/*
 * Decompiled with CFR 0.152.
 */
package edu.uci.jforests.learning.boosting;

import edu.uci.jforests.config.TrainingConfig;
import edu.uci.jforests.eval.EvaluationMetric;
import edu.uci.jforests.learning.LearningModule;
import edu.uci.jforests.learning.LearningUtils;
import edu.uci.jforests.learning.boosting.GradientBoostingConfig;
import edu.uci.jforests.learning.trees.Ensemble;
import edu.uci.jforests.learning.trees.Tree;
import edu.uci.jforests.learning.trees.TreeLeafInstances;
import edu.uci.jforests.learning.trees.regression.RegressionTree;
import edu.uci.jforests.sample.Sample;
import edu.uci.jforests.util.ConfigHolder;
import java.util.Arrays;
import java.util.Random;

/**
 * Gradient boosting driver.
 *
 * <p>Each iteration does three things:
 * <ul>
 *   <li>computes residuals ({@code target - current training score}) for the training set;</li>
 *   <li>takes a random subsample of that residual sample;</li>
 *   <li>asks the sub-learner to fit it.</li>
 * </ul>
 * Every tree the sub-learner returns is appended to the result {@link Ensemble}. When
 * {@link #postProcess} is invoked for a tree, the tree's leaf outputs are scaled by the
 * learning rate (see {@link #adjustOutputs}) and the tree is added to the running training scores.
 *
 * <p>When a validation set is supplied, the running validation scores are evaluated after every
 * iteration with the configured {@link EvaluationMetric}. The returned ensemble is trimmed back to
 * the last iteration that {@code isFirstBetter(..., earlyStoppingTolerance)} reported as better
 * than the best measurement so far. Training also stops once more than
 * {@link #earlyStoppingPatience} iterations have passed since the last strict improvement.
 */
public class GradientBoosting extends LearningModule {

    /** Default for {@link #earlyStoppingPatience}; matches the historical hard-coded limit. */
    protected static final int DEFAULT_EARLY_STOPPING_PATIENCE = 100;

    /** Running ensemble score per training instance; the first {@code curTrainSet.size} entries are used. */
    protected double[] trainPredictions;
    /** Running ensemble score per validation instance; the first {@code curValidSet.size} entries are used. */
    protected double[] validPredictions;
    /** Scratch buffer holding {@code target - trainPrediction} for the current iteration. */
    protected double[] residuals;
    protected int numInstances;
    /** Maximum number of boosting iterations ({@code GradientBoostingConfig.numTrees}). */
    private int numSubModules;
    /** Factor applied to every new tree's leaf outputs ({@code GradientBoostingConfig.learningRate}). */
    protected double learningRate;
    /** Fraction passed to {@link Sample#getRandomSubSample} each iteration. */
    protected double samplingRate;
    /** Tolerance passed to {@link EvaluationMetric#isFirstBetter} when picking the early-stopping iteration. */
    protected double earlyStoppingTolerance;
    /**
     * Number of iterations allowed after the last strict validation improvement before training
     * stops. Only used when a validation set is supplied. Set to {@link Integer#MAX_VALUE} to
     * disable this cutoff.
     */
    protected int earlyStoppingPatience = DEFAULT_EARLY_STOPPING_PATIENCE;
    protected Sample curTrainSet;
    protected Sample curValidSet;
    /** 1-based index of the boosting iteration currently being run. */
    protected int curIteration;
    /** Best validation measurement seen so far; {@code NaN} until one has been accepted. */
    protected double bestValidationMeasurement;
    protected boolean printIntermediateValidMeasurements;
    protected EvaluationMetric evaluationMetric;
    /** Random source for subsampling, seeded from {@code TrainingConfig.randomSeed}. */
    protected Random rnd;

    public GradientBoosting(String algorithmName) {
        super(algorithmName);
    }

    public GradientBoosting() {
        super("GradientBoosting");
    }

    /**
     * Reads the boosting and training configuration and allocates the score buffers.
     *
     * @param configHolder source of {@link GradientBoostingConfig} and {@link TrainingConfig}
     * @param maxNumTrainInstances capacity of the training-score and residual buffers
     * @param maxNumValidInstances capacity of the validation-score buffer
     * @param evaluationMetric metric used for validation measurements and early stopping
     * @throws Exception propagated from the configuration lookup
     */
    public void init(ConfigHolder configHolder, int maxNumTrainInstances, int maxNumValidInstances, EvaluationMetric evaluationMetric) throws Exception {
        this.evaluationMetric = evaluationMetric;
        GradientBoostingConfig gradientBoostingConfig = configHolder.getConfig(GradientBoostingConfig.class);
        this.numSubModules = gradientBoostingConfig.numTrees;
        this.learningRate = gradientBoostingConfig.learningRate;
        this.samplingRate = gradientBoostingConfig.samplingRate;
        this.earlyStoppingTolerance = gradientBoostingConfig.earlyStoppingTolerance;
        this.trainPredictions = new double[maxNumTrainInstances];
        this.residuals = new double[maxNumTrainInstances];
        this.validPredictions = new double[maxNumValidInstances];
        TrainingConfig trainingConfig = configHolder.getConfig(TrainingConfig.class);
        this.printIntermediateValidMeasurements = trainingConfig.printIntermediateValidMeasurements;
        this.rnd = new Random(trainingConfig.randomSeed);
    }

    /** Resets the training (and, if present, validation) scores of the current samples to zero. */
    protected void preprocess() {
        Arrays.fill(this.trainPredictions, 0, this.curTrainSet.size, 0.0);
        if (this.curValidSet != null) {
            Arrays.fill(this.validPredictions, 0, this.curValidSet.size, 0.0);
        }
    }

