package jsat.math.optimization;

import java.util.ArrayList;
import java.util.Iterator;
import jsat.classifiers.neuralnetwork.SOM;
import jsat.linear.ConstantVector;
import jsat.linear.IndexValue;
import jsat.linear.Vec;
import jsat.math.Function;
import jsat.math.FunctionVec;
import jsat.utils.DoubleList;

/**
 * L<sub>1</sub>-regularized quasi-Newton optimizer ("modified OWL-QN").
 * <p>
 * Minimizes {@code f(x) + lambda * sum_i m_i * |x_i|}, where {@code f} is the smooth
 * function passed to {@link #optimize} and {@code m_i} are optional per-coordinate
 * penalty multipliers set via {@link #setLambdaMultipler(Vec)} (all 1 when unset).
 * Each iteration takes one of two steps:
 * <ul>
 * <li>a soft-thresholded (proximal) gradient step, when some coordinate lies within
 * {@code epsilon_k} of zero and the pseudo-gradient pushes it across zero; or</li>
 * <li>an L-BFGS step projected onto the current orthant.</li>
 * </ul>
 * Both steps are followed by a backtracking line search that shrinks the step by
 * {@link #getBeta() beta} until a sufficient-decrease test holds.
 * <p>
 * NOTE(review): the algorithm appears to follow Gong &amp; Ye, "A Modified Orthant-Wise
 * Limited Memory Quasi-Newton Method with Convergence Analysis" (ICML 2015). Confirm
 * before citing.
 */
public class ModifiedOWLQN implements Optimizer {
    private static final int DEFAULT_M = 10;
    private static final double DEFAULT_EPS = 1.0E-12d;
    private static final double DEFAULT_ALPHA_0 = 1.0d;
    private static final double DEFAULT_BETA = 0.2d;
    private static final double DEFAULT_GAMMA = 0.01d;

    /** Number of (s, y) curvature pairs kept for the L-BFGS approximation. */
    private int m;
    /** Global L1 regularization strength; validated to be finite and non-negative. */
    private double lambda;
    /** Optional per-coordinate multipliers on {@code lambda}; {@code null} means all 1. */
    private Vec lambdaMultipler;
    /** Upper bound on epsilon_k, the "near zero" threshold that selects the proximal step. */
    private double eps;
    /** Initial step size of every backtracking line search. */
    private double alpha0;
    /** Step shrink factor used while backtracking, in (0, 1). */
    private double beta;
    /** Sufficient-decrease constant used by both line-search acceptance tests. */
    private double gamma;
    /** Maximum number of outer iterations performed by {@link #optimize}. */
    private int maxIterations;

    /** Creates an optimizer with no L1 penalty ({@code lambda = 0}). */
    public ModifiedOWLQN() {
        this(0.0d);
    }

    /**
     * Creates an optimizer with the given L1 regularization strength and default settings.
     *
     * @param lambda the L1 regularization strength
     * @throws IllegalArgumentException if {@code lambda} is negative, infinite or NaN
     */
    public ModifiedOWLQN(double lambda) {
        this.m = DEFAULT_M;
        this.lambdaMultipler = null;
        this.eps = DEFAULT_EPS;
        this.alpha0 = DEFAULT_ALPHA_0;
        this.beta = DEFAULT_BETA;
        this.gamma = DEFAULT_GAMMA;
        // NOTE(review): this default is borrowed from SOM.DEFAULT_MAX_ITERS, possibly a
        // decompiler constant substitution; consider a local DEFAULT_MAX_ITERS constant.
        this.maxIterations = SOM.DEFAULT_MAX_ITERS;
        setLambda(lambda);
    }

    /**
     * Copy constructor. The lambda multiplier vector, if present, is deep-copied.
     *
     * @param toCopy the optimizer whose settings are copied
     */
    protected ModifiedOWLQN(ModifiedOWLQN toCopy) {
        this(toCopy.lambda);
        if (toCopy.lambdaMultipler != null) {
            this.lambdaMultipler = toCopy.lambdaMultipler.mo46clone();
        }
        this.eps = toCopy.eps;
        this.m = toCopy.m;
        this.alpha0 = toCopy.alpha0;
        this.beta = toCopy.beta;
        this.gamma = toCopy.gamma;
        this.maxIterations = toCopy.maxIterations;
    }

