/*
 * Decompiled with CFR 0.152.
 */
package org.ddogleg.clustering.gmm;

import java.util.List;
import java.util.Random;
import org.ddogleg.clustering.ComputeClusters;
import org.ddogleg.clustering.GenericClusterChecks_F64;
import org.ddogleg.clustering.gmm.ExpectationMaximizationGmm_F64;
import org.ddogleg.clustering.gmm.GaussianGmm_F64;
import org.ddogleg.clustering.gmm.SeedFromKMeans_F64;
import org.ddogleg.clustering.gmm.TestGaussianLikelihoodManager;
import org.ddogleg.clustering.kmeans.InitializeStandard_F64;
import org.ddogleg.clustering.kmeans.StandardKMeans_F64;
import org.ddogleg.clustering.kmeans.TestStandardKMeans_F64;
import org.ejml.data.DenseMatrix64F;
import org.ejml.equation.Equation;
import org.ejml.ops.CommonOps;
import org.ejml.ops.MatrixFeatures;
import org.junit.Assert;
import org.junit.Test;

/**
 * Unit tests for {@link ExpectationMaximizationGmm_F64}. Besides the E-step and M-step tests declared
 * here, it extends {@link GenericClusterChecks_F64} and supplies the algorithm for those inherited
 * checks through {@link #createClustersAlg(boolean)}.
 */
