/*
 * Decompiled with CFR 0.152.
 */
package smile.regression;

import org.junit.After;
import org.junit.AfterClass;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Test;
import smile.data.AttributeDataset;
import smile.data.parser.ArffParser;
import smile.data.parser.IOUtils;
import smile.math.Math;
import smile.regression.RandomForest;
import smile.sort.QuickSort;
import smile.validation.CrossValidation;
import smile.validation.LOOCV;
import smile.validation.Validation;

/**
 * Tests for regression {@link RandomForest}.
 *
 * <p>Covers three scenarios:
 * <ul>
 * <li>leave-one-out cross validation on the Longley data set;</li>
 * <li>10-fold cross validation on ARFF benchmark data;</li>
 * <li>a 75/25 hold-out split on the CPU data set, including per-tree-count RMSE and variable
 * importance.</li>
 * </ul>
 *
 * <p>Results are printed rather than asserted.
 */
public class RandomForestTest {
    @BeforeClass
    public static void setUpClass() throws Exception {
    }

    @AfterClass
    public static void tearDownClass() throws Exception {
    }

    @Before
    public void setUp() {
    }

    @After
    public void tearDown() {
    }

    /**
     * Parses an ARFF file from the test data directory.
     *
     * @param url path of the file relative to the test data directory
     * @param response zero-based index of the response (target) column
     * @return the parsed data set
     * @throws Exception if the file cannot be read or parsed
     */
    private static AttributeDataset loadDataset(String url, int response) throws Exception {
        ArffParser parser = new ArffParser();
        parser.setResponseIndex(response);
        return parser.parse(IOUtils.getTestDataFile(url));
    }

    /**
     * Returns the number of variables to try at each split.
     *
     * <p>The value is one third of the attribute count, clamped to at least 1. Without the clamp,
     * a data set with fewer than three attributes would yield 0.
     *
     * @param numAttributes number of input attributes
     * @return the mtry value passed to {@link RandomForest}
     */
    private static int defaultMtry(int numAttributes) {
        // java.lang.Math is shadowed by the smile.math.Math import, hence the qualified name.
        return java.lang.Math.max(1, numAttributes / 3);
    }

    /**
     * Runs leave-one-out cross validation of a 300-tree forest on the Longley data set.
     *
     * <p>Prints the mean squared error over all held-out samples.
     */
    @Test
    public void testPredict() {
        System.out.println("predict");
        double[][] longley = new double[][]{{234.289, 235.6, 159.0, 107.608, 1947.0, 60.323}, {259.426, 232.5, 145.6, 108.632, 1948.0, 61.122}, {258.054, 368.2, 161.6, 109.773, 1949.0, 60.171}, {284.599, 335.1, 165.0, 110.929, 1950.0, 61.187}, {328.975, 209.9, 309.9, 112.075, 1951.0, 63.221}, {346.999, 193.2, 359.4, 113.27, 1952.0, 63.639}, {365.385, 187.0, 354.7, 115.094, 1953.0, 64.989}, {363.112, 357.8, 335.0, 116.219, 1954.0, 63.761}, {397.469, 290.4, 304.8, 117.388, 1955.0, 66.019}, {419.18, 282.2, 285.7, 118.734, 1956.0, 67.857}, {442.769, 293.6, 279.8, 120.445, 1957.0, 68.169}, {444.546, 468.1, 263.7, 121.95, 1958.0, 66.513}, {482.704, 381.3, 255.2, 123.366, 1959.0, 68.655}, {502.601, 393.1, 251.4, 125.368, 1960.0, 69.564}, {518.173, 480.6, 257.2, 127.852, 1961.0, 69.331}, {554.894, 400.7, 282.7, 130.081, 1962.0, 70.551}};
        double[] y = new double[]{83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9};
        int n = longley.length;
        LOOCV loocv = new LOOCV(n);
        double rss = 0.0;
        for (int i = 0; i < n; ++i) {
            double[][] trainx = (double[][]) Math.slice(longley, loocv.train[i]);
            double[] trainy = Math.slice(y, loocv.train[i]);
            try {
                RandomForest forest = new RandomForest(trainx, trainy, 300, n, 3, 2);
                double r = y[loocv.test[i]] - forest.predict(longley[loocv.test[i]]);
                rss += r * r;
            } catch (Exception ex) {
                // A failing fold is reported and skipped. Its residual does not contribute to rss.
                System.err.println(ex);
            }
        }
        System.out.println("MSE = " + rss / (double) n);
    }

