/*
 * Decompiled with CFR 0.152.
 */
package smile.classification;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.Callable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.classification.ClassifierTrainer;
import smile.classification.DecisionTree;
import smile.classification.SoftClassifier;
import smile.data.Attribute;
import smile.data.NumericAttribute;
import smile.math.Math;
import smile.util.MulticoreExecutor;
import smile.util.SmileUtils;
import smile.validation.Accuracy;
import smile.validation.ClassificationMeasure;

/**
 * Random forest classifier: an ensemble of {@link DecisionTree}s whose
 * predictions are combined by majority vote.
 * <p>
 * Each tree is grown on its own per-class sample of the training data. The
 * training samples a tree did not see (its out-of-bag, or OOB, samples) are
 * used to estimate the generalization error ({@link #error()}) and the tree's
 * OOB accuracy weights its posteriori probabilities in
 * {@link #predict(double[], double[])}.
 * <p>
 * Class labels must be the contiguous integers {@code 0, 1, ..., k-1}.
 */
public class RandomForest
implements SoftClassifier<double[]>,
Serializable {
    private static final long serialVersionUID = 1L;
    private static final Logger logger = LoggerFactory.getLogger(RandomForest.class);
    /** The trained trees, each paired with its OOB accuracy. */
    private List<Tree> trees;
    /** The number of classes. */
    private int k = 2;
    /** The out-of-bag misclassification rate estimated during training. */
    private double error;
    /** Variable importance summed over all trees. */
    private double[] importance;

    /**
     * Trains a random forest on numeric attributes, trying
     * {@code floor(sqrt(p))} random variables at each split.
     *
     * @param x the training samples.
     * @param y the class labels in {@code 0, 1, ..., k-1}.
     * @param ntrees the number of trees.
     */
    public RandomForest(double[][] x, int[] y, int ntrees) {
        this(null, x, y, ntrees);
    }

    /**
     * Trains a random forest on numeric attributes.
     *
     * @param x the training samples.
     * @param y the class labels in {@code 0, 1, ..., k-1}.
     * @param ntrees the number of trees.
     * @param mtry the number of random variables tried at each split.
     */
    public RandomForest(double[][] x, int[] y, int ntrees, int mtry) {
        this(null, x, y, ntrees, mtry);
    }

    /**
     * Trains a random forest, trying {@code floor(sqrt(p))} random variables
     * at each split.
     *
     * @param attributes the attribute metadata, or {@code null} for numeric attributes.
     * @param x the training samples.
     * @param y the class labels in {@code 0, 1, ..., k-1}.
     * @param ntrees the number of trees.
     */
    public RandomForest(Attribute[] attributes, double[][] x, int[] y, int ntrees) {
        this(attributes, x, y, ntrees, (int)Math.floor(Math.sqrt(x[0].length)));
    }

    /**
     * Trains a random forest with at most 100 leaves per tree, minimum leaf
     * size 5 and bootstrap sampling ({@code subsample = 1.0}).
     *
     * @param attributes the attribute metadata, or {@code null} for numeric attributes.
     * @param x the training samples.
     * @param y the class labels in {@code 0, 1, ..., k-1}.
     * @param ntrees the number of trees.
     * @param mtry the number of random variables tried at each split.
     */
    public RandomForest(Attribute[] attributes, double[][] x, int[] y, int ntrees, int mtry) {
        this(attributes, x, y, ntrees, 100, 5, mtry, 1.0);
    }

    /**
     * Trains a random forest using the Gini split rule.
     *
     * @param attributes the attribute metadata, or {@code null} for numeric attributes.
     * @param x the training samples.
     * @param y the class labels in {@code 0, 1, ..., k-1}.
     * @param ntrees the number of trees.
     * @param maxNodes the maximum number of leaf nodes per tree.
     * @param nodeSize the minimum number of samples in a leaf.
     * @param mtry the number of random variables tried at each split.
     * @param subsample the sampling rate in (0, 1]; 1.0 means sampling with replacement.
     */
    public RandomForest(Attribute[] attributes, double[][] x, int[] y, int ntrees, int maxNodes, int nodeSize, int mtry, double subsample) {
        this(attributes, x, y, ntrees, maxNodes, nodeSize, mtry, subsample, DecisionTree.SplitRule.GINI);
    }

    /**
     * Trains a random forest with equal class weights.
     *
     * @param attributes the attribute metadata, or {@code null} for numeric attributes.
     * @param x the training samples.
     * @param y the class labels in {@code 0, 1, ..., k-1}.
     * @param ntrees the number of trees.
     * @param maxNodes the maximum number of leaf nodes per tree.
     * @param nodeSize the minimum number of samples in a leaf.
     * @param mtry the number of random variables tried at each split.
     * @param subsample the sampling rate in (0, 1]; 1.0 means sampling with replacement.
     * @param rule the split rule of the decision trees.
     */
    public RandomForest(Attribute[] attributes, double[][] x, int[] y, int ntrees, int maxNodes, int nodeSize, int mtry, double subsample, DecisionTree.SplitRule rule) {
        this(attributes, x, y, ntrees, maxNodes, nodeSize, mtry, subsample, rule, null);
    }

    /**
     * Trains a random forest.
     *
     * @param attributes the attribute metadata, or {@code null} to treat every
     *        column as a numeric attribute named {@code V1, V2, ...}.
     * @param x the training samples.
     * @param y the class labels, which must be exactly {@code 0, 1, ..., k-1} with {@code k >= 2}.
     * @param ntrees the number of trees.
     * @param maxNodes the maximum number of leaf nodes per tree (at least 2).
     * @param nodeSize the minimum number of samples in a leaf (at least 1).
     * @param mtry the number of random variables tried at each split, in {@code [1, p]}.
     * @param subsample the sampling rate in (0, 1]. With 1.0 each tree draws
     *        samples with replacement; otherwise it draws without replacement.
     * @param rule the split rule of the decision trees.
     * @param classWeight per-class divisors of the sample size: class {@code c}
     *        contributes about {@code n_c * subsample / classWeight[c]} samples
     *        to each tree. Must be positive; {@code null} means all 1.
     * @throws IllegalArgumentException if any argument is invalid.
     */
    public RandomForest(Attribute[] attributes, double[][] x, int[] y, int ntrees, int maxNodes, int nodeSize, int mtry, double subsample, DecisionTree.SplitRule rule, int[] classWeight) {
        if (x.length != y.length) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", x.length, y.length));
        }
        if (ntrees < 1) {
            throw new IllegalArgumentException("Invalid number of trees: " + ntrees);
        }
        if (mtry < 1 || mtry > x[0].length) {
            throw new IllegalArgumentException("Invalid number of variables to split on at a node of the tree: " + mtry);
        }
        if (nodeSize < 1) {
            throw new IllegalArgumentException("Invalid minimum size of leaves: " + nodeSize);
        }
        if (maxNodes < 2) {
            throw new IllegalArgumentException("Invalid maximum number of leaves: " + maxNodes);
        }
        if (subsample <= 0.0 || subsample > 1.0) {
            throw new IllegalArgumentException("Invalid sampling rating: " + subsample);
        }

        this.k = countClasses(y);

        if (attributes == null) {
            attributes = defaultAttributes(x[0].length);
        }

        if (classWeight == null) {
            classWeight = new int[this.k];
            Arrays.fill(classWeight, 1);
        } else {
            validateClassWeight(classWeight, this.k);
        }

        // prediction[i][c] counts the trees that did not train on sample i and voted for class c.
        int[][] prediction = new int[x.length][this.k];
        int[][] order = SmileUtils.sort(attributes, x);
        List<TrainingTask> tasks = new ArrayList<TrainingTask>(ntrees);
        for (int t = 0; t < ntrees; t++) {
            tasks.add(new TrainingTask(attributes, x, y, maxNodes, nodeSize, mtry, subsample, rule, classWeight, order, prediction));
        }

        this.trees = trainAll(tasks, prediction);
        this.error = oobError(prediction, y);
        this.importance = sumImportance(this.trees, attributes.length);
    }

    /**
     * Validates the class labels and returns the number of classes.
     * The labels must be non-negative, start at 0 and have no gaps.
     */
    private static int countClasses(int[] y) {
        int[] labels = Math.unique(y);
        Arrays.sort(labels);
        for (int i = 0; i < labels.length; i++) {
            if (labels[i] < 0) {
                throw new IllegalArgumentException("Negative class label: " + labels[i]);
            }
            // Labels index the vote and class-weight arrays, so they must be exactly 0..k-1.
            int expected = i == 0 ? 0 : labels[i - 1] + 1;
            if (labels[i] != expected) {
                throw new IllegalArgumentException("Missing class: " + expected);
            }
        }
        if (labels.length < 2) {
            throw new IllegalArgumentException("Only one class.");
        }
        return labels.length;
    }

    /** Creates numeric attributes named V1..Vp. */
    private static Attribute[] defaultAttributes(int p) {
        Attribute[] attributes = new Attribute[p];
        for (int j = 0; j < p; j++) {
            attributes[j] = new NumericAttribute("V" + (j + 1));
        }
        return attributes;
    }

    /** Checks that there is a positive weight for every one of the k classes. */
    private static void validateClassWeight(int[] classWeight, int k) {
        if (classWeight.length < k) {
            throw new IllegalArgumentException(String.format("Invalid class weight vector size: %d, expected: %d", classWeight.length, k));
        }
        for (int c = 0; c < k; c++) {
            // The weight is used as a divisor of the per-class sample size.
            if (classWeight[c] <= 0) {
                throw new IllegalArgumentException("Invalid weight of class " + c + ": " + classWeight[c]);
            }
        }
    }

    /**
     * Runs the training tasks, in parallel when possible. If parallel training
     * fails, every task is run again sequentially.
     */
    private static List<Tree> trainAll(List<TrainingTask> tasks, int[][] prediction) {
        try {
            return MulticoreExecutor.run(tasks);
        } catch (Exception ex) {
            logger.error("Failed to train random forest on multi-core", ex);
            // Tasks that completed before the failure already added their OOB votes.
            // Every task runs again below, so discard those votes to avoid counting twice.
            // NOTE(review): assumes no task is still running once run() has thrown — confirm against MulticoreExecutor.
            for (int[] votes : prediction) {
                Arrays.fill(votes, 0);
            }
            List<Tree> result = new ArrayList<Tree>(tasks.size());
            for (TrainingTask task : tasks) {
                result.add(task.call());
            }
            return result;
        }
    }

    /**
     * Returns the misclassification rate of the OOB majority vote, counted only
     * over samples that received at least one OOB vote (0 if there are none).
     */
    private static double oobError(int[][] prediction, int[] y) {
        int voted = 0;
        int wrong = 0;
        for (int i = 0; i < prediction.length; i++) {
            int pred = Math.whichMax(prediction[i]);
            if (prediction[i][pred] > 0) {
                voted++;
                if (pred != y[i]) {
                    wrong++;
                }
            }
        }
        return voted > 0 ? (double) wrong / (double) voted : 0.0;
    }

    /** Sums the per-variable importance of all trees. */
    private static double[] sumImportance(List<Tree> trees, int p) {
        double[] total = new double[p];
        for (Tree tree : trees) {
            double[] imp = tree.tree.importance();
            for (int j = 0; j < imp.length; j++) {
                total[j] += imp[j];
            }
        }
        return total;
    }

    /**
     * Returns the out-of-bag error rate estimated during training.
     *
     * @return the OOB misclassification rate.
     */
    public double error() {
        return this.error;
    }

    /**
     * Returns the variable importance summed over all trees. The internal
     * array is returned, not a copy.
     *
     * @return the variable importance.
     */
    public double[] importance() {
        return this.importance;
    }

    /**
     * Returns the number of trees in the model.
     *
     * @return the number of trees.
     */
    public int size() {
        return this.trees.size();
    }

    /**
     * Keeps only the first {@code ntrees} trees.
     *
     * @param ntrees the new model size, in {@code [1, size()]}.
     * @throws IllegalArgumentException if {@code ntrees} is out of range.
     */
    public void trim(int ntrees) {
        if (ntrees > this.trees.size()) {
            throw new IllegalArgumentException("The new model size is larger than the current size.");
        }
        if (ntrees <= 0) {
            throw new IllegalArgumentException("Invalid new model size: " + ntrees);
        }
        this.trees = new ArrayList<Tree>(this.trees.subList(0, ntrees));
    }

    /**
     * Predicts the class of {@code x} by majority vote of the trees.
     *
     * @param x the sample.
     * @return the predicted class label.
     */
    @Override
    public int predict(double[] x) {
        int[] votes = new int[this.k];
        for (Tree tree : this.trees) {
            votes[tree.tree.predict(x)]++;
        }
        return Math.whichMax(votes);
    }

    /**
     * Predicts the class of {@code x} by majority vote. It also estimates the
     * posteriori probabilities as the OOB-accuracy-weighted sum of the trees'
     * posteriori, normalized by {@code Math.unitize1}.
     *
     * @param x the sample.
     * @param posteriori output array of length {@code k}, overwritten.
     * @return the predicted class label.
     * @throws IllegalArgumentException if {@code posteriori.length != k}.
     */
    @Override
    public int predict(double[] x, double[] posteriori) {
        if (posteriori.length != this.k) {
            throw new IllegalArgumentException(String.format("Invalid posteriori vector size: %d, expected: %d", posteriori.length, this.k));
        }
        Arrays.fill(posteriori, 0.0);
        int[] votes = new int[this.k];
        double[] pos = new double[this.k];
        for (Tree tree : this.trees) {
            votes[tree.tree.predict(x, pos)]++;
            for (int c = 0; c < this.k; c++) {
                posteriori[c] += tree.weight * pos[c];
            }
        }
        Math.unitize1(posteriori);
        return Math.whichMax(votes);
    }

    /**
     * Evaluates the accuracy of the first 1, 2, ..., T trees on a test set.
     *
     * @param x the test samples.
     * @param y the test labels.
     * @return {@code accuracy[t]} is the accuracy of the ensemble of the first {@code t + 1} trees.
     */
    public double[] test(double[][] x, int[] y) {
        int ntrees = this.trees.size();
        double[] accuracy = new double[ntrees];
        int n = x.length;
        int[] label = new int[n];
        int[][] votes = new int[n][this.k];
        Accuracy measure = new Accuracy();
        for (int t = 0; t < ntrees; t++) {
            DecisionTree tree = this.trees.get(t).tree;
            for (int j = 0; j < n; j++) {
                votes[j][tree.predict(x[j])]++;
                label[j] = Math.whichMax(votes[j]);
            }
            accuracy[t] = measure.measure(y, label);
        }
        return accuracy;
    }

    /**
     * Evaluates the given measures for the first 1, 2, ..., T trees on a test set.
     *
     * @param x the test samples.
     * @param y the test labels.
     * @param measures the performance measures.
     * @return {@code results[t][j]} is measure {@code j} for the ensemble of the first {@code t + 1} trees.
     */
    public double[][] test(double[][] x, int[] y, ClassificationMeasure[] measures) {
        int ntrees = this.trees.size();
        int m = measures.length;
        double[][] results = new double[ntrees][m];
        int n = x.length;
        int[] label = new int[n];
        double[][] votes = new double[n][this.k];
        for (int t = 0; t < ntrees; t++) {
            DecisionTree tree = this.trees.get(t).tree;
            for (int j = 0; j < n; j++) {
                votes[j][tree.predict(x[j])] += 1.0;
                label[j] = Math.whichMax(votes[j]);
            }
            for (int j = 0; j < m; j++) {
                results[t][j] = measures[j].measure(y, label);
            }
        }
        return results;
    }

    /**
     * Returns the decision trees of the forest.
     *
     * @return a new array holding the trees in model order.
     */
    public DecisionTree[] getTrees() {
        DecisionTree[] forest = new DecisionTree[this.trees.size()];
        for (int i = 0; i < forest.length; i++) {
            forest[i] = this.trees.get(i).tree;
        }
        return forest;
    }

    /**
     * Trains a single tree. Trees run concurrently and record their OOB votes
     * in the shared {@code prediction} matrix.
     */
    static class TrainingTask
    implements Callable<Tree> {
        Attribute[] attributes;
        double[][] x;
        int[] y;
        int mtry;
        int nodeSize = 5;
        int maxNodes = 100;
        double subsample = 1.0;
        DecisionTree.SplitRule rule;
        int[] classWeight;
        int[][] order;
        int[][] prediction;

        TrainingTask(Attribute[] attributes, double[][] x, int[] y, int maxNodes, int nodeSize, int mtry, double subsample, DecisionTree.SplitRule rule, int[] classWeight, int[][] order, int[][] prediction) {
            this.attributes = attributes;
            this.x = x;
            this.y = y;
            this.mtry = mtry;
            this.nodeSize = nodeSize;
            this.maxNodes = maxNodes;
            this.subsample = subsample;
            this.rule = rule;
            this.classWeight = classWeight;
            this.order = order;
            this.prediction = prediction;
        }

        /**
         * Samples the training data, grows a tree and casts its OOB votes.
         *
         * @return the tree, weighted by its OOB accuracy (1.0 if it has no OOB samples).
         */
        @Override
        public Tree call() {
            int n = this.x.length;
            int k = Math.max(this.y) + 1;
            int[] samples = this.subsample == 1.0 ? sampleWithReplacement(n, k) : sampleWithoutReplacement(n, k);

            DecisionTree tree = new DecisionTree(this.attributes, this.x, this.y, this.maxNodes, this.nodeSize, this.mtry, this.rule, samples.clone(), this.order);

            int oob = 0;
            int correct = 0;
            for (int i = 0; i < n; i++) {
                if (samples[i] != 0) {
                    continue;
                }
                oob++;
                int p = tree.predict(this.x[i]);
                if (p == this.y[i]) {
                    correct++;
                }
                // Other tasks may update the same row concurrently.
                synchronized (this.prediction[i]) {
                    this.prediction[i][p]++;
                }
            }

            double accuracy = 1.0;
            if (oob != 0) {
                accuracy = (double)correct / (double)oob;
                logger.info("Random forest tree OOB size: {}, accuracy: {}", (Object)oob, (Object)String.format("%.2f%%", 100.0 * accuracy));
            } else {
                logger.error("Random forest has a tree trained without OOB samples.");
            }
            return new Tree(tree, accuracy);
        }

        /**
         * For each class l, draws {@code n_l / classWeight[l]} samples with
         * replacement from that class's {@code n_l} samples.
         *
         * @return the number of times each sample was drawn.
         */
        private int[] sampleWithReplacement(int n, int k) {
            int[] samples = new int[n];
            for (int l = 0; l < k; l++) {
                List<Integer> members = new ArrayList<Integer>();
                for (int i = 0; i < n; i++) {
                    if (this.y[i] == l) {
                        members.add(i);
                    }
                }
                int nj = members.size();
                int size = nj / this.classWeight[l];
                for (int draw = 0; draw < size; draw++) {
                    samples[members.get(Math.randomInt(nj))]++;
                }
            }
            return samples;
        }

        /**
         * For each class l, takes {@code round(n_l * subsample / classWeight[l])}
         * distinct samples of that class in random order.
         *
         * @return 1 for each selected sample, 0 otherwise.
         */
        private int[] sampleWithoutReplacement(int n, int k) {
            int[] samples = new int[n];
            int[] perm = new int[n];
            for (int i = 0; i < n; i++) {
                perm[i] = i;
            }
            Math.permutate(perm);

            int[] nc = new int[k];
            for (int i = 0; i < n; i++) {
                nc[this.y[i]]++;
            }

            for (int l = 0; l < k; l++) {
                int subj = (int)Math.round((double)nc[l] * this.subsample / (double)this.classWeight[l]);
                int count = 0;
                for (int i = 0; i < n && count < subj; i++) {
                    int xi = perm[i];
                    if (this.y[xi] == l) {
                        samples[xi]++;
                        count++;
                    }
                }
            }
            return samples;
        }
    }

    /**
     * Builder-style trainer for {@link RandomForest}. Defaults: 500 trees, Gini
     * rule, {@code floor(sqrt(p))} random variables per split, minimum leaf size
     * 1, at most 100 leaves, and {@code subsample = 1.0}.
     */
    public static class Trainer
    extends ClassifierTrainer<double[]> {
        private int ntrees = 500;
        private DecisionTree.SplitRule rule = DecisionTree.SplitRule.GINI;
        /** Number of random variables per split; values {@code <= 0} mean floor(sqrt(p)). */
        private int mtry = -1;
        private int nodeSize = 1;
        private int maxNodes = 100;
        private double subsample = 1.0;

        /** Creates a trainer with default settings. */
        public Trainer() {
        }

        /**
         * Creates a trainer.
         *
         * @param ntrees the number of trees.
         * @throws IllegalArgumentException if {@code ntrees < 1}.
         */
        public Trainer(int ntrees) {
            if (ntrees < 1) {
                throw new IllegalArgumentException("Invalid number of trees: " + ntrees);
            }
            this.ntrees = ntrees;
        }

        /**
         * Creates a trainer.
         *
         * @param attributes the attribute metadata.
         * @param ntrees the number of trees.
         * @throws IllegalArgumentException if {@code ntrees < 1}.
         */
        public Trainer(Attribute[] attributes, int ntrees) {
            super(attributes);
            if (ntrees < 1) {
                throw new IllegalArgumentException("Invalid number of trees: " + ntrees);
            }
            this.ntrees = ntrees;
        }

        /** Sets the split rule of the decision trees. */
        public Trainer setSplitRule(DecisionTree.SplitRule rule) {
            this.rule = rule;
            return this;
        }

        /** Sets the number of trees; must be at least 1. */
        public Trainer setNumTrees(int ntrees) {
            if (ntrees < 1) {
                throw new IllegalArgumentException("Invalid number of trees: " + ntrees);
            }
            this.ntrees = ntrees;
            return this;
        }

        /** Sets the number of random variables tried at each split; must be at least 1. */
        public Trainer setNumRandomFeatures(int mtry) {
            if (mtry < 1) {
                throw new IllegalArgumentException("Invalid number of random selected features for splitting: " + mtry);
            }
            this.mtry = mtry;
            return this;
        }

        /** Sets the maximum number of leaf nodes per tree; must be at least 2. */
        public Trainer setMaxNodes(int maxNodes) {
            if (maxNodes < 2) {
                throw new IllegalArgumentException("Invalid maximum number of leaf nodes: " + maxNodes);
            }
            this.maxNodes = maxNodes;
            return this;
        }

        /** Sets the minimum number of samples in a leaf; must be at least 1. */
        public Trainer setNodeSize(int nodeSize) {
            if (nodeSize < 1) {
                throw new IllegalArgumentException("Invalid minimum size of leaf nodes: " + nodeSize);
            }
            this.nodeSize = nodeSize;
            return this;
        }

        /** Sets the sampling rate in (0, 1]. */
        public Trainer setSamplingRates(double subsample) {
            if (subsample <= 0.0 || subsample > 1.0) {
                throw new IllegalArgumentException("Invalid sampling rating: " + subsample);
            }
            this.subsample = subsample;
            return this;
        }

        /**
         * Trains a random forest with the configured settings.
         *
         * @param x the training samples.
         * @param y the class labels in {@code 0, 1, ..., k-1}.
         * @return the trained model.
         */
        public RandomForest train(double[][] x, int[] y) {
            int m = this.mtry;
            if (m <= 0 && x.length > 0) {
                // mtry was never set: use the same default as RandomForest(attributes, x, y, ntrees).
                m = (int)Math.floor(Math.sqrt(x[0].length));
            }
            return new RandomForest(this.attributes, x, y, this.ntrees, this.maxNodes, this.nodeSize, m, this.subsample, this.rule, null);
        }
    }

    /**
     * A trained decision tree and its weight (its OOB accuracy). Kept
     * structurally unchanged because it is part of the serialized model and
     * declares no explicit serialVersionUID.
     */
    static class Tree
    implements Serializable {
        DecisionTree tree;
        double weight;

        Tree(DecisionTree tree, double weight) {
            this.tree = tree;
            this.weight = weight;
        }
    }
}

