package smile.classification;

import org.junit.After;
import org.junit.AfterClass;
import org.junit.Assert;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Test;
import smile.classification.SVM;
import smile.data.AttributeDataset;
import smile.data.NominalAttribute;
import smile.data.parser.ArffParser;
import smile.data.parser.DelimitedTextParser;
import smile.data.parser.IOUtils;
import smile.math.Math;
import smile.math.kernel.GaussianKernel;
import smile.math.kernel.LinearKernel;
import smile.math.kernel.PolynomialKernel;

/**
 * Tests for {@link SVM}.
 *
 * <p>Covers a tiny hand-built linear problem, plus the iris, segment and USPS data sets. These use
 * linear, Gaussian and polynomial kernels under both the one-vs-all and one-vs-one multi-class
 * strategies.
 *
 * <p>The data-set tests load files through {@code IOUtils.getTestDataFile}. A missing or
 * unparsable file propagates as an exception and fails the test rather than being ignored.
 */
public class SVMTest {
    @BeforeClass
    public static void setUpClass() throws Exception {
    }

    @AfterClass
    public static void tearDownClass() throws Exception {
    }

    @Before
    public void setUp() {
    }

    @After
    public void tearDown() {
    }

    /**
     * Runs {@code epochs} online passes of {@code svm.learn(x, y)}, then calls {@code finish()}.
     *
     * @param svm the classifier to train
     * @param x the training samples
     * @param y the training labels, parallel to {@code x}
     * @param epochs the number of passes over the data
     */
    private static void train(SVM<double[]> svm, double[][] x, int[] y, int epochs) {
        for (int epoch = 0; epoch < epochs; epoch++) {
            svm.learn(x, y);
        }
        svm.finish();
    }

    /**
     * Counts the samples whose predicted label differs from the given label.
     *
     * @param svm a trained classifier
     * @param x the samples to classify
     * @param y the expected labels, parallel to {@code x}
     * @return the number of indices {@code i} with {@code svm.predict(x[i]) != y[i]}
     */
    private static int countErrors(SVM<double[]> svm, double[][] x, int[] y) {
        int errors = 0;
        for (int i = 0; i < x.length; i++) {
            if (svm.predict(x[i]) != y[i]) {
                errors++;
            }
        }
        return errors;
    }

    /**
     * Calls the posterior-probability overload of {@code predict} on every sample.
     *
     * <p>This is a smoke check that probability output can be requested once Platt scaling has
     * been trained.
     *
     * @param svm a trained classifier with Platt scaling
     * @param x the samples to classify
     * @param k the number of classes (length of the posterior buffer)
     */
    private static void predictPosteriori(SVM<double[]> svm, double[][] x, int k) {
        for (double[] xi : x) {
            svm.predict(xi, new double[k]);
        }
    }

    @Test
    public void testLinear() {
        SVM<double[]> svm = new SVM<double[]>(new LinearKernel(), 10.0);
        svm.learn(new double[]{3.0, 0.0, 0.0, 0.0}, 1);
        svm.learn(new double[]{1.0, 0.0, 1.0, 0.0}, 1);
        svm.learn(new double[]{0.0, 2.0, 0.0, 0.0}, 0);
        svm.learn(new double[]{0.0, 1.0, 0.0, 0.0}, 0);
        svm.learn(new double[]{0.0, 0.0, 1.0, 0.0}, 1);
        svm.learn(new double[]{0.0, 0.0, 0.0, 3.0}, 0);
        svm.finish();

        int prediction = svm.predict(new double[]{0.0, 0.0, 0.0, 1.0});
        System.out.println(prediction);
        // Solving the hard-margin problem for these six points by hand gives
        // w = (1/3, -1, 1, -1/3), b = 0; C = 10 does not bind. The query scores -1/3, the same
        // sign as the training point (0, 0, 0, 3), which is labeled 0.
        Assert.assertEquals(0, prediction);
    }

