package recunn.trainer;

import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import recunn.autodiff.Graph;
import recunn.datastructs.DataSequence;
import recunn.datastructs.DataSet;
import recunn.datastructs.DataStep;
import recunn.loss.Loss;
import recunn.matrix.Matrix;
import recunn.model.Model;

/* loaded from: input_file:recunn/trainer/Trainer.class */
public class Trainer {
    /** Decay factor of the running average of squared gradients kept in {@code Matrix.stepCache}. */
    public static double decayRate = 0.999d;

    /** Small constant added under the square root of the update denominator to avoid division by zero. */
    public static double smoothEpsilon = 1.0E-8d;

    /** Each gradient element is clamped to {@code [-gradientClipValue, gradientClipValue]} before the step. */
    public static double gradientClipValue = 5.0d;

    /** Coefficient of the weight-decay term {@code regularization * w} subtracted on every update. */
    public static double regularization = 1.0E-6d;

    /**
     * Convenience overload that delegates to the nine-argument {@code train}, forwarding its first five
     * arguments and {@code random} unchanged and passing {@code false}, {@code false} and {@code null}
     * for the two boolean flags and the string argument.
     *
     * <p>NOTE(review): the meaning of the integer/double arguments cannot be recovered from this
     * decompiled class. Because the nine-argument body was not decompiled (see below), this overload
     * currently always throws {@link UnsupportedOperationException}.
     */
    public static double train(int i, double d, Model model, DataSet dataSet, int i2, Random random) throws Exception {
        return train(i, d, model, dataSet, i2, false, false, null, random);
    }

    /**
     * NOTE(review): the body of this method was NOT recovered by the decompiler. JADX reported
     * "Code restructure failed: missing block" and skipped the method dump (about 490 bytecode
     * instructions). As written, every call throws {@link UnsupportedOperationException}.
     *
     * <p>Restore the real implementation from the original source before relying on this class for
     * training. Alternatively, re-run JADX with {@code --show-bad-code} to get a partially-correct
     * version.
     */
    public static double train(int r8, double r9, recunn.model.Model r11, recunn.datastructs.DataSet r12, int r13, boolean r14, boolean r15, java.lang.String r16, java.util.Random r17) throws java.lang.Exception {
        throw new UnsupportedOperationException("Method not decompiled: recunn.trainer.Trainer.train(int, double, recunn.model.Model, recunn.datastructs.DataSet, int, boolean, boolean, java.lang.String, java.util.Random):double");
    }

    /**
     * Runs every sequence once through {@code model} and returns the mean reporting loss.
     * The mean is taken over all steps that carry a target output.
     *
     * <p>Before each sequence the model state is reset and a fresh {@link Graph} is created. For every
     * step with a non-null target, {@code lossReporting.measure} is accumulated. When
     * {@code applyTraining} is true, {@code lossTraining.backward} is also called for that step. After
     * the sequence, {@code graph.backward()} runs, followed by
     * {@link #updateModelParams(Model, double)}. The result is one parameter update per sequence.
     *
     * @param learningRate  step size forwarded to {@link #updateModelParams(Model, double)}
     * @param model         model to run (and update, when {@code applyTraining} is true)
     * @param sequences     sequences to process, in order
     * @param applyTraining whether to run the backward passes and parameter updates; also passed to the
     *                      {@link Graph} constructor
     * @param lossTraining  loss whose {@code backward} is invoked per targeted step (training only)
     * @param lossReporting loss whose {@code measure} is averaged into the return value
     * @return the mean reporting loss. If a per-step loss is NaN or infinite, that value is returned
     *         immediately; sequences already processed keep their parameter updates. If no step has a
     *         target output, the result is {@code 0.0 / 0} = NaN.
     * @throws Exception propagated from the model, graph or loss implementations
     */
    public static double pass(double learningRate, Model model, List<DataSequence> sequences, boolean applyTraining, Loss lossTraining, Loss lossReporting) throws Exception {
        double lossSum = 0.0d;
        long lossCount = 0L;
        for (DataSequence sequence : sequences) {
            model.resetState();
            Graph graph = new Graph(applyTraining);
            for (DataStep step : sequence.steps) {
                Matrix output = model.forward(step.input, graph);
                if (step.targetOutput == null) {
                    continue; // steps without a target contribute neither loss nor gradient
                }
                double stepLoss = lossReporting.measure(output, step.targetOutput);
                if (Double.isNaN(stepLoss) || Double.isInfinite(stepLoss)) {
                    // Stop at the first non-finite loss and hand it back to the caller unchanged.
                    return stepLoss;
                }
                lossSum += stepLoss;
                lossCount++;
                if (applyTraining) {
                    lossTraining.backward(output, step.targetOutput);
                }
            }
            if (applyTraining) {
                graph.backward();
                updateModelParams(model, learningRate);
            }
        }
        return lossSum / lossCount;
    }

    /**
     * Applies one RMSProp-style update to every parameter matrix of {@code model}, then clears its
     * gradients.
     *
     * <p>For each element {@code i}, with {@code g = dw[i]}:
     * <ol>
     *   <li>{@code stepCache[i] = stepCache[i] * decayRate + (1 - decayRate) * g * g}. This uses the
     *       raw, unclipped gradient.</li>
     *   <li>{@code g} is clamped to {@code [-gradientClipValue, gradientClipValue]}.</li>
     *   <li>{@code w[i] += -stepSize * g / sqrt(stepCache[i] + smoothEpsilon) - regularization * w[i]}.</li>
     *   <li>{@code dw[i} is reset to {@code 0}.</li>
     * </ol>
     *
     * @param model    model whose {@code getParameters()} matrices are updated in place
     * @param stepSize learning rate applied to the normalized gradient
     * @throws Exception propagated from {@code model.getParameters()}
     */
    public static void updateModelParams(Model model, double stepSize) throws Exception {
        for (Matrix matrix : model.getParameters()) {
            double[] w = matrix.w;
            double[] dw = matrix.dw;
            double[] cache = matrix.stepCache;
            for (int i = 0; i < w.length; i++) {
                double grad = dw[i];
                cache[i] = (cache[i] * decayRate) + ((1.0d - decayRate) * grad * grad);
                // Upper bound first, then lower bound: same order (and NaN behavior) as two if-checks.
                grad = Math.max(-gradientClipValue, Math.min(gradientClipValue, grad));
                w[i] += ((-stepSize * grad) / Math.sqrt(cache[i] + smoothEpsilon)) - (regularization * w[i]);
                dw[i] = 0.0d;
            }
        }
    }
}
