package edu.uci.jforests.learning.classification;

import edu.uci.jforests.eval.EvaluationMetric;
import edu.uci.jforests.learning.LearningUtils;
import edu.uci.jforests.learning.boosting.GradientBoosting;
import edu.uci.jforests.learning.boosting.GradientBoostingConfig;
import edu.uci.jforests.learning.trees.LeafInstances;
import edu.uci.jforests.learning.trees.Tree;
import edu.uci.jforests.learning.trees.TreeLeafInstances;
import edu.uci.jforests.learning.trees.regression.RegressionTree;
import edu.uci.jforests.sample.Sample;
import edu.uci.jforests.util.ConfigHolder;
import java.util.Arrays;

/* loaded from: input_file:edu/uci/jforests/learning/classification/GradientBoostingBinaryClassifier.class */
/**
 * Binary classifier trained by gradient boosting on the loss {@code log(1 + exp(-2 * y * F))}.
 *
 * <p>Targets equal to {@code 0} are mapped to the label {@code y = -1}; every other target is
 * mapped to {@code y = +1}. Each iteration hands the sub-learner a sample whose targets are the
 * negative gradients of that loss (see {@link #getSubLearnerSample()}). {@link #adjustOutputs}
 * then replaces every leaf value with a learning-rate-scaled Newton step. When
 * {@code imbalanceCostAdjustment} is enabled, every instance is weighted by the inverse size of
 * its class.
 */
public class GradientBoostingBinaryClassifier extends GradientBoosting {

    /**
     * Tiny term added to both sums of a leaf's Newton step so the division stays finite when the
     * weighted second-derivative sum is zero. The value is about {@code Float.MIN_VALUE}. Note that
     * a leaf whose two sums are both exactly zero therefore yields {@code learningRate * 1}.
     */
    private static final double LEAF_OUTPUT_SMOOTHING = 1.4E-45d;

    /**
     * Per-instance multiplier (indexed by position in the current training set) applied to the
     * residual and second-derivative sums of each leaf. It is {@code 1 / classCount} when
     * imbalance cost adjustment is on, otherwise {@code 1}.
     */
    protected double[] balancingFactors;
    /** Probabilities derived from {@code trainPredictions} after training. */
    protected double[] prob;
    /** Probabilities derived from {@code validPredictions} when measuring validation quality. */
    protected double[] validProb;
    /** Per-instance second derivative of the loss, {@code |r| * (2 - |r|)} for residual {@code r}. */
    protected double[] weights;
    /** Maps a position in the current sub-learner sample to its position in the training set. */
    private int[] subLearnerSampleIndicesInTrainSet;
    /** Whether instances are re-weighted by inverse class frequency. */
    private boolean imbalanceCostAdjustment;

    public GradientBoostingBinaryClassifier() throws Exception {
        super("GradientBoostingBinaryClassifier");
    }

    /**
     * Initialises the base learner and allocates the per-instance buffers of this classifier.
     *
     * @param configHolder source of the {@link GradientBoostingConfig}
     * @param numTrainInstances size of the training-side buffers ({@code prob}, {@code weights}, ...)
     * @param numValidInstances size of the validation probability buffer
     * @param evaluationMetric metric used on the validation set
     */
    @Override
    public void init(ConfigHolder configHolder, int numTrainInstances, int numValidInstances,
            EvaluationMetric evaluationMetric) throws Exception {
        super.init(configHolder, numTrainInstances, numValidInstances, evaluationMetric);
        GradientBoostingConfig config =
                (GradientBoostingConfig) configHolder.getConfig(GradientBoostingConfig.class);
        this.imbalanceCostAdjustment = config.imbalanceCostAdjustment;
        this.prob = new double[numTrainInstances];
        this.validProb = new double[numValidInstances];
        this.weights = new double[numTrainInstances];
        this.subLearnerSampleIndicesInTrainSet = new int[numTrainInstances];
    }

    /**
     * Computes the class balancing factors and seeds train/validation predictions with the
     * constant that minimises the (balancing-weighted) loss.
     */
    @Override
    protected void preprocess() {
        int trainSize = this.curTrainSet.size;
        if (this.balancingFactors == null || this.balancingFactors.length < trainSize) {
            // Indexed by train-set position, so it must hold at least trainSize entries.
            this.balancingFactors = new double[Math.max(this.residuals.length, trainSize)];
        }

        int numPositives = 0;
        int numNegatives = 0;
        for (int i = 0; i < trainSize; i++) {
            if (isNegativeLabel(this.curTrainSet.targets[i])) {
                numNegatives++;
            } else {
                numPositives++;
            }
        }

        if (this.imbalanceCostAdjustment) {
            for (int i = 0; i < trainSize; i++) {
                // A class's factor is only assigned to its own members, so its count is > 0 here.
                this.balancingFactors[i] = isNegativeLabel(this.curTrainSet.targets[i])
                        ? 1.0d / numNegatives
                        : 1.0d / numPositives;
            }
        } else {
            Arrays.fill(this.balancingFactors, 1.0d);
        }

        double initialPrediction =
                initialPrediction(numPositives, numNegatives, this.imbalanceCostAdjustment);
        Arrays.fill(this.trainPredictions, 0, trainSize, initialPrediction);
        if (this.curValidSet != null) {
            Arrays.fill(this.validPredictions, 0, this.curValidSet.size, initialPrediction);
        }
    }

