package org.ddogleg.clustering.gmm;

import com.mhuss.AstroLib.Astro;
import java.util.List;
import java.util.Random;
import org.ddogleg.clustering.ComputeClusters;
import org.ddogleg.clustering.GenericClusterChecks_F64;
import org.ddogleg.clustering.gmm.ExpectationMaximizationGmm_F64;
import org.ddogleg.clustering.kmeans.InitializeStandard_F64;
import org.ddogleg.clustering.kmeans.StandardKMeans_F64;
import org.ddogleg.clustering.kmeans.TestStandardKMeans_F64;
import org.ejml.data.DenseMatrix64F;
import org.ejml.equation.Equation;
import org.ejml.ops.CommonOps;
import org.ejml.ops.MatrixFeatures;
import org.junit.Assert;
import org.junit.Test;

/* loaded from: input_file:org/ddogleg/clustering/gmm/TestExpectationMaximizationGmm_F64.class */
/**
 * Unit tests for {@link ExpectationMaximizationGmm_F64}: the individual expectation and
 * maximization steps, plus the generic clustering checks inherited from
 * {@link GenericClusterChecks_F64}.
 */
public class TestExpectationMaximizationGmm_F64 extends GenericClusterChecks_F64 {

    /** Iteration limit passed to k-means and to the EM instances built by {@link #createClustersAlg}. */
    private static final int MAX_ITERATIONS = 1000;

    /** Convergence tolerance for every clustering algorithm created in this test. */
    private static final double CONVERGE_TOL = 1e-8;

    /** Tolerance used when comparing floating-point results. */
    private static final double TEST_TOL = 1e-8;

    Random rand = new Random(234);

    // Initial mixtures come from k-means seeded with TestStandardKMeans_F64.FixedSeeds.
    StandardKMeans_F64 kmeans = new StandardKMeans_F64(
            MAX_ITERATIONS, MAX_ITERATIONS, CONVERGE_TOL, new TestStandardKMeans_F64.FixedSeeds());
    SeedFromKMeans_F64 seeds = new SeedFromKMeans_F64(this.kmeans);

    /**
     * Adds Gaussians one at a time, each centered exactly on a different sample point. After every
     * {@code expectation()} call it checks two things for each center point: its membership weight
     * is largest for its own Gaussian, and its weights sum to one.
     */
    @Test
    public void expectation() {
        ExpectationMaximizationGmm_F64 alg = new ExpectationMaximizationGmm_F64(100, CONVERGE_TOL, seeds);
        alg.init(3, 34535L);

        // 20 random 3-D points
        for (int i = 0; i < 20; i++) {
            ExpectationMaximizationGmm_F64.PointInfo p = alg.info.grow();
            p.point = new double[3];
            for (int k = 0; k < 3; k++) {
                p.point[k] = rand.nextGaussian() * 5.0;
            }
        }

        for (int g = 0; g < 3; g++) {
            int numGaussians = g + 1;

            // every point needs one weight slot per Gaussian in the mixture
            for (int i = 0; i < alg.info.size; i++) {
                alg.info.get(i).weights.resize(numGaussians);
            }

            // new Gaussian centered on point 'g'; mixture weights here are not normalized to one
            GaussianGmm_F64 gaussian = alg.mixture.grow();
            gaussian.setMean(alg.info.get(g).point);
            gaussian.weight = 2.0;
            CommonOps.setIdentity(gaussian.covariance);

            alg.likelihoodManager.precomputeAll();
            alg.expectation();

            // only the points that are Gaussian centers have a known "winning" component
            for (int i = 0; i < numGaussians; i++) {
                ExpectationMaximizationGmm_F64.PointInfo p = alg.info.get(i);
                double own = p.weights.get(i);
                double sum = 0.0;
                for (int j = 0; j < numGaussians; j++) {
                    double w = p.weights.get(j);
                    sum += w;
                    if (j != i) {
                        Assert.assertTrue(w < own);
                    }
                }
                Assert.assertEquals(1.0, sum, TEST_TOL);
            }
        }
    }

