/*
 * Decompiled with CFR 0.152.
 */
package smile.regression;

import org.junit.After;
import org.junit.AfterClass;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Test;
import smile.data.AttributeDataset;
import smile.data.parser.ArffParser;
import smile.data.parser.IOUtils;
import smile.math.Math;
import smile.regression.GradientTreeBoost;
import smile.sort.QuickSort;
import smile.validation.CrossValidation;
import smile.validation.Validation;

/**
 * Exercises {@link GradientTreeBoost} regression on several ARFF test datasets.
 *
 * <p>The tests print error metrics (RMSE, absolute deviation, per-tree RMSE and
 * variable importance) to standard output; they do not compare the metrics
 * against thresholds. A test fails only if loading, training or prediction
 * throws.
 */
public class GradientTreeBoostTest {
    /** Class-level fixture hook; intentionally empty. */
    @BeforeClass
    public static void setUpClass() throws Exception {
    }

    /** Class-level teardown hook; intentionally empty. */
    @AfterClass
    public static void tearDownClass() throws Exception {
    }

    /** Per-test fixture hook; intentionally empty. */
    @Before
    public void setUp() {
    }

    /** Per-test teardown hook; intentionally empty. */
    @After
    public void tearDown() {
    }

    /**
     * Runs a 10-fold cross validation of {@link GradientTreeBoost} with the given
     * loss on one dataset and prints the pooled RMSE and mean absolute deviation.
     *
     * @param loss     loss function handed to the booster
     * @param dataset  human-readable dataset name, used only in the printed header
     * @param url      test-data path resolved through {@code IOUtils.getTestDataFile}
     * @param response index of the response attribute passed to the ARFF parser
     * @throws AssertionError if the dataset cannot be loaded or the model fails
     */
    public void test(GradientTreeBoost.Loss loss, String dataset, String url, int response) {
        System.out.println(dataset + "\t" + loss);
        ArffParser parser = new ArffParser();
        parser.setResponseIndex(response);
        try {
            AttributeDataset data = parser.parse(IOUtils.getTestDataFile(url));
            // The argument type selects the overload: double[] yields the responses,
            // double[][] yields the feature rows.
            double[] datay = data.toArray(new double[data.size()]);
            double[][] datax = data.toArray(new double[data.size()][]);
            int n = datax.length;
            int k = 10;
            CrossValidation cv = new CrossValidation(n, k);
            double rss = 0.0;
            double ad = 0.0;
            for (int i = 0; i < k; ++i) {
                double[][] trainx = Math.slice(datax, cv.train[i]);
                double[] trainy = Math.slice(datay, cv.train[i]);
                double[][] testx = Math.slice(datax, cv.test[i]);
                double[] testy = Math.slice(datay, cv.test[i]);
                // NOTE(review): the trailing literals are presumably number of trees,
                // max leaf nodes, shrinkage and subsampling rate; confirm against
                // the GradientTreeBoost constructor.
                GradientTreeBoost boost = new GradientTreeBoost(data.attributes(), trainx, trainy, loss, 100, 6, 0.05, 0.7);
                for (int j = 0; j < testx.length; ++j) {
                    double r = testy[j] - boost.predict(testx[j]);
                    ad += Math.abs(r);
                    rss += r * r;
                }
            }
            // Errors are accumulated over all folds and normalised by the full sample count.
            System.out.format("10-CV RMSE = %.4f \t AbsoluteDeviation = %.4f%n", Math.sqrt(rss / (double)n), ad / (double)n);
        }
        catch (Exception ex) {
            // Surface the failure instead of printing it and letting the test pass.
            throw new AssertionError("GradientTreeBoost test failed on " + dataset + " (" + url + ")", ex);
        }
    }

    /** Cross-validates the least-squares loss on the CPU, autoMPG and cal_housing datasets. */
    @Test
    public void testLS() {
        this.test(GradientTreeBoost.Loss.LeastSquares, "CPU", "weka/cpu.arff", 6);
        this.test(GradientTreeBoost.Loss.LeastSquares, "autoMPG", "weka/regression/autoMpg.arff", 7);
        this.test(GradientTreeBoost.Loss.LeastSquares, "cal_housing", "weka/regression/cal_housing.arff", 8);
    }

    /** Cross-validates the least-absolute-deviation loss on the same three datasets. */
    @Test
    public void testLAD() {
        this.test(GradientTreeBoost.Loss.LeastAbsoluteDeviation, "CPU", "weka/cpu.arff", 6);
        this.test(GradientTreeBoost.Loss.LeastAbsoluteDeviation, "autoMPG", "weka/regression/autoMpg.arff", 7);
        this.test(GradientTreeBoost.Loss.LeastAbsoluteDeviation, "cal_housing", "weka/regression/cal_housing.arff", 8);
    }

    /** Cross-validates the Huber loss on the same three datasets. */
    @Test
    public void testHuber() {
        this.test(GradientTreeBoost.Loss.Huber, "CPU", "weka/cpu.arff", 6);
        this.test(GradientTreeBoost.Loss.Huber, "autoMPG", "weka/regression/autoMpg.arff", 7);
        this.test(GradientTreeBoost.Loss.Huber, "cal_housing", "weka/regression/cal_housing.arff", 8);
    }

    /**
     * Trains on a random three-quarter split of the CPU dataset, then prints the
     * hold-out RMSE, the value reported by {@code boost.test} for each entry of
     * its result array, and the variable importance of every attribute.
     *
     * @throws AssertionError if the dataset cannot be loaded or the model fails
     */
    @Test
    public void testCPU() {
        System.out.println("CPU");
        ArffParser parser = new ArffParser();
        parser.setResponseIndex(6);
        try {
            AttributeDataset data = parser.parse(IOUtils.getTestDataFile("weka/cpu.arff"));
            double[] datay = data.toArray(new double[data.size()]);
            double[][] datax = data.toArray(new double[data.size()][]);
            int n = datax.length;
            int m = 3 * n / 4;
            // The first m shuffled samples form the training set, the rest the test set.
            int[] index = Math.permutate(n);
            double[][] trainx = new double[m][];
            double[] trainy = new double[m];
            for (int i = 0; i < m; ++i) {
                trainx[i] = datax[index[i]];
                trainy[i] = datay[index[i]];
            }
            double[][] testx = new double[n - m][];
            double[] testy = new double[n - m];
            for (int i = m; i < n; ++i) {
                testx[i - m] = datax[index[i]];
                testy[i - m] = datay[index[i]];
            }
            GradientTreeBoost boost = new GradientTreeBoost(data.attributes(), trainx, trainy, 100);
            System.out.format("RMSE = %.4f%n", Validation.test(boost, testx, testy));
            double[] rmse = boost.test(testx, testy);
            for (int i = 1; i <= rmse.length; ++i) {
                System.out.format("%d trees RMSE = %.4f%n", i, rmse[i - 1]);
            }
            double[] importance = boost.importance();
            // NOTE(review): this pairing assumes QuickSort.sort orders `importance`
            // in place and returns the original positions, so importance[i] belongs
            // to attribute index[i]; confirm against the QuickSort documentation.
            index = QuickSort.sort(importance);
            int i = importance.length;
            while (i-- > 0) {
                System.out.format("%s importance is %.4f%n", data.attributes()[index[i]], importance[i]);
            }
        }
        catch (Exception ex) {
            // Surface the failure instead of printing it and letting the test pass.
            throw new AssertionError("GradientTreeBoost CPU test failed", ex);
        }
    }
}

