/*
 * Decompiled with CFR 0.152.
 */
package edu.uci.jforests.learning.trees.decision;

import edu.uci.jforests.eval.EvaluationMetric;
import edu.uci.jforests.learning.LearningUtils;
import edu.uci.jforests.learning.trees.Tree;
import edu.uci.jforests.learning.trees.decision.DecisionTree;
import edu.uci.jforests.sample.Predictions;
import edu.uci.jforests.util.ArraysUtil;
import java.util.Arrays;

/**
 * Predictions for decision (classification) trees.
 *
 * <p>For each instance of the current sample this class keeps a per-class distribution row
 * ({@link #perInstanceDistribution}), which {@link #update(Tree, double)} hands to
 * {@code LearningUtils.updateDistributions} for every tree. It also keeps a scratch array of
 * scalar predictions, which {@link #evaluate(EvaluationMetric)} rebuilds from those rows before
 * delegating to the sample's evaluation.
 */
public class DecisionPredictions
extends Predictions {
    /** Per-class values, one row of length {@code numClasses} per instance slot. */
    protected double[][] perInstanceDistribution;
    /** Scalar predictions, recomputed from {@link #perInstanceDistribution} on each evaluation. */
    protected double[] perInstancePredictions;
    /** Number of classes; the length of every distribution row. */
    protected int numClasses;

    /**
     * Creates predictions for a problem with {@code numClasses} classes.
     * {@link #allocate(int)} must be called before the instance is used.
     *
     * @param numClasses number of classes (length of each distribution row)
     */
    public DecisionPredictions(int numClasses) {
        this.numClasses = numClasses;
    }

    /**
     * Allocates zero-filled storage for up to {@code maxNumInstances} instances.
     *
     * @param maxNumInstances number of instance slots to allocate
     */
    @Override
    public void allocate(int maxNumInstances) {
        // A two-dimensional array creation expression already allocates every row (zero-filled),
        // so re-allocating each row in a loop would only create garbage.
        this.perInstanceDistribution = new double[maxNumInstances][this.numClasses];
        this.perInstancePredictions = new double[maxNumInstances];
    }

    /**
     * Folds one tree into the per-instance distributions by delegating to
     * {@code LearningUtils.updateDistributions}.
     *
     * @param tree   the tree to add; it must be a {@link DecisionTree} (it is cast unchecked)
     * @param weight weight passed through to {@code LearningUtils.updateDistributions}
     */
    @Override
    public void update(Tree tree, double weight) {
        LearningUtils.updateDistributions(this.sample, this.perInstanceDistribution, (DecisionTree)tree, weight);
    }

    /**
     * Turns the current distributions into scalar predictions and evaluates them on the sample.
     *
     * <p>With two classes, the prediction for an instance is the share of class 1 in the
     * instance's two-entry distribution. With more classes, the prediction is the value returned
     * by {@code ArraysUtil.findMaxIndex} for the row (presumably the index of the largest entry).
     *
     * @param evalMetric metric forwarded to the sample's evaluation
     * @return the value returned by {@code sample.evaluate}
     * @throws Exception propagated from {@code sample.evaluate}
     */
    @Override
    public double evaluate(EvaluationMetric evalMetric) throws Exception {
        if (this.numClasses == 2) {
            for (int i = 0; i < this.sample.size; ++i) {
                double[] distribution = this.perInstanceDistribution[i];
                // NOTE(review): if both entries are 0 (no tree has contributed yet), this is
                // 0/0 and yields NaN; confirm callers never evaluate in that state.
                this.perInstancePredictions[i] = distribution[1] / (distribution[0] + distribution[1]);
            }
        } else {
            for (int i = 0; i < this.sample.size; ++i) {
                this.perInstancePredictions[i] = ArraysUtil.findMaxIndex(this.perInstanceDistribution[i]);
            }
        }
        return this.sample.evaluate(this.perInstancePredictions, evalMetric);
    }

    /**
     * Clears all per-instance state so that a new sequence of updates starts from zero.
     * Calling this before {@link #allocate(int)} does nothing.
     */
    @Override
    public void reset() {
        // The distribution rows are the state that update() feeds and evaluate() reads. Clearing
        // only the predictions (which evaluate() overwrites anyway) would let values from before
        // the reset leak into later evaluations.
        if (this.perInstanceDistribution != null) {
            for (double[] distribution : this.perInstanceDistribution) {
                Arrays.fill(distribution, 0.0);
            }
        }
        if (this.perInstancePredictions != null) {
            Arrays.fill(this.perInstancePredictions, 0.0);
        }
    }
}