public class TestExpectationMaximizationGmm_F64
extends GenericClusterChecks_F64 {
    /** Fixed-seed RNG so the randomly generated test points are reproducible. */
    final Random rand = new Random(234L);
    /**
     * K-means used by {@link #seeds}. It is initialized with {@code TestStandardKMeans_F64.FixedSeeds}
     * (presumably a deterministic seeding strategy; see that class).
     */
    final StandardKMeans_F64 kmeans = new StandardKMeans_F64(1000, 1000, 1.0E-8, new TestStandardKMeans_F64.FixedSeeds());
    /** Mixture initializer built on top of {@link #kmeans}. */
    final SeedFromKMeans_F64 seeds = new SeedFromKMeans_F64(this.kmeans);

    /**
     * Checks the expectation (E) step.
     *
     * <p>The test creates twenty random 3-DOF points. It then adds Gaussians one at a time; Gaussian
     * {@code i} is centered exactly on point {@code i} with identity covariance. After each addition
     * the expectation step is run. For every point that coincides with a Gaussian's mean, that
     * Gaussian must receive strictly the largest weight, and the point's weights must sum to one.
     */
    @Test
    public void expectation() {
        int dof = 3;
        ExpectationMaximizationGmm_F64 alg = new ExpectationMaximizationGmm_F64(100, 1.0E-8, this.seeds);
        alg.init(dof, 34535L);

        for (int i = 0; i < 20; i++) {
            ExpectationMaximizationGmm_F64.PointInfo p = alg.info.grow();
            p.point = new double[dof];
            for (int d = 0; d < dof; d++) {
                p.point[d] = this.rand.nextGaussian() * 5.0;
            }
        }

        for (int i = 0; i < 3; i++) {
            // every point needs one weight slot per Gaussian that will be in the mixture
            for (int j = 0; j < alg.info.size; j++) {
                alg.info.get(j).weights.resize(i + 1);
            }

            // the new Gaussian is centered exactly on point i
            GaussianGmm_F64 g = alg.mixture.grow();
            g.setMean(alg.info.get(i).point);
            g.weight = 2.0;
            CommonOps.setIdentity(g.covariance);

            alg.likelihoodManager.precomputeAll();
            alg.expectation();

            // point j sits on the mean of Gaussian j, so that Gaussian must dominate its weights
            for (int j = 0; j <= i; j++) {
                ExpectationMaximizationGmm_F64.PointInfo p = alg.info.get(j);
                double expectedMax = p.weights.get(j);
                double total = 0.0;
                for (int k = 0; k <= i; k++) {
                    double w = p.weights.get(k);
                    total += w;
                    if (k != j) {
                        Assert.assertTrue(w < expectedMax);
                    }
                }
                Assert.assertEquals(1.0, total, 1.0E-8);
            }
        }
    }

    /**
     * Checks the maximization (M) step.
     *
     * <p>Points are generated with weights equal to their normalized likelihood under a known
     * two-component mixture. The mixture's means, covariances and weights are then zeroed, so
     * {@code maximization()} must rebuild them. The resulting means and covariances must match the
     * weighted reference values produced by {@link #computeGaussian(int, List)}.
     */
    @Test
    public void maximization() {
        int dof = 2;
        ExpectationMaximizationGmm_F64 alg = new ExpectationMaximizationGmm_F64(100, 1.0E-8, this.seeds);
        alg.init(dof, 34535L);

        GaussianGmm_F64 a = alg.mixture.grow();
        a.setMean(new double[]{1.0, 0.5});
        CommonOps.diag(a.covariance, 2, 0.75, 1.0);
        a.weight = 0.25;

        GaussianGmm_F64 b = alg.mixture.grow();
        b.setMean(new double[]{4.0, 8.0});
        CommonOps.diag(b.covariance, 2, 0.5, 0.75);
        b.weight = 0.75;

        this.createPointsAround(1.0, 0.5, alg);
        this.createPointsAround(2.0, 3.0, alg);

        // wipe the parameters so the M-step cannot pass by leaving them untouched
        for (int i = 0; i < alg.mixture.size; i++) {
            GaussianGmm_F64 g = alg.mixture.get(i);
            CommonOps.fill(g.mean, 0.0);
            CommonOps.fill(g.covariance, 0.0);
            g.weight = 0.0;
        }

        alg.maximization();

        // a and b are the mixture's own elements (returned by grow()), so they now hold the M-step output
        List<ExpectationMaximizationGmm_F64.PointInfo> points = alg.info.toList();
        GaussianGmm_F64 expectedA = this.computeGaussian(0, points);
        GaussianGmm_F64 expectedB = this.computeGaussian(1, points);
        Assert.assertTrue(MatrixFeatures.isIdentical(expectedA.mean, a.mean, 1.0E-8));
        Assert.assertTrue(MatrixFeatures.isIdentical(expectedB.mean, b.mean, 1.0E-8));
        Assert.assertTrue(MatrixFeatures.isIdentical(expectedA.covariance, a.covariance, 1.0E-8));
        Assert.assertTrue(MatrixFeatures.isIdentical(expectedB.covariance, b.covariance, 1.0E-8));
    }

    /**
     * Appends a 50 x 50 grid of 2D points to {@code alg.info}.
     *
     * <p>The grid has 0.1 spacing and spans [cx-2.5, cx+2.4] x [cy-2.5, cy+2.4]. Each point's weights
     * are set to its likelihood under every component of {@code alg.mixture}, normalized to sum to one.
     *
     * @param cx  x offset of the grid
     * @param cy  y offset of the grid
     * @param alg algorithm whose point list is extended and whose current mixture defines the weights
     */
    private void createPointsAround(double cx, double cy, ExpectationMaximizationGmm_F64 alg) {
        int numComponents = alg.mixture.size;
        for (int i = 0; i < 50; i++) {
            for (int j = 0; j < 50; j++) {
                // x follows i and y follows j so the points cover a full 2D grid, not a single diagonal
                double x = cx + i * 0.1 - 2.5;
                double y = cy + j * 0.1 - 2.5;

                ExpectationMaximizationGmm_F64.PointInfo p = alg.info.grow();
                p.point = new double[]{x, y};
                p.weights.resize(numComponents);

                double total = 0.0;
                for (int k = 0; k < numComponents; k++) {
                    double likelihood = TestGaussianLikelihoodManager.computeLikelihood(alg.mixture.get(k), p.point);
                    p.weights.data[k] = likelihood;
                    total += likelihood;
                }
                for (int k = 0; k < numComponents; k++) {
                    p.weights.data[k] /= total;
                }
            }
        }
    }

    /**
     * Reference weighted Gaussian fit for one mixture component.
     *
     * <p>The mean is sum(w*x)/sum(w). The covariance is sum(w*(x-mu)*(x-mu)')/sum(w). Here w is each
     * point's weight for component {@code which}.
     *
     * @param which  index of the component whose weights are used
     * @param points weighted points; must be non-empty because the dimension is read from the first one
     * @return a new Gaussian holding the weighted mean and covariance; its {@code weight} is not set
     */
    private GaussianGmm_F64 computeGaussian(int which, List<ExpectationMaximizationGmm_F64.PointInfo> points) {
        int n = points.get(0).point.length;
        GaussianGmm_F64 out = new GaussianGmm_F64(n);

        // weighted mean; accumulates into out.mean, assuming a new GaussianGmm_F64 starts zeroed
        double total = 0.0;
        for (ExpectationMaximizationGmm_F64.PointInfo p : points) {
            double w = p.weights.data[which];
            total += w;
            for (int j = 0; j < n; j++) {
                out.mean.data[j] += w * p.point[j];
            }
        }
        CommonOps.divide(out.mean, total);

        // weighted covariance about that mean
        Equation eq = new Equation();
        eq.alias(out.mean, "mu", out.covariance, "Q");
        for (ExpectationMaximizationGmm_F64.PointInfo p : points) {
            double w = p.weights.data[which];
            DenseMatrix64F x = DenseMatrix64F.wrap(n, 1, p.point);
            eq.alias(x, "x", w, "w");
            eq.process("Q = Q + w*(x-mu)*(x-mu)'");
        }
        CommonOps.divide(out.covariance, total);
        return out;
    }

    /**
     * Supplies the clustering algorithm used by the inherited {@link GenericClusterChecks_F64} tests.
     *
     * @param hint if true, the mixture is seeded through {@link #seeds}, which uses the fixed-seed
     *             k-means; otherwise it is seeded from a new k-means initialized with
     *             {@code InitializeStandard_F64}
     * @return a new EM-GMM clusterer (1000 iterations, 1e-8 tolerance)
     */
    @Override
    public ComputeClusters<double[]> createClustersAlg(boolean hint) {
        if (hint) {
            return new ExpectationMaximizationGmm_F64(1000, 1.0E-8, this.seeds);
        }
        InitializeStandard_F64 standardInit = new InitializeStandard_F64();
        StandardKMeans_F64 standardKMeans = new StandardKMeans_F64(1000, 1000, 1.0E-8, standardInit);
        SeedFromKMeans_F64 standardSeeds = new SeedFromKMeans_F64(standardKMeans);
        return new ExpectationMaximizationGmm_F64(1000, 1.0E-8, standardSeeds);
    }
}

