package org.ea.javacnn.trainers;

import org.ea.javacnn.JavaCNN;

/**
 * Stochastic gradient descent trainer with optional classical (heavy-ball) momentum.
 *
 * <p>When {@code momentum} is positive, each element keeps a velocity term stored in the
 * {@code gsum} buffers inherited from {@link Trainer}. Otherwise a plain SGD step is taken.
 */
public class SGDTrainer extends Trainer {

    /**
     * Creates an SGD trainer. All arguments are forwarded unchanged to {@link Trainer}.
     *
     * @param network the network to train
     * @param i       forwarded as-is to the {@code Trainer} constructor (meaning defined there)
     * @param f       forwarded as-is to the {@code Trainer} constructor (meaning defined there)
     */
    public SGDTrainer(JavaCNN network, int i, float f) {
        super(network, i, f);
    }

    /**
     * Applies one SGD step to a single element of a parameter array, in place.
     *
     * <ul>
     *   <li>If {@code momentum <= 0}: {@code weights[index] -= learning_rate * gradient}.
     *       The velocity buffer is not touched.</li>
     *   <li>Otherwise: {@code v = momentum * v - learning_rate * gradient}.
     *       The new {@code v} is stored back into the buffer and added to {@code weights[index]}.</li>
     * </ul>
     *
     * @param bufferKey key passed to {@code gsum.get(...)} to select the velocity buffer
     *                  associated with {@code weights}
     * @param index     position of the element to update, in both {@code weights} and the buffer
     * @param gradient  gradient of the loss with respect to {@code weights[index]}
     * @param weights   parameter array, modified in place
     */
    @Override // org.ea.javacnn.trainers.Trainer
    public void update(int bufferKey, int index, double gradient, double[] weights) {
        // The buffer is looked up even when momentum is disabled, matching the original
        // evaluation order (any exception from the lookup still surfaces).
        double[] velocity = this.gsum.get(bufferKey);

        if (this.momentum <= 0.0d) {
            // Vanilla SGD: step directly against the gradient.
            weights[index] -= this.learning_rate * gradient;
        } else {
            // Classical momentum: decay the previous velocity, then add the new gradient step.
            double step = this.momentum * velocity[index] - this.learning_rate * gradient;
            velocity[index] = step;
            weights[index] += step;
        }
    }
}