    /**
     * Sets the global L1 regularization strength.
     *
     * @param lambda the non-negative, finite regularization strength
     * @throws IllegalArgumentException if {@code lambda} is negative, infinite or NaN
     */
    public void setLambda(double lambda) {
        if (lambda < 0.0d || Double.isInfinite(lambda) || Double.isNaN(lambda)) {
            throw new IllegalArgumentException("lambda must be non-negative, not " + lambda);
        }
        this.lambda = lambda;
    }

    /**
     * Returns the global L1 regularization strength.
     *
     * @return the current {@code lambda}
     */
    public double getLambda() {
        return this.lambda;
    }

    /**
     * Sets per-coordinate multipliers applied to {@code lambda}. Coordinate {@code i} is
     * penalized by {@code lambda * multipliers.get(i) * |x_i|}. The vector is stored by
     * reference, not copied.
     *
     * @param multipliers the multipliers, or {@code null} to use 1 for every coordinate
     */
    public void setLambdaMultipler(Vec multipliers) {
        this.lambdaMultipler = multipliers;
    }

    /**
     * Returns the per-coordinate lambda multipliers.
     *
     * @return the stored multiplier vector (not a copy), or {@code null} if unset
     */
    public Vec getLambdaMultipler() {
        return this.lambdaMultipler;
    }

    /**
     * Sets how many curvature pairs the L-BFGS approximation remembers.
     *
     * @param m the history size
     * @throws IllegalArgumentException if {@code m < 1}
     */
    public void setM(int m) {
        if (m < 1) {
            throw new IllegalArgumentException("m must be positive, not " + m);
        }
        this.m = m;
    }

    /**
     * Returns the L-BFGS history size.
     *
     * @return the number of curvature pairs remembered
     */
    public int getM() {
        return this.m;
    }

    /**
     * Sets the upper bound on the "near zero" threshold that selects the proximal step.
     *
     * @param eps the non-negative, finite threshold bound
     * @throws IllegalArgumentException if {@code eps} is negative, infinite or NaN
     */
    public void setEps(double eps) {
        if (eps < 0.0d || Double.isInfinite(eps) || Double.isNaN(eps)) {
            throw new IllegalArgumentException("eps must be non-negative, not " + eps);
        }
        this.eps = eps;
    }

    /**
     * Returns the upper bound on the "near zero" threshold.
     *
     * @return the current {@code eps}
     */
    public double getEps() {
        return this.eps;
    }

    /**
     * Sets the factor by which the line search shrinks the step after each rejection.
     *
     * @param beta the shrink factor, strictly between 0 and 1
     * @throws IllegalArgumentException if {@code beta} is not in (0, 1)
     */
    public void setBeta(double beta) {
        if (beta <= 0.0d || beta >= 1.0d || Double.isNaN(beta)) {
            throw new IllegalArgumentException("shrinkage term must be in (0, 1), not " + beta);
        }
        this.beta = beta;
    }

    /**
     * Returns the line-search shrink factor.
     *
     * @return the current {@code beta}
     */
    public double getBeta() {
        return this.beta;
    }