    /**
     * Runs 10-fold cross validation of a 200-tree forest on an ARFF data set.
     *
     * <p>Prints the out-of-bag error of each fold's forest. It then prints the overall RMSE and
     * mean absolute deviation over all held-out samples. Any exception is printed to stderr
     * rather than propagated.
     *
     * @param dataset display name printed before the results
     * @param url path of the ARFF file relative to the test data directory
     * @param response zero-based index of the response column
     */
    public void test(String dataset, String url, int response) {
        System.out.println(dataset);
        try {
            AttributeDataset data = loadDataset(url, response);
            double[] datay = data.toArray(new double[data.size()]);
            double[][] datax = data.toArray(new double[data.size()][]);
            int n = datax.length;
            int k = 10;
            CrossValidation cv = new CrossValidation(n, k);
            double rss = 0.0;
            double ad = 0.0;
            for (int i = 0; i < k; ++i) {
                double[][] trainx = (double[][]) Math.slice(datax, cv.train[i]);
                double[] trainy = Math.slice(datay, cv.train[i]);
                double[][] testx = (double[][]) Math.slice(datax, cv.test[i]);
                double[] testy = Math.slice(datay, cv.test[i]);
                RandomForest forest = new RandomForest(data.attributes(), trainx, trainy, 200, n, 5, defaultMtry(trainx[0].length));
                System.out.format("OOB error rate = %.4f%n", forest.error());
                for (int j = 0; j < testx.length; ++j) {
                    double r = testy[j] - forest.predict(testx[j]);
                    rss += r * r;
                    ad += Math.abs(r);
                }
            }
            // Each sample lands in exactly one test fold, so dividing by n averages over all of them.
            System.out.format("10-CV RMSE = %.4f \t AbsoluteDeviation = %.4f%n", Math.sqrt(rss / (double) n), ad / (double) n);
        } catch (Exception ex) {
            System.err.println(ex);
        }
    }

    /** Runs the cross-validation benchmark on the CPU and autoMPG data sets. */
    @Test
    public void testAll() {
        this.test("CPU", "weka/cpu.arff", 6);
        this.test("autoMPG", "weka/regression/autoMpg.arff", 7);
    }

    /**
     * Trains a 100-tree forest on a random 75% of the CPU data and evaluates it on the rest.
     *
     * <p>Prints the test RMSE, the RMSE obtained using the first 1..ntrees trees, and attribute
     * importance from the largest sorted entry down.
     */
    @Test
    public void testCPU() {
        System.out.println("CPU");
        try {
            AttributeDataset data = loadDataset("weka/cpu.arff", 6);
            double[] datay = data.toArray(new double[data.size()]);
            double[][] datax = data.toArray(new double[data.size()][]);
            int n = datax.length;
            int m = 3 * n / 4;
            int[] index = Math.permutate(n);
            double[][] trainx = new double[m][];
            double[] trainy = new double[m];
            for (int i = 0; i < m; ++i) {
                trainx[i] = datax[index[i]];
                trainy[i] = datay[index[i]];
            }
            double[][] testx = new double[n - m][];
            double[] testy = new double[n - m];
            for (int i = m; i < n; ++i) {
                testx[i - m] = datax[index[i]];
                testy[i - m] = datay[index[i]];
            }
            RandomForest forest = new RandomForest(data.attributes(), trainx, trainy, 100, n, 5, defaultMtry(trainx[0].length));
            System.out.format("RMSE = %.4f%n", Validation.test(forest, testx, testy));
            double[] rmse = forest.test(testx, testy);
            for (int i = 1; i <= rmse.length; ++i) {
                System.out.format("%d trees RMSE = %.4f%n", i, rmse[i - 1]);
            }
            double[] importance = forest.importance();
            // Usage below relies on QuickSort.sort reordering `importance` in place and returning
            // the original attribute position of each sorted entry.
            int[] order = QuickSort.sort(importance);
            for (int i = importance.length - 1; i >= 0; --i) {
                System.out.format("%s importance is %.4f%n", data.attributes()[order[i]], importance[i]);
            }
        } catch (Exception ex) {
            System.err.println(ex);
        }
    }
}