    @Test
    public void testLearn() throws Exception {
        System.out.println("learn");
        ArffParser arffParser = new ArffParser();
        arffParser.setResponseIndex(4);
        AttributeDataset iris = arffParser.parse(IOUtils.getTestDataFile("weka/iris.arff"));
        double[][] x = iris.toArray(new double[iris.size()][]);
        int[] y = iris.toArray(new int[iris.size()]);
        int k = Math.max(y) + 1;

        SVM<double[]> linear = new SVM<double[]>(new LinearKernel(), 10.0, k, SVM.Multiclass.ONE_VS_ALL);
        train(linear, x, y, 2);
        int linearError = countErrors(linear, x, y);
        System.out.println("Linear ONE vs. ALL error = " + linearError);
        Assert.assertTrue(linearError <= 10);

        SVM<double[]> gaussianOva = new SVM<double[]>(new GaussianKernel(1.0), 1.0, k, SVM.Multiclass.ONE_VS_ALL);
        train(gaussianOva, x, y, 2);
        gaussianOva.trainPlattScaling(x, y);
        int gaussianOvaError = countErrors(gaussianOva, x, y);
        predictPosteriori(gaussianOva, x, k);
        System.out.println("Gaussian ONE vs. ALL error = " + gaussianOvaError);
        Assert.assertTrue(gaussianOvaError <= 5);

        SVM<double[]> gaussianOvo = new SVM<double[]>(new GaussianKernel(1.0), 1.0, k, SVM.Multiclass.ONE_VS_ONE);
        train(gaussianOvo, x, y, 2);
        Assert.assertFalse(gaussianOvo.hasPlattScaling());
        gaussianOvo.trainPlattScaling(x, y);
        Assert.assertTrue(gaussianOvo.hasPlattScaling());
        int gaussianOvoError = countErrors(gaussianOvo, x, y);
        predictPosteriori(gaussianOvo, x, k);
        System.out.println("Gaussian ONE vs. ONE error = " + gaussianOvoError);
        Assert.assertTrue(gaussianOvoError <= 5);

        SVM<double[]> polynomial = new SVM<double[]>(new PolynomialKernel(2), 1.0, k, SVM.Multiclass.ONE_VS_ALL);
        train(polynomial, x, y, 2);
        int polynomialError = countErrors(polynomial, x, y);
        System.out.println("Polynomial ONE vs. ALL error = " + polynomialError);
        Assert.assertTrue(polynomialError <= 5);
    }

    @Test
    public void testSegment() throws Exception {
        System.out.println("Segment");
        ArffParser arffParser = new ArffParser();
        arffParser.setResponseIndex(19);
        AttributeDataset train = arffParser.parse(IOUtils.getTestDataFile("weka/segment-challenge.arff"));
        AttributeDataset test = arffParser.parse(IOUtils.getTestDataFile("weka/segment-test.arff"));
        System.out.println(train.size() + " " + test.size());

        double[][] x = train.toArray(new double[train.size()][]);
        int[] y = train.toArray(new int[train.size()]);
        double[][] testx = test.toArray(new double[test.size()][]);
        int[] testy = test.toArray(new int[test.size()]);

        SVM<double[]> svm = new SVM<double[]>(new GaussianKernel(8.0), 5.0, Math.max(y) + 1, SVM.Multiclass.ONE_VS_ALL);
        train(svm, x, y, 1);

        int error = countErrors(svm, testx, testy);
        System.out.format("Segment error rate = %.2f%%%n", 100.0 * error / testx.length);
        Assert.assertTrue(error < 70);
    }

    @Test
    public void testUSPS() throws Exception {
        System.out.println("USPS");
        DelimitedTextParser parser = new DelimitedTextParser();
        parser.setResponseIndex(new NominalAttribute("class"), 0);
        AttributeDataset train = parser.parse("USPS Train", IOUtils.getTestDataFile("usps/zip.train"));
        AttributeDataset test = parser.parse("USPS Test", IOUtils.getTestDataFile("usps/zip.test"));

        double[][] x = train.toArray(new double[train.size()][]);
        int[] y = train.toArray(new int[train.size()]);
        double[][] testx = test.toArray(new double[test.size()][]);
        int[] testy = test.toArray(new int[test.size()]);

        SVM<double[]> svm = new SVM<double[]>(new GaussianKernel(8.0), 5.0, Math.max(y) + 1, SVM.Multiclass.ONE_VS_ONE);
        train(svm, x, y, 1);

        int error = countErrors(svm, testx, testy);
        System.out.format("USPS error rate = %.2f%%%n", 100.0 * error / testx.length);
        Assert.assertTrue(error < 95);

        System.out.println("USPS one more epoch...");
        // One further online epoch over randomly drawn (with replacement) training samples.
        for (int n = 0; n < x.length; n++) {
            int j = Math.randomInt(x.length);
            svm.learn(x[j], y[j]);
        }
        svm.finish();

        error = countErrors(svm, testx, testy);
        System.out.format("USPS error rate = %.2f%%%n", 100.0 * error / testx.length);
        Assert.assertTrue(error < 95);
    }
}