    /**
     * Minimizes {@code f(x) + L1 penalty}, starting from {@code x0}.
     * <p>
     * The loop stops after {@link #getMaximumIterations()} iterations, or earlier when any
     * of these drops below {@code tolerance}:
     * <ul>
     * <li>the largest absolute component of the smooth gradient at the new point;</li>
     * <li>the objective value at the start of the iteration;</li>
     * <li>the L1 norm of the step just taken.</li>
     * </ul>
     * The L1 penalty is added internally, so {@code f} and {@code fp} describe only the
     * smooth part.
     *
     * @param tolerance convergence tolerance (see above)
     * @param w         receives the final iterate
     * @param x0        the starting point; only read (cloned), never written unless it is {@code w}
     * @param f         the smooth part of the objective
     * @param fp        the gradient of {@code f}, or {@code null} to use forward differences
     * @param parallel  passed through to {@code f} and {@code fp}
     */
    @Override
    public void optimize(double tolerance, Vec w, Vec x0, Function f, FunctionVec fp, boolean parallel) {
        if (fp == null) {
            fp = Function.forwardDifference(f);
        }
        final Vec lambdaMul = lambdaMultipler != null
                ? lambdaMultipler
                : new ConstantVector(1.0d, x0.length());

        // Working vectors, all shaped like x0 (mo46clone() is presumably Vec's renamed clone).
        Vec x = x0.mo46clone();        // current iterate x_k
        Vec gradNext = x0.mo46clone(); // gradient of f at the accepted trial point
        Vec y = x0.mo46clone();        // y_k = gradNext - grad
        Vec v = x0.mo46clone();        // negative pseudo-gradient v_k
        Vec d = x0.mo46clone();        // L-BFGS direction computed from v_k
        Vec p = x0.mo46clone();        // d with components disagreeing in sign with v zeroed
        Vec xAlpha = x0.mo46clone();   // line-search trial point x_k(alpha)
        Vec s = x0.mo46clone();        // step s_k = x_k(alpha) - x_k

        // L-BFGS history, most recent pair at index 0.
        DoubleList rho = new DoubleList(m);
        ArrayList<Vec> sHistory = new ArrayList<>(m);
        ArrayList<Vec> yHistory = new ArrayList<>(m);
        double[] alphas = new double[m];

        double fx = f.f(x, parallel) + getL1Penalty(x, lambdaMul);
        Vec grad = fp.f(x, x0.mo46clone(), parallel);

        for (int iter = 0; iter < maxIterations; iter++) {
            double vNorm = negativePseudoGradient(x, grad, lambdaMul, v);
            double epsK = Math.min(vNorm, eps);

            double fNext;
            if (needsProximalStep(x, v, epsK)) {
                // Soft-thresholded gradient step with backtracking.
                double alpha = alpha0 / beta; // pre-divided so the first trial uses alpha0
                do {
                    alpha *= beta;
                    grad.copyTo(xAlpha);
                    xAlpha.mutableMultiply(-alpha);
                    xAlpha.mutableAdd(x); // xAlpha = x - alpha * grad
                    for (int i = 0; i < xAlpha.length(); i++) {
                        double xi = xAlpha.get(i);
                        double shrink = lambda * lambdaMul.get(i) * alpha;
                        xAlpha.set(i, Math.signum(xi) * Math.max(0.0d, Math.abs(xi) - shrink));
                    }
                    xAlpha.copyTo(s);
                    s.mutableSubtract(x);
                    fNext = f.f(xAlpha, parallel) + getL1Penalty(xAlpha, lambdaMul);
                } while (fNext > fx - (gamma / (2.0d * alpha)) * s.dot(s));
            } else {
                // Quasi-Newton step restricted to the current orthant.
                LBFGS.twoLoopHp(v, rho, sHistory, yHistory, d, alphas);
                for (int i = 0; i < p.length(); i++) {
                    double di = d.get(i);
                    p.set(i, Math.signum(di) == Math.signum(v.get(i)) ? di : 0.0d);
                }
                double slope = gamma * v.dot(d);
                double alpha = alpha0 / beta;
                do {
                    alpha *= beta;
                    x.copyTo(xAlpha);
                    xAlpha.mutableSubtract(-alpha, p); // xAlpha = x + alpha * p
                    // Zero any coordinate that left the orthant chosen by sign(x), or sign(v) where x == 0.
                    for (int i = 0; i < p.length(); i++) {
                        double xi = x.get(i);
                        double orthant = xi != 0.0d ? xi : v.get(i);
                        if (Math.signum(xAlpha.get(i)) != Math.signum(orthant)) {
                            xAlpha.set(i, 0.0d);
                        }
                    }
                    fNext = f.f(xAlpha, parallel) + getL1Penalty(xAlpha, lambdaMul);
                } while (fNext > fx - alpha * slope);
                xAlpha.copyTo(s);
                s.mutableSubtract(x);
            }

            gradNext = fp.f(xAlpha, gradNext, parallel);
            double maxAbsGrad = 0.0d;
            for (int i = 0; i < gradNext.length(); i++) {
                maxAbsGrad = Math.max(maxAbsGrad, Math.abs(gradNext.get(i)));
            }
            // NOTE(review): maxAbsGrad measures only the smooth gradient; with lambda > 0 it
            // typically stays near lambda at the optimum, so the step-norm test usually triggers.
            if (maxAbsGrad < tolerance || fx < tolerance || s.pNorm(1.0d) < tolerance) {
                // Keep the step the line search just accepted instead of discarding it.
                xAlpha.copyTo(x);
                break;
            }

            gradNext.copyTo(y);
            y.mutableSubtract(grad);
            double rhoK = 1.0d / s.dot(y);
            if (Double.isInfinite(rhoK) || Double.isNaN(rhoK)) {
                // Degenerate curvature pair: restart the L-BFGS approximation from scratch.
                rho.clear();
                sHistory.clear();
                yHistory.clear();
            } else {
                sHistory.add(0, s.mo46clone());
                yHistory.add(0, y.mo46clone());
                rho.add(0, rhoK);
                while (rho.size() > m) {
                    rho.remove(m);
                    sHistory.remove(m);
                    yHistory.remove(m);
                }
            }

            fx = fNext;
            xAlpha.copyTo(x);
            gradNext.copyTo(grad);
        }
        x.copyTo(w);
    }

