/*
 * Decompiled with CFR 0.152.
 */
package smile.sequence;

import org.junit.After;
import org.junit.AfterClass;
import org.junit.Assert;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Test;
import smile.math.Math;
import smile.sequence.HMM;
import smile.stat.distribution.EmpiricalDistribution;

/**
 * Unit tests for {@link HMM}: accessors, joint/marginal likelihoods, Viterbi
 * decoding and parameter learning, for both int-coded and String-coded symbols.
 */
public class HMMTest {
    /** Initial state probabilities of the reference two-state model. */
    final double[] pi = new double[]{0.5, 0.5};
    /** State transition matrix of the reference model; row i is P(next | i). */
    final double[][] a = new double[][]{{0.8, 0.2}, {0.2, 0.8}};
    /** Symbol emission matrix of the reference model; row i is P(symbol | i). */
    final double[][] b = new double[][]{{0.6, 0.4}, {0.4, 0.6}};

    @BeforeClass
    public static void setUpClass() throws Exception {
    }

    @AfterClass
    public static void tearDownClass() throws Exception {
    }

    @Before
    public void setUp() {
    }

    @After
    public void tearDown() {
    }

    /** Asserts that two vectors have the same length and agree element-wise within {@code tol}. */
    private static void assertVectorEquals(double[] expected, double[] actual, double tol) {
        Assert.assertEquals(expected.length, actual.length);
        for (int i = 0; i < expected.length; ++i) {
            Assert.assertEquals(expected[i], actual[i], tol);
        }
    }

    /** Asserts that two matrices have the same row count and agree row-wise within {@code tol}. */
    private static void assertMatrixEquals(double[][] expected, double[][] actual, double tol) {
        Assert.assertEquals(expected.length, actual.length);
        for (int i = 0; i < expected.length; ++i) {
            assertVectorEquals(expected[i], actual[i], tol);
        }
    }

    /**
     * Fills {@code sequences} and {@code labels} (same length) with observation
     * and hidden-state sequences sampled from the reference model (pi, a, b).
     * Each sequence length is 30 * (Math.randomInt(5) + 1).
     */
    private void sample(int[][] sequences, int[][] labels) {
        EmpiricalDistribution initial = new EmpiricalDistribution(this.pi);
        EmpiricalDistribution[] transition = new EmpiricalDistribution[this.a.length];
        for (int i = 0; i < transition.length; ++i) {
            transition[i] = new EmpiricalDistribution(this.a[i]);
        }
        EmpiricalDistribution[] emission = new EmpiricalDistribution[this.b.length];
        for (int i = 0; i < emission.length; ++i) {
            emission[i] = new EmpiricalDistribution(this.b[i]);
        }
        for (int i = 0; i < sequences.length; ++i) {
            sequences[i] = new int[30 * (Math.randomInt(5) + 1)];
            labels[i] = new int[sequences[i].length];
            int state = (int) initial.rand();
            sequences[i][0] = (int) emission[state].rand();
            labels[i][0] = state;
            for (int j = 1; j < sequences[i].length; ++j) {
                state = (int) transition[state].rand();
                sequences[i][j] = (int) emission[state].rand();
                labels[i][j] = state;
            }
        }
    }

    @Test
    public void testNumStates() {
        System.out.println("numStates");
        HMM hmm = new HMM(this.pi, this.a, this.b);
        Assert.assertEquals(2, hmm.numStates());
    }

    @Test
    public void testNumSymbols() {
        System.out.println("numSymbols");
        HMM hmm = new HMM(this.pi, this.a, this.b);
        Assert.assertEquals(2, hmm.numSymbols());
    }

    @Test
    public void testGetInitialStateProbabilities() {
        System.out.println("getInitialStateProbabilities");
        HMM hmm = new HMM(this.pi, this.a, this.b);
        assertVectorEquals(this.pi, hmm.getInitialStateProbabilities(), 1.0E-7);
    }

    @Test
    public void testGetStateTransitionProbabilities() {
        System.out.println("getStateTransitionProbabilities");
        HMM hmm = new HMM(this.pi, this.a, this.b);
        assertMatrixEquals(this.a, hmm.getStateTransitionProbabilities(), 1.0E-7);
    }

    @Test
    public void testGetSymbolEmissionProbabilities() {
        System.out.println("getSymbolEmissionProbabilities");
        HMM hmm = new HMM(this.pi, this.a, this.b);
        assertMatrixEquals(this.b, hmm.getSymbolEmissionProbabilities(), 1.0E-7);
    }

