/*
 * Decompiled with CFR 0.152.
 */
package edu.uci.jforests.learning.trees.decision;

import edu.uci.jforests.dataset.Dataset;
import edu.uci.jforests.learning.trees.Tree;
import edu.uci.jforests.learning.trees.TreeSplit;
import edu.uci.jforests.learning.trees.decision.DecisionTreeSplit;
import edu.uci.jforests.sample.Sample;
import edu.uci.jforests.util.ArraysUtil;
import edu.uci.jforests.util.MathUtil;

/**
 * A classification tree whose leaves each hold a per-class target distribution.
 * An instance is classified by routing it to a leaf and returning the index of the
 * largest entry in that leaf's distribution.
 */
public class DecisionTree
extends Tree {
    /** One row per leaf ({@code maxLeaves} rows after {@link #init}), one column per class. */
    private double[][] leafTargetDistributions;
    /** Number of target classes, i.e. the width of every leaf distribution. */
    private int numClasses;

    /**
     * Returns a deep copy of this tree: base-tree state via {@code copyTo}, the class
     * count, and a cloned distribution matrix.
     */
    @Override
    public Object clone() {
        DecisionTree copy = new DecisionTree();
        super.copyTo(copy);
        // numClasses must be carried over as well. split, normalization, serialization
        // and backfit all iterate over it, so a clone with numClasses == 0 would
        // silently skip work (or throw in backfit).
        copy.numClasses = this.numClasses;
        copy.leafTargetDistributions = MathUtil.cloneDoubleMatrix(this.leafTargetDistributions);
        return copy;
    }

    /**
     * Initializes the base tree for up to {@code maxLeaves} leaves and allocates an
     * all-zero {@code maxLeaves x numClasses} distribution matrix.
     *
     * @param maxLeaves  maximum number of leaves the tree may grow
     * @param numClasses number of target classes
     */
    public void init(int maxLeaves, int numClasses) {
        super.init(maxLeaves);
        this.numClasses = numClasses;
        this.leafTargetDistributions = new double[maxLeaves][numClasses];
    }

    /**
     * Returns the distribution stored for {@code leaf}. This is the live internal
     * array, not a copy, so writes to it change the tree.
     */
    public double[] getLeafTargetDistribution(int leaf) {
        return this.leafTargetDistributions[leaf];
    }

    /**
     * Copies {@code dist.length} values from {@code dist} into the row for
     * {@code leaf}. {@code dist} must not be longer than that row.
     */
    public void setLeafTargetDistribution(int leaf, double[] dist) {
        System.arraycopy(dist, 0, this.leafTargetDistributions[leaf], 0, dist.length);
    }

    /**
     * Returns the predicted class for the given instance: the index of the maximum
     * entry (as chosen by {@code ArraysUtil.findMaxIndex}) in its leaf's distribution.
     */
    public int classify(Dataset dataset, int instanceIndex) {
        double[] dist = this.leafTargetDistributions[this.getLeaf(dataset, instanceIndex)];
        return ArraysUtil.findMaxIndex(dist);
    }

    /**
     * Returns the distribution of the leaf the given instance falls into. This is the
     * live internal array, not a copy.
     */
    public double[] getDistributionForInstance(Dataset dataset, int instanceIndex) {
        return this.leafTargetDistributions[this.getLeaf(dataset, instanceIndex)];
    }

    /** Classifies every instance of {@code dataset}, in index order. */
    public int[] getPredictions(Dataset dataset) {
        int[] predictions = new int[dataset.numInstances];
        for (int i = 0; i < dataset.numInstances; ++i) {
            predictions[i] = this.classify(dataset, i);
        }
        return predictions;
    }

    /**
     * Splits {@code leaf} using {@code split}, which must be a {@link DecisionTreeSplit}.
     * After the base split, {@code leaf} receives the left distribution and the last
     * leaf index ({@code numLeaves - 1}) receives the right distribution. Both are then
     * normalized.
     *
     * @return the value returned by {@code super.split} (index of the new internal node)
     */
    @Override
    public int split(int leaf, TreeSplit split) {
        int indexOfNewNonLeaf = super.split(leaf, split);
        DecisionTreeSplit dsplit = (DecisionTreeSplit) split;
        // The code relies on super.split having incremented numLeaves, so that the
        // newly created (right) leaf is the last one.
        int rightLeaf = this.numLeaves - 1;
        double[] leftDist = this.leafTargetDistributions[leaf];
        double[] rightDist = this.leafTargetDistributions[rightLeaf];
        for (int c = 0; c < this.numClasses; ++c) {
            leftDist[c] = dsplit.leftTargetDist[c];
            rightDist[c] = dsplit.rightTargetDist[c];
        }
        this.normalizeLeafTargetDistributions(leaf);
        this.normalizeLeafTargetDistributions(rightLeaf);
        return indexOfNewNonLeaf;
    }

    /**
     * Scales the distribution of {@code leaf} in place so its entries sum to 1.
     * A distribution that sums to zero is left unchanged; dividing it would turn
     * every entry into NaN (0.0 / 0.0).
     */
    private void normalizeLeafTargetDistributions(int leaf) {
        double[] dist = this.leafTargetDistributions[leaf];
        double sum = this.sumOverClasses(dist);
        if (sum == 0.0) {
            return;
        }
        for (int c = 0; c < this.numClasses; ++c) {
            dist[c] /= sum;
        }
    }

    /**
     * Replaces the distribution matrix with values parsed from the
     * {@code LeafTargetDistributions} element of {@code str}. The current
     * {@code numLeaves} and {@code numClasses} are used as the matrix dimensions.
     * NOTE(review): numClasses is not parsed here; presumably it is set elsewhere
     * before this runs — confirm.
     */
    @Override
    public void loadCustomData(String str) throws Exception {
        this.leafTargetDistributions = ArraysUtil.loadDoubleMatrixFromLine(this.removeXmlTag(str, "LeafTargetDistributions"), this.numLeaves, this.numClasses);
    }

    /**
     * Appends a {@code <LeafTargetDistributions>} element to {@code sb}. The element
     * holds the leaf distributions flattened row by row (leaf-major) and separated by
     * single spaces.
     */
    @Override
    protected void addCustomData(String linePrefix, StringBuilder sb) {
        StringBuilder values = new StringBuilder();
        String separator = "";
        for (int n = 0; n < this.numLeaves; ++n) {
            for (int c = 0; c < this.numClasses; ++c) {
                values.append(separator).append(this.leafTargetDistributions[n][c]);
                separator = " ";
            }
        }
        sb.append("\n").append(linePrefix)
                .append("\t<LeafTargetDistributions>")
                .append(values)
                .append("</LeafTargetDistributions>");
    }

    /**
     * Recomputes leaf distributions from {@code sample}. Each leaf receives the summed
     * instance weights per class of the sample instances routed to it. A leaf whose
     * total weight is exactly zero instead inherits the accumulated distribution of its
     * nearest ancestor with positive total weight. If no such ancestor exists, the leaf
     * keeps its previous distribution.
     * NOTE(review): unlike {@link #split}, the distributions written here are raw
     * weighted counts, not normalized — confirm this is intended.
     */
    @Override
    public void backfit(Sample sample) {
        // Per-leaf, per-class weight of the sample instances.
        double[][] distPerLeaf = new double[this.numLeaves][this.numClasses];
        for (int i = 0; i < sample.size; ++i) {
            int leaf = this.getLeaf(sample.dataset, sample.indicesInDataset[i]);
            distPerLeaf[leaf][(int) sample.targets[i]] += sample.weights[i];
        }
        double[] weightedCountPerLeaf = new double[this.numLeaves];
        for (int l = 0; l < this.numLeaves; ++l) {
            weightedCountPerLeaf[l] = this.sumOverClasses(distPerLeaf[l]);
        }

        // Store the distributions of populated leaves and push them up to every
        // ancestor. The ancestor totals serve as fallbacks for empty leaves.
        boolean hasZeroCountLeaf = false;
        double[][] distPerInternalNode = new double[this.numLeaves - 1][this.numClasses];
        for (int l = 0; l < this.numLeaves; ++l) {
            if (weightedCountPerLeaf[l] > 0.0) {
                this.setLeafTargetDistribution(l, distPerLeaf[l]);
                // Leaves are addressed as ~leafIndex when asking for their parent.
                for (int parent = this.getParent(~l); parent >= 0; parent = this.getParent(parent)) {
                    this.addOverClasses(distPerInternalNode[parent], distPerLeaf[l]);
                }
            } else {
                hasZeroCountLeaf = true;
            }
        }
        if (!hasZeroCountLeaf) {
            return;
        }

        double[] weightedCountPerInternalNode = new double[this.numLeaves - 1];
        for (int i = 0; i < weightedCountPerInternalNode.length; ++i) {
            weightedCountPerInternalNode[i] = this.sumOverClasses(distPerInternalNode[i]);
        }
        for (int l = 0; l < this.numLeaves; ++l) {
            if (weightedCountPerLeaf[l] != 0.0) {
                continue;
            }
            // Walk up to the nearest ancestor that saw any weight.
            for (int parent = this.getParent(~l); parent >= 0; parent = this.getParent(parent)) {
                if (weightedCountPerInternalNode[parent] > 0.0) {
                    this.setLeafTargetDistribution(l, distPerInternalNode[parent]);
                    break;
                }
            }
        }
    }

    /** Returns the sum of the first {@code numClasses} entries of {@code dist}, in index order. */
    private double sumOverClasses(double[] dist) {
        double total = 0.0;
        for (int c = 0; c < this.numClasses; ++c) {
            total += dist[c];
        }
        return total;
    }

    /** Adds the first {@code numClasses} entries of {@code src} element-wise into {@code dst}. */
    private void addOverClasses(double[] dst, double[] src) {
        for (int c = 0; c < this.numClasses; ++c) {
            dst[c] += src[c];
        }
    }
}

