/*
 * Decompiled with CFR 0.152.
 */
package recunn.trainer;

import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import recunn.autodiff.Graph;
import recunn.datastructs.DataSequence;
import recunn.datastructs.DataSet;
import recunn.datastructs.DataStep;
import recunn.loss.Loss;
import recunn.matrix.Matrix;
import recunn.model.Model;
import recunn.util.FileIO;

/**
 * Static training loop for {@link Model} instances over a {@link DataSet}.
 *
 * <p>Parameters are updated with an RMSProp-style rule (see {@link #updateModelParams}). A decaying
 * average of squared gradients scales each step. Gradients are clipped to
 * {@code [-gradientClipValue, gradientClipValue]}. Every update also subtracts
 * {@code regularization * w} from each weight.
 *
 * <p>The hyper-parameters below are mutable public statics shared by every caller. They are read on
 * every parameter update, so changing them affects all subsequent training in this JVM.
 */
public class Trainer {
    /** Decay factor for the running average of squared gradients kept in {@code Matrix.stepCache}. */
    public static double decayRate = 0.999;
    /** Added to the squared-gradient average before the square root, so the divisor is never zero. */
    public static double smoothEpsilon = 1.0E-8;
    /** Gradients are clamped to {@code [-gradientClipValue, gradientClipValue]} before the step. */
    public static double gradientClipValue = 5.0;
    /** Each update subtracts {@code regularization * w} from every weight (not scaled by the step size). */
    public static double regularization = 1.0E-6;

    /**
     * Trains {@code model} without loading or saving any state.
     *
     * <p>This is equivalent to the full overload with {@code initFromSaved = false},
     * {@code overwriteSaved = false} and {@code savePath = null}.
     *
     * @return see {@link #train(int, double, Model, DataSet, int, boolean, boolean, String, Random)}
     * @throws Exception see the full overload
     */
    public static double train(int trainingEpochs, double learningRate, Model model, DataSet data, int reportEveryNthEpoch, Random rng) throws Exception {
        return train(trainingEpochs, learningRate, model, data, reportEveryNthEpoch, false, false, null, rng);
    }

    /**
     * Runs up to {@code trainingEpochs} epochs of training and prints the losses after each epoch.
     *
     * <p>Each epoch makes one training pass over {@code data.training}, which updates the parameters.
     * It then makes evaluation-only passes over {@code data.validation} and {@code data.testing} when
     * those are non-null. Training stops early and prints "DONE." once the training loss and the
     * validation loss are both exactly 0.0. The validation loss counts as 0.0 when there is no
     * validation set.
     *
     * <p>Note: when {@code initFromSaved} succeeds, the deserialized model is trained in place of the
     * {@code model} argument. The caller's instance is not modified in that case.
     *
     * @param trainingEpochs      maximum number of epochs to run
     * @param learningRate        step size forwarded to {@link #updateModelParams}
     * @param model               model to train; replaced by the saved model if loading succeeds
     * @param data                data set providing the sequences, the losses and the report hook
     * @param reportEveryNthEpoch call {@code data.DisplayReport} after every Nth epoch; values
     *                            {@code <= 0} disable the periodic report
     * @param initFromSaved       first try to load the model from {@code savePath}; on failure, print a
     *                            warning and keep using {@code model}
     * @param overwriteSaved      serialize the model to {@code savePath} after every epoch
     * @param savePath            file path used for loading and saving
     * @param rng                 randomness source forwarded to {@code data.DisplayReport}
     * @return the loss of the last pass run in the final epoch: the testing loss if a testing set
     *         exists, otherwise the validation loss, otherwise the training loss; 1.0 if no epoch ran
     * @throws Exception if the training loss is NaN or infinite; exceptions from the passes, the
     *         report hook or serialization also propagate
     */
    public static double train(int trainingEpochs, double learningRate, Model model, DataSet data, int reportEveryNthEpoch, boolean initFromSaved, boolean overwriteSaved, String savePath, Random rng) throws Exception {
        System.out.println("--------------------------------------------------------------");
        if (initFromSaved) {
            model = loadSavedModel(model, data, savePath, rng);
        }
        double result = 1.0;
        for (int epoch = 0; epoch < trainingEpochs; ++epoch) {
            double reportedLossTrain = pass(learningRate, model, data.training, true, data.lossTraining, data.lossReporting);
            result = reportedLossTrain;
            if (isInvalid(reportedLossTrain)) {
                throw new Exception("WARNING: invalid value for training loss. Try lowering learning rate.");
            }
            double reportedLossValidation = 0.0;
            double reportedLossTesting = 0.0;
            if (data.validation != null) {
                reportedLossValidation = pass(learningRate, model, data.validation, false, data.lossTraining, data.lossReporting);
                result = reportedLossValidation;
            }
            if (data.testing != null) {
                reportedLossTesting = pass(learningRate, model, data.testing, false, data.lossTraining, data.lossReporting);
                result = reportedLossTesting;
            }

            StringBuilder show = new StringBuilder();
            show.append("epoch[").append(epoch + 1).append('/').append(trainingEpochs).append(']');
            show.append("\ttrain loss = ").append(String.format("%.5f", reportedLossTrain));
            if (data.validation != null) {
                show.append("\tvalid loss = ").append(String.format("%.5f", reportedLossValidation));
            }
            if (data.testing != null) {
                show.append("\ttest loss  = ").append(String.format("%.5f", reportedLossTesting));
            }
            System.out.println(show);

            // Guard the modulo: a zero period would otherwise throw ArithmeticException.
            if (reportEveryNthEpoch > 0 && epoch % reportEveryNthEpoch == reportEveryNthEpoch - 1) {
                data.DisplayReport(model, rng);
            }
            if (overwriteSaved) {
                FileIO.serialize(savePath, model);
            }
            if (reportedLossTrain == 0.0 && reportedLossValidation == 0.0) {
                System.out.println("--------------------------------------------------------------");
                System.out.println("\nDONE.");
                break;
            }
        }
        return result;
    }

