/*
 * Decompiled with CFR 0.152.
 */
package smile.regression;

import org.junit.After;
import org.junit.AfterClass;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Test;
import smile.data.AttributeDataset;
import smile.data.parser.ArffParser;
import smile.data.parser.IOUtils;
import smile.math.Math;
import smile.regression.RLS;
import smile.validation.CrossValidation;

/**
 * Smoke tests for {@link RLS}: fits a model on a batch of samples, updates it
 * with further samples via {@code learn}, and reports the cross-validated
 * mean squared error on several ARFF data sets.
 */
public class RLSTest {
    /** Class-level fixture hook; intentionally empty. */
    @BeforeClass
    public static void setUpClass() throws Exception {
    }

    /** Class-level teardown hook; intentionally empty. */
    @AfterClass
    public static void tearDownClass() throws Exception {
    }

    /** Per-test fixture hook; intentionally empty. */
    @Before
    public void setUp() {
    }

    /** Per-test teardown hook; intentionally empty. */
    @After
    public void tearDown() {
    }

    /**
     * Runs a 10-fold cross validation of {@link RLS} on one data set and prints
     * the resulting mean squared error to standard output.
     *
     * <p>Within each fold the first half of the training samples is passed to
     * the {@code RLS} constructor and the second half is passed to
     * {@code learn}. When the training set has an odd number of samples, the
     * last sample is not used.
     *
     * @param name          label printed in the progress and result output
     * @param fileName      path of the ARFF file, resolved through
     *                      {@code IOUtils.getTestDataFile}
     * @param responseIndex index of the response attribute handed to the parser
     * @throws AssertionError if loading the data or running the model throws,
     *                        so that the calling test fails instead of passing
     *                        silently
     */
    public void testOnlineLearn(String name, String fileName, int responseIndex) {
        System.out.println(name + "\t Online Learn");
        ArffParser parser = new ArffParser();
        parser.setResponseIndex(responseIndex);
        try {
            AttributeDataset data = parser.parse(IOUtils.getTestDataFile(fileName));
            double[][] datax = data.toArray(new double[data.size()][]);
            double[] datay = data.toArray(new double[data.size()]);
            int n = datax.length;
            int k = 10;
            CrossValidation cv = new CrossValidation(n, k);
            double rss = 0.0;
            for (int i = 0; i < k; ++i) {
                double[][] trainx = Math.slice(datax, cv.train[i]);
                double[] trainy = Math.slice(datay, cv.train[i]);
                double[][] testx = Math.slice(datax, cv.test[i]);
                double[] testy = Math.slice(datay, cv.test[i]);

                // Split the training fold in two equal halves: one for the
                // initial fit, one for the incremental update.
                int l = trainx.length / 2;
                double[][] batchx = new double[l][];
                double[] batchy = new double[l];
                double[][] onlinex = new double[l][];
                double[] onliney = new double[l];
                for (int j = 0; j < l; ++j) {
                    batchx[j] = trainx[j];
                    batchy[j] = trainy[j];
                    onlinex[j] = trainx[l + j];
                    onliney[j] = trainy[l + j];
                }

                // NOTE(review): the third argument is presumably the forgetting
                // factor (1.0 = no forgetting) -- confirm against RLS.
                RLS rls = new RLS(batchx, batchy, 1.0);
                rls.learn(onlinex, onliney);

                // Accumulate squared residuals over the held-out samples.
                for (int j = 0; j < testx.length; ++j) {
                    double r = testy[j] - rls.predict(testx[j]);
                    rss += r * r;
                }
            }
            // Dividing by n assumes the test folds together cover each sample
            // exactly once -- TODO confirm against CrossValidation.
            System.out.println("MSE = " + rss / (double) n);
        } catch (Exception ex) {
            // Do not swallow: a test that only logs failures can never fail.
            throw new AssertionError(
                    "Online learning test failed for " + name + " (" + fileName + ")", ex);
        }
    }

    /** Exercises the online-learning scenario on the CPU, 2dplanes and abalone data sets. */
    @Test
    public void testOnlineLearn() {
        testOnlineLearn("CPU", "weka/cpu.arff", 6);
        testOnlineLearn("2dplanes", "weka/regression/2dplanes.arff", 10);
        testOnlineLearn("abalone", "weka/regression/abalone.arff", 8);
    }
}

