package smile.clustering;

import org.junit.After;
import org.junit.AfterClass;
import org.junit.Assert;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Test;
import smile.data.AttributeDataset;
import smile.data.NominalAttribute;
import smile.data.parser.DelimitedTextParser;
import smile.data.parser.IOUtils;
import smile.stat.distribution.MultivariateGaussianDistribution;
import smile.validation.AdjustedRandIndex;
import smile.validation.RandIndex;

/**
 * Tests for {@link KMeans}.
 *
 * <p>The synthetic tests cluster 100,000 three-dimensional points drawn from four Gaussian
 * components (20,000 / 30,000 / 30,000 / 20,000 samples). They print the rand index and
 * adjusted rand index of the clustering against the component each point was drawn from.
 * They only report these indices and do not assert on them.
 *
 * <p>{@link #testUSPS()} clusters the {@code usps/zip.train} data into 10 clusters. It then
 * asserts minimum rand / adjusted rand indices on the training data and on predictions for
 * {@code usps/zip.test}.
 */
public class KMeansTest {
    final double[] mu1 = {1.0, 1.0, 1.0};
    final double[][] sigma1 = {
        {1.0, 0.0, 0.0},
        {0.0, 1.0, 0.0},
        {0.0, 0.0, 1.0}
    };
    final double[] mu2 = {-2.0, -2.0, -2.0};
    final double[][] sigma2 = {
        {1.0, 0.3, 0.8},
        {0.3, 1.0, 0.5},
        {0.8, 0.5, 1.0}
    };
    final double[] mu3 = {4.0, 2.0, 3.0};
    final double[][] sigma3 = {
        {1.0, 0.8, 0.3},
        {0.8, 1.0, 0.5},
        {0.3, 0.5, 1.0}
    };
    final double[] mu4 = {3.0, 5.0, 1.0};
    final double[][] sigma4 = {
        {1.0, 0.5, 0.5},
        {0.5, 1.0, 0.5},
        {0.5, 0.5, 1.0}
    };

    /** Synthetic samples, one point per row; rows are filled by the constructor. */
    final double[][] data = new double[100000][];

    /** Index (0-3) of the Gaussian component that produced the same row of {@link #data}. */
    final int[] label = new int[100000];

    /**
     * Draws the synthetic data set.
     *
     * <p>The components occupy contiguous row ranges: [0, 20000), [20000, 50000),
     * [50000, 80000) and [80000, 100000).
     */
    public KMeansTest() {
        drawComponent(0, 20000, mu1, sigma1, 0);
        drawComponent(20000, 30000, mu2, sigma2, 1);
        drawComponent(50000, 30000, mu3, sigma3, 2);
        drawComponent(80000, 20000, mu4, sigma4, 3);
    }

    /**
     * Fills rows {@code [offset, offset + count)} of {@link #data} with samples from one Gaussian.
     * The same rows of {@link #label} are set to {@code component}.
     *
     * @param offset first row to fill
     * @param count number of samples to draw
     * @param mu mean vector of the component
     * @param sigma covariance matrix of the component
     * @param component ground-truth label recorded for every drawn row
     */
    private void drawComponent(int offset, int count, double[] mu, double[][] sigma, int component) {
        MultivariateGaussianDistribution gaussian = new MultivariateGaussianDistribution(mu, sigma);
        for (int row = offset; row < offset + count; row++) {
            data[row] = gaussian.rand();
            // Label the same row that was just sampled so labels stay aligned with data.
            label[row] = component;
        }
    }

    /**
     * Computes and prints the rand index and adjusted rand index of a clustering against
     * ground truth.
     *
     * @param phase prefix printed before the indices (e.g. {@code "Training"})
     * @param truth ground-truth labels
     * @param clusters cluster assignments to compare with {@code truth}
     * @return {@code {randIndex, adjustedRandIndex}} as fractions (not percentages)
     */
    private static double[] reportIndices(String phase, int[] truth, int[] clusters) {
        double rand = new RandIndex().measure(truth, clusters);
        double adjustedRand = new AdjustedRandIndex().measure(truth, clusters);
        System.out.format("%s rand index = %.2f%%\tadjusted rand index = %.2f%%%n",
                phase, 100.0 * rand, 100.0 * adjustedRand);
        return new double[] {rand, adjustedRand};
    }

    @BeforeClass
    public static void setUpClass() throws Exception {
    }

    @AfterClass
    public static void tearDownClass() throws Exception {
    }

    @Before
    public void setUp() {
    }

    @After
    public void tearDown() {
    }

    /** Clusters the synthetic data into 4 clusters with the {@code KMeans} constructor. */
    @Test
    public void testBBD4() {
        System.out.println("BBD 4");
        KMeans kmeans = new KMeans(data, 4, 100);
        reportIndices("Training", label, kmeans.getClusterLabel());
    }

    /** Clusters the synthetic data into 4 clusters with {@code KMeans.lloyd}. */
    @Test
    public void testLloyd4() {
        System.out.println("Lloyd 4");
        KMeans lloyd = KMeans.lloyd(data, 4, 100);
        reportIndices("Training", label, lloyd.getClusterLabel());
    }

    /** Clusters the synthetic data into 64 clusters with the {@code KMeans} constructor. */
    @Test
    public void testBBD64() {
        System.out.println("BBD 64");
        KMeans kmeans = new KMeans(data, 64, 100);
        reportIndices("Training", label, kmeans.getClusterLabel());
    }

    /** Clusters the synthetic data into 64 clusters with {@code KMeans.lloyd}. */
    @Test
    public void testLloyd64() {
        System.out.println("Lloyd 64");
        KMeans lloyd = KMeans.lloyd(data, 64, 100);
        reportIndices("Training", label, lloyd.getClusterLabel());
    }

    /**
     * Clusters the USPS training file into 10 clusters and checks agreement with the class labels.
     * Agreement is checked on the training data and on {@code predict} results for the test file.
     * Any failure to load or process the data fails the test instead of being silently ignored.
     */
    @Test
    public void testUSPS() {
        System.out.println("USPS");
        DelimitedTextParser parser = new DelimitedTextParser();
        // Column 0 holds the class; it is the response, the remaining columns are features.
        parser.setResponseIndex(new NominalAttribute("class"), 0);
        try {
            AttributeDataset train = parser.parse("USPS Train", IOUtils.getTestDataFile("usps/zip.train"));
            AttributeDataset test = parser.parse("USPS Test", IOUtils.getTestDataFile("usps/zip.test"));
            double[][] x = train.toArray(new double[train.size()][]);
            int[] y = train.toArray(new int[train.size()]);
            double[][] testX = test.toArray(new double[test.size()][]);
            int[] testY = test.toArray(new int[test.size()]);

            KMeans kmeans = new KMeans(x, 10, 100, 4);

            double[] training = reportIndices("Training", y, kmeans.getClusterLabel());
            Assert.assertTrue("training rand index too low", training[0] > 0.85);
            Assert.assertTrue("training adjusted rand index too low", training[1] > 0.45);

            int[] prediction = new int[testX.length];
            for (int i = 0; i < testX.length; i++) {
                prediction[i] = kmeans.predict(testX[i]);
            }
            double[] testing = reportIndices("Testing", testY, prediction);
            Assert.assertTrue("testing rand index too low", testing[0] > 0.85);
            Assert.assertTrue("testing adjusted rand index too low", testing[1] > 0.45);
        } catch (Exception ex) {
            // Previously only printed to stderr, which let the test pass without checking anything.
            throw new AssertionError("USPS test could not run: " + ex, ex);
        }
    }
}
