/*
 * Decompiled with CFR 0.152.
 */
package org.encog.neural.networks.training.propagation.sgd.update;

import org.encog.neural.networks.training.propagation.sgd.StochasticGradientDescent;
import org.encog.neural.networks.training.propagation.sgd.update.UpdateRule;

/**
 * Adam ("adaptive moment estimation") update rule for {@link StochasticGradientDescent}.
 *
 * <p>For every weight this keeps an exponential moving average of the gradient (first moment,
 * {@code m}) and of the squared gradient (second moment, {@code v}). Both are bias-corrected using
 * the trainer's current iteration number. The weight is then moved by
 * {@code learningRate * mHat / (sqrt(vHat) + eps)}.
 *
 * <p>Instances hold mutable per-weight state without any synchronization and are bound to a single
 * trainer by {@link #init(StochasticGradientDescent)}; they are not thread-safe.
 */
public class AdamUpdate
implements UpdateRule {
    /** Default exponential decay rate for the first-moment (mean) estimate. */
    public static final double DEFAULT_BETA1 = 0.9;
    /** Default exponential decay rate for the second-moment (uncentered variance) estimate. */
    public static final double DEFAULT_BETA2 = 0.999;
    /** Default constant added to the denominator for numerical stability. */
    public static final double DEFAULT_EPS = 1.0E-8;

    /** Trainer supplying the iteration number and learning rate; set by {@link #init}. */
    private StochasticGradientDescent training;
    /** Per-weight first-moment estimates; sized from the trainer's flat weights in {@link #init}. */
    private double[] m;
    /** Per-weight second-moment estimates; sized from the trainer's flat weights in {@link #init}. */
    private double[] v;
    private final double beta1;
    private final double beta2;
    private final double eps;

    /** Creates an Adam rule with the default hyperparameters (0.9, 0.999, 1e-8). */
    public AdamUpdate() {
        this(DEFAULT_BETA1, DEFAULT_BETA2, DEFAULT_EPS);
    }

    /**
     * Creates an Adam rule with explicit hyperparameters.
     *
     * @param beta1 decay rate of the first-moment estimate, in {@code [0, 1)}
     * @param beta2 decay rate of the second-moment estimate, in {@code [0, 1)}
     * @param eps   positive stabilizer added to {@code sqrt(vHat)} to avoid division by zero
     * @throws IllegalArgumentException if any argument is out of range or NaN
     */
    public AdamUpdate(double beta1, double beta2, double eps) {
        // Negated range checks so that NaN is rejected as well.
        if (!(beta1 >= 0.0 && beta1 < 1.0)) {
            throw new IllegalArgumentException("beta1 must be in [0, 1): " + beta1);
        }
        if (!(beta2 >= 0.0 && beta2 < 1.0)) {
            throw new IllegalArgumentException("beta2 must be in [0, 1): " + beta2);
        }
        if (!(eps > 0.0)) {
            throw new IllegalArgumentException("eps must be > 0: " + eps);
        }
        this.beta1 = beta1;
        this.beta2 = beta2;
        this.eps = eps;
    }

    /**
     * Binds this rule to a trainer and allocates zeroed moment buffers, one slot per weight of the
     * trainer's flat network.
     *
     * @param theTraining the trainer that will call {@link #update(double[], double[])}
     */
    @Override
    public void init(StochasticGradientDescent theTraining) {
        this.training = theTraining;
        final int weightCount = theTraining.getFlat().getWeights().length;
        this.m = new double[weightCount];
        this.v = new double[weightCount];
    }

    /**
     * Applies one Adam step to {@code weights} in place and advances the moment estimates.
     *
     * @param gradients per-weight gradients for this step; must be at least as long as
     *                  {@code weights}
     * @param weights   weights to update in place; must not be longer than the buffers allocated
     *                  in {@link #init}
     * @throws IllegalStateException    if {@link #init} has not been called
     * @throws IllegalArgumentException if the array lengths are incompatible (checked before any
     *                                  weight is modified)
     */
    @Override
    public void update(double[] gradients, double[] weights) {
        if (this.training == null) {
            throw new IllegalStateException("AdamUpdate.init must be called before update");
        }
        if (weights.length > this.m.length || gradients.length < weights.length) {
            throw new IllegalArgumentException("incompatible lengths: weights=" + weights.length
                    + ", gradients=" + gradients.length + ", state=" + this.m.length);
        }

        // Bias correction divides by (1 - beta^t), which is 0 when t == 0 and would write
        // NaN/Infinity into the weights; treat any earlier step as the first one.
        final double t = Math.max(1, this.training.getIteration());
        // These do not depend on the weight index, so compute them once per step.
        final double biasCorrection1 = 1.0 - Math.pow(this.beta1, t);
        final double biasCorrection2 = 1.0 - Math.pow(this.beta2, t);
        final double learningRate = this.training.getLearningRate();

        for (int i = 0; i < weights.length; i++) {
            final double g = gradients[i];
            this.m[i] = this.beta1 * this.m[i] + (1.0 - this.beta1) * g;
            this.v[i] = this.beta2 * this.v[i] + (1.0 - this.beta2) * g * g;
            final double mHat = this.m[i] / biasCorrection1;
            final double vHat = this.v[i] / biasCorrection2;
            // The step is added, not subtracted, which presumes `gradients` already points in the
            // descent direction under this trainer's convention.
            // NOTE(review): confirm against the gradient calculation.
            weights[i] += learningRate * mHat / (Math.sqrt(vHat) + this.eps);
        }
    }
}

