/*
 * Decompiled with CFR 0.152.
 */
package smile.clustering;

import org.junit.After;
import org.junit.AfterClass;
import org.junit.Assert;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Test;
import smile.clustering.SpectralClustering;
import smile.data.AttributeDataset;
import smile.data.NominalAttribute;
import smile.data.parser.DelimitedTextParser;
import smile.data.parser.IOUtils;
import smile.validation.AdjustedRandIndex;
import smile.validation.RandIndex;

/**
 * Regression tests for {@link SpectralClustering} on the USPS handwritten-digit training set.
 *
 * <p>Each test clusters the feature vectors of {@code usps/zip.train} and compares the resulting
 * cluster labels against the true class labels using the Rand index and the adjusted Rand index,
 * asserting that both exceed fixed quality thresholds.
 */
public class SpectralClusteringTest {
    /** Name given to the parsed dataset. */
    private static final String DATASET_NAME = "USPS Train";

    /** Test-data path of the USPS training file. */
    private static final String DATASET_PATH = "usps/zip.train";

    /** Intentionally empty; kept as a JUnit lifecycle hook. */
    @BeforeClass
    public static void setUpClass() throws Exception {
    }

    /** Intentionally empty; kept as a JUnit lifecycle hook. */
    @AfterClass
    public static void tearDownClass() throws Exception {
    }

    /** Intentionally empty; kept as a JUnit lifecycle hook. */
    @Before
    public void setUp() {
    }

    /** Intentionally empty; kept as a JUnit lifecycle hook. */
    @After
    public void tearDown() {
    }

    /**
     * Exact spectral clustering on USPS.
     *
     * <p>Any parsing or I/O failure (e.g. missing test data) propagates and fails the test instead
     * of being printed and ignored.
     *
     * @throws Exception if the dataset cannot be loaded or parsed
     */
    @Test
    public void testUSPS() throws Exception {
        System.out.println("USPS");
        AttributeDataset train = loadUSPSTrain();
        double[][] x = train.toArray(new double[train.size()][]);
        int[] y = train.toArray(new int[train.size()]);
        // NOTE(review): arguments presumably are (data, k = 10 clusters, kernel width = 8.0);
        // confirm against the SpectralClustering constructor.
        SpectralClustering spectral = new SpectralClustering(x, 10, 8.0);
        assertClusteringQuality(y, spectral.getClusterLabel(), 0.85, 0.45);
    }

    /**
     * Spectral clustering on USPS using the Nystrom approximation.
     *
     * <p>The thresholds are looser than in {@link #testUSPS()} because the approximation is expected
     * to trade some accuracy for speed.
     *
     * @throws Exception if the dataset cannot be loaded or parsed
     */
    @Test
    public void testUSPSNystrom() throws Exception {
        System.out.println("USPS Nystrom approximation");
        AttributeDataset train = loadUSPSTrain();
        double[][] x = train.toArray(new double[train.size()][]);
        int[] y = train.toArray(new int[train.size()]);
        // NOTE(review): the extra argument (100) is presumably the number of Nystrom samples;
        // confirm against the SpectralClustering constructor.
        SpectralClustering spectral = new SpectralClustering(x, 10, 100, 8.0);
        assertClusteringQuality(y, spectral.getClusterLabel(), 0.8, 0.35);
    }

    /**
     * Parses the USPS training file, treating column 0 as the nominal "class" response.
     *
     * @return the parsed training dataset
     * @throws Exception if the file cannot be located, read, or parsed
     */
    private static AttributeDataset loadUSPSTrain() throws Exception {
        DelimitedTextParser parser = new DelimitedTextParser();
        parser.setResponseIndex(new NominalAttribute("class"), 0);
        return parser.parse(DATASET_NAME, IOUtils.getTestDataFile(DATASET_PATH));
    }

    /**
     * Scores cluster labels against ground truth, prints both indices as percentages, and asserts
     * that each exceeds its (strict) lower bound.
     *
     * @param truth true class labels
     * @param labels cluster labels produced by the algorithm under test
     * @param minRand exclusive lower bound for the Rand index
     * @param minAdjustedRand exclusive lower bound for the adjusted Rand index
     */
    private static void assertClusteringQuality(
            int[] truth, int[] labels, double minRand, double minAdjustedRand) {
        double r = new RandIndex().measure(truth, labels);
        double r2 = new AdjustedRandIndex().measure(truth, labels);
        System.out.format("Training rand index = %.2f%%\tadjusted rand index = %.2f%%%n", 100.0 * r, 100.0 * r2);
        Assert.assertTrue("rand index " + r + " not > " + minRand, r > minRand);
        Assert.assertTrue("adjusted rand index " + r2 + " not > " + minAdjustedRand, r2 > minAdjustedRand);
    }
}