    @Test
    public void testP_intArr_intArr() {
        System.out.println("p");
        int[] o = new int[]{0, 0, 1, 1, 0, 1, 1, 0};
        int[] s = new int[]{0, 0, 1, 1, 1, 1, 1, 0};
        HMM hmm = new HMM(this.pi, this.a, this.b);
        double expResult = 7.33836E-5;
        double result = hmm.p(o, s);
        Assert.assertEquals(expResult, result, 1.0E-10);
    }

    @Test
    public void testLogp_intArr_intArr() {
        System.out.println("logp");
        HMM hmm = new HMM(this.pi, this.a, this.b);
        int[] o = new int[]{0, 0, 1, 1, 0, 1, 1, 0};
        int[] s = new int[]{0, 0, 1, 1, 1, 1, 1, 0};
        double expResult = -9.51981;
        double result = hmm.logp(o, s);
        Assert.assertEquals(expResult, result, 1.0E-5);
    }

    @Test
    public void testP_intArr() {
        System.out.println("p");
        HMM hmm = new HMM(this.pi, this.a, this.b);
        int[] o = new int[]{0, 0, 1, 1, 0, 1, 1, 0};
        double expResult = 0.003663364;
        double result = hmm.p(o);
        Assert.assertEquals(expResult, result, 1.0E-9);
    }

    @Test
    public void testLogp_intArr() {
        System.out.println("logp");
        HMM hmm = new HMM(this.pi, this.a, this.b);
        int[] o = new int[]{0, 0, 1, 1, 0, 1, 1, 0};
        double expResult = -5.609373;
        double result = hmm.logp(o);
        Assert.assertEquals(expResult, result, 1.0E-6);
    }

    @Test
    public void testPredict() {
        System.out.println("predict");
        HMM hmm = new HMM(this.pi, this.a, this.b);
        int[] o = new int[]{0, 0, 1, 1, 0, 1, 1, 0};
        int[] s = new int[]{0, 0, 0, 0, 0, 0, 0, 0};
        int[] result = hmm.predict(o);
        Assert.assertArrayEquals(s, result);
    }

    @Test
    public void testLearn() {
        System.out.println("learn");
        int[][] sequences = new int[5000][];
        int[][] labels = new int[5000][];
        sample(sequences, labels);

        // Supervised estimate from labelled data; with ~450k samples it should
        // be very close to the generating model.
        HMM hmm = new HMM(sequences, labels);
        System.out.println(hmm);
        Assert.assertEquals(2, hmm.numStates());
        Assert.assertEquals(2, hmm.numSymbols());
        assertVectorEquals(this.pi, hmm.getInitialStateProbabilities(), 0.05);
        assertMatrixEquals(this.a, hmm.getStateTransitionProbabilities(), 0.05);
        assertMatrixEquals(this.b, hmm.getSymbolEmissionProbabilities(), 0.05);

        // Unsupervised (Baum-Welch) learning from a perturbed start. Only the
        // model dimensions are checked: EM may converge to a local optimum or
        // a label-permuted solution.
        double[] pi2 = new double[]{0.55, 0.45};
        double[][] a2 = new double[][]{{0.7, 0.3}, {0.15, 0.85}};
        double[][] b2 = new double[][]{{0.45, 0.55}, {0.3, 0.7}};
        HMM init = new HMM(pi2, a2, b2);
        HMM result = init.learn(sequences, 100);
        System.out.println(result);
        Assert.assertEquals(2, result.numStates());
        Assert.assertEquals(2, result.numSymbols());
    }

    @Test
    public void testP_intArr_intArr2() {
        System.out.println("p");
        String[] symbols = new String[]{"0", "1"};
        HMM<String> hmm = new HMM<String>(this.pi, this.a, this.b, symbols);
        String[] o = new String[]{"0", "0", "1", "1", "0", "1", "1", "0"};
        int[] s = new int[]{0, 0, 1, 1, 1, 1, 1, 0};
        double expResult = 7.33836E-5;
        double result = hmm.p(o, s);
        Assert.assertEquals(expResult, result, 1.0E-10);
    }

    @Test
    public void testLogp_intArr_intArr2() {
        System.out.println("logp");
        String[] symbols = new String[]{"0", "1"};
        HMM<String> hmm = new HMM<String>(this.pi, this.a, this.b, symbols);
        String[] o = new String[]{"0", "0", "1", "1", "0", "1", "1", "0"};
        int[] s = new int[]{0, 0, 1, 1, 1, 1, 1, 0};
        double expResult = -9.51981;
        double result = hmm.logp(o, s);
        Assert.assertEquals(expResult, result, 1.0E-5);
    }