    /**
     * Checks that {@code maximization()} recomputes each Gaussian's mean and covariance as the
     * membership-weighted sample mean and covariance of the points, using
     * {@link #computeGaussian} as the reference implementation.
     */
    @Test
    public void maximization() {
        ExpectationMaximizationGmm_F64 alg = new ExpectationMaximizationGmm_F64(100, CONVERGE_TOL, seeds);
        alg.init(2, 34535L);

        // Two Gaussians used to generate the points' membership weights
        GaussianGmm_F64 a = alg.mixture.grow();
        a.setMean(new double[]{1.0, 0.5});
        CommonOps.diag(a.covariance, 2, new double[]{0.75, 1.0});
        a.weight = 0.25;

        GaussianGmm_F64 b = alg.mixture.grow();
        b.setMean(new double[]{4.0, 8.0});
        CommonOps.diag(b.covariance, 2, new double[]{0.5, 0.75});
        b.weight = 0.75;

        createPointsAround(1.0, 0.5, alg);
        createPointsAround(2.0, 3.0, alg);

        // Wipe the mixture so the values checked below can only come from maximization()
        for (int i = 0; i < alg.mixture.size; i++) {
            GaussianGmm_F64 g = alg.mixture.get(i);
            CommonOps.fill(g.mean, 0.0);
            CommonOps.fill(g.covariance, 0.0);
            g.weight = 0.0;
        }

        alg.maximization();

        GaussianGmm_F64 expectedA = computeGaussian(0, alg.info.toList());
        GaussianGmm_F64 expectedB = computeGaussian(1, alg.info.toList());

        // 'a' and 'b' are the objects returned by mixture.grow(); the assertions expect
        // maximization() to have updated them in place.
        Assert.assertTrue(MatrixFeatures.isIdentical(expectedA.mean, a.mean, TEST_TOL));
        Assert.assertTrue(MatrixFeatures.isIdentical(expectedB.mean, b.mean, TEST_TOL));
        Assert.assertTrue(MatrixFeatures.isIdentical(expectedA.covariance, a.covariance, TEST_TOL));
        Assert.assertTrue(MatrixFeatures.isIdentical(expectedB.covariance, b.covariance, TEST_TOL));
    }

    /**
     * Adds a 50x50 grid of 2-D points with 0.1 spacing to {@code alg.info}, covering
     * [x-2.5, x+2.4] x [y-2.5, y+2.4].
     *
     * <p>Each point's weights are set to its likelihood under every Gaussian currently in
     * {@code alg.mixture}, normalized to sum to one.
     */
    private void createPointsAround(double x, double y, ExpectationMaximizationGmm_F64 alg) {
        int numGaussians = alg.mixture.size;
        for (int i = 0; i < 50; i++) {
            for (int j = 0; j < 50; j++) {
                ExpectationMaximizationGmm_F64.PointInfo p = alg.info.grow();
                // x varies with i and y with j so the points form a full 2-D grid
                p.point = new double[]{x + i * 0.1 - 2.5, y + j * 0.1 - 2.5};
                p.weights.resize(numGaussians);

                double total = 0.0;
                for (int k = 0; k < numGaussians; k++) {
                    double likelihood = TestGaussianLikelihoodManager.computeLikelihood(alg.mixture.get(k), p.point);
                    p.weights.data[k] = likelihood;
                    total += likelihood;
                }
                for (int k = 0; k < numGaussians; k++) {
                    p.weights.data[k] /= total;
                }
            }
        }
    }

    /**
     * Reference computation of a single mixture component from weighted points.
     *
     * @param which index of the component whose weights ({@code weights.data[which]}) are used
     * @param points points and their membership weights; must be non-empty
     * @return a Gaussian whose mean is the weighted mean of the points and whose covariance is
     *         sum(w * (x-mu)(x-mu)') / sum(w). Its {@code weight} field is left as constructed.
     */
    private GaussianGmm_F64 computeGaussian(int which, List<ExpectationMaximizationGmm_F64.PointInfo> points) {
        int dof = points.get(0).point.length;
        GaussianGmm_F64 out = new GaussianGmm_F64(dof);

        double totalWeight = 0.0;
        for (ExpectationMaximizationGmm_F64.PointInfo p : points) {
            double w = p.weights.data[which];
            totalWeight += w;
            for (int k = 0; k < dof; k++) {
                out.mean.data[k] += w * p.point[k];
            }
        }
        CommonOps.divide(out.mean, totalWeight);

        Equation eq = new Equation();
        eq.alias(out.mean, "mu", out.covariance, "Q");
        for (ExpectationMaximizationGmm_F64.PointInfo p : points) {
            eq.alias(DenseMatrix64F.wrap(dof, 1, p.point), "x", Double.valueOf(p.weights.data[which]), "w");
            eq.process("Q = Q + w*(x-mu)*(x-mu)'");
        }
        CommonOps.divide(out.covariance, totalWeight);
        return out;
    }

    /**
     * Creates the EM algorithm exercised by the generic cluster checks.
     *
     * @param hint if true, seed with the k-means instance that uses the test's fixed seeds;
     *             otherwise seed with a fresh k-means using {@link InitializeStandard_F64}
     */
    @Override
    public ComputeClusters<double[]> createClustersAlg(boolean hint) {
        if (hint) {
            return new ExpectationMaximizationGmm_F64(MAX_ITERATIONS, CONVERGE_TOL, this.seeds);
        }
        StandardKMeans_F64 standardKMeans = new StandardKMeans_F64(
                MAX_ITERATIONS, MAX_ITERATIONS, CONVERGE_TOL, new InitializeStandard_F64());
        return new ExpectationMaximizationGmm_F64(MAX_ITERATIONS, CONVERGE_TOL, new SeedFromKMeans_F64(standardKMeans));
    }
}
