/*
 * Decompiled with CFR 0.152.
 */
package edu.uci.jforests.learning.trees.regression;

import edu.uci.jforests.dataset.Dataset;
import edu.uci.jforests.dataset.Feature;
import edu.uci.jforests.dataset.Histogram;
import edu.uci.jforests.learning.trees.CandidateSplitsForLeaf;
import edu.uci.jforests.learning.trees.Tree;
import edu.uci.jforests.learning.trees.TreeLearner;
import edu.uci.jforests.learning.trees.TreeSplit;
import edu.uci.jforests.learning.trees.regression.RegressionCandidateSplitsForLeaf;
import edu.uci.jforests.learning.trees.regression.RegressionHistogram;
import edu.uci.jforests.learning.trees.regression.RegressionTree;
import edu.uci.jforests.learning.trees.regression.RegressionTreeSplit;
import edu.uci.jforests.learning.trees.regression.RegressionTreesConfig;
import edu.uci.jforests.util.ConfigHolder;

/**
 * Tree learner that produces {@link RegressionTree} instances and chooses split
 * thresholds from {@link RegressionHistogram} statistics.
 *
 * <p>The factory methods below supply the regression-specific tree, split,
 * candidate-split and histogram objects to the {@link TreeLearner} base class.
 */
public class RegressionTreeLearner
extends TreeLearner {
    /** Value copied from {@link RegressionTreesConfig#maxLeafOutput} and handed to every new tree. */
    protected double maxLeafOutput;

    /** Creates a learner registered under the name {@code "RegressionTree"}. */
    public RegressionTreeLearner() {
        super("RegressionTree");
    }

    /**
     * Runs the base-class initialisation and then reads {@code maxLeafOutput}
     * from the {@link RegressionTreesConfig} stored in {@code configHolder}.
     *
     * @param dataset           forwarded to {@code super.init}
     * @param configHolder      source of the {@link RegressionTreesConfig}
     * @param maxTrainInstances forwarded to {@code super.init}
     * @throws Exception propagated from {@code super.init}
     */
    @Override
    public void init(Dataset dataset, ConfigHolder configHolder, int maxTrainInstances) throws Exception {
        super.init(dataset, configHolder, maxTrainInstances);
        RegressionTreesConfig regressionTreesConfig = configHolder.getConfig(RegressionTreesConfig.class);
        this.maxLeafOutput = regressionTreesConfig.maxLeafOutput;
    }

    /** @return a fresh {@link RegressionTree} initialised with {@code maxLeaves} and {@code maxLeafOutput} */
    @Override
    protected Tree getNewTree() {
        RegressionTree tree = new RegressionTree();
        tree.init(this.maxLeaves, this.maxLeafOutput);
        return tree;
    }

    /** @return a new, empty {@link RegressionTreeSplit} */
    @Override
    protected TreeSplit getNewSplit() {
        return new RegressionTreeSplit();
    }

    /** @return a new {@link RegressionCandidateSplitsForLeaf} built from the two given sizes */
    @Override
    protected CandidateSplitsForLeaf getNewCandidateSplitsForLeaf(int numFeatures, int numInstances) {
        return new RegressionCandidateSplitsForLeaf(numFeatures, numInstances);
    }

    /** @return a new {@link RegressionHistogram} for feature {@code f} */
    @Override
    protected Histogram getNewHistogram(Feature f) {
        return new RegressionHistogram(f);
    }

    /**
     * Chooses a threshold for {@code split} from the per-value statistics in
     * {@code histogram} and writes threshold, left/right outputs and gain into
     * {@code split}.
     *
     * <p>Side effect: {@code histogram.splittable} is reset to {@code false} and
     * set to {@code true} only if at least one histogram index leaves at least
     * {@code minInstancesPerLeaf} instances on both sides.
     *
     * <p>When {@code randomizedSplits} is set, a single random index is drawn and
     * evaluated; otherwise every index is scanned and the one with the largest
     * value of {@code sumLeft^2 / weightedLeft + sumRight^2 / weightedRight} wins.
     *
     * <p>If no threshold is accepted, the defaults are written through unchanged
     * (threshold index 0, gain derived from negative infinity, NaN outputs), so
     * callers are expected to consult {@code histogram.splittable}.
     *
     * @param split     receives the result; must be a {@link RegressionTreeSplit}
     * @param histogram must be a {@link RegressionHistogram}
     */
    @Override
    protected void setBestThresholdForSplit(TreeSplit split, Histogram histogram) {
        RegressionHistogram regHistogram = (RegressionHistogram)histogram;

        // Defaults that survive when no threshold is accepted.
        double bestSumLeftTargets = Double.NaN;
        double bestGain = Double.NEGATIVE_INFINITY;
        double bestWeightedLeftCount = -1.0;
        int bestThreshold = 0;

        histogram.splittable = false;

        if (this.randomizedSplits) {
            // Pass 1: narrow [minIdx, maxIdx) using the minimum leaf size and
            // record whether any index satisfies the constraint on both sides.
            int minIdx = 0;
            int maxIdx = histogram.numValues - 1;
            int scannedLeftCount = 0;
            for (int t = 0; t < histogram.numValues - 1; ++t) {
                scannedLeftCount += histogram.perValueCount[t];
                if (scannedLeftCount < this.minInstancesPerLeaf) {
                    minIdx = t;
                    continue;
                }
                if (histogram.totalCount - scannedLeftCount < this.minInstancesPerLeaf) {
                    maxIdx = t + 1;
                    break;
                }
                histogram.splittable = true;
            }

            // With fewer than two histogram values the loop above never runs and
            // the range is <= 0. Drawing from such a range would throw (assuming
            // rand is a java.util.Random, whose nextInt rejects non-positive
            // bounds), so skip the draw; splittable is false in that case anyway.
            // For any positive range the draw still happens, so the random
            // sequence is unchanged for inputs that worked before.
            int range = maxIdx - minIdx;
            int randThresholdIdx = minIdx;
            if (range > 0) {
                randThresholdIdx += this.rand.nextInt(range);
            }

            if (histogram.splittable) {
                // Pass 2: accumulate the left-side statistics for the drawn index.
                // NOTE(review): the sums cover bins [0, randThresholdIdx) while the
                // stored threshold is upperBounds[randThresholdIdx]; the exhaustive
                // branch pairs threshold t with bins [0, t]. Kept as-is -- confirm
                // whether this off-by-one is intended.
                double sumLeftTargets = 0.0;
                double weightedLeftCount = 0.0;
                for (int t = 0; t < randThresholdIdx; ++t) {
                    sumLeftTargets += regHistogram.perValueSumTargets[t];
                    weightedLeftCount += histogram.perValueWeightedCount[t];
                }
                double sumRightTargets = regHistogram.sumTargets - sumLeftTargets;
                double weightedRightCount = histogram.totalWeightedCount - weightedLeftCount;
                double currentGain = sumLeftTargets * sumLeftTargets / weightedLeftCount + sumRightTargets * sumRightTargets / weightedRightCount;
                // A NaN gain (e.g. an empty left side) fails this comparison and
                // leaves the defaults in place.
                if (currentGain > bestGain) {
                    bestWeightedLeftCount = weightedLeftCount;
                    bestSumLeftTargets = sumLeftTargets;
                    bestThreshold = randThresholdIdx;
                    bestGain = currentGain;
                }
            }
        } else {
            // Exhaustive scan: after adding bin t, the left side holds bins [0, t].
            double sumLeftTargets = 0.0;
            double weightedLeftCount = 0.0;
            int leftCount = 0;
            for (int t = 0; t < histogram.numValues - 1; ++t) {
                weightedLeftCount += histogram.perValueWeightedCount[t];
                sumLeftTargets += regHistogram.perValueSumTargets[t];
                leftCount += histogram.perValueCount[t];
                if (leftCount < this.minInstancesPerLeaf || leftCount == 0) {
                    continue;
                }
                int rightCount = histogram.totalCount - leftCount;
                // The right side only shrinks as t grows, so stop at the first failure.
                if (rightCount < this.minInstancesPerLeaf || rightCount == 0) {
                    break;
                }
                histogram.splittable = true;
                double sumRightTargets = regHistogram.sumTargets - sumLeftTargets;
                double weightedRightCount = histogram.totalWeightedCount - weightedLeftCount;
                double currentGain = sumLeftTargets * sumLeftTargets / weightedLeftCount + sumRightTargets * sumRightTargets / weightedRightCount;
                if (currentGain > bestGain) {
                    bestWeightedLeftCount = weightedLeftCount;
                    bestSumLeftTargets = sumLeftTargets;
                    bestThreshold = t;
                    bestGain = currentGain;
                }
            }
        }

        // Write the chosen threshold and the per-side outputs into the split.
        Feature feature = this.curTrainSet.dataset.features[split.feature];
        split.threshold = feature.upperBounds[bestThreshold];
        split.originalThreshold = feature.getOriginalValue(split.threshold);
        RegressionTreeSplit regressionSplit = (RegressionTreeSplit)split;
        regressionSplit.leftOutput = bestSumLeftTargets / bestWeightedLeftCount;
        regressionSplit.rightOutput = (regHistogram.sumTargets - bestSumLeftTargets) / (histogram.totalWeightedCount - bestWeightedLeftCount);
        // Gain is reported relative to the unsplit term sumTargets^2 / totalWeightedCount.
        split.gain = bestGain - regHistogram.sumTargets * regHistogram.sumTargets / histogram.totalWeightedCount;
    }
}