    /**
     * Tries to deserialize a model from {@code savePath} and runs {@code data.DisplayReport} on it.
     *
     * <p>The loaded model is adopted only if both steps succeed. This keeps the printed
     * "continuing from freshly initialized model" message truthful when the report step fails.
     *
     * @return the loaded model on success, otherwise {@code fallback}
     */
    private static Model loadSavedModel(Model fallback, DataSet data, String savePath, Random rng) {
        System.out.println("initializing model from saved state...");
        try {
            Model loaded = (Model) FileIO.deserialize(savePath);
            data.DisplayReport(loaded, rng);
            return loaded;
        } catch (Exception e) {
            System.out.println("Oops. Unable to load from a saved state.");
            System.out.println("WARNING: " + e.getMessage());
            System.out.println("Continuing from freshly initialized model instead.");
            return fallback;
        }
    }

    /**
     * Runs the model over every sequence and returns the mean reporting loss per target step.
     *
     * <p>The model state is reset and a fresh {@link Graph} is created for each sequence. Steps
     * without a target are fed forward but do not contribute to the loss. When
     * {@code applyTraining} is true, {@code lossTraining.backward} is called for each target step.
     * After each sequence the graph is back-propagated and the parameters are updated.
     *
     * @param learningRate   step size used for parameter updates (ignored when not training)
     * @param model          model to run
     * @param sequences      sequences to iterate over
     * @param applyTraining  whether to back-propagate and update parameters
     * @param lossTraining   loss whose {@code backward} drives the gradients
     * @param lossReporting  loss whose {@code measure} is averaged and returned
     * @return the mean of {@code lossReporting} over all steps with a non-null target. If a single
     *         step's loss is NaN or infinite, that value is returned immediately. If no step has a
     *         target, the result is {@code 0.0 / 0.0}, i.e. NaN.
     * @throws Exception propagated from the model, the losses or the graph
     */
    public static double pass(double learningRate, Model model, List<DataSequence> sequences, boolean applyTraining, Loss lossTraining, Loss lossReporting) throws Exception {
        double numerLoss = 0.0;
        double denomLoss = 0.0;
        for (DataSequence seq : sequences) {
            model.resetState();
            Graph g = new Graph(applyTraining);
            for (DataStep step : seq.steps) {
                Matrix output = model.forward(step.input, g);
                if (step.targetOutput == null) {
                    continue;
                }
                double loss = lossReporting.measure(output, step.targetOutput);
                if (isInvalid(loss)) {
                    return loss;
                }
                numerLoss += loss;
                denomLoss += 1.0;
                if (applyTraining) {
                    lossTraining.backward(output, step.targetOutput);
                }
            }
            if (applyTraining) {
                g.backward();
                updateModelParams(model, learningRate);
            }
        }
        return numerLoss / denomLoss;
    }

    /**
     * Applies one RMSProp-style update to every parameter matrix and zeroes its gradient.
     *
     * <p>For each weight, the squared-gradient cache is first updated with the unclipped gradient.
     * The gradient is then clipped to {@code [-gradientClipValue, gradientClipValue]}. Finally, the
     * weight moves by {@code -stepSize * g / sqrt(cache + smoothEpsilon) - regularization * w}.
     *
     * @param model    model whose {@code getParameters()} matrices are updated in place
     * @param stepSize learning rate for this update
     * @throws Exception propagated from {@code model.getParameters()}
     */
    public static void updateModelParams(Model model, double stepSize) throws Exception {
        for (Matrix m : model.getParameters()) {
            for (int i = 0; i < m.w.length; ++i) {
                double mdwi = m.dw[i];
                // Note: the cache sees the unclipped gradient; clipping only limits the step itself.
                m.stepCache[i] = m.stepCache[i] * decayRate + (1.0 - decayRate) * mdwi * mdwi;
                if (mdwi > gradientClipValue) {
                    mdwi = gradientClipValue;
                }
                if (mdwi < -gradientClipValue) {
                    mdwi = -gradientClipValue;
                }
                m.w[i] += -stepSize * mdwi / Math.sqrt(m.stepCache[i] + smoothEpsilon) - regularization * m.w[i];
                m.dw[i] = 0.0;
            }
        }
    }

    /** Returns true for NaN or +/-infinity, i.e. a loss value that cannot be averaged or reported. */
    private static boolean isInvalid(double value) {
        return Double.isNaN(value) || Double.isInfinite(value);
    }
}

