/*
 * Decompiled with CFR 0.152.
 */
package Catalano.MachineLearning.Classification.DecisionTrees.Learning;

import Catalano.Core.ArraysUtil;
import Catalano.Core.Concurrent.MulticoreExecutor;
import Catalano.MachineLearning.Classification.DecisionTrees.DecisionTree;
import Catalano.MachineLearning.Classification.IClassifier;
import Catalano.MachineLearning.Dataset.DatasetClassification;
import Catalano.MachineLearning.Dataset.DecisionVariable;
import Catalano.Math.Matrix;
import Catalano.Math.Tools;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import java.util.concurrent.Callable;

/**
 * Random forest classifier: an ensemble of {@link DecisionTree}s, each grown on a bootstrap
 * sample (drawn with replacement) of the training set, that predicts by majority vote.
 *
 * <p>During training the forest also records an out-of-bag (OOB) error estimate. Each training
 * row is classified only by the trees whose bootstrap sample did not contain it. {@link #error()}
 * is the fraction of misclassified rows among the rows that received at least one such vote.
 *
 * <p>Class labels must be the contiguous integers {@code 0..k-1} with {@code k >= 2}.
 */
public class RandomForest
implements IClassifier,
Serializable {
    /** Attribute descriptors; when {@code null}, descriptors named F0, F1, ... are generated per training call. */
    private DecisionVariable[] attributes;
    /** Number of trees to grow. */
    private int T;
    /** Number of variables for splitting; 0 means "derive from the input dimension in Learn". */
    private int M;
    /** How to derive M when it is 0; {@code null} behaves like {@link RandomSelection#Log}. */
    private RandomSelection rs;
    /** Split rule handed to every tree. */
    private DecisionTree.SplitRule rule;
    /** Trained trees; {@code null} until the model has been trained. */
    private List<DecisionTree> trees;
    /** Number of classes seen in the last training call. */
    private int k = 2;
    /** Out-of-bag error of the last training call. */
    private double error;
    /** Sum of every tree's {@code getImportance()} vector from the last training call. */
    private double[] importance;

    /**
     * Returns the out-of-bag error estimated during the last training call.
     *
     * @return fraction of out-of-bag-voted training rows that were misclassified (0 if none were voted on)
     */
    public double error() {
        return this.error;
    }

    /**
     * Returns the per-attribute importance accumulated over all trees.
     *
     * @return element-wise sum of each tree's {@code getImportance()}; {@code null} before training
     */
    public double[] importance() {
        return this.importance;
    }

    /**
     * Returns the number of trees in the forest.
     *
     * @return tree count, or 0 if the model has not been trained
     */
    public int size() {
        return this.trees == null ? 0 : this.trees.size();
    }

    /** @return the split rule used when growing trees */
    public DecisionTree.SplitRule getRule() {
        return this.rule;
    }

    /** @param rule split rule to use for subsequent training calls */
    public void setRule(DecisionTree.SplitRule rule) {
        this.rule = rule;
    }

    /** Creates a forest of 100 trees; the split size is derived from the data at training time. */
    public RandomForest() {
        this(100);
    }

    /**
     * @param T number of trees; the split size is derived from the data at training time
     */
    public RandomForest(int T) {
        this(T, 0);
    }

    /**
     * @param T number of trees
     * @param M number of variables for splitting (0 = derive from the data)
     */
    public RandomForest(int T, int M) {
        this(null, T, M);
    }

    /**
     * @param T number of trees
     * @param M number of variables for splitting (0 = derive from the data)
     * @param rule split rule for the trees
     */
    public RandomForest(int T, int M, DecisionTree.SplitRule rule) {
        this(null, T, M, rule);
    }

    /**
     * @param T number of trees
     * @param randomSelection how to derive the number of variables for splitting
     */
    public RandomForest(int T, RandomSelection randomSelection) {
        this(null, T, randomSelection);
    }

    /**
     * @param T number of trees
     * @param randomSelection how to derive the number of variables for splitting
     * @param rule split rule for the trees
     */
    public RandomForest(int T, RandomSelection randomSelection, DecisionTree.SplitRule rule) {
        this(null, T, randomSelection, rule);
    }

    /**
     * @param attributes attribute descriptors ({@code null} = generate at training time)
     */
    public RandomForest(DecisionVariable[] attributes) {
        this(attributes, 100);
    }

    /**
     * @param attributes attribute descriptors ({@code null} = generate at training time)
     * @param T number of trees
     */
    public RandomForest(DecisionVariable[] attributes, int T) {
        // Without attributes the dimension is unknown here, so defer the split size to Learn.
        this(attributes, T, attributes == null ? 0 : Tools.Log2(attributes.length) + 1);
    }

    /**
     * @param attributes attribute descriptors ({@code null} = generate at training time)
     * @param T number of trees
     * @param M number of variables for splitting (0 = derive from the data)
     */
    public RandomForest(DecisionVariable[] attributes, int T, int M) {
        this(attributes, T, M, DecisionTree.SplitRule.GINI);
    }

    /**
     * @param attributes attribute descriptors ({@code null} = generate at training time)
     * @param T number of trees
     * @param M number of variables for splitting (0 = derive from the data)
     * @param rule split rule for the trees
     */
    public RandomForest(DecisionVariable[] attributes, int T, int M, DecisionTree.SplitRule rule) {
        this.attributes = attributes;
        this.T = T;
        this.M = M;
        this.rule = rule;
    }

    /**
     * @param attributes attribute descriptors ({@code null} = generate at training time)
     * @param T number of trees
     * @param randomSelection how to derive the number of variables for splitting
     */
    public RandomForest(DecisionVariable[] attributes, int T, RandomSelection randomSelection) {
        this(attributes, T, randomSelection, DecisionTree.SplitRule.GINI);
    }

    /**
     * @param attributes attribute descriptors ({@code null} = generate at training time)
     * @param T number of trees
     * @param randomSelection how to derive the number of variables for splitting
     * @param rule split rule for the trees
     */
    public RandomForest(DecisionVariable[] attributes, int T, RandomSelection randomSelection, DecisionTree.SplitRule rule) {
        this.attributes = attributes;
        this.T = T;
        this.rs = randomSelection;
        this.rule = rule;
    }

    /**
     * Validates the training data, grows {@code T} trees and computes the OOB error and importance.
     *
     * @throws IllegalArgumentException on mismatched sizes, invalid T/M, or invalid class labels
     */
    private void BuildModel(DecisionVariable[] attributes, double[][] x, int[] y, int T, int M, DecisionTree.SplitRule rule) {
        if (x.length != y.length) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", x.length, y.length));
        }
        if (T < 1) {
            throw new IllegalArgumentException("Invalid number of trees: " + T);
        }
        if (M < 1) {
            throw new IllegalArgumentException("Invalid number of variables for splitting: " + M);
        }
        if (attributes == null) {
            attributes = defaultAttributes(x[0].length);
        }
        this.k = countClasses(y);

        int n = x.length;
        // prediction[i][c] = number of out-of-bag votes for class c on training row i.
        int[][] prediction = new int[n][this.k];
        int[][] order = this.sort(attributes, x);
        ArrayList<TrainingTask> tasks = new ArrayList<TrainingTask>(T);
        for (int t = 0; t < T; ++t) {
            tasks.add(new TrainingTask(attributes, x, y, M, order, prediction, rule));
        }
        this.trees = train(tasks, prediction);
        this.error = outOfBagError(prediction, y);
        this.importance = sumImportance(this.trees, attributes.length);
    }

    /** Generates descriptors named F0..F(p-1) for a p-column input. */
    private static DecisionVariable[] defaultAttributes(int p) {
        DecisionVariable[] generated = new DecisionVariable[p];
        for (int j = 0; j < p; ++j) {
            generated[j] = new DecisionVariable("F" + j);
        }
        return generated;
    }

    /**
     * Checks that the labels are exactly {@code 0..k-1} with {@code k >= 2} and returns k.
     *
     * @throws IllegalArgumentException on a negative label, a gap, a missing 0, or a single class
     */
    private static int countClasses(int[] y) {
        int[] labels = Tools.Unique(y);
        Arrays.sort(labels);
        if (labels.length > 0) {
            if (labels[0] < 0) {
                throw new IllegalArgumentException("Negative class label: " + labels[0]);
            }
            // Votes and posteriors are indexed by label value, so label 0 must exist.
            if (labels[0] != 0) {
                throw new IllegalArgumentException("Missing class: 0");
            }
        }
        for (int i = 1; i < labels.length; ++i) {
            if (labels[i] - labels[i - 1] > 1) {
                throw new IllegalArgumentException("Missing class: " + (labels[i - 1] + 1));
            }
        }
        if (labels.length < 2) {
            throw new IllegalArgumentException("Only one class.");
        }
        return labels.length;
    }

    /**
     * Runs the training tasks in parallel, falling back to sequential execution if the parallel
     * run throws. Before the fallback, OOB votes already cast by tasks that ran are discarded so
     * the rerun does not count them twice.
     */
    private static List<DecisionTree> train(ArrayList<TrainingTask> tasks, int[][] prediction) {
        try {
            return MulticoreExecutor.run(tasks);
        }
        catch (Exception ex) {
            System.err.println(ex);
            if (ex instanceof InterruptedException) {
                // Preserve the interrupt status for callers.
                Thread.currentThread().interrupt();
            }
            for (int[] votes : prediction) {
                Arrays.fill(votes, 0);
            }
            List<DecisionTree> sequential = new ArrayList<DecisionTree>(tasks.size());
            for (TrainingTask task : tasks) {
                sequential.add(task.call());
            }
            return sequential;
        }
    }

    /** Fraction of OOB-voted rows whose majority OOB vote differs from the true label (0 if none voted). */
    private static double outOfBagError(int[][] prediction, int[] y) {
        int voted = 0;
        int wrong = 0;
        for (int i = 0; i < prediction.length; ++i) {
            int pred = Matrix.MaxIndex(prediction[i]);
            if (prediction[i][pred] <= 0) {
                // Row was in every tree's bootstrap sample: no OOB vote.
                continue;
            }
            ++voted;
            if (pred != y[i]) {
                ++wrong;
            }
        }
        return voted > 0 ? (double)wrong / (double)voted : 0.0;
    }

    /** Element-wise sum of every tree's importance vector into an array of length p. */
    private static double[] sumImportance(List<DecisionTree> trees, int p) {
        double[] total = new double[p];
        for (DecisionTree tree : trees) {
            double[] imp = tree.getImportance();
            for (int j = 0; j < imp.length; ++j) {
                total[j] += imp[j];
            }
        }
        return total;
    }

    /**
     * Pre-computes a sort order for every continuous column.
     *
     * @return {@code index[j] = ArraysUtil.Argsort(column j, true)} for continuous attributes
     *         (the flag presumably means ascending — TODO confirm); {@code null} for other attributes
     */
    private int[][] sort(DecisionVariable[] attributes, double[][] x) {
        int n = x.length;
        int p = x[0].length;
        double[] column = new double[n];
        int[][] index = new int[p][];
        for (int j = 0; j < p; ++j) {
            if (attributes[j].type != DecisionVariable.Type.Continuous) continue;
            for (int i = 0; i < n; ++i) {
                column[i] = x[i][j];
            }
            index[j] = ArraysUtil.Argsort(column, true);
        }
        return index;
    }

    /**
     * Keeps only the first {@code T} trees of the forest.
     *
     * @param T new number of trees, {@code 1 <= T <= size()}
     * @throws IllegalStateException if the model has not been trained
     * @throws IllegalArgumentException if T is out of range
     */
    public void trim(int T) {
        if (this.trees == null) {
            throw new IllegalStateException("The model has not been trained.");
        }
        if (T > this.trees.size()) {
            throw new IllegalArgumentException("The new model size is larger than the current size.");
        }
        if (T <= 0) {
            throw new IllegalArgumentException("Invalid new model size: " + T);
        }
        this.trees = new ArrayList<DecisionTree>(this.trees.subList(0, T));
    }

    /** Trains the forest on the dataset's input and output. */
    @Override
    public void Learn(DatasetClassification dataset) {
        this.Learn(dataset.getInput(), dataset.getOutput());
    }

    /**
     * Trains the forest.
     *
     * <p>If the split size M is 0 it is derived from this input's dimension p. With
     * {@link RandomSelection#Sqrt} it is floor(sqrt(p)); otherwise it is (int) log2(p) + 1. The
     * configured value is left at 0, so later calls derive it again for their own data.
     *
     * @param input training rows (non-empty)
     * @param output class labels {@code 0..k-1}, one per row
     * @throws IllegalArgumentException on empty or inconsistent training data
     */
    @Override
    public void Learn(double[][] input, int[] output) {
        if (input == null || input.length == 0) {
            throw new IllegalArgumentException("Empty training set.");
        }
        int m = this.M;
        if (m == 0) {
            int p = input[0].length;
            m = this.rs == RandomSelection.Sqrt ? (int)Math.floor(Math.sqrt(p)) : (int)Tools.Log(p, 2.0) + 1;
        }
        this.BuildModel(this.attributes, input, output, this.T, m, this.rule);
    }

    /**
     * Predicts the class of a feature vector by majority vote over all trees.
     *
     * @return class with the most votes (lowest index wins ties per {@code Matrix.MaxIndex})
     * @throws IllegalStateException if the model has not been trained
     */
    @Override
    public int Predict(double[] feature) {
        return Matrix.MaxIndex(this.vote(feature));
    }

    /**
     * Predicts the class and fills {@code posteriori} with the fraction of trees voting for each class.
     *
     * @param posteriori output array of length k
     * @return class with the most votes
     * @throws IllegalArgumentException if {@code posteriori.length != k}
     * @throws IllegalStateException if the model has not been trained
     */
    public int Predict(double[] feature, double[] posteriori) {
        if (posteriori.length != this.k) {
            throw new IllegalArgumentException(String.format("Invalid posteriori vector size: %d, expected: %d", posteriori.length, this.k));
        }
        int[] votes = this.vote(feature);
        double total = this.trees.size();
        for (int c = 0; c < this.k; ++c) {
            posteriori[c] = (double)votes[c] / total;
        }
        return Matrix.MaxIndex(votes);
    }

    /** Counts, per class, how many trees predict that class for {@code feature}. */
    private int[] vote(double[] feature) {
        if (this.trees == null) {
            throw new IllegalStateException("The model has not been trained.");
        }
        int[] votes = new int[this.k];
        for (DecisionTree tree : this.trees) {
            ++votes[tree.Predict(feature)];
        }
        return votes;
    }

    /** Returns a shallow copy (trees are shared with this instance). */
    @Override
    public IClassifier clone() {
        try {
            return (IClassifier)super.clone();
        }
        catch (CloneNotSupportedException ex) {
            throw new IllegalArgumentException("Clone not supported: " + ex.getMessage(), ex);
        }
    }

    /**
     * Grows one tree on a bootstrap sample and records its votes for the out-of-bag rows into the
     * shared {@code prediction} matrix. Each row of that matrix is locked while it is updated.
     */
    static class TrainingTask
    implements Callable<DecisionTree> {
        DecisionVariable[] attributes;
        double[][] x;
        int[] y;
        int[][] order;
        int M;
        int[][] prediction;
        DecisionTree.SplitRule rule;

        TrainingTask(DecisionVariable[] attributes, double[][] x, int[] y, int M, int[][] order, int[][] prediction, DecisionTree.SplitRule rule) {
            this.attributes = attributes;
            this.x = x;
            this.y = y;
            this.order = order;
            this.M = M;
            this.prediction = prediction;
            this.rule = rule;
        }

        @Override
        public DecisionTree call() {
            int n = this.x.length;
            // new Random() gives each instance a distinct seed. A seed built from thread id and
            // wall-clock milliseconds would repeat for tasks run on one thread within the same
            // millisecond, yielding identical bootstrap samples.
            Random random = new Random();
            // samples[i] = how many times row i was drawn (with replacement) into this tree's sample.
            int[] samples = new int[n];
            for (int draw = 0; draw < n; ++draw) {
                ++samples[random.nextInt(n)];
            }
            DecisionTree tree = new DecisionTree(this.attributes, this.x, this.y, this.M, samples, this.order, this.rule);
            // Rows never drawn are out-of-bag for this tree: add its vote for them.
            for (int i = 0; i < n; ++i) {
                if (samples[i] != 0) continue;
                int p = tree.Predict(this.x[i]);
                int[] votes = this.prediction[i];
                synchronized (votes) {
                    ++votes[p];
                }
            }
            return tree;
        }
    }

    /** Rule for deriving the number of variables for splitting from the input dimension p. */
    public static enum RandomSelection {
        /** floor(sqrt(p)). */
        Sqrt,
        /** (int) log2(p) + 1. */
        Log;

    }
}

