/*
 * Decompiled with CFR 0.152.
 */
package smile.regression;

import org.junit.After;
import org.junit.AfterClass;
import org.junit.Assert;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Test;
import smile.data.AttributeDataset;
import smile.data.parser.ArffParser;
import smile.data.parser.IOUtils;
import smile.math.Math;
import smile.math.matrix.DenseMatrix;
import smile.math.matrix.Matrix;
import smile.regression.LASSO;
import smile.validation.CrossValidation;
import smile.validation.LOOCV;

/**
 * Tests for {@link LASSO}: a small synthetic fit with reference coefficients, a leave-one-out
 * cross-validation on the Longley data, and a 10-fold cross-validation on the Weka CPU dataset.
 */
public class LASSOTest {
    @BeforeClass
    public static void setUpClass() throws Exception {
    }

    @AfterClass
    public static void tearDownClass() throws Exception {
    }

    @Before
    public void setUp() {
    }

    @After
    public void tearDown() {
    }

    /**
     * Fits LASSO on a 5x4 design matrix whose responses are generated from the sparse weight
     * vector {@code x0} plus a constant offset of 5, then checks the fitted intercept and
     * coefficients against reference values.
     */
    @Test
    public void testToy2() {
        double[][] A = {
            {1.0, 0.0, 0.0, 0.5},
            {0.0, 1.0, 0.2, 0.3},
            {1.0, 0.5, 0.2, 0.3},
            {0.0, 0.1, 0.0, 0.2},
            {0.0, 0.1, 1.0, 0.2}
        };
        double[] x0 = {1.0, 0.0, 1.0, 0.0};
        double[] y = new double[A.length];

        // Assumed: DenseMatrix.ax(x, y) writes A * x into y. TODO confirm against smile's Matrix API.
        DenseMatrix a = Matrix.newInstance(A);
        a.ax(x0, y);

        // Shift every response by 5 so that the model has to recover a non-zero intercept.
        for (int i = 0; i < y.length; i++) {
            y[i] += 5.0;
        }

        // NOTE(review): the trailing arguments are presumably (lambda, tolerance, maxIter). Verify
        // against the LASSO constructor.
        LASSO lasso = new LASSO(A, y, 0.1, 0.001, 500);

        double rss = 0.0;
        int n = A.length;
        for (int i = 0; i < n; i++) {
            double r = y[i] - lasso.predict(A[i]);
            rss += r * r;
        }
        System.out.println("MSE = " + rss / (double) n);

        Assert.assertEquals(5.0259443688265355, lasso.intercept(), 1.0E-7);

        double[] w = {0.9659945126777854, -3.7147706312985876E-4, 0.9553629503697613, 9.416740009376934E-4};
        double[] coef = lasso.coefficients();
        for (int i = 0; i < w.length; i++) {
            Assert.assertEquals(w[i], coef[i], 1.0E-5);
        }
    }

    /**
     * Runs leave-one-out cross-validation of LASSO on the 16-row Longley dataset and checks the
     * resulting mean squared error against a reference value.
     */
    @Test
    public void testLongley() {
        System.out.println("longley");

        double[][] longley = {
            {234.289, 235.6, 159.0, 107.608, 1947.0, 60.323},
            {259.426, 232.5, 145.6, 108.632, 1948.0, 61.122},
            {258.054, 368.2, 161.6, 109.773, 1949.0, 60.171},
            {284.599, 335.1, 165.0, 110.929, 1950.0, 61.187},
            {328.975, 209.9, 309.9, 112.075, 1951.0, 63.221},
            {346.999, 193.2, 359.4, 113.27, 1952.0, 63.639},
            {365.385, 187.0, 354.7, 115.094, 1953.0, 64.989},
            {363.112, 357.8, 335.0, 116.219, 1954.0, 63.761},
            {397.469, 290.4, 304.8, 117.388, 1955.0, 66.019},
            {419.18, 282.2, 285.7, 118.734, 1956.0, 67.857},
            {442.769, 293.6, 279.8, 120.445, 1957.0, 68.169},
            {444.546, 468.1, 263.7, 121.95, 1958.0, 66.513},
            {482.704, 381.3, 255.2, 123.366, 1959.0, 68.655},
            {502.601, 393.1, 251.4, 125.368, 1960.0, 69.564},
            {518.173, 480.6, 257.2, 127.852, 1961.0, 69.331},
            {554.894, 400.7, 282.7, 130.081, 1962.0, 70.551}
        };
        double[] y = {
            83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0,
            101.2, 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9
        };

        double rss = 0.0;
        int n = longley.length;
        LOOCV loocv = new LOOCV(n);
        for (int i = 0; i < n; i++) {
            double[][] trainx = (double[][]) Math.slice(longley, loocv.train[i]);
            double[] trainy = Math.slice(y, loocv.train[i]);
            LASSO lasso = new LASSO(trainx, trainy, 0.1);

            // In each LOOCV round, loocv.test[i] is the single held-out row index.
            double r = y[loocv.test[i]] - lasso.predict(longley[loocv.test[i]]);
            rss += r * r;
        }

        System.out.println("LOOCV MSE = " + rss / (double) n);
        Assert.assertEquals(2.0012529348358212, rss / (double) n, 1.0E-4);
    }

    /**
     * Runs 10-fold cross-validation of LASSO on the Weka CPU dataset and prints the MSE.
     *
     * @throws Exception if the test data file cannot be located or parsed. The exception is
     *     propagated rather than caught, so JUnit reports it as a test failure instead of
     *     silently passing.
     */
    @Test
    public void testCPU() throws Exception {
        System.out.println("CPU");
        ArffParser parser = new ArffParser();
        parser.setResponseIndex(6);

        AttributeDataset data = parser.parse(IOUtils.getTestDataFile("weka/cpu.arff"));
        double[][] datax = data.toArray(new double[data.size()][]);
        double[] datay = data.toArray(new double[data.size()]);

        int n = datax.length;
        int k = 10;
        CrossValidation cv = new CrossValidation(n, k);
        double rss = 0.0;
        for (int i = 0; i < k; i++) {
            double[][] trainx = (double[][]) Math.slice(datax, cv.train[i]);
            double[] trainy = Math.slice(datay, cv.train[i]);
            double[][] testx = (double[][]) Math.slice(datax, cv.test[i]);
            double[] testy = Math.slice(datay, cv.test[i]);

            LASSO lasso = new LASSO(trainx, trainy, 50.0);
            for (int j = 0; j < testx.length; j++) {
                double r = testy[j] - lasso.predict(testx[j]);
                rss += r * r;
            }
        }

        System.out.println("10-CV MSE = " + rss / (double) n);
    }
}

