/*
 * Decompiled with CFR 0.152.
 */
package Catalano.MachineLearning.Classification.DecisionTrees.Learning;

import Catalano.Core.ArraysUtil;
import Catalano.MachineLearning.Classification.DecisionTrees.DecisionTree;
import Catalano.MachineLearning.Classification.IClassifier;
import Catalano.MachineLearning.Dataset.DatasetClassification;
import Catalano.MachineLearning.Dataset.DecisionVariable;
import Catalano.Math.Matrix;
import Catalano.Math.Tools;
import java.io.Serializable;
import java.util.Arrays;

/**
 * AdaBoost ensemble of small decision trees.
 *
 * <p>Each boosting round draws a weighted resample of the training rows, fits a
 * {@link DecisionTree} (Gini split rule, at most {@code J} leaves) to it, and gives the tree a
 * vote weight {@code alpha = log((1 - e) / e) + log(k - 1)}, where {@code e} is the tree's
 * weighted training error and {@code k} the number of classes. This matches the multi-class
 * SAMME form of AdaBoost. Misclassified rows have their weight multiplied by {@code exp(alpha)}
 * before the next round.
 *
 * <p>Class labels must be non-negative and contiguous (no gaps between consecutive distinct
 * labels), and at least two distinct labels are required.
 */
