package org.ddogleg.clustering.gmm;

import org.ddogleg.struct.FastQueue;
import org.ejml.data.DenseMatrix64F;
import org.ejml.equation.Equation;
import org.junit.Assert;
import org.junit.Test;

/* loaded from: input_file:org/ddogleg/clustering/gmm/TestGaussianLikelihoodManager.class */
/**
 * Tests for {@link GaussianLikelihoodManager}.
 *
 * <p>The chi-square distance and likelihood reported for each Gaussian are checked against
 * reference values. The reference values come from the multivariate normal formulas, evaluated
 * with EJML's {@link Equation}.
 */
public class TestGaussianLikelihoodManager {

    /** Absolute tolerance used for every floating-point comparison in this test. */
    private static final double TOL = 1e-8;

    @Test
    public void likelihood() {
        GaussianGmm_F64 gaussianA = new GaussianGmm_F64(3);
        GaussianGmm_F64 gaussianB = new GaussianGmm_F64(3);

        gaussianA.mean.data = new double[]{5.0, 3.0, 5.0};
        gaussianB.mean.data = new double[]{-5.0, 6.0, -1.5};

        // Only the diagonal is assigned. The off-diagonal entries are presumably zero from
        // construction, making these axis-aligned Gaussians (TODO confirm in GaussianGmm_F64).
        gaussianA.covariance.set(0, 0, 3.0);
        gaussianA.covariance.set(1, 1, 6.0);
        gaussianA.covariance.set(2, 2, 12.0);
        gaussianB.covariance.set(0, 0, 20.0);
        gaussianB.covariance.set(1, 1, 30.0);
        gaussianB.covariance.set(2, 2, 25.0);

        FastQueue<GaussianGmm_F64> mixture =
                new FastQueue<GaussianGmm_F64>(GaussianGmm_F64.class, false);
        mixture.add(gaussianA);
        mixture.add(gaussianB);

        GaussianLikelihoodManager manager = new GaussianLikelihoodManager(3, mixture.toList());
        manager.precomputeAll();

        double[] point = {4.0, 3.0, -1.0};

        // getChisq() is read right after likelihood(). It presumably reports a value computed
        // by that call, so keep this ordering.
        double likelihoodA = manager.getLikelihood(0).likelihood(point);
        double chisqA = manager.getLikelihood(0).getChisq();
        double likelihoodB = manager.getLikelihood(1).likelihood(point);
        double chisqB = manager.getLikelihood(1).getChisq();

        // Sanity check: the Gaussian parameters still hold the values assigned above after
        // precomputation and evaluation.
        Assert.assertEquals(5.0, gaussianA.mean.get(0, 0), TOL);
        Assert.assertEquals(3.0, gaussianA.covariance.get(0, 0), TOL);

        Assert.assertEquals(computeChiSq(gaussianA, point), chisqA, TOL);
        Assert.assertEquals(computeChiSq(gaussianB, point), chisqB, TOL);

        // Likelihoods are compared as a ratio, not individually. This is presumably because the
        // manager may drop a normalization factor that cancels between Gaussians (verify
        // against GaussianLikelihoodManager).
        double expectedRatio = computeLikelihood(gaussianA, point) / computeLikelihood(gaussianB, point);
        Assert.assertEquals(expectedRatio, likelihoodA / likelihoodB, TOL);
    }

    /**
     * Evaluates the multivariate normal density of a Gaussian at a point.
     *
     * <p>The formula is
     * {@code exp(-0.5*(x-mu)'*inv(S)*(x-mu)) / ((2*pi)^(D/2) * sqrt(det(S)))}.
     *
     * @param gaussianGmm_F64 Gaussian supplying the mean {@code mu} and covariance {@code S}
     * @param dArr point {@code x}, whose length is used as the dimension {@code D}
     * @return the density value at {@code dArr}
     */
    public static double computeLikelihood(GaussianGmm_F64 gaussianGmm_F64, double[] dArr) {
        Equation eq = createEquation(gaussianGmm_F64, dArr);
        eq.process("left = 1.0/((2*pi)^(D/2.0)*sqrt(det(S)))");
        eq.process("likelihood = left*exp(-0.5*(x-mu)'*inv(S)*(x-mu))");
        return eq.lookupDouble("likelihood");
    }

    /**
     * Computes the squared Mahalanobis distance {@code (x-mu)'*inv(S)*(x-mu)} of a point from a
     * Gaussian.
     *
     * @param gaussian Gaussian supplying the mean {@code mu} and covariance {@code S}
     * @param x the point
     * @return the chi-square (squared Mahalanobis) distance
     */
    private static double computeChiSq(GaussianGmm_F64 gaussian, double[] x) {
        Equation eq = createEquation(gaussian, x);
        eq.process("chisq = (x-mu)'*inv(S)*(x-mu)");
        return eq.lookupDouble("chisq");
    }

    /**
     * Creates an {@link Equation} with the variables shared by the reference formulas.
     *
     * <p>The aliases are:
     * <ul>
     *   <li>{@code mu}: the Gaussian's mean</li>
     *   <li>{@code S}: its covariance</li>
     *   <li>{@code D}: the dimension</li>
     *   <li>{@code x}: a column vector wrapping {@code x} without copying it</li>
     * </ul>
     */
    private static Equation createEquation(GaussianGmm_F64 gaussian, double[] x) {
        Equation eq = new Equation();
        eq.alias(gaussian.mean, "mu", gaussian.covariance, "S", x.length, "D");
        eq.alias(DenseMatrix64F.wrap(x.length, 1, x), "x");
        return eq;
    }
}
