/*
 * Decompiled with CFR 0.152.
 */
package edu.uci.jforests.learning;

import edu.uci.jforests.learning.LearningProgressListener;
import edu.uci.jforests.learning.trees.Ensemble;
import edu.uci.jforests.learning.trees.decision.DecisionTree;
import edu.uci.jforests.learning.trees.regression.RegressionTree;
import edu.uci.jforests.sample.Sample;

/**
 * Static helpers that accumulate tree outputs into per-sample score and
 * distribution arrays and convert scores into probabilities.
 *
 * <p>Throughout this class, position {@code i} of a {@link Sample} is mapped to
 * a dataset row through {@code sampleSet.indicesInDataset[i]}; when that array
 * is {@code null}, position {@code i} itself is used as the row.
 */
public class LearningUtils {
    /**
     * Adds the weighted output of every tree of {@code ensemble} to {@code scores},
     * without progress notification.
     *
     * @param sampleSet sample whose first {@code sampleSet.size} positions are scored
     * @param scores    accumulator updated in place; entry {@code i} belongs to sample position {@code i}
     * @param ensemble  source of the trees and their weights
     */
    public static void updateScores(Sample sampleSet, double[] scores, Ensemble ensemble) {
        LearningUtils.updateScores(sampleSet, scores, ensemble, null);
    }

    /**
     * Adds the weighted output of every tree of {@code ensemble} to {@code scores}.
     * Trees are processed in index order and {@code progressListener.onScoreEval()}
     * is invoked once after each tree has been applied.
     *
     * @param sampleSet        sample whose first {@code sampleSet.size} positions are scored
     * @param scores           accumulator updated in place; entry {@code i} belongs to sample position {@code i}
     * @param ensemble         source of the trees (each is cast to {@link RegressionTree}) and their weights
     * @param progressListener notified after each tree; may be {@code null} to disable notification
     */
    public static void updateScores(Sample sampleSet, double[] scores, Ensemble ensemble, LearningProgressListener progressListener) {
        for (int t = 0; t < ensemble.getNumTrees(); ++t) {
            RegressionTree tree = (RegressionTree) ensemble.getTreeAt(t);
            double treeWeight = ensemble.getWeightAt(t);
            // Delegate to the single-tree overload so that a null indicesInDataset
            // is handled the same way on both code paths.
            updateScores(sampleSet, scores, tree, treeWeight);
            if (progressListener != null) {
                progressListener.onScoreEval();
            }
        }
    }

    /**
     * Adds {@code treeWeight * tree.getOutput(...)} to the score of every sample position.
     *
     * @param sampleSet  sample whose first {@code sampleSet.size} positions are scored
     * @param scores     accumulator updated in place; entry {@code i} belongs to sample position {@code i}
     * @param tree       tree queried for the output of each dataset row
     * @param treeWeight multiplier applied to the tree output before accumulation
     */
    public static void updateScores(Sample sampleSet, double[] scores, RegressionTree tree, double treeWeight) {
        for (int i = 0; i < sampleSet.size; ++i) {
            scores[i] = scores[i] + treeWeight * tree.getOutput(sampleSet.dataset, datasetIndex(sampleSet, i));
        }
    }

    /**
     * Adds the weighted distribution returned by {@code tree} for each sample position
     * to the matching row of {@code dist}.
     *
     * @param sampleSet  sample whose first {@code sampleSet.size} positions are processed
     * @param dist       accumulator updated in place; row {@code i} belongs to sample position {@code i}
     *                   and must be at least as long as the distribution returned by the tree
     * @param tree       tree queried for the distribution of each dataset row
     * @param treeWeight multiplier applied to every distribution entry before accumulation
     */
    public static void updateDistributions(Sample sampleSet, double[][] dist, DecisionTree tree, double treeWeight) {
        for (int i = 0; i < sampleSet.size; ++i) {
            double[] curDist = tree.getDistributionForInstance(sampleSet.dataset, datasetIndex(sampleSet, i));
            double[] target = dist[i];
            for (int c = 0; c < curDist.length; ++c) {
                target[c] = target[c] + treeWeight * curDist[c];
            }
        }
    }

    /**
     * Writes {@code 1 / (1 + exp(-2 * scores[i]))} into {@code prob[i]} for the
     * first {@code size} entries.
     *
     * @param prob   output array, overwritten at positions {@code 0..size-1}
     * @param scores input scores
     * @param size   number of leading entries to convert
     */
    public static void updateProbabilities(double[] prob, double[] scores, int size) {
        for (int i = 0; i < size; ++i) {
            prob[i] = scoreToProbability(scores[i]);
        }
    }

    /**
     * Writes {@code 1 / (1 + exp(-2 * scores[k]))} into {@code prob[k]} for every
     * {@code k} among the first {@code size} entries of {@code instances}.
     *
     * @param prob      output array, overwritten only at the listed positions
     * @param scores    input scores, read at the listed positions
     * @param instances positions to convert
     * @param size      number of leading entries of {@code instances} to use
     */
    public static void updateProbabilities(double[] prob, double[] scores, int[] instances, int size) {
        for (int i = 0; i < size; ++i) {
            int instance = instances[i];
            prob[instance] = scoreToProbability(scores[instance]);
        }
    }

    /** Returns the dataset row for sample position {@code i}; {@code i} itself when no index array is present. */
    private static int datasetIndex(Sample sampleSet, int i) {
        return sampleSet.indicesInDataset == null ? i : sampleSet.indicesInDataset[i];
    }

    /** Maps a score to {@code 1 / (1 + exp(-2 * score))}. */
    private static double scoreToProbability(double score) {
        return 1.0 / (1.0 + Math.exp(-2.0 * score));
    }
}