public class AdaBoost
implements IClassifier,
Serializable {
    private DecisionVariable[] attributes;
    private int T;
    private int J;
    private int k;
    private DecisionTree[] trees;
    private double[] alpha;
    private double[] error;
    private double[] importance;

    /**
     * Returns the per-attribute importance, i.e. the element-wise sum of
     * {@link DecisionTree#getImportance()} over the trees currently in the ensemble.
     *
     * @return the importance array (the internal array, not a copy), or {@code null} before training
     */
    public double[] getImportance() {
        return this.importance;
    }

    /**
     * Returns the number of trees actually kept by training. This can be smaller than
     * {@link #getNumberOfTrees()} when boosting stopped early or after {@link #trim(int)}.
     *
     * @return the ensemble size
     */
    public int size() {
        return this.trees.length;
    }

    /** @return the configured maximum number of boosting rounds */
    public int getNumberOfTrees() {
        return this.T;
    }

    /** @param T the maximum number of boosting rounds; validated (must be >= 1) when learning */
    public void setNumberOfTrees(int T) {
        this.T = T;
    }

    /** @return the configured maximum number of leaves per tree */
    public int getNumberOfLeafs() {
        return this.J;
    }

    /** @param J the maximum number of leaves per tree; validated (must be >= 2) when learning */
    public void setNumberOfLeafs(int J) {
        this.J = J;
    }

    /** Creates an ensemble of up to 10 trees with at most 2 leaves each. */
    public AdaBoost() {
        this(10);
    }

    /**
     * @param T maximum number of boosting rounds
     */
    public AdaBoost(int T) {
        this(T, 2);
    }

    /**
     * @param T maximum number of boosting rounds
     * @param J maximum number of leaves per tree
     */
    public AdaBoost(int T, int J) {
        this.T = T;
        this.J = J;
    }

    /**
     * @param attributes attribute descriptions; {@code null} means every feature is described by a
     *                   default {@code DecisionVariable} named {@code F<index>}
     */
    public AdaBoost(DecisionVariable[] attributes) {
        // Previously passed null here, silently discarding the caller's attributes.
        this(attributes, 10);
    }

    /**
     * @param attributes attribute descriptions, or {@code null} for defaults
     * @param T          maximum number of boosting rounds
     */
    public AdaBoost(DecisionVariable[] attributes, int T) {
        this(attributes, T, 2);
    }

    /**
     * @param attributes attribute descriptions, or {@code null} for defaults
     * @param T          maximum number of boosting rounds
     * @param J          maximum number of leaves per tree
     */
    public AdaBoost(DecisionVariable[] attributes, int T, int J) {
        this.attributes = attributes;
        this.T = T;
        this.J = J;
    }

    /**
     * Trains the ensemble.
     *
     * @throws IllegalArgumentException if x and y differ in length, T < 1, J < 2, a label is
     *                                  negative, labels have a gap, or fewer than two classes exist
     */
    private void BuildModel(DecisionVariable[] attributes, double[][] x, int[] y, int T, int J) {
        if (x.length != y.length) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", x.length, y.length));
        }
        if (T < 1) {
            throw new IllegalArgumentException("Invalid number of trees: " + T);
        }
        if (J < 2) {
            throw new IllegalArgumentException("Invalid maximum leaves: " + J);
        }
        this.T = T;
        this.J = J;

        int[] labels = Tools.Unique(y);
        Arrays.sort(labels);
        for (int i = 0; i < labels.length; ++i) {
            if (labels[i] < 0) {
                throw new IllegalArgumentException("Negative class label: " + labels[i]);
            }
            if (i > 0 && labels[i] - labels[i - 1] > 1) {
                // Parenthesized so the first absent label is reported numerically,
                // not as a concatenated string such as "51".
                throw new IllegalArgumentException("Missing class: " + (labels[i - 1] + 1));
            }
        }
        this.k = labels.length;
        if (this.k < 2) {
            throw new IllegalArgumentException("Only one class.");
        }

        if (attributes == null) {
            int s = x[0].length;
            attributes = new DecisionVariable[s];
            for (int i = 0; i < s; ++i) {
                attributes[i] = new DecisionVariable("F" + i);
            }
        }

        int[][] order = this.sort(attributes, x);
        int n = x.length;
        double[] w = new double[n];
        Arrays.fill(w, 1.0);
        boolean[] err = new boolean[n];

        // Random guessing over k classes is right with probability 1/k; a tree must beat that.
        double guess = 1.0 / this.k;
        double b = Math.log(this.k - 1);

        this.trees = new DecisionTree[T];
        this.alpha = new double[T];
        this.error = new double[T];

        for (int t = 0; t < T; ++t) {
            // Normalize weights into a sampling distribution.
            double W = Tools.Sum(w);
            for (int i = 0; i < n; ++i) {
                w[i] /= W;
            }

            // samples[i] = number of times row i was drawn in this round's weighted resample.
            // A fresh array per round so no tree shares its counts with the next round.
            int[] samples = new int[n];
            for (int s : this.random(w, n)) {
                samples[s]++;
            }

            DecisionTree tree = new DecisionTree(attributes, J, samples, order, DecisionTree.SplitRule.GINI);
            tree.Learn(x, y);
            this.trees[t] = tree;

            // Weighted training error of this tree.
            double e = 0.0;
            for (int i = 0; i < n; ++i) {
                err[i] = tree.Predict(x[i]) != y[i];
                if (err[i]) {
                    e += w[i];
                }
            }

            if (1.0 - e <= guess) {
                // No better than random guessing: drop this tree and stop boosting.
                this.trees = Arrays.copyOf(this.trees, t);
                this.alpha = Arrays.copyOf(this.alpha, t);
                this.error = Arrays.copyOf(this.error, t);
                break;
            }

            this.error[t] = e;
            // Clamp e away from 0 so a perfect tree gets a large but finite weight.
            this.alpha[t] = Math.log((1.0 - e) / Math.max(1.0E-10, e)) + b;
            double a = Math.exp(this.alpha[t]);
            for (int i = 0; i < n; ++i) {
                if (err[i]) {
                    w[i] *= a;
                }
            }
        }

        this.importance = this.computeImportance(attributes.length);
    }

    /** Sums each tree's importance into an array of length {@code p}. */
    private double[] computeImportance(int p) {
        double[] total = new double[p];
        for (DecisionTree tree : this.trees) {
            double[] imp = tree.getImportance();
            for (int i = 0; i < imp.length; ++i) {
                total[i] += imp[i];
            }
        }
        return total;
    }

    /** Throws if the model has not been trained yet. */
    private void checkTrained() {
        if (this.trees == null) {
            throw new IllegalStateException("The model has not been trained.");
        }
    }

    @Override
    public void Learn(DatasetClassification dataset) {
        this.Learn(dataset.getInput(), dataset.getOutput());
    }

    @Override
    public void Learn(double[][] input, int[] output) {
        this.BuildModel(this.attributes, input, output, this.T, this.J);
    }

    /**
     * Predicts the class of {@code feature} by alpha-weighted voting of the trees.
     *
     * <p>For two classes a tie goes to class 0. For more classes, {@code Matrix.MaxIndex}
     * decides among the vote totals.
     *
     * @throws IllegalStateException if the model has not been trained
     */
    @Override
    public int Predict(double[] feature) {
        this.checkTrained();
        double[] votes = new double[this.k];
        for (int i = 0; i < this.trees.length; ++i) {
            votes[this.trees[i].Predict(feature)] += this.alpha[i];
        }
        if (this.k == 2) {
            // Previously alpha * {0,1} was summed and compared with 0. Every alpha is positive,
            // so a single vote for class 1 won. Compare the weighted votes instead.
            return votes[1] > votes[0] ? 1 : 0;
        }
        return Matrix.MaxIndex(votes);
    }

    /**
     * Draws {@code n} indices with replacement, where index i is chosen with probability
     * {@code prob[i]} (assumed to sum to 1), using Walker's alias method.
     */
    private int[] random(double[] prob, int n) {
        int m = prob.length;
        double[] q = new double[m];
        int[] alias = new int[m];
        for (int i = 0; i < m; ++i) {
            q[i] = prob[i] * m;
            alias[i] = i;
        }

        // HL stores "high" entries (q >= 1) from the front and "low" entries (q < 1) from the back.
        int[] HL = new int[m];
        int head = 0;
        int tail = m - 1;
        for (int i = 0; i < m; ++i) {
            if (q[i] >= 1.0) {
                HL[head++] = i;
            } else {
                HL[tail--] = i;
            }
        }

        // Pair each low column with a high donor until one list is exhausted.
        while (head != 0 && tail != m - 1) {
            int lo = HL[tail + 1];
            int hi = HL[head - 1];
            alias[lo] = hi;
            q[hi] += q[lo] - 1.0;
            ++tail;
            if (q[hi] < 1.0) {
                // The donor dropped below 1: move it from the high list to the low list.
                HL[tail--] = hi;
                --head;
            }
        }

        int[] ans = new int[n];
        for (int i = 0; i < n; ++i) {
            double rU = Tools.RandomNextDouble() * m;
            int col = (int) rU;
            rU -= col;
            ans[i] = rU < q[col] ? col : alias[col];
        }
        return ans;
    }

    /**
     * Keeps only the first {@code T} trees and recomputes {@link #getImportance()} for them.
     *
     * @param T new ensemble size, 1 <= T <= {@link #size()}
     * @throws IllegalArgumentException if T is out of range
     * @throws IllegalStateException    if the model has not been trained
     */
    public void trim(int T) {
        this.checkTrained();
        if (T > this.trees.length) {
            throw new IllegalArgumentException("The new model size is larger than the current size.");
        }
        if (T <= 0) {
            throw new IllegalArgumentException("Invalid new model size: " + T);
        }
        if (T < this.trees.length) {
            this.trees = Arrays.copyOf(this.trees, T);
            this.alpha = Arrays.copyOf(this.alpha, T);
            this.error = Arrays.copyOf(this.error, T);
            // Keep importance consistent with the trees that remain.
            this.importance = this.computeImportance(this.importance.length);
        }
    }

    /**
     * For every continuous attribute j, returns the row indices of x sorted by column j in
     * ascending order ({@code ArraysUtil.Argsort(..., true)}). Entries for non-continuous
     * attributes are left {@code null}.
     */
    private int[][] sort(DecisionVariable[] attributes, double[][] x) {
        int n = x.length;
        int p = x[0].length;
        double[] column = new double[n];
        int[][] index = new int[p][];
        for (int j = 0; j < p; ++j) {
            if (attributes[j].type != DecisionVariable.Type.Continuous) {
                continue;
            }
            for (int i = 0; i < n; ++i) {
                column[i] = x[i][j];
            }
            index[j] = ArraysUtil.Argsort(column, true);
        }
        return index;
    }

    /**
     * Returns a shallow copy from {@code Object.clone()}. The trees and arrays are shared with
     * this instance until either one is retrained or trimmed, since those replace the arrays.
     */
    @Override
    public IClassifier clone() {
        try {
            return (IClassifier) super.clone();
        } catch (CloneNotSupportedException ex) {
            throw new IllegalArgumentException("Clone not supported: " + ex.getMessage(), ex);
        }
    }
}

