package org.ea.javacnn.trainers;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.ea.javacnn.JavaCNN;
import org.ea.javacnn.data.BackPropResult;
import org.ea.javacnn.data.DataBlock;
import org.ea.javacnn.data.TrainResult;

/* loaded from: input_file:org/ea/javacnn/trainers/Trainer.class */
/**
 * Base class for mini-batch gradient-descent trainers of a {@link JavaCNN} network.
 *
 * <p>Each call to {@link #train(DataBlock, int)} runs one forward and one backward pass. On every
 * {@code batch_size}-th call, it processes each parameter set returned by
 * {@code net.getBackPropagationResult()}:
 *
 * <ul>
 *   <li>adds the L1/L2 regularisation terms to each accumulated gradient;
 *   <li>divides the result by the batch size;
 *   <li>passes it to the subclass-specific {@link #update} rule;
 *   <li>resets the gradient to zero.
 * </ul>
 */
public abstract class Trainer {
    /** Network being trained. */
    private final JavaCNN net;
    /** Global L2 weight-decay coefficient; scaled per parameter set by {@code getL2DecayMul()}. */
    protected double l2_decay;
    /** Number of {@link #train} calls between parameter updates; validated positive at construction. */
    protected int batch_size;
    /** Step size for use by {@link #update} implementations; not read by this class. */
    protected double learning_rate = 0.01d;
    /** Global L1 weight-decay coefficient; scaled per parameter set by {@code getL1DecayMul()}. */
    protected double l1_decay = 0.001d;
    /** Momentum coefficient for use by {@link #update} implementations; not read by this class. */
    protected double momentum = 0.9d;
    /** Small constant for use by {@link #update} implementations; not read by this class. */
    protected double eps = 1.0E-8d;
    /**
     * Per-parameter-set accumulators for {@link #update} implementations: one zero-initialised
     * array per parameter set, each as long as that set's weights. Allocated when the first
     * mini-batch completes.
     */
    protected List<double[]> gsum = new ArrayList<>();
    /** Second accumulator list for subclasses; this class never populates it. */
    protected List<double[]> xsum = new ArrayList<>();
    /** Number of {@link #train} calls made so far. */
    protected int k = 0;

    /**
     * Creates a trainer for the given network.
     *
     * @param net the network whose parameters will be updated
     * @param batchSize the mini-batch size; parameters are updated on every {@code batchSize}-th
     *     call to {@link #train}
     * @param l2Decay the global L2 weight-decay coefficient
     * @throws IllegalArgumentException if {@code batchSize} is not positive; zero would make
     *     {@code train} divide by zero, and a negative value would flip the sign of every update
     */
    public Trainer(JavaCNN net, int batchSize, float l2Decay) {
        if (batchSize <= 0) {
            throw new IllegalArgumentException("batch_size must be positive, got " + batchSize);
        }
        this.net = net;
        this.l2_decay = l2Decay;
        this.batch_size = batchSize;
    }

    /**
     * Runs one training step on a single example.
     *
     * <p>Every step runs one forward/backward pass. A step that completes a mini-batch also
     * applies a regularised update to every parameter.
     *
     * @param dataBlock the input example, forwarded through the network in training mode
     * @param y the target value passed to {@code net.backward}
     * @return the losses for this step. The L1/L2 decay losses are non-zero only on steps that
     *     complete a mini-batch, and the two timing fields are always {@code 0}.
     */
    public TrainResult train(DataBlock dataBlock, int y) {
        net.forward(dataBlock, true);
        double costLoss = net.backward(y);
        double l2DecayLoss = 0.0d;
        double l1DecayLoss = 0.0d;
        k++;
        if (k % batch_size == 0) {
            List<BackPropResult> paramSets = net.getBackPropagationResult();
            // Allocate accumulators whenever they are missing, not only when momentum > 0:
            // update rules that use gsum/initTrainData may not depend on momentum at all.
            if (gsum.isEmpty()) {
                initAccumulators(paramSets);
            }
            for (int p = 0; p < paramSets.size(); p++) {
                BackPropResult paramSet = paramSets.get(p);
                double[] weights = paramSet.getWeights();
                double[] gradients = paramSet.getGradients();
                double l2 = l2_decay * paramSet.getL2DecayMul();
                double l1 = l1_decay * paramSet.getL1DecayMul();
                for (int j = 0; j < weights.length; j++) {
                    double w = weights[j];
                    l2DecayLoss += l2 * w * w / 2.0d;
                    l1DecayLoss += l1 * Math.abs(w);
                    double l2Grad = l2 * w;
                    // Subgradient of |w|; a weight of exactly 0 is treated as negative.
                    double l1Grad = l1 * (w > 0.0d ? 1 : -1);
                    update(p, j, (l2Grad + l1Grad + gradients[j]) / batch_size, weights);
                    gradients[j] = 0.0d;
                }
            }
        }
        return new TrainResult(
                0L, 0L, l1DecayLoss, l2DecayLoss, costLoss, costLoss,
                costLoss + l1DecayLoss + l2DecayLoss);
    }

    /**
     * Allocates one zero-filled {@link #gsum} array per parameter set.
     *
     * <p>It calls the {@link #initTrainData} hook for each set, in the same order. Java already
     * zero-initialises new arrays, so no explicit fill is needed.
     */
    private void initAccumulators(List<BackPropResult> paramSets) {
        for (BackPropResult paramSet : paramSets) {
            gsum.add(new double[paramSet.getWeights().length]);
            initTrainData(paramSet);
        }
    }

    /**
     * Applies the optimiser-specific step to a single weight.
     *
     * @param paramIndex index of the parameter set within {@code net.getBackPropagationResult()};
     *     also the index of its accumulator in {@link #gsum}
     * @param weightIndex index of the weight within that parameter set
     * @param gradient the regularised gradient for that weight, divided by the batch size
     * @param weights the parameter set's weight array (implementations presumably modify
     *     {@code weights[weightIndex]} in place; confirm against subclasses)
     */
    public abstract void update(int paramIndex, int weightIndex, double gradient, double[] weights);

    /**
     * Hook called once per parameter set when the accumulators are allocated.
     *
     * <p>The allocation happens on the first completed mini-batch. The default implementation
     * does nothing.
     *
     * @param backPropResult the parameter set being initialised
     */
    public void initTrainData(BackPropResult backPropResult) {
    }
}
