package jsat.math.optimization.stochastic;

import java.util.Iterator;
import jsat.linear.DenseVector;
import jsat.linear.IndexValue;
import jsat.linear.ScaledVector;
import jsat.linear.Vec;

/* loaded from: input_file:jsat/math/optimization/stochastic/Adam.class */
/**
 * Adam ("adaptive moment estimation") gradient updater.
 *
 * <p>Maintains a running average of the gradient ({@code m}) and of the element-wise squared
 * gradient ({@code v}), plus scalar counterparts for the bias term ({@code mBias},
 * {@code vBias}), and applies a bias-corrected, per-coordinate scaled step.
 *
 * <p>Parameterization note: {@code beta_1} and {@code beta_2} are the weights given to the
 * <em>new</em> gradient information in each running average (e.g. {@code v <- (1 - beta_2) * v +
 * beta_2 * g^2}), so the defaults 0.1 / 0.001 correspond to "retain" factors of 0.9 / 0.999. The
 * first-moment weight is annealed each step as {@code beta_1t = 1 - (1 - beta_1) * lambda^(t-1)}.
 *
 * <p>{@link #setup(int)} must be called before {@code update}, since the moment vectors are only
 * allocated there (or copied by the copy constructor).
 */
public class Adam implements GradientUpdater {
    private static final long serialVersionUID = 5352504067435579553L;

    public static final double DEFAULT_ALPHA = 2.0E-4d;
    public static final double DEFAULT_BETA_1 = 0.1d;
    public static final double DEFAULT_BETA_2 = 0.001d;
    public static final double DEFAULT_EPS = 1.0E-8d;
    /**
     * Default per-step decay of the first-moment weight. It must be close to 1: a value near 0
     * (such as 1e-8) drives {@code beta_1t} to 1 after the first step, which discards the moving
     * average and invalidates the {@code 1 - (1 - beta_1)^t} bias correction.
     */
    public static final double DEFAULT_LAMBDA = 1.0d - 1.0E-8d;

    // Running first / second moment estimates of the weight gradient (allocated in setup).
    private Vec m;
    private Vec v;
    // Number of update steps taken since setup.
    private long t;
    private final double alpha;
    private final double beta_1;
    private final double beta_2;
    private final double eps;
    private final double lambda;
    // Scalar second / first moment estimates for the bias gradient.
    private double vBias;
    private double mBias;

    /** Creates an Adam updater using the {@code DEFAULT_*} hyperparameters. */
    public Adam() {
        this(DEFAULT_ALPHA, DEFAULT_BETA_1, DEFAULT_BETA_2, DEFAULT_EPS, DEFAULT_LAMBDA);
    }

    /**
     * Creates an Adam updater.
     *
     * @param alpha base step size; must be positive and finite
     * @param beta_1 weight of the new gradient in the first-moment average; must be in (0, 1]
     * @param beta_2 weight of the new squared gradient in the second-moment average; must be in
     *     (0, 1]
     * @param eps small constant added to the denominator for numerical stability; must be
     *     non-negative and finite
     * @param lambda per-step decay applied to the first-moment weight; must be in (0, 1) and is
     *     normally very close to 1
     * @throws IllegalArgumentException if any argument is out of range, or if
     *     {@code (1-beta_1)^2 / sqrt(1-beta_2) >= 1}
     */
    public Adam(double alpha, double beta_1, double beta_2, double eps, double lambda) {
        if (alpha <= 0.0d || Double.isInfinite(alpha) || Double.isNaN(alpha)) {
            throw new IllegalArgumentException("alpha must be a positive value, not " + alpha);
        }
        if (beta_1 <= 0.0d || beta_1 > 1.0d || Double.isInfinite(beta_1) || Double.isNaN(beta_1)) {
            throw new IllegalArgumentException("beta_1 must be in (0, 1], not " + beta_1);
        }
        if (beta_2 <= 0.0d || beta_2 > 1.0d || Double.isInfinite(beta_2) || Double.isNaN(beta_2)) {
            throw new IllegalArgumentException("beta_2 must be in (0, 1], not " + beta_2);
        }
        if (Math.pow(1.0d - beta_1, 2.0d) / Math.sqrt(1.0d - beta_2) >= 1.0d) {
            throw new IllegalArgumentException("the required property (1-beta_1)^2 / sqrt(1-beta_2) < 1, is not held by beta_1=" + beta_1 + " and beta_2=" + beta_2);
        }
        // NaN fails every comparison, so it must be rejected explicitly.
        if (eps < 0.0d || Double.isInfinite(eps) || Double.isNaN(eps)) {
            throw new IllegalArgumentException("eps must be a non-negative value, not " + eps);
        }
        if (lambda <= 0.0d || lambda >= 1.0d || Double.isInfinite(lambda) || Double.isNaN(lambda)) {
            throw new IllegalArgumentException("lambda must be in (0, 1), not " + lambda);
        }
        this.alpha = alpha;
        this.beta_1 = beta_1;
        this.beta_2 = beta_2;
        this.eps = eps;
        this.lambda = lambda;
    }

