package smile.classification;

import java.io.IOException;
import java.text.ParseException;
import org.junit.After;
import org.junit.AfterClass;
import org.junit.Assert;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Test;
import smile.data.AttributeDataset;
import smile.data.NominalAttribute;
import smile.data.parser.ArffParser;
import smile.data.parser.DelimitedTextParser;
import smile.data.parser.IOUtils;

/* loaded from: input_file:smile/classification/KNNTest.class */
/**
 * JUnit tests for {@link KNN}.
 *
 * <p>Each test loads a data set from the test-data directory, trains a k-nearest-neighbor
 * classifier and compares the number of misclassified samples with a pinned value.
 *
 * <p>Loading or parsing failures are deliberately allowed to propagate out of the test
 * methods: catching and printing them would let a test pass without ever reaching its
 * assertions.
 */
public class KNNTest {
    /** Neighborhood sizes exercised by {@link #testLearn_3args()}. */
    private static final int[] IRIS_K = {1, 3, 5, 7, 9, 11, 13, 15};

    /** Expected misclassification counts on iris, index-aligned with {@link #IRIS_K}. */
    private static final int[] IRIS_EXPECTED_ERRORS = {6, 6, 5, 5, 5, 4, 5, 4};

    /** Class-level fixture hook; nothing to prepare. */
    @BeforeClass
    public static void setUpClass() throws Exception {
    }

    /** Class-level fixture hook; nothing to release. */
    @AfterClass
    public static void tearDownClass() throws Exception {
    }

    /** Per-test fixture hook; nothing to prepare. */
    @Before
    public void setUp() {
    }

    /** Per-test fixture hook; nothing to release. */
    @After
    public void tearDown() {
    }

    /**
     * Counts the samples for which the model's prediction differs from the given label.
     *
     * @param model trained classifier
     * @param x samples to classify
     * @param y labels, index-aligned with {@code x}
     * @return number of indices {@code i} where {@code model.predict(x[i]) != y[i]}
     */
    private static int countErrors(KNN<double[]> model, double[][] x, int[] y) {
        int errors = 0;
        for (int i = 0; i < x.length; i++) {
            if (model.predict(x[i]) != y[i]) {
                errors++;
            }
        }
        return errors;
    }

    /**
     * Trains on the iris data with each k in {@link #IRIS_K} and checks the number of
     * misclassifications when predicting the same samples the model was trained on.
     *
     * @throws Exception if the data file cannot be read or parsed
     */
    @Test
    public void testLearn_3args() throws Exception {
        System.out.println("learn");
        ArffParser arffParser = new ArffParser();
        arffParser.setResponseIndex(4);

        AttributeDataset iris = arffParser.parse(IOUtils.getTestDataFile("weka/iris.arff"));
        double[][] x = iris.toArray(new double[0][]);
        int[] y = iris.toArray(new int[0]);

        for (int j = 0; j < IRIS_K.length; j++) {
            int k = IRIS_K[j];
            KNN<double[]> knn = KNN.learn(x, y, k);
            int errors = countErrors(knn, x, y);
            System.out.println(k + "-nn error = " + errors);
            Assert.assertEquals(IRIS_EXPECTED_ERRORS[j], errors);
        }
    }

    /**
     * Trains on the segment-challenge data with the two-argument {@code KNN.learn}
     * overload and checks the number of misclassifications on the segment-test data.
     *
     * @throws IOException if a data file cannot be read
     * @throws ParseException if a data file cannot be parsed
     */
    @Test
    public void testSegment() throws IOException, ParseException {
        System.out.println("Segment");
        ArffParser arffParser = new ArffParser();
        arffParser.setResponseIndex(19);

        AttributeDataset train = arffParser.parse(IOUtils.getTestDataFile("weka/segment-challenge.arff"));
        AttributeDataset test = arffParser.parse(IOUtils.getTestDataFile("weka/segment-test.arff"));
        double[][] x = train.toArray(new double[0][]);
        int[] y = train.toArray(new int[0]);
        double[][] testx = test.toArray(new double[0][]);
        int[] testy = test.toArray(new int[0]);

        KNN<double[]> knn = KNN.learn(x, y);
        int errors = countErrors(knn, testx, testy);
        System.out.format("Segment error rate = %.2f%%%n", (100.0d * errors) / testx.length);
        Assert.assertEquals(39L, errors);
    }

    /**
     * Trains on the USPS training file with the two-argument {@code KNN.learn} overload
     * and checks the number of misclassifications on the USPS test file. The class label
     * is read from column 0 of the delimited text.
     *
     * @throws Exception if a data file cannot be read or parsed
     */
    @Test
    public void testUSPS() throws Exception {
        System.out.println("USPS");
        DelimitedTextParser parser = new DelimitedTextParser();
        parser.setResponseIndex(new NominalAttribute("class"), 0);

        AttributeDataset train = parser.parse("USPS Train", IOUtils.getTestDataFile("usps/zip.train"));
        AttributeDataset test = parser.parse("USPS Test", IOUtils.getTestDataFile("usps/zip.test"));
        double[][] x = train.toArray(new double[train.size()][]);
        int[] y = train.toArray(new int[train.size()]);
        double[][] testx = test.toArray(new double[test.size()][]);
        int[] testy = test.toArray(new int[test.size()]);

        KNN<double[]> knn = KNN.learn(x, y);
        int errors = countErrors(knn, testx, testy);
        System.out.format("USPS error rate = %.2f%%%n", (100.0d * errors) / testx.length);
        Assert.assertEquals(113L, errors);
    }
}
