/*
 * Decompiled with CFR 0.152.
 */
package org.ea.javacnn.trainers;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import org.ea.javacnn.JavaCNN;
import org.ea.javacnn.data.BackPropResult;
import org.ea.javacnn.data.DataBlock;
import org.ea.javacnn.data.TrainResult;

/**
 * Base class for mini-batch gradient trainers of a {@link JavaCNN} network.
 *
 * <p>Each call to {@link #train(DataBlock, int)} runs one forward and one backward pass. On every
 * {@code batch_size}-th call the trainer walks all parameter groups returned by
 * {@link JavaCNN#getBackPropagationResult()}, adds the L1/L2 decay gradients, averages over the
 * batch, hands each parameter to the subclass's {@link #update} rule and then zeroes that
 * parameter's gradient entry.
 *
 * <p>This class performs no synchronization of its own.
 */
public abstract class Trainer {
    /** Network being trained. */
    private final JavaCNN net;
    /** Step size (0.01 by default); not read by this class, available to subclasses. */
    protected double learning_rate;
    /** Base L1 regularization coefficient (0.001 by default), scaled per group by its L1 multiplier. */
    protected double l1_decay;
    /** Base L2 regularization coefficient, scaled per group by its L2 multiplier. */
    protected double l2_decay;
    /** Number of {@link #train} calls between parameter updates; always at least 1. */
    protected int batch_size;
    /** Number of {@link #train} calls made so far. */
    protected int k;
    /** Momentum coefficient (0.9 by default); not read by this class, available to subclasses. */
    protected double momentum;
    /**
     * Small constant (1e-8 by default); not read by this class. Presumably a numerical-stability
     * term for adaptive subclasses — TODO confirm.
     */
    protected double eps;
    /**
     * Per-group optimizer state: one zero-initialized array per parameter group, sized like that
     * group's weights, allocated right before the first batch update.
     */
    protected List<double[]> gsum;
    /**
     * Additional per-group optimizer state. This class leaves it empty; presumably subclasses fill
     * it from {@link #initTrainData} — TODO confirm.
     */
    protected List<double[]> xsum;

    /**
     * Creates a trainer with default hyper-parameters: learning rate 0.01, L1 decay 0.001,
     * momentum 0.9 and eps 1e-8.
     *
     * @param net the network to train
     * @param batch_size number of {@link #train} calls to accumulate before each update; must be
     *     at least 1
     * @param l2_decay base L2 regularization coefficient
     * @throws NullPointerException if {@code net} is null
     * @throws IllegalArgumentException if {@code batch_size} is less than 1
     */
    public Trainer(JavaCNN net, int batch_size, float l2_decay) {
        // Fail fast: a non-positive batch size would otherwise surface later as a division by zero
        // (k % 0) or as sign-flipped updates (division by a negative batch size).
        if (batch_size < 1) {
            throw new IllegalArgumentException("batch_size must be >= 1, got " + batch_size);
        }
        this.net = Objects.requireNonNull(net, "net");
        this.learning_rate = 0.01;
        this.l1_decay = 0.001;
        this.l2_decay = l2_decay;
        this.batch_size = batch_size;
        this.momentum = 0.9;
        this.eps = 1.0E-8;
        this.gsum = new ArrayList<double[]>();
        this.xsum = new ArrayList<double[]>();
        this.k = 0;
    }

    /**
     * Trains the network on a single example.
     *
     * <p>Runs {@code net.forward(x, true)} and {@code net.backward(y)}. When this call completes a
     * batch (every {@code batch_size}-th call), every parameter is passed to {@link #update} with
     * its batch-averaged, decay-adjusted gradient, and its gradient entry is then reset to zero.
     *
     * @param x input sample
     * @param y target passed to {@code net.backward}; presumably a class index — TODO confirm
     * @return the losses for this call; the L1/L2 decay losses are 0 unless this call completed a
     *     batch
     */
    public TrainResult train(DataBlock x, int y) {
        this.net.forward(x, true);
        double cost_loss = this.net.backward(y);
        double l2_decay_loss = 0.0;
        double l1_decay_loss = 0.0;
        // k is never reset; subclasses may read it.
        ++this.k;
        if (this.k % this.batch_size == 0) {
            List<BackPropResult> pglist = this.net.getBackPropagationResult();
            this.initTrainStateIfNeeded(pglist);
            for (int i = 0; i < pglist.size(); ++i) {
                BackPropResult pg = pglist.get(i);
                double[] p = pg.getWeights();
                double[] g = pg.getGradients();
                double group_l2_decay = this.l2_decay * pg.getL2DecayMul();
                double group_l1_decay = this.l1_decay * pg.getL1DecayMul();
                for (int j = 0; j < p.length; ++j) {
                    l2_decay_loss += group_l2_decay * p[j] * p[j] / 2.0;
                    l1_decay_loss += group_l1_decay * Math.abs(p[j]);
                    // Math.signum gives the standard subgradient of |p| (0 at p == 0), so weights
                    // that are exactly zero are not pushed negative.
                    double l1grad = group_l1_decay * Math.signum(p[j]);
                    double l2grad = group_l2_decay * p[j];
                    double gij = (l2grad + l1grad + g[j]) / this.batch_size;
                    this.update(i, j, gij, p);
                    g[j] = 0.0;
                }
            }
        }
        // The two leading 0L arguments are not measured here (presumably timing fields — TODO
        // confirm against TrainResult).
        return new TrainResult(0L, 0L, l1_decay_loss, l2_decay_loss, cost_loss, cost_loss,
                cost_loss + l1_decay_loss + l2_decay_loss);
    }

    /**
     * Allocates per-group optimizer state the first time a batch update runs: one zeroed
     * {@link #gsum} array per group, followed by the {@link #initTrainData} hook for that group.
     * Done regardless of {@link #momentum} so {@link #update} always finds its state.
     */
    private void initTrainStateIfNeeded(List<BackPropResult> pglist) {
        if (!this.gsum.isEmpty()) {
            return;
        }
        for (BackPropResult pg : pglist) {
            // Java zero-initializes new arrays.
            this.gsum.add(new double[pg.getWeights().length]);
            this.initTrainData(pg);
        }
    }

    /**
     * Applies one optimizer step to a single parameter.
     *
     * @param i index of the parameter group, matching the network's back-propagation result list
     *     and {@link #gsum}
     * @param j index of the parameter within that group
     * @param gij batch-averaged gradient for the parameter, including the L1/L2 decay terms
     * @param p the group's weight array from {@code getWeights()}; {@code p[j]} is the parameter
     *     being stepped
     */
    public abstract void update(int i, int j, double gij, double[] p);

    /**
     * Hook called once per parameter group when optimizer state is first allocated, right after
     * that group's {@link #gsum} entry is added. The default does nothing.
     *
     * @param bpr the parameter group being initialized
     */
    public void initTrainData(BackPropResult bpr) {
    }
}

