package org.ea.javacnn.trainers;

import java.util.Arrays;
import org.ea.javacnn.JavaCNN;
import org.ea.javacnn.data.BackPropResult;

/* loaded from: input_file:org/ea/javacnn/trainers/AdamTrainer.class */
/**
 * Trainer that applies the Adam update rule (Kingma &amp; Ba, 2015) to each weight.
 *
 * <p>For every weight it keeps an exponentially decayed average of the gradient (first moment,
 * stored in {@code gsum}) and of the squared gradient (second moment, stored in {@code xsum}).
 * Both estimates are bias-corrected using the step counter {@code k}, and the weight is then moved
 * by {@code -learning_rate * mHat / (sqrt(vHat) + eps)}.
 */
public class AdamTrainer extends Trainer {
    /** Default decay rate of the first-moment (mean) estimate. */
    private static final double DEFAULT_BETA1 = 0.9d;
    /** Default decay rate of the second-moment (uncentered variance) estimate. */
    private static final double DEFAULT_BETA2 = 0.999d;

    private final double beta1;
    private final double beta2;

    /**
     * Creates an Adam trainer with the standard decay rates {@code beta1 = 0.9} and
     * {@code beta2 = 0.999}.
     *
     * @param javaCNN network to train, forwarded to {@link Trainer}
     * @param i forwarded unchanged to the {@link Trainer} constructor (meaning defined there)
     * @param f forwarded unchanged to the {@link Trainer} constructor (meaning defined there)
     */
    public AdamTrainer(JavaCNN javaCNN, int i, float f) {
        this(javaCNN, i, f, DEFAULT_BETA1, DEFAULT_BETA2);
    }

    /**
     * Creates an Adam trainer with custom moment decay rates.
     *
     * @param javaCNN network to train, forwarded to {@link Trainer}
     * @param i forwarded unchanged to the {@link Trainer} constructor (meaning defined there)
     * @param f forwarded unchanged to the {@link Trainer} constructor (meaning defined there)
     * @param beta1 decay rate of the first-moment estimate, in {@code [0, 1)}
     * @param beta2 decay rate of the second-moment estimate, in {@code [0, 1)}
     * @throws IllegalArgumentException if either decay rate is outside {@code [0, 1)} or is NaN
     */
    public AdamTrainer(JavaCNN javaCNN, int i, float f, double beta1, double beta2) {
        super(javaCNN, i, f);
        this.beta1 = requireDecayRate(beta1, "beta1");
        this.beta2 = requireDecayRate(beta2, "beta2");
    }

    /** Validates that a moment decay rate lies in {@code [0, 1)}; a value of 1 would never decay. */
    private static double requireDecayRate(double value, String name) {
        // Written as a negated range check so that NaN is rejected as well.
        if (!(value >= 0.0d && value < 1.0d)) {
            throw new IllegalArgumentException(name + " must be in [0, 1), got " + value);
        }
        return value;
    }

    /**
     * Allocates the second-moment accumulator for one {@link BackPropResult}.
     *
     * <p>The array is sized to the result's weight array. Java zero-initialises new arrays, so no
     * explicit fill is required.
     *
     * @param backPropResult result whose weight count determines the accumulator size
     */
    @Override // org.ea.javacnn.trainers.Trainer
    public void initTrainData(BackPropResult backPropResult) {
        this.xsum.add(new double[backPropResult.getWeights().length]);
    }

    /**
     * Applies one Adam step to a single weight.
     *
     * @param listIndex index of the accumulator arrays in {@code gsum} / {@code xsum}
     * @param weightIndex index of the weight inside those arrays and inside {@code weights}
     * @param gradient gradient of the loss with respect to this weight
     * @param weights array holding the weight to update; modified in place
     */
    @Override // org.ea.javacnn.trainers.Trainer
    public void update(int listIndex, int weightIndex, double gradient, double[] weights) {
        double[] firstMoment = this.gsum.get(listIndex);
        double[] secondMoment = this.xsum.get(listIndex);

        // Biased moment estimates: exponential moving averages of g and g^2.
        firstMoment[weightIndex] = beta1 * firstMoment[weightIndex] + (1.0d - beta1) * gradient;
        secondMoment[weightIndex] =
                beta2 * secondMoment[weightIndex] + (1.0d - beta2) * gradient * gradient;

        // Bias correction divides by (1 - beta^t), as in Adam's Algorithm 1. The step counter is
        // clamped to at least 1 because (1 - beta^0) == 0 would otherwise cause a division by zero.
        double step = Math.max(this.k, 1);
        double mHat = firstMoment[weightIndex] / (1.0d - Math.pow(beta1, step));
        double vHat = secondMoment[weightIndex] / (1.0d - Math.pow(beta2, step));

        weights[weightIndex] += -this.learning_rate * mHat / (Math.sqrt(vHat) + this.eps);
    }
}
