package smile.clustering;

import org.junit.After;
import org.junit.AfterClass;
import org.junit.Assert;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Test;
import smile.data.AttributeDataset;
import smile.data.NominalAttribute;
import smile.data.parser.DelimitedTextParser;
import smile.data.parser.IOUtils;
import smile.validation.AdjustedRandIndex;
import smile.validation.RandIndex;

/**
 * Tests for {@link SpectralClustering} on the USPS handwritten digit training set
 * ({@code usps/zip.train}, with the class label in column 0).
 *
 * <p>Each test clusters the data into 10 clusters, then checks the rand index and the
 * adjusted rand index against the true labels, requiring each to exceed a fixed threshold.
 * Any failure to load the data or run the clustering fails the test. It is not swallowed.
 */
public class SpectralClusteringTest {
    @BeforeClass
    public static void setUpClass() throws Exception {
    }

    @AfterClass
    public static void tearDownClass() throws Exception {
    }

    @Before
    public void setUp() {
    }

    @After
    public void tearDown() {
    }

    /**
     * Exact spectral clustering with 10 clusters and Gaussian width 8.0.
     */
    @Test
    public void testUSPS() throws Exception {
        System.out.println("USPS");
        AttributeDataset train = loadUSPSTrain();
        // Use the double[][] overload; a primitive double[] cannot be cast to Object[].
        double[][] x = train.toArray(new double[train.size()][]);
        int[] y = train.toArray(new int[train.size()]);

        SpectralClustering clustering = new SpectralClustering(x, 10, 8.0d);
        assertClusteringQuality(y, clustering.getClusterLabel(), 0.85d, 0.45d);
    }

    /**
     * Spectral clustering with the Nystrom approximation: 10 clusters, 100 samples,
     * and Gaussian width 8.0. Its thresholds are looser than the exact variant's.
     */
    @Test
    public void testUSPSNystrom() throws Exception {
        System.out.println("USPS Nystrom approximation");
        AttributeDataset train = loadUSPSTrain();
        double[][] x = train.toArray(new double[train.size()][]);
        int[] y = train.toArray(new int[train.size()]);

        SpectralClustering clustering = new SpectralClustering(x, 10, 100, 8.0d);
        assertClusteringQuality(y, clustering.getClusterLabel(), 0.8d, 0.35d);
    }

    /**
     * Parses the USPS training file, treating column 0 as the nominal "class" response.
     *
     * @return the parsed dataset
     * @throws Exception if the test data file cannot be located or parsed
     */
    private static AttributeDataset loadUSPSTrain() throws Exception {
        DelimitedTextParser parser = new DelimitedTextParser();
        parser.setResponseIndex(new NominalAttribute("class"), 0);
        return parser.parse("USPS Train", IOUtils.getTestDataFile("usps/zip.train"));
    }

    /**
     * Prints the rand index and adjusted rand index of {@code label} against {@code y},
     * then asserts that each strictly exceeds its threshold.
     *
     * @param y                true class labels
     * @param label            cluster labels produced by the algorithm
     * @param minRand          exclusive lower bound for the rand index
     * @param minAdjustedRand  exclusive lower bound for the adjusted rand index
     */
    private static void assertClusteringQuality(int[] y, int[] label, double minRand, double minAdjustedRand) {
        double rand = new RandIndex().measure(y, label);
        double adjustedRand = new AdjustedRandIndex().measure(y, label);
        System.out.format("Training rand index = %.2f%%\tadjusted rand index = %.2f%%%n", 100.0d * rand, 100.0d * adjustedRand);
        Assert.assertTrue("rand index " + rand + " not above " + minRand, rand > minRand);
        Assert.assertTrue("adjusted rand index " + adjustedRand + " not above " + minAdjustedRand,
                adjustedRand > minAdjustedRand);
    }
}