    @Test
    public void testP_intArr2() {
        System.out.println("p");
        String[] symbols = new String[]{"0", "1"};
        HMM<String> hmm = new HMM<String>(this.pi, this.a, this.b, symbols);
        String[] o = new String[]{"0", "0", "1", "1", "0", "1", "1", "0"};
        double expResult = 0.003663364;
        double result = hmm.p(o);
        Assert.assertEquals(expResult, result, 1.0E-9);
    }

    @Test
    public void testLogp_intArr2() {
        System.out.println("logp");
        String[] symbols = new String[]{"0", "1"};
        HMM<String> hmm = new HMM<String>(this.pi, this.a, this.b, symbols);
        String[] o = new String[]{"0", "0", "1", "1", "0", "1", "1", "0"};
        double expResult = -5.609373;
        double result = hmm.logp(o);
        Assert.assertEquals(expResult, result, 1.0E-6);
    }

    @Test
    public void testPredict2() {
        System.out.println("predict");
        String[] symbols = new String[]{"0", "1"};
        HMM<String> hmm = new HMM<String>(this.pi, this.a, this.b, symbols);
        String[] o = new String[]{"0", "0", "1", "1", "0", "1", "1", "0"};
        int[] s = new int[]{0, 0, 0, 0, 0, 0, 0, 0};
        // The decompiled "(O[])o" cast referenced HMM's type parameter, which
        // is not in scope here; a String[] already selects predict(O[]).
        int[] result = hmm.predict(o);
        Assert.assertArrayEquals(s, result);
    }

    @Test
    public void testPredict3() {
        System.out.println("predict");
        String[] symbols = new String[]{"H", "T", "P"};
        double[] pi2 = new double[]{0.4, 0.3, 0.3};
        double[][] a2 = new double[][]{{0.3, 0.4, 0.3}, {0.3, 0.3, 0.4}, {0.4, 0.2, 0.4}};
        double[][] b2 = new double[][]{{0.4, 0.3, 0.3}, {0.5, 0.2, 0.3}, {0.2, 0.3, 0.5}};
        HMM<String> hmm = new HMM<String>(pi2, a2, b2, symbols);
        String[] o = new String[]{"H", "H", "P", "P", "P", "H", "H", "H", "P", "P", "P", "H", "T", "T", "T"};
        int[] s = new int[]{0, 1, 2, 2, 2, 0, 1, 1, 2, 2, 2, 0, 2, 2, 0};
        int[] result = hmm.predict(o);
        Assert.assertArrayEquals(s, result);
    }

    @Test
    public void testLearn2() {
        System.out.println("learn");
        String[] symbols = new String[]{"0", "1"};
        int[][] coded = new int[5000][];
        int[][] labels = new int[5000][];
        sample(coded, labels);

        // Map sampled symbol indices onto their String names.
        String[][] sequences = new String[coded.length][];
        for (int i = 0; i < coded.length; ++i) {
            sequences[i] = new String[coded[i].length];
            for (int j = 0; j < coded[i].length; ++j) {
                sequences[i][j] = symbols[coded[i][j]];
            }
        }

        // Supervised estimate. Emission columns are not compared because the
        // String-to-index mapping chosen by HMM is not pinned down here.
        HMM<String> hmm = new HMM<String>(sequences, labels);
        System.out.println(hmm);
        Assert.assertEquals(2, hmm.numStates());
        Assert.assertEquals(2, hmm.numSymbols());
        assertVectorEquals(this.pi, hmm.getInitialStateProbabilities(), 0.05);
        assertMatrixEquals(this.a, hmm.getStateTransitionProbabilities(), 0.05);

        // Unsupervised (Baum-Welch) learning; only dimensions are checked.
        double[] pi2 = new double[]{0.55, 0.45};
        double[][] a2 = new double[][]{{0.7, 0.3}, {0.15, 0.85}};
        double[][] b2 = new double[][]{{0.45, 0.55}, {0.3, 0.7}};
        HMM<String> init = new HMM<String>(pi2, a2, b2, symbols);
        HMM<String> result = init.learn(sequences, 100);
        System.out.println(result);
        Assert.assertEquals(2, result.numStates());
        Assert.assertEquals(2, result.numSymbols());
    }
}