    /**
     * Returns the constant score F0 minimising the balancing-weighted loss.
     *
     * <p>Setting the weighted sum of the residuals {@code 2y / (1 + exp(2yF))} to zero gives
     * {@code exp(2 * F0) = W+ / W-}, where {@code W+}/{@code W-} are each class's total balancing
     * factor. Without imbalance adjustment these are the class counts. With it, each class sums to
     * 1, so F0 = 0. If a class is absent the optimum is infinite, so 0 is returned to keep
     * predictions finite.
     */
    private static double initialPrediction(int numPositives, int numNegatives,
            boolean imbalanceCostAdjustment) {
        if (imbalanceCostAdjustment || numPositives == 0 || numNegatives == 0) {
            return 0.0d;
        }
        return 0.5d * Math.log((double) numPositives / numNegatives);
    }

    /** Returns {@code true} when {@code target} denotes the negative class ({@code y = -1}). */
    private static boolean isNegativeLabel(double target) {
        return target == 0.0d;
    }

    /** Converts validation scores to probabilities and evaluates them with the configured metric. */
    @Override
    protected double getValidMeasurement() throws Exception {
        LearningUtils.updateProbabilities(this.validProb, this.validPredictions, this.curValidSet.size);
        return this.curValidSet.evaluate(this.validProb, this.evaluationMetric);
    }

    /**
     * Computes the negative gradient ({@code residuals}) and second derivative ({@code weights})
     * of the loss for every training instance. It then draws a random sub-sample whose targets are
     * those residuals and records where each sub-sample instance sits in the training set.
     *
     * <p>NOTE(review): residuals/weights are written at {@code indicesInDataset[i]} here but read
     * at train-set positions in {@link #getAdjustedOutput}. These coincide only when the training
     * set's dataset indices are {@code 0..size-1} — TODO confirm for sub-sampled training sets.
     */
    @Override
    protected Sample getSubLearnerSample() {
        for (int i = 0; i < this.curTrainSet.size; i++) {
            int instance = this.curTrainSet.indicesInDataset[i];
            double label = isNegativeLabel(this.curTrainSet.targets[i]) ? -1.0d : 1.0d;
            double residual =
                    (2.0d * label) / (1.0d + Math.exp(2.0d * label * this.trainPredictions[i]));
            this.residuals[instance] = residual;
            double absResidual = Math.abs(residual);
            this.weights[instance] = absResidual * (2.0d - absResidual);
        }
        Sample subSample = this.curTrainSet.getRandomSubSample(this.samplingRate, this.rnd).getClone();
        subSample.targets = this.residuals;
        System.arraycopy(subSample.indicesInParentSample, 0,
                this.subLearnerSampleIndicesInTrainSet, 0, subSample.size);
        return subSample;
    }

    /**
     * Returns the learning-rate-scaled Newton step for one leaf: the balancing-weighted sum of
     * residuals divided by the balancing-weighted sum of second derivatives.
     *
     * @param leafInstances sub-sample positions that fell into the leaf
     * @return the new output value of the leaf
     */
    protected double getAdjustedOutput(LeafInstances leafInstances) {
        double weightedResidualSum = 0.0d;
        double weightedSecondDerivativeSum = 0.0d;
        for (int i = leafInstances.begin; i < leafInstances.end; i++) {
            int trainIndex = this.subLearnerSampleIndicesInTrainSet[leafInstances.indices[i]];
            double factor = this.balancingFactors[trainIndex];
            weightedResidualSum += this.residuals[trainIndex] * factor;
            weightedSecondDerivativeSum += this.weights[trainIndex] * factor;
        }
        return this.learningRate * ((weightedResidualSum + LEAF_OUTPUT_SMOOTHING)
                / (weightedSecondDerivativeSum + LEAF_OUTPUT_SMOOTHING));
    }

    /** Overwrites every leaf output of the freshly grown tree with its Newton-step value. */
    @Override
    protected void adjustOutputs(Tree tree, TreeLeafInstances treeLeafInstances) {
        RegressionTree regressionTree = (RegressionTree) tree;
        LeafInstances leafInstances = new LeafInstances();
        for (int leaf = 0; leaf < tree.numLeaves; leaf++) {
            treeLeafInstances.loadLeafInstances(leaf, leafInstances);
            regressionTree.setLeafOutput(leaf, getAdjustedOutput(leafInstances));
        }
    }

    /** Converts the final training scores into probabilities stored in {@link #prob}. */
    @Override
    protected void postProcessScores() {
        LearningUtils.updateProbabilities(this.prob, this.trainPredictions, this.curTrainSet.size);
    }
}
