/*
 * Decompiled with CFR 0.152.
 */
package smile.classification;

import java.io.Serializable;
import java.util.Arrays;
import smile.classification.ClassifierTrainer;
import smile.classification.SoftClassifier;
import smile.data.Attribute;
import smile.data.NumericAttribute;
import smile.math.Math;
import smile.regression.RegressionTree;
import smile.util.SmileUtils;
import smile.validation.Accuracy;
import smile.validation.ClassificationMeasure;

/**
 * Gradient tree boosting classifier built from shallow {@link RegressionTree}s.
 * <p>
 * For two classes, a single sequence of trees is fit to the negative gradient of the
 * binomial deviance, starting from an intercept {@code b}. The predicted label is
 * {@code 1} when {@code b + shrinkage * sum(tree(x)) > 0}, else {@code 0}.
 * <p>
 * For {@code k > 2} classes, one sequence of trees is fit per class to the residual
 * {@code 1[y == j] - p_j}, where {@code p_j} is the softmax of the current class scores.
 * The class with the largest score wins.
 * <p>
 * Each tree is trained on a class-stratified subsample in which every class contributes
 * {@code round(count * subsample)} randomly chosen samples.
 * <p>
 * Class labels must be non-negative integers. The number of classes is taken to be
 * {@code max(y) + 1}, and at least two distinct labels must be present.
 */
