package smile.feature;

import org.junit.After;
import org.junit.AfterClass;
import org.junit.Assert;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Test;
import smile.classification.LDA;
import smile.data.AttributeDataset;
import smile.data.NominalAttribute;
import smile.data.parser.ArffParser;
import smile.data.parser.DelimitedTextParser;
import smile.data.parser.IOUtils;
import smile.sort.QuickSort;
import smile.validation.Accuracy;

/**
 * Tests for {@link SumSquaresRatio} feature ranking.
 *
 * <p>Both tests load their data through {@link IOUtils#getTestDataFile(String)}. Any parse or
 * I/O failure is propagated to JUnit, so a missing or broken data file fails the test instead of
 * letting it pass silently.
 */
public class SumSquaresRatioTest {
    /** Number of top-ranked USPS features kept for the LDA classifier in {@link #testLearn()}. */
    private static final int NUM_SELECTED_FEATURES = 135;

    @BeforeClass
    public static void setUpClass() throws Exception {
    }

    @AfterClass
    public static void tearDownClass() throws Exception {
    }

    @Before
    public void setUp() {
    }

    @After
    public void tearDown() {
    }

    /**
     * Checks the per-feature scores produced by {@link SumSquaresRatio#rank} on the Iris data set
     * against precomputed reference values.
     */
    @Test
    public void testRank() throws Exception {
        System.out.println("rank");
        ArffParser arffParser = new ArffParser();
        // Column 4 is the response (class label); the test expects four feature scores back.
        arffParser.setResponseIndex(4);
        AttributeDataset iris = arffParser.parse(IOUtils.getTestDataFile("weka/iris.arff"));
        double[][] x = iris.toArray(new double[iris.size()][]);
        int[] y = iris.toArray(new int[iris.size()]);

        double[] rank = new SumSquaresRatio().rank(x, y);

        Assert.assertEquals(4, rank.length);
        Assert.assertEquals(1.6226463, rank[0], 1E-7);
        Assert.assertEquals(0.6444144, rank[1], 1E-7);
        Assert.assertEquals(16.0412833, rank[2], 1E-7);
        Assert.assertEquals(13.0520327, rank[3], 1E-7);
    }

    /**
     * Ranks the USPS training features, keeps the {@link #NUM_SELECTED_FEATURES} best ones, trains
     * an {@link LDA} classifier on them and reports its accuracy on the USPS test set.
     *
     * <p>This is a smoke test: it asserts no accuracy threshold, it only prints the result.
     */
    @Test
    public void testLearn() throws Exception {
        System.out.println("USPS");
        DelimitedTextParser parser = new DelimitedTextParser();
        // Column 0 holds the nominal class label.
        parser.setResponseIndex(new NominalAttribute("class"), 0);
        AttributeDataset train = parser.parse("USPS Train", IOUtils.getTestDataFile("usps/zip.train"));
        AttributeDataset test = parser.parse("USPS Test", IOUtils.getTestDataFile("usps/zip.test"));

        double[][] x = train.toArray(new double[train.size()][]);
        int[] y = train.toArray(new int[train.size()]);
        double[][] testx = test.toArray(new double[test.size()][]);
        int[] testy = test.toArray(new int[test.size()]);

        // QuickSort.sort is used here as an argsort of the scores. The selected features are taken
        // from the END of the returned order, i.e. larger scores are treated as better.
        int[] order = QuickSort.sort(new SumSquaresRatio().rank(x, y));
        double[][] trainSelected = selectFeatures(x, order, NUM_SELECTED_FEATURES);
        double[][] testSelected = selectFeatures(testx, order, NUM_SELECTED_FEATURES);

        LDA lda = new LDA(trainSelected, y);
        int[] prediction = new int[testSelected.length];
        for (int i = 0; i < testSelected.length; i++) {
            prediction[i] = lda.predict(testSelected[i]);
        }

        double accuracy = new Accuracy().measure(testy, prediction);
        System.out.format("SSR %.2f%%%n", 100.0 * accuracy);
    }

    /**
     * Returns a new matrix holding, for every row of {@code x}, the columns named by the last
     * {@code k} entries of {@code order}. Output column {@code j} is input column
     * {@code order[order.length - 1 - j]}.
     *
     * @param x the data matrix, one row per sample
     * @param order feature indices; the last {@code k} are selected, walking backwards
     * @param k number of features to keep
     * @return a {@code x.length} by {@code k} matrix
     * @throws IllegalArgumentException if {@code k} exceeds the number of ranked features
     */
    private static double[][] selectFeatures(double[][] x, int[] order, int k) {
        if (k > order.length) {
            throw new IllegalArgumentException(
                    "Cannot select " + k + " features out of " + order.length);
        }
        double[][] selected = new double[x.length][k];
        for (int i = 0; i < x.length; i++) {
            for (int j = 0; j < k; j++) {
                selected[i][j] = x[i][order[order.length - 1 - j]];
            }
        }
        return selected;
    }
}