    /**
     * Writes the negative pseudo-gradient of the L1-penalized objective at {@code x} into
     * {@code v}. At coordinates where {@code x_i == 0}, it uses the one-sided derivative that
     * decreases the objective, or 0 if neither side does.
     *
     * @param x         the current point
     * @param grad      the gradient of the smooth part at {@code x}
     * @param lambdaMul per-coordinate lambda multipliers
     * @param v         receives the negative pseudo-gradient
     * @return the Euclidean norm of {@code v}
     */
    private double negativePseudoGradient(Vec x, Vec grad, Vec lambdaMul, Vec v) {
        double sumSq = 0.0d;
        for (int i = 0; i < grad.length(); i++) {
            double xi = x.get(i);
            double gi = grad.get(i);
            double li = lambda * lambdaMul.get(i);
            double pg;
            if (xi > 0.0d) {
                pg = gi + li;
            } else if (xi < 0.0d) {
                pg = gi - li;
            } else if (gi + li < 0.0d) {
                pg = gi + li; // x_i == 0: increasing x_i decreases the objective
            } else if (gi - li > 0.0d) {
                pg = gi - li; // x_i == 0: decreasing x_i decreases the objective
            } else {
                pg = 0.0d;    // x_i == 0 and neither direction helps
            }
            v.set(i, -pg);
            sumSq += pg * pg;
        }
        return Math.sqrt(sumSq);
    }

    /**
     * Tests whether any coordinate satisfies {@code 0 < |x_i| < epsK} while {@code v}
     * points it toward zero ({@code x_i * v_i < 0}). In that case a proximal step is taken
     * instead of the orthant-restricted quasi-Newton step.
     *
     * @param x    the current point
     * @param v    the negative pseudo-gradient at {@code x}
     * @param epsK the near-zero threshold for this iteration
     * @return {@code true} if at least one such coordinate exists
     */
    private static boolean needsProximalStep(Vec x, Vec v, double epsK) {
        for (int i = 0; i < v.length(); i++) {
            double xi = x.get(i);
            double absXi = Math.abs(xi);
            if (0.0d < absXi && absXi < epsK && xi * v.get(i) < 0.0d) {
                return true;
            }
        }
        return false;
    }

    /**
     * Computes {@code lambda * sum_i lambdaMul_i * |w_i|} over the entries produced by
     * {@code w.iterator()}. Presumably that iterator yields only the stored or non-zero
     * entries; confirm against Vec.
     *
     * @param w         the point to penalize
     * @param lambdaMul per-coordinate lambda multipliers
     * @return the L1 penalty, or 0 when {@code lambda <= 0}
     */
    private double getL1Penalty(Vec w, Vec lambdaMul) {
        if (lambda <= 0.0d) {
            return 0.0d;
        }
        double penalty = 0.0d;
        Iterator<IndexValue> it = w.iterator();
        while (it.hasNext()) {
            IndexValue iv = it.next();
            penalty += lambda * lambdaMul.get(iv.getIndex()) * Math.abs(iv.getValue());
        }
        return penalty;
    }

    /**
     * Sets the maximum number of outer iterations.
     *
     * @param iterations the iteration limit
     * @throws IllegalArgumentException if {@code iterations < 1}
     */
    @Override
    public void setMaximumIterations(int iterations) {
        if (iterations < 1) {
            throw new IllegalArgumentException("Number of iterations must be positive, not " + iterations);
        }
        this.maxIterations = iterations;
    }

    /**
     * Returns the maximum number of outer iterations.
     *
     * @return the iteration limit
     */
    @Override
    public int getMaximumIterations() {
        return this.maxIterations;
    }

    /**
     * Returns a copy of this optimizer made with the copy constructor. The method name is a
     * decompiler rename of {@code clone} and is kept so it still overrides the matching
     * Optimizer method.
     *
     * @return an independent copy of this optimizer's settings
     */
    @Override
    public ModifiedOWLQN m234clone() {
        return new ModifiedOWLQN(this);
    }
}