public class GradientTreeBoost
implements SoftClassifier<double[]>,
Serializable {
    private static final long serialVersionUID = 1L;
    /** Number of classes, {@code max(y) + 1}. */
    private int k = 2;
    /** Trees of a binary model; {@code null} when {@code k > 2}. */
    private RegressionTree[] trees;
    /** Trees of a multi-class model, indexed {@code [class][iteration]}; {@code null} when {@code k == 2}. */
    private RegressionTree[][] forest;
    /** Sum of the importance vectors of every tree grown during training. */
    private double[] importance;
    /** Intercept of a binary model: {@code 0.5 * log(n1 / n0)} of the training labels. */
    private double b = 0.0;
    /** Learning rate applied to every tree's output. */
    private double shrinkage = 0.005;
    /** Maximum number of leaf nodes per tree. */
    private int maxNodes = 6;
    /** Number of boosting iterations (trees per class) currently used for prediction. */
    private int ntrees = 500;
    /** Fraction of each class drawn (without replacement) to train each tree. */
    private double subsample = 0.7;

    /**
     * Trains a model on numeric features using the default hyper-parameters.
     *
     * @param x the training samples.
     * @param y the class labels, non-negative integers.
     * @param ntrees the number of boosting iterations.
     */
    public GradientTreeBoost(double[][] x, int[] y, int ntrees) {
        this(null, x, y, ntrees);
    }

    /**
     * Trains a model on numeric features.
     *
     * @param x the training samples.
     * @param y the class labels, non-negative integers.
     * @param ntrees the number of boosting iterations.
     * @param maxNodes the maximum number of leaf nodes per tree.
     * @param shrinkage the learning rate, in (0, 1].
     * @param subsample the per-class sampling fraction, in (0, 1].
     */
    public GradientTreeBoost(double[][] x, int[] y, int ntrees, int maxNodes, double shrinkage, double subsample) {
        this(null, x, y, ntrees, maxNodes, shrinkage, subsample);
    }

    /**
     * Trains a model with {@code maxNodes = 6} and {@code subsample = 0.7}. The shrinkage is
     * 0.005 for fewer than 2000 samples and 0.05 otherwise.
     *
     * @param attributes the feature descriptors, or {@code null} for all-numeric features.
     * @param x the training samples.
     * @param y the class labels, non-negative integers.
     * @param ntrees the number of boosting iterations.
     */
    public GradientTreeBoost(Attribute[] attributes, double[][] x, int[] y, int ntrees) {
        this(attributes, x, y, ntrees, 6, x.length < 2000 ? 0.005 : 0.05, 0.7);
    }

    /**
     * Trains a gradient tree boosting classifier.
     *
     * @param attributes the feature descriptors, or {@code null} to treat every column of
     *        {@code x} as a numeric attribute named {@code V1, V2, ...}.
     * @param x the training samples.
     * @param y the class labels, non-negative integers with at least two distinct values.
     * @param ntrees the number of boosting iterations, at least 1.
     * @param maxNodes the maximum number of leaf nodes per tree, at least 2.
     * @param shrinkage the learning rate, in (0, 1].
     * @param subsample the per-class sampling fraction, in (0, 1].
     * @throws IllegalArgumentException if any argument is invalid.
     */
    public GradientTreeBoost(Attribute[] attributes, double[][] x, int[] y, int ntrees, int maxNodes, double shrinkage, double subsample) {
        if (x.length != y.length) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", x.length, y.length));
        }
        if (x.length == 0) {
            // Without this, the x[0] access below fails with an opaque index exception.
            throw new IllegalArgumentException("Empty training data.");
        }
        if (ntrees < 1) {
            throw new IllegalArgumentException("Invalid number of trees: " + ntrees);
        }
        if (maxNodes < 2) {
            throw new IllegalArgumentException("Invalid maximum leaves: " + maxNodes);
        }
        if (shrinkage <= 0.0 || shrinkage > 1.0) {
            throw new IllegalArgumentException("Invalid shrinkage: " + shrinkage);
        }
        if (subsample <= 0.0 || subsample > 1.0) {
            throw new IllegalArgumentException("Invalid sampling fraction: " + subsample);
        }
        for (int label : y) {
            // Labels index per-class arrays, so a negative one would otherwise crash later.
            if (label < 0) {
                throw new IllegalArgumentException("Negative class label: " + label);
            }
        }
        if (attributes == null) {
            int p = x[0].length;
            attributes = new Attribute[p];
            for (int i = 0; i < p; ++i) {
                attributes[i] = new NumericAttribute("V" + (i + 1));
            }
        }
        this.ntrees = ntrees;
        this.maxNodes = maxNodes;
        this.shrinkage = shrinkage;
        this.subsample = subsample;
        this.k = Math.max(y) + 1;

        // Per-class sample counts, shared by both training paths for stratified sampling.
        int[] nc = new int[this.k];
        for (int label : y) {
            nc[label]++;
        }
        int present = 0;
        for (int count : nc) {
            if (count > 0) {
                present++;
            }
        }
        // Checking k alone is not enough: labels {1} give k == 2 but an infinite intercept.
        if (present < 2) {
            throw new IllegalArgumentException("Only one class or negative class labels.");
        }

        this.importance = new double[attributes.length];
        if (this.k == 2) {
            this.train2(attributes, x, y, nc);
            for (RegressionTree tree : this.trees) {
                this.addImportance(tree);
            }
        } else {
            this.traink(attributes, x, y, nc);
            for (RegressionTree[] grove : this.forest) {
                for (RegressionTree tree : grove) {
                    this.addImportance(tree);
                }
            }
        }
    }

    /** Adds one tree's importance vector into {@link #importance}, element-wise. */
    private void addImportance(RegressionTree tree) {
        double[] imp = tree.importance();
        for (int i = 0; i < imp.length; ++i) {
            this.importance[i] += imp[i];
        }
    }

    /**
     * Returns the variable importance: the element-wise sum of the importance vectors of
     * all trees grown during training. The internal array is returned, not a copy.
     *
     * @return the variable importance, one entry per attribute.
     */
    public double[] importance() {
        return this.importance;
    }

    /**
     * Fills {@code samples} with a fresh class-stratified subsample.
     * <p>
     * {@code perm} is shuffled in place. Then, for every class {@code l}, the first
     * {@code round(nc[l] * subsample)} indices of {@code perm} with label {@code l} are
     * marked with 1. All other entries of {@code samples} are 0.
     */
    private void drawSamples(int[] y, int[] nc, int[] perm, int[] samples) {
        Arrays.fill(samples, 0);
        Math.permutate(perm);
        int n = perm.length;
        for (int l = 0; l < nc.length; ++l) {
            int quota = (int) Math.round((double) nc[l] * this.subsample);
            int count = 0;
            for (int i = 0; i < n && count < quota; ++i) {
                int xi = perm[i];
                if (y[xi] == l) {
                    samples[xi] = 1;
                    ++count;
                }
            }
        }
    }

    /** Returns the identity permutation {@code [0, 1, ..., n-1]}. */
    private static int[] identity(int n) {
        int[] perm = new int[n];
        for (int i = 0; i < n; ++i) {
            perm[i] = i;
        }
        return perm;
    }

    /**
     * Trains the two-class model (labels 0 and 1).
     * <p>
     * Labels are mapped to {-1, +1}. Each iteration fits a tree to the negative gradient of
     * the binomial deviance, {@code 2y / (1 + exp(2yF))}.
     */
    private void train2(Attribute[] attributes, double[][] x, int[] y, int[] nc) {
        int n = x.length;
        int[] y2 = new int[n];
        for (int i = 0; i < n; ++i) {
            y2[i] = y[i] == 1 ? 1 : -1;
        }
        double[] h = new double[n];
        double[] response = new double[n];
        // mu is the mean of the ±1 labels, so (1 + mu) / (1 - mu) equals n1 / n0.
        double mu = Math.mean(y2);
        this.b = 0.5 * Math.log((1.0 + mu) / (1.0 - mu));
        Arrays.fill(h, this.b);
        int[][] order = SmileUtils.sort(attributes, x);
        // The node-output rule reads `response` by reference, so it sees each iteration's update.
        L2NodeOutput output = new L2NodeOutput(response);
        this.trees = new RegressionTree[this.ntrees];
        int[] perm = identity(n);
        int[] samples = new int[n];
        for (int m = 0; m < this.ntrees; ++m) {
            this.drawSamples(y, nc, perm, samples);
            for (int i = 0; i < n; ++i) {
                response[i] = 2.0 * y2[i] / (1.0 + Math.exp(2 * y2[i] * h[i]));
            }
            // NOTE(review): the literal 5 is passed through as in the original; presumably a node-size limit.
            this.trees[m] = new RegressionTree(attributes, x, response, this.maxNodes, 5, x[0].length, order, samples, output);
            for (int i = 0; i < n; ++i) {
                h[i] += this.shrinkage * this.trees[m].predict(x[i]);
            }
        }
    }

    /**
     * Trains the multi-class model.
     * <p>
     * Each iteration first converts the current scores to softmax probabilities
     * {@code p}. It then fits one tree per class {@code j} to the residual
     * {@code 1[y == j] - p_j}, each on its own stratified subsample.
     */
    private void traink(Attribute[] attributes, double[][] x, int[] y, int[] nc) {
        int n = x.length;
        double[][] h = new double[this.k][n];
        double[][] p = new double[this.k][n];
        double[][] response = new double[this.k][n];
        int[][] order = SmileUtils.sort(attributes, x);
        this.forest = new RegressionTree[this.k][this.ntrees];
        LKNodeOutput[] output = new LKNodeOutput[this.k];
        for (int j = 0; j < this.k; ++j) {
            output[j] = new LKNodeOutput(response[j]);
        }
        int[] perm = identity(n);
        int[] samples = new int[n];
        for (int m = 0; m < this.ntrees; ++m) {
            for (int i = 0; i < n; ++i) {
                // Subtract the max score before exp() to avoid overflow in the softmax.
                double max = Double.NEGATIVE_INFINITY;
                for (int j = 0; j < this.k; ++j) {
                    if (max < h[j][i]) {
                        max = h[j][i];
                    }
                }
                double Z = 0.0;
                for (int j = 0; j < this.k; ++j) {
                    p[j][i] = Math.exp(h[j][i] - max);
                    Z += p[j][i];
                }
                for (int j = 0; j < this.k; ++j) {
                    p[j][i] /= Z;
                }
            }
            for (int j = 0; j < this.k; ++j) {
                for (int i = 0; i < n; ++i) {
                    response[j][i] = (y[i] == j ? 1.0 : 0.0) - p[j][i];
                }
                this.drawSamples(y, nc, perm, samples);
                this.forest[j][m] = new RegressionTree(attributes, x, response[j], this.maxNodes, 5, x[0].length, order, samples, output[j]);
                for (int i = 0; i < n; ++i) {
                    h[j][i] += this.shrinkage * this.forest[j][m].predict(x[i]);
                }
            }
        }
    }

    /**
     * Returns the number of boosting iterations used for prediction. For multi-class
     * models this is the number of trees per class. It reflects any earlier {@link #trim}.
     *
     * @return the model size.
     */
    public int size() {
        // trees is null for multi-class models, so ntrees is the only field valid for both.
        return this.ntrees;
    }

    /**
     * Truncates the model to its first {@code ntrees} boosting iterations. The variable
     * importance is not recomputed.
     *
     * @param ntrees the new model size.
     * @throws IllegalArgumentException if {@code ntrees < 1} or larger than the current size.
     */
    public void trim(int ntrees) {
        if (ntrees < 1) {
            throw new IllegalArgumentException("Invalid new model size: " + ntrees);
        }
        if (this.k == 2) {
            if (ntrees > this.trees.length) {
                throw new IllegalArgumentException("The new model size is larger than the current size.");
            }
            if (ntrees < this.trees.length) {
                this.trees = Arrays.copyOf(this.trees, ntrees);
                this.ntrees = ntrees;
            }
        } else {
            if (ntrees > this.forest[0].length) {
                throw new IllegalArgumentException("The new model size is larger than the current one.");
            }
            if (ntrees < this.forest[0].length) {
                for (int i = 0; i < this.forest.length; ++i) {
                    this.forest[i] = Arrays.copyOf(this.forest[i], ntrees);
                }
                this.ntrees = ntrees;
            }
        }
    }

    /** Binary score {@code b + shrinkage * sum(tree(x))} over the first {@code ntrees} trees. */
    private double binaryScore(double[] x) {
        double score = this.b;
        for (int i = 0; i < this.ntrees; ++i) {
            score += this.shrinkage * this.trees[i].predict(x);
        }
        return score;
    }

    /** Multi-class score {@code shrinkage * sum(tree(x))} over class {@code j}'s first {@code ntrees} trees. */
    private double classScore(int j, double[] x) {
        double score = 0.0;
        for (int i = 0; i < this.ntrees; ++i) {
            score += this.shrinkage * this.forest[j][i].predict(x);
        }
        return score;
    }

    /**
     * Predicts the class label of {@code x}.
     *
     * @param x the sample.
     * @return the predicted label. For multi-class models, returns -1 if no score
     *         compares greater than negative infinity (e.g. all NaN).
     */
    @Override
    public int predict(double[] x) {
        if (this.k == 2) {
            return this.binaryScore(x) > 0.0 ? 1 : 0;
        }
        double max = Double.NEGATIVE_INFINITY;
        int label = -1;
        for (int j = 0; j < this.k; ++j) {
            double score = this.classScore(j, x);
            if (score > max) {
                max = score;
                label = j;
            }
        }
        return label;
    }

    /**
     * Predicts the class label of {@code x} and fills in the class probabilities.
     * <p>
     * Binary models use {@code P(0) = 1 / (1 + exp(2F))} and {@code P(1) = 1 - P(0)}.
     * Multi-class models use the softmax of the class scores.
     *
     * @param x the sample.
     * @param posteriori output array of length {@code k}; overwritten with the probabilities.
     * @return the predicted label.
     * @throws IllegalArgumentException if {@code posteriori.length != k}.
     */
    @Override
    public int predict(double[] x, double[] posteriori) {
        if (posteriori.length != this.k) {
            throw new IllegalArgumentException(String.format("Invalid posteriori vector size: %d, expected: %d", posteriori.length, this.k));
        }
        if (this.k == 2) {
            double score = this.binaryScore(x);
            posteriori[0] = 1.0 / (1.0 + Math.exp(2.0 * score));
            posteriori[1] = 1.0 - posteriori[0];
            return score > 0.0 ? 1 : 0;
        }
        double max = Double.NEGATIVE_INFINITY;
        int label = -1;
        for (int j = 0; j < this.k; ++j) {
            posteriori[j] = this.classScore(j, x);
            if (posteriori[j] > max) {
                max = posteriori[j];
                label = j;
            }
        }
        double Z = 0.0;
        for (int j = 0; j < this.k; ++j) {
            posteriori[j] = Math.exp(posteriori[j] - max);
            Z += posteriori[j];
        }
        for (int j = 0; j < this.k; ++j) {
            posteriori[j] /= Z;
        }
        return label;
    }

    /**
     * Evaluates accuracy as the model grows.
     *
     * @param x the test samples.
     * @param y the test labels.
     * @return an array whose entry {@code i} is the accuracy of the model truncated to its
     *         first {@code i + 1} iterations.
     */
    public double[] test(double[][] x, int[] y) {
        double[][] results = this.test(x, y, new ClassificationMeasure[]{new Accuracy()});
        double[] accuracy = new double[results.length];
        for (int i = 0; i < results.length; ++i) {
            accuracy[i] = results[i][0];
        }
        return accuracy;
    }

    /**
     * Evaluates several measures as the model grows. Predictions are updated incrementally,
     * one iteration at a time.
     *
     * @param x the test samples.
     * @param y the test labels.
     * @param measures the measures to compute.
     * @return an {@code ntrees x measures.length} array whose row {@code i} holds the
     *         measures of the model truncated to its first {@code i + 1} iterations.
     */
    public double[][] test(double[][] x, int[] y, ClassificationMeasure[] measures) {
        int m = measures.length;
        double[][] results = new double[this.ntrees][m];
        int n = x.length;
        int[] label = new int[n];
        if (this.k == 2) {
            double[] prediction = new double[n];
            Arrays.fill(prediction, this.b);
            for (int i = 0; i < this.ntrees; ++i) {
                for (int j = 0; j < n; ++j) {
                    prediction[j] += this.shrinkage * this.trees[i].predict(x[j]);
                    label[j] = prediction[j] > 0.0 ? 1 : 0;
                }
                for (int j = 0; j < m; ++j) {
                    results[i][j] = measures[j].measure(y, label);
                }
            }
        } else {
            double[][] prediction = new double[n][this.k];
            for (int i = 0; i < this.ntrees; ++i) {
                for (int j = 0; j < n; ++j) {
                    for (int l = 0; l < this.k; ++l) {
                        prediction[j][l] += this.shrinkage * this.forest[l][i].predict(x[j]);
                    }
                    label[j] = Math.whichMax(prediction[j]);
                }
                for (int j = 0; j < m; ++j) {
                    results[i][j] = measures[j].measure(y, label);
                }
            }
        }
        return results;
    }

    /**
     * Returns the trees of a binary model (the internal array, not a copy).
     *
     * @return the trees, or {@code null} when the model has more than two classes.
     */
    public RegressionTree[] getTrees() {
        return this.trees;
    }

    /**
     * Leaf-value rule for multi-class trees:
     * {@code (k - 1) / k * sum(r) / sum(|r| * (1 - |r|))} over the sampled residuals
     * {@code r}. It falls back to the mean residual when the denominator is below 1e-10.
     */
    class LKNodeOutput
    implements RegressionTree.NodeOutput {
        /** Residuals for one class, shared by reference with the trainer. */
        double[] y;

        public LKNodeOutput(double[] response) {
            this.y = response;
        }

        @Override
        public double calculate(int[] samples) {
            int n = 0;
            double nu = 0.0;
            double de = 0.0;
            for (int i = 0; i < samples.length; ++i) {
                if (samples[i] <= 0) continue;
                ++n;
                double abs = Math.abs(this.y[i]);
                nu += this.y[i];
                de += abs * (1.0 - abs);
            }
            if (de < 1.0E-10) {
                return nu / (double)n;
            }
            return ((double)GradientTreeBoost.this.k - 1.0) / (double)GradientTreeBoost.this.k * (nu / de);
        }
    }

    /**
     * Leaf-value rule for binary trees: {@code sum(r) / sum(|r| * (2 - |r|))} over the
     * sampled pseudo-responses {@code r}.
     */
    class L2NodeOutput
    implements RegressionTree.NodeOutput {
        /** Pseudo-responses, shared by reference with the trainer. */
        double[] y;

        public L2NodeOutput(double[] y) {
            this.y = y;
        }

        @Override
        public double calculate(int[] samples) {
            double nu = 0.0;
            double de = 0.0;
            for (int i = 0; i < samples.length; ++i) {
                if (samples[i] <= 0) continue;
                double abs = Math.abs(this.y[i]);
                nu += this.y[i];
                de += abs * (2.0 - abs);
            }
            return nu / de;
        }
    }

    /**
     * Builder-style trainer. Defaults: 500 trees, shrinkage 0.005, 6 leaf nodes,
     * subsample 0.7.
     */
    public static class Trainer
    extends ClassifierTrainer<double[]> {
        private int ntrees = 500;
        private double shrinkage = 0.005;
        private int maxNodes = 6;
        private double subsample = 0.7;

        /** Creates a trainer with default hyper-parameters. */
        public Trainer() {
        }

        /**
         * @param ntrees the number of boosting iterations, at least 1.
         * @throws IllegalArgumentException if {@code ntrees < 1}.
         */
        public Trainer(int ntrees) {
            if (ntrees < 1) {
                throw new IllegalArgumentException("Invalid number of trees: " + ntrees);
            }
            this.ntrees = ntrees;
        }

        /**
         * @param attributes the feature descriptors.
         * @param ntrees the number of boosting iterations, at least 1.
         * @throws IllegalArgumentException if {@code ntrees < 1}.
         */
        public Trainer(Attribute[] attributes, int ntrees) {
            super(attributes);
            if (ntrees < 1) {
                throw new IllegalArgumentException("Invalid number of trees: " + ntrees);
            }
            this.ntrees = ntrees;
        }

        /** Sets the number of boosting iterations (at least 1). */
        public Trainer setNumTrees(int ntrees) {
            if (ntrees < 1) {
                throw new IllegalArgumentException("Invalid number of trees: " + ntrees);
            }
            this.ntrees = ntrees;
            return this;
        }

        /** Sets the maximum number of leaf nodes per tree (at least 2). */
        public Trainer setMaxNodes(int maxNodes) {
            if (maxNodes < 2) {
                throw new IllegalArgumentException("Invalid maximum number of leaf nodes: " + maxNodes);
            }
            this.maxNodes = maxNodes;
            return this;
        }

        /** Sets the learning rate, in (0, 1]. */
        public Trainer setShrinkage(double shrinkage) {
            if (shrinkage <= 0.0 || shrinkage > 1.0) {
                throw new IllegalArgumentException("Invalid shrinkage: " + shrinkage);
            }
            this.shrinkage = shrinkage;
            return this;
        }

        /** Sets the per-class sampling fraction, in (0, 1]. */
        public Trainer setSamplingRates(double subsample) {
            if (subsample <= 0.0 || subsample > 1.0) {
                throw new IllegalArgumentException("Invalid sampling fraction: " + subsample);
            }
            this.subsample = subsample;
            return this;
        }

        @Override
        public GradientTreeBoost train(double[][] x, int[] y) {
            return new GradientTreeBoost(this.attributes, x, y, this.ntrees, this.maxNodes, this.shrinkage, this.subsample);
        }
    }
}

