/*
 * Decompiled with CFR 0.152.
 */
package smile.sequence;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.text.ParseException;
import java.util.ArrayList;
import java.util.List;
import org.junit.After;
import org.junit.AfterClass;
import org.junit.Assert;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Test;
import smile.data.Attribute;
import smile.data.NominalAttribute;
import smile.data.parser.IOUtils;
import smile.math.Math;
import smile.sequence.CRF;

/**
 * Regression tests for {@link CRF}: trains on the protein and hyphenation sequence
 * data sets, in both raw-integer-feature and attribute-typed form, and pins the number
 * of mis-labelled test tokens for forward-backward and Viterbi decoding.
 */
public class CRFTest {
    /**
     * Loads a sequence data set whose features are kept as raw integers
     * (for {@code CRF.Trainer(int, int)}).
     *
     * <p>Format as parsed here: a header line {@code "<nseq> <k> <p>"}, then one
     * space-separated line per token {@code "<seqid> <pos> <len> f_1 ... f_len <label>"}.
     * Consecutive lines sharing a {@code seqid} form one sequence.
     *
     * @param resource test-data resource path passed to {@link IOUtils#getTestDataReader}
     * @return the loaded data set
     * @throws UncheckedIOException if the resource cannot be read or is empty
     */
    IntDataset load(String resource) {
        IntDataset dataset = new IntDataset();
        List<int[][]> x = new ArrayList<>();
        List<int[]> y = new ArrayList<>();
        List<int[]> seq = new ArrayList<>();
        List<Integer> label = new ArrayList<>();

        try (BufferedReader input = IOUtils.getTestDataReader(resource)) {
            String[] header = readHeader(input, resource);
            dataset.k = Integer.parseInt(header[1]);
            dataset.p = Integer.parseInt(header[2]);

            int id = 1;
            String line;
            while ((line = input.readLine()) != null) {
                String[] words = line.split(" ");
                int seqid = Integer.parseInt(words[0]);
                int len = Integer.parseInt(words[2]);

                int[] feature = new int[len];
                for (int i = 0; i < len; i++) {
                    try {
                        feature[i] = Integer.parseInt(words[i + 3]);
                    } catch (NumberFormatException ex) {
                        // Best effort: report the malformed value and leave the feature at 0.
                        System.err.println(ex);
                    }
                }

                if (seqid != id) {
                    // A new sequence starts; flush the tokens collected so far (if any).
                    id = seqid;
                    if (!seq.isEmpty()) {
                        x.add(seq.toArray(new int[0][]));
                        y.add(toIntArray(label));
                        seq.clear();
                        label.clear();
                    }
                }
                seq.add(feature);
                label.add(Integer.valueOf(words[len + 3]));
            }

            if (!seq.isEmpty()) {
                x.add(seq.toArray(new int[0][]));
                y.add(toIntArray(label));
            }
        } catch (IOException ex) {
            // Fail fast: an empty/partial data set would only surface later as an obscure error.
            throw new UncheckedIOException("Failed to load " + resource, ex);
        }

        dataset.x = x.toArray(new int[0][][]);
        dataset.y = y.toArray(new int[0][]);
        return dataset;
    }

