/*
 * Decompiled with CFR 0.152.
 */
package org.ea.javacnn.trainers;

import java.util.Arrays;
import org.ea.javacnn.JavaCNN;
import org.ea.javacnn.data.BackPropResult;
import org.ea.javacnn.trainers.Trainer;

/**
 * Trainer that updates network parameters with the Adam rule (Kingma &amp; Ba, 2014).
 *
 * <p>For every parameter, {@code gsum} holds the exponential moving average of the gradient
 * (first moment, decay {@link #BETA1}). {@code xsum} holds the exponential moving average of
 * the squared gradient (second moment, decay {@link #BETA2}). Both moments are bias-corrected
 * before the step is applied.
 *
 * <p>NOTE(review): this class allocates only the {@code xsum} buffers in
 * {@link #initTrainData}. The {@code gsum} buffers are presumably allocated by {@code Trainer}.
 * Verify against the base class.
 */
public class AdamTrainer extends Trainer {
    /** Decay rate of the first-moment (mean gradient) estimate. */
    private static final double BETA1 = 0.9;
    /** Decay rate of the second-moment (uncentered gradient variance) estimate. */
    private static final double BETA2 = 0.999;

    /**
     * Creates an Adam trainer for the given network.
     *
     * @param net        the network to train
     * @param batch_size number of examples per batch, passed through to {@code Trainer}
     * @param l2_decay   L2 regularization coefficient, passed through to {@code Trainer}
     */
    public AdamTrainer(JavaCNN net, int batch_size, float l2_decay) {
        super(net, batch_size, l2_decay);
    }

    /**
     * Allocates a zeroed second-moment buffer sized to {@code bpr}'s weight array.
     *
     * @param bpr the parameter/gradient group the buffer is created for
     */
    @Override
    public void initTrainData(BackPropResult bpr) {
        // Java zero-initializes new arrays, so no explicit fill is needed.
        this.xsum.add(new double[bpr.getWeights().length]);
    }

    /**
     * Applies one Adam step to parameter {@code p[j]} of parameter group {@code i}.
     *
     * @param i   index of the parameter group in {@code gsum}/{@code xsum}
     * @param j   index of the parameter within the group
     * @param gij gradient of the loss with respect to {@code p[j]}
     * @param p   parameter array, updated in place
     */
    @Override
    public void update(int i, int j, double gij, double[] p) {
        double[] gsumi = (double[]) this.gsum.get(i);
        double[] xsumi = (double[]) this.xsum.get(i);

        // Biased moving averages of the gradient (m) and of the squared gradient (v).
        gsumi[j] = BETA1 * gsumi[j] + (1.0 - BETA1) * gij;
        xsumi[j] = BETA2 * xsumi[j] + (1.0 - BETA2) * gij * gij;

        // Both moments start at zero, so they are biased toward zero early on. Adam corrects
        // this by DIVIDING by (1 - beta^t); the previous code multiplied, which inflated early
        // steps instead of correcting them. Adam's timestep starts at 1, so clamp to avoid a
        // division by zero if the step counter is still 0.
        double t = Math.max(1, this.k);
        double mHat = gsumi[j] / (1.0 - Math.pow(BETA1, t));
        double vHat = xsumi[j] / (1.0 - Math.pow(BETA2, t));

        p[j] += -this.learning_rate * mHat / (Math.sqrt(vHat) + this.eps);
    }
}

