/*
 * Decompiled with CFR 0.152.
 */
package org.ddogleg.clustering.gmm;

import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import org.ddogleg.clustering.gmm.GaussianGmm_F64;
import org.ddogleg.clustering.gmm.SeedFromKMeans_F64;
import org.ddogleg.clustering.kmeans.InitializeKMeans_F64;
import org.ddogleg.clustering.kmeans.StandardKMeans_F64;
import org.ejml.data.DenseMatrix64F;
import org.ejml.equation.Equation;
import org.ejml.ops.CommonOps;
import org.ejml.ops.MatrixFeatures;
import org.junit.Assert;
import org.junit.Test;

/**
 * Tests for {@link SeedFromKMeans_F64}. Seeds produced from a k-means clustering are compared
 * against the weight, mean, and covariance computed directly from the known cluster membership.
 */
public class TestSeedFromKMeans_F64 {
    /** Fixed seed so the generated point cloud is reproducible across runs. */
    final Random rand = new Random(234L);

    /**
     * Generates two 2D Gaussian clusters of equal size, far apart relative to their spread. Points
     * are interleaved: even indices belong to cluster A, odd indices to cluster B. Because
     * {@link FixedSeeds} starts k-means from points 0 and 1, seed 0 is expected to describe
     * cluster A and seed 1 cluster B.
     */
    @Test
    public void selectSeeds() {
        double sigmaX = 1.0;
        double sigmaY = 0.5;
        double x0 = 10.0;
        double y0 = 50.0;
        double x1 = -30.0;
        double y1 = 20.0;

        List<double[]> points = new ArrayList<double[]>();
        for (int i = 0; i < 1000; ++i) {
            double[] a = new double[2];
            double[] b = new double[2];
            a[0] = x0 + rand.nextGaussian() * sigmaX;
            a[1] = y0 + rand.nextGaussian() * sigmaY;
            b[0] = x1 + rand.nextGaussian() * sigmaX;
            b[1] = y1 + rand.nextGaussian() * sigmaY;
            points.add(a);
            points.add(b);
        }

        SeedFromKMeans_F64 alg = new SeedFromKMeans_F64(createKMeans());
        alg.init(2, 234234L);

        List<GaussianGmm_F64> seeds = new ArrayList<GaussianGmm_F64>();
        seeds.add(new GaussianGmm_F64(2));
        seeds.add(new GaussianGmm_F64(2));
        alg.selectSeeds(points, seeds);

        GaussianGmm_F64 a = seeds.get(0);
        GaussianGmm_F64 b = seeds.get(1);

        // Both clusters contain exactly half of the points, so each weight should be ~0.5.
        // The same tight tolerance is used for both; a loose one would make the check meaningless.
        Assert.assertEquals(0.5, a.weight, 0.01);
        Assert.assertEquals(0.5, b.weight, 0.01);

        GaussianGmm_F64 expectedA = computeGaussian(0, points);
        GaussianGmm_F64 expectedB = computeGaussian(1, points);

        Assert.assertTrue(MatrixFeatures.isIdentical(expectedA.mean, a.mean, 1.0E-8));
        Assert.assertTrue(MatrixFeatures.isIdentical(expectedB.mean, b.mean, 1.0E-8));
        Assert.assertTrue(MatrixFeatures.isIdentical(expectedA.covariance, a.covariance, 1.0E-8));
        Assert.assertTrue(MatrixFeatures.isIdentical(expectedB.covariance, b.covariance, 1.0E-8));
    }

    /**
     * Computes the reference mean and covariance for the points at indices
     * {@code offset, offset + 2, offset + 4, ...}. The covariance is normalized by
     * {@code count - 1}, i.e. the sample covariance. The weight of the returned Gaussian is not set.
     *
     * <p>NOTE(review): assumes a freshly constructed {@code GaussianGmm_F64} has a zeroed mean and
     * covariance, since both are accumulated in place.
     *
     * @param offset index of the first point of the cluster (0 or 1 for the interleaved layout)
     * @param points interleaved 2D points of both clusters
     * @return Gaussian holding the cluster's mean and sample covariance
     */
    private GaussianGmm_F64 computeGaussian(int offset, List<double[]> points) {
        GaussianGmm_F64 out = new GaussianGmm_F64(2);

        // Count members explicitly rather than assuming each cluster holds exactly size/2 points.
        int count = 0;
        for (int i = offset; i < points.size(); i += 2) {
            double[] p = points.get(i);
            out.mean.data[0] += p[0];
            out.mean.data[1] += p[1];
            count++;
        }
        CommonOps.divide(out.mean, count);

        Equation eq = new Equation();
        eq.alias(out.mean, "mu", out.covariance, "Q");
        for (int i = offset; i < points.size(); i += 2) {
            DenseMatrix64F x = DenseMatrix64F.wrap(2, 1, points.get(i));
            eq.alias(x, "x");
            eq.process("Q = Q + (x-mu)*(x-mu)'");
        }
        CommonOps.divide(out.covariance, count - 1);

        return out;
    }

    /**
     * Creates the k-means instance used by the test. It is seeded by {@link FixedSeeds}, so the
     * clustering does not depend on random initialization.
     *
     * <p>NOTE(review): the numeric arguments are presumably iteration limits and a convergence
     * tolerance; confirm against {@code StandardKMeans_F64}.
     */
    private StandardKMeans_F64 createKMeans() {
        return new StandardKMeans_F64(200, 200, 1.0E-6, new FixedSeeds());
    }

    /**
     * K-means seed initializer that copies point {@code i} into seed {@code i} for every requested
     * seed. This makes the clustering deterministic and fixes which cluster each seed converges to.
     */
    public static class FixedSeeds implements InitializeKMeans_F64 {
        @Override
        public void init(int pointDimension, long randomSeed) {
            // Nothing to configure: seed selection does not use randomness.
        }

        @Override
        public void selectSeeds(List<double[]> points, List<double[]> seeds) {
            // One point per requested seed, sized by each seed's own length, instead of a
            // hard-coded count of 2.
            for (int i = 0; i < seeds.size(); ++i) {
                double[] seed = seeds.get(i);
                System.arraycopy(points.get(i), 0, seed, 0, seed.length);
            }
        }
    }
}