    /**
     * Loads a sequence data set whose features are converted through {@link Attribute}s
     * (for {@code CRF.Trainer(Attribute[], int)}). The line format is the same as for
     * {@link #load(String)}.
     *
     * @param resource test-data resource path passed to {@link IOUtils#getTestDataReader}
     * @param attributes attributes used to parse feature values; if {@code null}, one
     *     {@link NominalAttribute} per feature is created from the first token line
     * @return the loaded data set, carrying the attributes actually used
     * @throws UncheckedIOException if the resource cannot be read or is empty
     */
    Dataset load(String resource, Attribute[] attributes) {
        Dataset dataset = new Dataset();
        dataset.attributes = attributes;
        List<double[][]> x = new ArrayList<>();
        List<int[]> y = new ArrayList<>();
        List<double[]> seq = new ArrayList<>();
        List<Integer> label = new ArrayList<>();

        try (BufferedReader input = IOUtils.getTestDataReader(resource)) {
            String[] header = readHeader(input, resource);
            dataset.k = Integer.parseInt(header[1]);
            dataset.p = Integer.parseInt(header[2]);

            int id = 1;
            String line;
            while ((line = input.readLine()) != null) {
                String[] words = line.split(" ");
                int seqid = Integer.parseInt(words[0]);
                int len = Integer.parseInt(words[2]);

                if (dataset.attributes == null) {
                    // No attributes supplied: size them from the first token's feature count.
                    dataset.attributes = new Attribute[len];
                    for (int i = 0; i < len; i++) {
                        dataset.attributes[i] = new NominalAttribute("Attr" + (i + 1));
                    }
                }

                double[] feature = new double[len];
                for (int i = 0; i < len; i++) {
                    try {
                        feature[i] = dataset.attributes[i].valueOf(words[i + 3]);
                    } catch (ParseException ex) {
                        // Best effort: report the unparsable value and leave the feature at 0.
                        System.err.println(ex);
                    }
                }

                if (seqid != id) {
                    // A new sequence starts; flush the tokens collected so far (if any).
                    id = seqid;
                    if (!seq.isEmpty()) {
                        x.add(seq.toArray(new double[0][]));
                        y.add(toIntArray(label));
                        seq.clear();
                        label.clear();
                    }
                }
                seq.add(feature);
                label.add(Integer.valueOf(words[len + 3]));
            }

            if (!seq.isEmpty()) {
                x.add(seq.toArray(new double[0][]));
                y.add(toIntArray(label));
            }
        } catch (IOException ex) {
            // Fail fast: an empty/partial data set would only surface later as an obscure error.
            throw new UncheckedIOException("Failed to load " + resource, ex);
        }

        dataset.x = x.toArray(new double[0][][]);
        dataset.y = y.toArray(new int[0][]);
        return dataset;
    }

    /** Reads and splits the header line; an empty resource is reported as an I/O error. */
    private static String[] readHeader(BufferedReader input, String resource) throws IOException {
        String header = input.readLine();
        if (header == null) {
            throw new IOException("Empty sequence data resource: " + resource);
        }
        return header.split(" ");
    }

    /** Unboxes a list of labels into a new {@code int[]}. */
    private static int[] toIntArray(List<Integer> values) {
        int[] result = new int[values.size()];
        for (int i = 0; i < result.length; i++) {
            result[i] = values.get(i);
        }
        return result;
    }

    /** Applies the tree-boosting hyper-parameters shared by every test in this class. */
    private static CRF.Trainer configure(CRF.Trainer trainer, double learningRate) {
        trainer.setLearningRate(learningRate);
        trainer.setMaxNodes(100);
        trainer.setNumTrees(100);
        return trainer;
    }

    /** Counts test tokens whose predicted label differs from the reference label. */
    private static int countErrors(CRF crf, int[][][] x, int[][] y) {
        int errors = 0;
        for (int i = 0; i < x.length; i++) {
            errors += mismatches(y[i], crf.predict(x[i]));
        }
        return errors;
    }

    /** Counts test tokens whose predicted label differs from the reference label. */
    private static int countErrors(CRF crf, double[][][] x, int[][] y) {
        int errors = 0;
        for (int i = 0; i < x.length; i++) {
            errors += mismatches(y[i], crf.predict(x[i]));
        }
        return errors;
    }

    /** Number of positions at which {@code predicted} disagrees with {@code truth}. */
    private static int mismatches(int[] truth, int[] predicted) {
        int count = 0;
        for (int j = 0; j < truth.length; j++) {
            if (truth[j] != predicted[j]) {
                count++;
            }
        }
        return count;
    }

    /** Total number of labelled tokens across all sequences. */
    private static int countTokens(int[][] y) {
        int n = 0;
        for (int[] labels : y) {
            n += labels.length;
        }
        return n;
    }