    /**
     * Copy constructor: copies hyperparameters, the step counter, and deep copies of the moment
     * state (if {@code setup} has been called on the source).
     *
     * @param adam the updater to copy
     */
    public Adam(Adam adam) {
        this.alpha = adam.alpha;
        this.beta_1 = adam.beta_1;
        this.beta_2 = adam.beta_2;
        this.eps = adam.eps;
        this.lambda = adam.lambda;
        this.t = adam.t;
        this.mBias = adam.mBias;
        this.vBias = adam.vBias;
        if (adam.m != null) {
            this.m = adam.m.mo46clone();
            this.v = adam.v.mo46clone();
        }
    }

    /**
     * Updates {@code x} in place with no bias term (the bias gradient is treated as 0).
     *
     * @param x the weight vector to update in place
     * @param grad the gradient of the weights
     * @param eta multiplier applied to the step size
     */
    @Override // jsat.math.optimization.stochastic.GradientUpdater
    public void update(Vec x, Vec grad, double eta) {
        update(x, grad, eta, 0.0d, 0.0d);
    }

    /**
     * Performs one Adam step: updates the moment estimates, then adds the bias-corrected step to
     * every coordinate of {@code x} in place.
     *
     * @param x the weight vector to update in place
     * @param grad the gradient of the weights
     * @param eta multiplier applied to the step size
     * @param bias current bias value (not read by this implementation)
     * @param biasGrad gradient of the bias term
     * @return the positive magnitude of the bias step. Its sign is opposite to the increments
     *     applied to {@code x}; presumably the caller subtracts it from the bias (confirm against
     *     the {@code GradientUpdater} contract)
     */
    @Override // jsat.math.optimization.stochastic.GradientUpdater
    public double update(Vec x, Vec grad, double eta, double bias, double biasGrad) {
        t++;
        // Annealed first-moment weight: equals beta_1 at t == 1 and drifts towards 1 as
        // lambda^(t-1) shrinks.
        final double beta1t = 1.0d - (1.0d - beta_1) * Math.pow(lambda, t - 1);

        // First moment: m <- (1 - beta1t) * m + beta1t * grad, and the same recurrence for the bias.
        m.mutableMultiply(1.0d - beta1t);
        m.mutableAdd(beta1t, grad);
        mBias = (1.0d - beta1t) * mBias + beta1t * biasGrad;

        // Second moment: v <- (1 - beta_2) * v + beta_2 * grad^2 (only non-zero grad entries add).
        v.mutableMultiply(1.0d - beta_2);
        vBias = (1.0d - beta_2) * vBias + beta_2 * biasGrad * biasGrad;
        for (Iterator<IndexValue> it = grad.iterator(); it.hasNext(); ) {
            IndexValue entry = it.next();
            double g = entry.getValue();
            v.increment(entry.getIndex(), beta_2 * g * g);
        }

        // Both bias corrections folded into a single scalar step size.
        final double step = eta * alpha * Math.sqrt(1.0d - Math.pow(1.0d - beta_2, t))
                / (1.0d - Math.pow(1.0d - beta_1, t));
        for (int i = 0; i < m.length(); i++) {
            x.increment(i, -step * m.get(i) / (Math.sqrt(v.get(i)) + eps));
        }
        return step * mBias / (Math.sqrt(vBias) + eps);
    }

    @Override // jsat.math.optimization.stochastic.GradientUpdater
    /* renamed from: clone, reason: merged with bridge method [inline-methods] */
    public Adam m244clone() {
        return new Adam(this);
    }

    /**
     * Resets all state and allocates zeroed moment vectors for a problem of the given dimension.
     *
     * @param d number of weights
     */
    @Override // jsat.math.optimization.stochastic.GradientUpdater
    public void setup(int d) {
        this.t = 0L;
        this.m = new ScaledVector(new DenseVector(d));
        this.v = new ScaledVector(new DenseVector(d));
        this.mBias = 0.0d;
        this.vBias = 0.0d;
    }
}
