package org.ddogleg.clustering.gmm;

import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import org.ddogleg.clustering.kmeans.InitializeKMeans_F64;
import org.ddogleg.clustering.kmeans.StandardKMeans_F64;
import org.ejml.data.DenseMatrix64F;
import org.ejml.equation.Equation;
import org.ejml.ops.CommonOps;
import org.ejml.ops.MatrixFeatures;
import org.junit.Assert;
import org.junit.Test;

/**
 * Unit tests for {@link SeedFromKMeans_F64}, which seeds {@link GaussianGmm_F64} components using a
 * {@link StandardKMeans_F64} clusterer.
 */
public class TestSeedFromKMeans_F64 {
    /** Fixed seed so the synthetic point cloud is identical on every run. */
    final Random rand = new Random(234);

    /**
     * Deterministic k-means initializer. Each seed array receives a copy of the input point at the
     * same index, so seed {@code i} starts at {@code points.get(i)}. This lets the test know which
     * output Gaussian corresponds to which synthetic cluster.
     */
    public static class FixedSeeds implements InitializeKMeans_F64 {
        @Override // org.ddogleg.clustering.kmeans.InitializeKMeans_F64
        public void init(int pointDimension, long randomSeed) {
            // Nothing to configure: seed selection is fully deterministic.
        }

        @Override // org.ddogleg.clustering.kmeans.InitializeKMeans_F64
        public void selectSeeds(List<double[]> points, List<double[]> seeds) {
            for (int i = 0; i < seeds.size(); i++) {
                double[] seed = seeds.get(i);
                System.arraycopy(points.get(i), 0, seed, 0, seed.length);
            }
        }
    }

    @Test
    public void selectSeeds() {
        // Two well separated 2D clusters with interleaved samples: even indices are drawn around
        // (10, 50), odd indices around (-30, 20).
        List<double[]> points = new ArrayList<double[]>();
        for (int i = 0; i < 1000; i++) {
            points.add(new double[]{10.0 + rand.nextGaussian() * 1.0, 50.0 + rand.nextGaussian() * 0.5});
            points.add(new double[]{-30.0 + rand.nextGaussian() * 1.0, 20.0 + rand.nextGaussian() * 0.5});
        }

        SeedFromKMeans_F64 alg = new SeedFromKMeans_F64(createKMeans());
        alg.init(2, 234234L);

        List<GaussianGmm_F64> seeds = new ArrayList<GaussianGmm_F64>();
        seeds.add(new GaussianGmm_F64(2));
        seeds.add(new GaussianGmm_F64(2));

        alg.selectSeeds(points, seeds);

        // FixedSeeds starts k-means from points 0 and 1, so the test expects seed 0 to describe the
        // even-index cluster and seed 1 the odd-index cluster.
        GaussianGmm_F64 first = seeds.get(0);
        GaussianGmm_F64 second = seeds.get(1);

        // Each cluster holds exactly half of the points.
        Assert.assertEquals(0.5, first.weight, 0.01);
        Assert.assertEquals(0.5, second.weight, 0.01);

        GaussianGmm_F64 expectedFirst = computeGaussian(0, points);
        GaussianGmm_F64 expectedSecond = computeGaussian(1, points);

        Assert.assertTrue(MatrixFeatures.isIdentical(expectedFirst.mean, first.mean, 1e-8));
        Assert.assertTrue(MatrixFeatures.isIdentical(expectedSecond.mean, second.mean, 1e-8));
        Assert.assertTrue(MatrixFeatures.isIdentical(expectedFirst.covariance, first.covariance, 1e-8));
        Assert.assertTrue(MatrixFeatures.isIdentical(expectedSecond.covariance, second.covariance, 1e-8));
    }

    /**
     * Computes the reference statistics for the 2D points at indices {@code offset},
     * {@code offset + 2}, {@code offset + 4}, ... of {@code points}. These are the sample mean and
     * the sample covariance normalized by {@code n - 1}.
     *
     * @param offset index of the first point to include; every second point after it is also used
     * @param points interleaved samples from the two synthetic clusters
     * @return a Gaussian whose mean and covariance are filled in; its weight is not set here
     */
    private GaussianGmm_F64 computeGaussian(int offset, List<double[]> points) {
        GaussianGmm_F64 gaussian = new GaussianGmm_F64(2);
        double[] mu = gaussian.mean.data;

        // Count the samples actually visited instead of assuming points.size() / 2,
        // which is wrong when the list length is odd.
        int count = 0;
        for (int i = offset; i < points.size(); i += 2) {
            double[] p = points.get(i);
            mu[0] += p[0];
            mu[1] += p[1];
            count++;
        }
        CommonOps.divide(gaussian.mean, count);

        // Accumulate the outer products (x - mu)(x - mu)' into the covariance matrix.
        Equation eq = new Equation();
        eq.alias(gaussian.mean, "mu", gaussian.covariance, "Q");
        for (int i = offset; i < points.size(); i += 2) {
            eq.alias(DenseMatrix64F.wrap(2, 1, points.get(i)), "x");
            eq.process("Q = Q + (x-mu)*(x-mu)'");
        }
        CommonOps.divide(gaussian.covariance, count - 1);
        return gaussian;
    }

    /** Creates the k-means clusterer under test, seeded deterministically via {@link FixedSeeds}. */
    private StandardKMeans_F64 createKMeans() {
        return new StandardKMeans_F64(200, 200, 1.0E-6d, new FixedSeeds());
    }
}
