package smile.regression;

import org.junit.After;
import org.junit.AfterClass;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Test;
import smile.data.AttributeDataset;
import smile.data.parser.ArffParser;
import smile.data.parser.IOUtils;
import smile.math.Math;
import smile.regression.NeuralNetwork;
import smile.validation.CrossValidation;

/* loaded from: input_file:smile/regression/NeuralNetworkTest.class */
/**
 * Tests for {@link NeuralNetwork} regression. Each test prints the
 * cross-validated mean squared error of a small two-hidden-layer network on
 * several Weka regression datasets.
 */
public class NeuralNetworkTest {
    /** Number of cross-validation folds. */
    private static final int FOLDS = 10;

    /** Number of units in each of the two hidden layers. */
    private static final int HIDDEN_UNITS = 10;

    @BeforeClass
    public static void setUpClass() throws Exception {
    }

    @AfterClass
    public static void tearDownClass() throws Exception {
    }

    @Before
    public void setUp() {
    }

    @After
    public void tearDown() {
    }

    /**
     * Runs {@value #FOLDS}-fold cross-validation of a neural network on one ARFF
     * dataset and prints the resulting mean squared error.
     *
     * <p>Features and response are standardized (centered, divided by standard
     * deviation) before training, so the printed MSE is in standardized-response
     * units. A column with zero standard deviation is only centered, so it does
     * not turn into NaN.
     *
     * <p>NOTE(review): the standardization statistics are computed on the whole
     * dataset, not per training fold, so a little information leaks from test
     * folds into training. Kept as-is so the reported numbers stay comparable.
     *
     * @param activationFunction activation function of the hidden layers
     * @param str dataset label used in the printed output
     * @param str2 path of the ARFF file, resolved via {@link IOUtils#getTestDataFile}
     * @param i index of the response attribute in the ARFF file
     * @throws AssertionError if the dataset cannot be loaded or evaluation fails
     */
    public void test(NeuralNetwork.ActivationFunction activationFunction, String str, String str2, int i) {
        System.out.println(str + "\t" + activationFunction);
        ArffParser arffParser = new ArffParser();
        arffParser.setResponseIndex(i);
        try {
            AttributeDataset data = arffParser.parse(IOUtils.getTestDataFile(str2));
            double[] y = data.toArray(new double[data.size()]);
            double[][] x = data.toArray(new double[data.size()][]);

            int n = x.length;
            int p = x[0].length;

            double[] colMeans = Math.colMeans(x);
            double[] colSds = Math.colSds(x);
            double yMean = Math.mean(y);
            double ySd = nonZeroScale(Math.sd(y));

            double[] colScales = new double[p];
            for (int j = 0; j < p; j++) {
                colScales[j] = nonZeroScale(colSds[j]);
            }

            for (int r = 0; r < n; r++) {
                y[r] = (y[r] - yMean) / ySd;
                for (int j = 0; j < p; j++) {
                    x[r][j] = (x[r][j] - colMeans[j]) / colScales[j];
                }
            }

            CrossValidation cv = new CrossValidation(n, FOLDS);
            double sse = 0.0;
            for (int fold = 0; fold < FOLDS; fold++) {
                double[][] trainX = (double[][]) Math.slice(x, cv.train[fold]);
                double[] trainY = Math.slice(y, cv.train[fold]);
                double[][] testX = (double[][]) Math.slice(x, cv.test[fold]);
                double[] testY = Math.slice(y, cv.test[fold]);

                NeuralNetwork net = new NeuralNetwork(activationFunction, p, HIDDEN_UNITS, HIDDEN_UNITS, 1);
                net.learn(trainX, trainY);

                for (int k = 0; k < testX.length; k++) {
                    double residual = testY[k] - net.predict(testX[k]);
                    sse += residual * residual;
                }
            }

            // Every sample lands in exactly one test fold, so SSE / n is the CV MSE.
            System.out.format("10-CV MSE = %.4f%n", sse / n);
        } catch (Exception e) {
            // Fail loudly instead of letting a missing file or parse error pass silently.
            throw new AssertionError("Failed to evaluate " + str + " (" + str2 + ")", e);
        }
    }

    /**
     * Returns {@code sd}, or 1.0 when {@code sd} is zero, so that a constant
     * column is centered instead of being divided by zero.
     */
    private static double nonZeroScale(double sd) {
        return sd == 0.0 ? 1.0 : sd;
    }

    @Test
    public void testLogisticSigmoid() {
        test(NeuralNetwork.ActivationFunction.LOGISTIC_SIGMOID, "CPU", "weka/cpu.arff", 6);
        test(NeuralNetwork.ActivationFunction.LOGISTIC_SIGMOID, "abalone", "weka/regression/abalone.arff", 8);
        test(NeuralNetwork.ActivationFunction.LOGISTIC_SIGMOID, "cal_housing", "weka/regression/cal_housing.arff", 8);
        test(NeuralNetwork.ActivationFunction.LOGISTIC_SIGMOID, "kin8nm", "weka/regression/kin8nm.arff", 8);
    }

    @Test
    public void testTanh() {
        test(NeuralNetwork.ActivationFunction.TANH, "CPU", "weka/cpu.arff", 6);
        test(NeuralNetwork.ActivationFunction.TANH, "abalone", "weka/regression/abalone.arff", 8);
        test(NeuralNetwork.ActivationFunction.TANH, "cal_housing", "weka/regression/cal_housing.arff", 8);
        test(NeuralNetwork.ActivationFunction.TANH, "kin8nm", "weka/regression/kin8nm.arff", 8);
    }
}
