package smile.regression;

import org.junit.After;
import org.junit.AfterClass;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Test;
import smile.data.AttributeDataset;
import smile.data.parser.ArffParser;
import smile.data.parser.IOUtils;
import smile.math.Math;
import smile.math.kernel.PolynomialKernel;
import smile.validation.CrossValidation;

/* loaded from: input_file:smile/regression/SVRTest.class */
/**
 * JUnit test for {@link SVR}: runs a 10-fold cross-validation on the
 * {@code weka/cpu.arff} test data set and prints the resulting RMSE.
 */
public class SVRTest {
    /** Class-level fixture hook; nothing to prepare. */
    @BeforeClass
    public static void setUpClass() throws Exception {
    }

    /** Class-level fixture hook; nothing to release. */
    @AfterClass
    public static void tearDownClass() throws Exception {
    }

    /** Per-test fixture hook; nothing to prepare. */
    @Before
    public void setUp() {
    }

    /** Per-test fixture hook; nothing to release. */
    @After
    public void tearDown() {
    }

    /**
     * Trains an {@link SVR} with a {@link PolynomialKernel} on each training
     * fold of the CPU data set, accumulates the squared prediction error over
     * the matching test folds, and prints {@code sqrt(sumSquaredError / n)}.
     *
     * @throws Exception if the data file cannot be located or parsed, or if
     *                   training/prediction fails; the exception is propagated
     *                   so that JUnit reports the test as failed instead of
     *                   silently passing.
     */
    @Test
    public void testCPU() throws Exception {
        System.out.println("CPU");

        ArffParser parser = new ArffParser();
        // Column 6 of cpu.arff is used as the regression response.
        parser.setResponseIndex(6);

        AttributeDataset data = parser.parse(IOUtils.getTestDataFile("weka/cpu.arff"));
        double[] y = data.toArray(new double[data.size()]);
        // The feature buffer must be a double[][] (one row per sample); a flat
        // double[] cannot be passed where an array of rows is expected.
        double[][] x = data.toArray(new double[data.size()][]);
        Math.standardize(x);

        int n = x.length;
        int k = 10; // number of cross-validation folds

        CrossValidation cv = new CrossValidation(n, k);
        double rss = 0.0;
        for (int fold = 0; fold < k; fold++) {
            double[][] trainX = Math.slice(x, cv.train[fold]);
            double[] trainY = Math.slice(y, cv.train[fold]);
            double[][] testX = Math.slice(x, cv.test[fold]);
            double[] testY = Math.slice(y, cv.test[fold]);

            // NOTE(review): kernel arguments are presumably (degree, scale, offset)
            // and the trailing SVR arguments (epsilon, C) - confirm against the
            // SMILE API documentation.
            SVR<double[]> svr =
                    new SVR<double[]>(trainX, trainY, new PolynomialKernel(3, 1.0, 1.0), 0.1, 1.0);

            for (int j = 0; j < testX.length; j++) {
                double residual = testY[j] - svr.predict(testX[j]);
                rss += residual * residual;
            }
        }

        System.out.println("10-CV RMSE = " + Math.sqrt(rss / n));
    }
}
