/*
 * Decompiled with CFR 0.152.
 */
package smile.clustering;

import org.junit.After;
import org.junit.AfterClass;
import org.junit.Assert;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Test;
import smile.clustering.KMeans;
import smile.data.AttributeDataset;
import smile.data.NominalAttribute;
import smile.data.parser.DelimitedTextParser;
import smile.data.parser.IOUtils;
import smile.stat.distribution.MultivariateGaussianDistribution;
import smile.validation.AdjustedRandIndex;
import smile.validation.RandIndex;

public class KMeansTest {
    double[] mu1 = new double[]{1.0, 1.0, 1.0};
    double[][] sigma1 = new double[][]{{1.0, 0.0, 0.0}, {0.0, 1.0, 0.0}, {0.0, 0.0, 1.0}};
    double[] mu2 = new double[]{-2.0, -2.0, -2.0};
    double[][] sigma2 = new double[][]{{1.0, 0.3, 0.8}, {0.3, 1.0, 0.5}, {0.8, 0.5, 1.0}};
    double[] mu3 = new double[]{4.0, 2.0, 3.0};
    double[][] sigma3 = new double[][]{{1.0, 0.8, 0.3}, {0.8, 1.0, 0.5}, {0.3, 0.5, 1.0}};
    double[] mu4 = new double[]{3.0, 5.0, 1.0};
    double[][] sigma4 = new double[][]{{1.0, 0.5, 0.5}, {0.5, 1.0, 0.5}, {0.5, 0.5, 1.0}};
    double[][] data = new double[100000][];
    int[] label = new int[100000];

    /**
     * Builds a synthetic data set of 100,000 points drawn from four multivariate Gaussians and
     * records, in {@link #label}, which Gaussian generated each point. The layout is:
     * component 0 in {@code [0, 20000)}, 1 in {@code [20000, 50000)},
     * 2 in {@code [50000, 80000)} and 3 in {@code [80000, 100000)}.
     */
    public KMeansTest() {
        int next = 0;
        next = sample(new MultivariateGaussianDistribution(this.mu1, this.sigma1), next, 20000, 0);
        next = sample(new MultivariateGaussianDistribution(this.mu2, this.sigma2), next, 30000, 1);
        next = sample(new MultivariateGaussianDistribution(this.mu3, this.sigma3), next, 30000, 2);
        sample(new MultivariateGaussianDistribution(this.mu4, this.sigma4), next, 20000, 3);
    }

    /**
     * Draws {@code count} points from {@code g} into {@link #data} starting at {@code offset},
     * and labels each of those same slots with {@code component}.
     *
     * @param g         the distribution to sample from
     * @param offset    first index of {@link #data} / {@link #label} to write
     * @param count     number of points to draw
     * @param component ground-truth label for the drawn points
     * @return the index just past the last point written
     */
    private int sample(MultivariateGaussianDistribution g, int offset, int count, int component) {
        for (int i = offset; i < offset + count; ++i) {
            this.data[i] = g.rand();
            // The label must go into the same slot as the point; the decompiled code wrote
            // label[i] relative to 0, so the labels no longer matched the data.
            this.label[i] = component;
        }
        return offset + count;
    }

    /**
     * Scores a clustering against ground truth with the Rand index and the adjusted Rand index,
     * prints both as percentages, and returns them.
     *
     * @param phase     prefix of the printed line, e.g. {@code "Training"} or {@code "Testing"}
     * @param truth     ground-truth labels
     * @param predicted cluster assignments to score
     * @return {@code {randIndex, adjustedRandIndex}}
     */
    private static double[] printRandIndices(String phase, int[] truth, int[] predicted) {
        double r = new RandIndex().measure(truth, predicted);
        double r2 = new AdjustedRandIndex().measure(truth, predicted);
        System.out.format("%s rand index = %.2f%%\tadjusted rand index = %.2f%%%n", phase, 100.0 * r, 100.0 * r2);
        return new double[]{r, r2};
    }

    @BeforeClass
    public static void setUpClass() throws Exception {
    }

    @AfterClass
    public static void tearDownClass() throws Exception {
    }

    @Before
    public void setUp() {
    }

    @After
    public void tearDown() {
    }

    /** Clusters the synthetic data into 4 clusters with the default (BBD-tree) constructor. */
    @Test
    public void testBBD4() {
        System.out.println("BBD 4");
        KMeans kmeans = new KMeans(this.data, 4, 100);
        printRandIndices("Training", this.label, kmeans.getClusterLabel());
    }

    /** Clusters the synthetic data into 4 clusters with {@link KMeans#lloyd}. */
    @Test
    public void testLloyd4() {
        System.out.println("Lloyd 4");
        KMeans kmeans = KMeans.lloyd(this.data, 4, 100);
        printRandIndices("Training", this.label, kmeans.getClusterLabel());
    }

    /** Over-clusters the synthetic data into 64 clusters with the default constructor. */
    @Test
    public void testBBD64() {
        System.out.println("BBD 64");
        KMeans kmeans = new KMeans(this.data, 64, 100);
        printRandIndices("Training", this.label, kmeans.getClusterLabel());
    }

    /** Over-clusters the synthetic data into 64 clusters with {@link KMeans#lloyd}. */
    @Test
    public void testLloyd64() {
        System.out.println("Lloyd 64");
        KMeans kmeans = KMeans.lloyd(this.data, 64, 100);
        printRandIndices("Training", this.label, kmeans.getClusterLabel());
    }

    /**
     * Clusters the USPS digit training set into 10 clusters. It then checks the Rand and
     * adjusted Rand indices on both the training set and the held-out test set.
     *
     * <p>Failures to read or parse the data files propagate so the test fails. Catching and
     * printing them would let the test pass without checking anything.
     */
    @Test
    public void testUSPS() throws Exception {
        System.out.println("USPS");
        DelimitedTextParser parser = new DelimitedTextParser();
        parser.setResponseIndex(new NominalAttribute("class"), 0);

        AttributeDataset train = parser.parse("USPS Train", IOUtils.getTestDataFile("usps/zip.train"));
        AttributeDataset test = parser.parse("USPS Test", IOUtils.getTestDataFile("usps/zip.test"));
        double[][] x = train.toArray(new double[train.size()][]);
        int[] y = train.toArray(new int[train.size()]);
        double[][] testx = test.toArray(new double[test.size()][]);
        int[] testy = test.toArray(new int[test.size()]);

        KMeans kmeans = new KMeans(x, 10, 100, 4);
        double[] trainScores = printRandIndices("Training", y, kmeans.getClusterLabel());
        Assert.assertTrue("training rand index " + trainScores[0], trainScores[0] > 0.85);
        Assert.assertTrue("training adjusted rand index " + trainScores[1], trainScores[1] > 0.45);

        int[] p = new int[testx.length];
        for (int i = 0; i < testx.length; ++i) {
            p[i] = kmeans.predict(testx[i]);
        }
        double[] testScores = printRandIndices("Testing", testy, p);
        Assert.assertTrue("testing rand index " + testScores[0], testScores[0] > 0.85);
        Assert.assertTrue("testing adjusted rand index " + testScores[1], testScores[1] > 0.45);
    }
}