    /** Prints error counts and rates for both decoding modes over {@code n} tokens. */
    private static void report(String name, int error, int viterbiError, int n) {
        System.out.format("%s error (forward-backward) is %d of %d%n", name, error, n);
        System.out.format("%s error (forward-backward) rate = %.2f%%%n", name, 100.0 * error / n);
        System.out.format("%s error (Viterbi) is %d of %d%n", name, viterbiError, n);
        System.out.format("%s error (Viterbi) rate = %.2f%%%n", name, 100.0 * viterbiError / n);
    }

    @BeforeClass
    public static void setUpClass() throws Exception {
    }

    @AfterClass
    public static void tearDownClass() throws Exception {
    }

    @Before
    public void setUp() {
        // Fixed seed so the asserted error counts are reproducible.
        Math.setSeed(54217137L);
    }

    @After
    public void tearDown() {
    }

    @Test
    public void testLearnProteinSparse() {
        System.out.println("learn protein sparse");
        IntDataset train = load("sequence/sparse.protein.11.train");
        IntDataset test = load("sequence/sparse.protein.11.test");
        CRF crf = configure(new CRF.Trainer(train.p, train.k), 0.3).train(train.x, train.y);

        int error = countErrors(crf, test.x, test.y);
        crf.setViterbi(true);
        int viterbiError = countErrors(crf, test.x, test.y);

        report("Protein", error, viterbiError, countTokens(test.y));
        Assert.assertEquals(1234L, error);
        Assert.assertEquals(1318L, viterbiError);
    }

    @Test
    public void testLearnHyphenSparse() {
        System.out.println("learn hyphen sparse");
        IntDataset train = load("sequence/sparse.hyphen.6.train");
        IntDataset test = load("sequence/sparse.hyphen.6.test");
        CRF crf = configure(new CRF.Trainer(train.p, train.k), 1.0).train(train.x, train.y);

        int error = countErrors(crf, test.x, test.y);
        crf.setViterbi(true);
        int viterbiError = countErrors(crf, test.x, test.y);

        report("Hyphen", error, viterbiError, countTokens(test.y));
        Assert.assertEquals(470L, error);
        Assert.assertEquals(478L, viterbiError);
    }

    @Test
    public void testLearnProtein() {
        System.out.println("learn protein");
        Dataset train = load("sequence/sparse.protein.11.train", null);
        Dataset test = load("sequence/sparse.protein.11.test", train.attributes);
        CRF crf = configure(new CRF.Trainer(train.attributes, train.k), 0.3).train(train.x, train.y);

        int error = countErrors(crf, test.x, test.y);
        crf.setViterbi(true);
        int viterbiError = countErrors(crf, test.x, test.y);

        report("Protein", error, viterbiError, countTokens(test.y));
        Assert.assertEquals(1270L, error);
        Assert.assertEquals(1420L, viterbiError);
    }

    @Test
    public void testLearnHyphen() {
        System.out.println("learn hyphen");
        Dataset train = load("sequence/sparse.hyphen.6.train", null);
        Dataset test = load("sequence/sparse.hyphen.6.test", train.attributes);
        CRF crf = configure(new CRF.Trainer(train.attributes, train.k), 1.0).train(train.x, train.y);

        int error = countErrors(crf, test.x, test.y);
        crf.setViterbi(true);
        int viterbiError = countErrors(crf, test.x, test.y);

        report("Hyphen", error, viterbiError, countTokens(test.y));
        Assert.assertEquals(473L, error);
        Assert.assertEquals(478L, viterbiError);
    }

    /** Sequence data with raw integer features: x[seq][token][feature], y[seq][token]. */
    static class IntDataset {
        int[][][] x;
        int[][] y;
        /** Feature dimension read from the header. */
        int p;
        /** Number of label classes read from the header. */
        int k;

        IntDataset() {
        }
    }

    /** Sequence data with attribute-converted features: x[seq][token][feature], y[seq][token]. */
    static class Dataset {
        Attribute[] attributes;
        double[][][] x;
        int[][] y;
        /** Feature dimension read from the header. */
        int p;
        /** Number of label classes read from the header. */
        int k;

        Dataset() {
        }
    }
}

