/*
 * Decompiled with CFR 0.152.
 */
package smile.regression;

import org.junit.After;
import org.junit.AfterClass;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Test;
import smile.data.AttributeDataset;
import smile.data.parser.ArffParser;
import smile.data.parser.IOUtils;
import smile.math.Math;
import smile.regression.NeuralNetwork;
import smile.validation.CrossValidation;

/**
 * Cross-validation checks for {@link NeuralNetwork} regression on several ARFF
 * test datasets, run once per activation function.
 */
public class NeuralNetworkTest {
    @BeforeClass
    public static void setUpClass() throws Exception {
    }

    @AfterClass
    public static void tearDownClass() throws Exception {
    }

    @Before
    public void setUp() {
    }

    @After
    public void tearDown() {
    }

    /**
     * Runs 10-fold cross-validation of a neural network regressor on one dataset
     * and prints the mean squared error of the standardized response.
     *
     * @param activation the activation function passed to the network constructor
     * @param dataset    display name printed alongside the result
     * @param url        test-data path resolved through {@link IOUtils#getTestDataFile}
     * @param response   index of the response attribute, passed to
     *                   {@link ArffParser#setResponseIndex}
     * @throws AssertionError if loading, training or prediction throws, so the
     *         calling JUnit test fails instead of passing silently
     */
    public void test(NeuralNetwork.ActivationFunction activation, String dataset, String url, int response) {
        System.out.println(dataset + "\t" + activation);
        ArffParser parser = new ArffParser();
        parser.setResponseIndex(response);
        try {
            AttributeDataset data = parser.parse(IOUtils.getTestDataFile(url));
            double[] datay = data.toArray(new double[data.size()]);
            double[][] datax = data.toArray(new double[data.size()][]);

            int n = datax.length;
            int p = datax[0].length;

            // Standardize every feature column and the response to zero mean and unit sd.
            // NOTE(review): the statistics are computed on the full dataset, so the test
            // folds influence the scaling. A constant column (sd == 0) would produce NaN.
            double[] mux = Math.colMeans(datax);
            double[] sdx = Math.colSds(datax);
            double muy = Math.mean(datay);
            double sdy = Math.sd(datay);
            for (int i = 0; i < n; ++i) {
                datay[i] = (datay[i] - muy) / sdy;
                for (int j = 0; j < p; ++j) {
                    datax[i][j] = (datax[i][j] - mux[j]) / sdx[j];
                }
            }

            int k = 10;
            CrossValidation cv = new CrossValidation(n, k);
            double rss = 0.0;
            for (int i = 0; i < k; ++i) {
                double[][] trainx = (double[][]) Math.slice(datax, cv.train[i]);
                double[] trainy = Math.slice(datay, cv.train[i]);
                double[][] testx = (double[][]) Math.slice(datax, cv.test[i]);
                double[] testy = Math.slice(datay, cv.test[i]);

                // NOTE(review): (p, 10, 10, 1) are presumably the unit counts per layer
                // (input, two hidden, output) -- confirm against NeuralNetwork's constructor.
                NeuralNetwork neuralNetwork = new NeuralNetwork(activation, p, 10, 10, 1);
                neuralNetwork.learn(trainx, trainy);

                for (int j = 0; j < testx.length; ++j) {
                    double r = testy[j] - neuralNetwork.predict(testx[j]);
                    rss += r * r;
                }
            }
            // Every sample lands in exactly one test fold, so rss / n is the CV MSE.
            System.out.format("10-CV MSE = %.4f%n", rss / n);
        } catch (Exception ex) {
            // Previously only printed, which let the test pass on any failure.
            throw new AssertionError(dataset + " (" + url + "): " + ex, ex);
        }
    }

    @Test
    public void testLogisticSigmoid() {
        this.test(NeuralNetwork.ActivationFunction.LOGISTIC_SIGMOID, "CPU", "weka/cpu.arff", 6);
        this.test(NeuralNetwork.ActivationFunction.LOGISTIC_SIGMOID, "abalone", "weka/regression/abalone.arff", 8);
        this.test(NeuralNetwork.ActivationFunction.LOGISTIC_SIGMOID, "cal_housing", "weka/regression/cal_housing.arff", 8);
        this.test(NeuralNetwork.ActivationFunction.LOGISTIC_SIGMOID, "kin8nm", "weka/regression/kin8nm.arff", 8);
    }

    @Test
    public void testTanh() {
        this.test(NeuralNetwork.ActivationFunction.TANH, "CPU", "weka/cpu.arff", 6);
        this.test(NeuralNetwork.ActivationFunction.TANH, "abalone", "weka/regression/abalone.arff", 8);
        this.test(NeuralNetwork.ActivationFunction.TANH, "cal_housing", "weka/regression/cal_housing.arff", 8);
        this.test(NeuralNetwork.ActivationFunction.TANH, "kin8nm", "weka/regression/kin8nm.arff", 8);
    }
}