    /**
     * Runs up to {@code numTrees} boosting iterations.
     *
     * <p>The loop ends early when the sub-learner returns {@code null}. When {@code validSet} is
     * non-null, it also ends once more than {@link #earlyStoppingPatience} iterations have passed
     * since the last strict improvement.
     *
     * @param trainSet training sample
     * @param validSet validation sample, or {@code null} for no validation / early stopping
     * @return the boosted ensemble, trimmed to the trees built up to the early-stopping iteration
     * @throws Exception propagated from the sub-learner or metric evaluation
     */
    @Override
    public Ensemble learn(Sample trainSet, Sample validSet) throws Exception {
        this.curTrainSet = trainSet;
        this.curValidSet = validSet;
        this.preprocess();
        Ensemble ensemble = new Ensemble();
        this.bestValidationMeasurement = Double.NaN;
        int earlyStoppingIteration = 0;
        int bestIteration = 0;
        // treeCounts[i] = ensemble size after iteration i + 1; used to trim the result later.
        int[] treeCounts = new int[this.numSubModules];
        this.subLearner.setTreeWeight(this.treeWeight);
        for (this.curIteration = 1; this.curIteration <= this.numSubModules; ++this.curIteration) {
            // getSubLearnerSample() recomputes residuals from the current training scores.
            Ensemble subEnsemble = this.subLearner.learn(this.getSubLearnerSample(), validSet);
            if (subEnsemble == null) {
                break;
            }
            for (int t = 0; t < subEnsemble.getNumTrees(); ++t) {
                Tree tree = subEnsemble.getTreeAt(t);
                ensemble.addTree(tree, subEnsemble.getWeightAt(t));
                if (validSet != null) {
                    // NOTE(review): scores are added with weight 1.0, not getWeightAt(t), mirroring the
                    // training-score update in postProcess — confirm intended when tree weights != 1.
                    LearningUtils.updateScores(validSet, this.validPredictions, (RegressionTree) tree, 1.0);
                }
            }
            treeCounts[this.curIteration - 1] = ensemble.getNumTrees();
            if (validSet == null) {
                // Without validation data every iteration is kept.
                earlyStoppingIteration = this.curIteration;
            } else {
                double validMeasurement = this.getValidMeasurement();
                if (this.evaluationMetric.isFirstBetter(validMeasurement, this.bestValidationMeasurement, this.earlyStoppingTolerance)) {
                    earlyStoppingIteration = this.curIteration;
                    // Only a strict (zero-tolerance) improvement moves the best measurement.
                    if (this.evaluationMetric.isFirstBetter(validMeasurement, this.bestValidationMeasurement, 0.0)) {
                        this.bestValidationMeasurement = validMeasurement;
                        bestIteration = this.curIteration;
                    }
                }
                if (this.curIteration - bestIteration > this.earlyStoppingPatience) {
                    break;
                }
                if (this.printIntermediateValidMeasurements) {
                    this.printTrainAndValidMeasurement(this.curIteration, validMeasurement, this.getTrainMeasurement(), this.evaluationMetric);
                }
            }
            this.onIterationEnd();
        }
        if (earlyStoppingIteration > 0) {
            int treesToKeep = treeCounts[earlyStoppingIteration - 1];
            ensemble.removeLastTrees(ensemble.getNumTrees() - treesToKeep);
        }
        this.onLearningEnd();
        return ensemble;
    }

    /**
     * Returns the best validation measurement recorded by the last {@link #learn} call.
     * The value is {@code NaN} if no validation set was given or no measurement was ever accepted.
     */
    @Override
    public double getValidationMeasurement() {
        return this.bestValidationMeasurement;
    }

    /** Evaluates the current validation scores with the configured metric. */
    protected double getValidMeasurement() throws Exception {
        return this.curValidSet.evaluate(this.validPredictions, this.evaluationMetric);
    }

    /** Evaluates the current training scores with the configured metric. */
    protected double getTrainMeasurement() throws Exception {
        return this.curTrainSet.evaluate(this.trainPredictions, this.evaluationMetric);
    }

    /**
     * Builds the sample for the next sub-learner fit: a clone of the training set whose targets are
     * the current residuals, randomly subsampled at {@code samplingRate}.
     *
     * <p>The residuals are written into the shared {@link #residuals} buffer. That buffer is then
     * installed as the clone's targets, so the returned sample may alias it — TODO confirm how
     * {@code getRandomSubSample} copies targets.
     */
    protected Sample getSubLearnerSample() {
        for (int i = 0; i < this.curTrainSet.size; ++i) {
            this.residuals[i] = this.curTrainSet.targets[i] - this.trainPredictions[i];
        }
        Sample subLearnerSample = this.curTrainSet.getClone();
        subLearnerSample.targets = this.residuals;
        return subLearnerSample.getRandomSubSample(this.samplingRate, this.rnd);
    }

    /** Shrinks the new tree by scaling its leaf outputs by {@link #learningRate}. */
    protected void adjustOutputs(Tree tree, TreeLeafInstances treeLeafInstances) {
        ((RegressionTree) tree).multiplyLeafOutputs(this.learningRate);
    }

    /**
     * Per-tree callback (presumably invoked by the sub-learner after it builds each tree).
     * It adjusts the tree's outputs, adds the tree to the training scores with weight 1.0,
     * and then runs {@link #postProcessScores()}.
     */
    @Override
    public void postProcess(Tree tree, TreeLeafInstances treeLeafInstances) {
        this.adjustOutputs(tree, treeLeafInstances);
        LearningUtils.updateScores(this.curTrainSet, this.trainPredictions, (RegressionTree) tree, 1.0);
        this.postProcessScores();
    }

    /** Hook for subclasses to transform training scores after each tree; no-op here. */
    protected void postProcessScores() {
    }
}

