/*
 * Decompiled with CFR 0.152.
 */
package smile.classification;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.List;
import org.junit.After;
import org.junit.AfterClass;
import org.junit.Assert;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Test;
import smile.classification.Maxent;
import smile.data.parser.IOUtils;

public class MaxentTest {
    /**
     * Loads a sparse sequence dataset from the test-data directory.
     *
     * <p>The first line is a whitespace-separated header whose third value is read as the
     * feature dimension {@code p}; the other header values are not used here. Every following
     * non-blank line has the form {@code seqid pos len f_1 ... f_len label}: the {@code len}
     * feature values become one row of {@code x} and the trailing value becomes the matching
     * entry of {@code y}.
     *
     * @param resource path of the data file, as accepted by {@code IOUtils.getTestDataReader}
     * @return the parsed dataset
     * @throws UncheckedIOException if the resource cannot be read or has no header line
     */
    Dataset load(String resource) {
        int p;
        List<int[]> x = new ArrayList<>();
        List<Integer> y = new ArrayList<>();
        try (BufferedReader input = IOUtils.getTestDataReader(resource)) {
            String header = input.readLine();
            if (header == null) {
                throw new IOException("Empty data file: " + resource);
            }
            // Only the third header field (feature dimension) is needed for training.
            p = Integer.parseInt(header.trim().split("\\s+")[2]);

            String line;
            while ((line = input.readLine()) != null) {
                line = line.trim();
                if (line.isEmpty()) {
                    continue; // tolerate blank / trailing lines instead of failing in parseInt
                }
                String[] words = line.split("\\s+");
                // words[0] and words[1] (sequence id, position) are not used by this test.
                int len = Integer.parseInt(words[2]);
                int[] feature = new int[len];
                for (int i = 0; i < len; ++i) {
                    feature[i] = Integer.parseInt(words[i + 3]);
                }
                x.add(feature);
                y.add(Integer.parseInt(words[len + 3]));
            }
        } catch (IOException ex) {
            // Fail fast: a silently empty dataset would only surface later as a
            // confusing error-count mismatch in the assertions.
            throw new UncheckedIOException("Failed to load " + resource, ex);
        }

        Dataset dataset = new Dataset();
        dataset.p = p;
        dataset.x = x.toArray(new int[0][]);
        dataset.y = new int[y.size()];
        for (int i = 0; i < dataset.y.length; ++i) {
            dataset.y[i] = y.get(i);
        }
        return dataset;
    }

    /**
     * Counts the samples of {@code test} whose label differs from the model's prediction.
     *
     * @param maxent trained model
     * @param test   evaluation data
     * @return number of misclassified samples
     */
    private static int countErrors(Maxent maxent, Dataset test) {
        int error = 0;
        for (int i = 0; i < test.x.length; ++i) {
            if (test.y[i] != maxent.predict(test.x[i])) {
                ++error;
            }
        }
        return error;
    }

    /**
     * Prints the absolute error count and the error rate for one dataset.
     *
     * @param name  label used as the prefix of both output lines
     * @param error number of misclassified samples
     * @param n     total number of evaluated samples
     */
    private static void report(String name, int error, int n) {
        System.out.format("%s error is %d of %d%n", name, error, n);
        System.out.format("%s error rate = %.2f%%%n", name, 100.0 * error / n);
    }

    @BeforeClass
    public static void setUpClass() throws Exception {
    }

    @AfterClass
    public static void tearDownClass() throws Exception {
    }

    @Before
    public void setUp() {
    }

    @After
    public void tearDown() {
    }

    /** Trains Maxent on the protein training set and checks the error count on the test set. */
    @Test
    public void testLearnProtein() {
        System.out.println("learn protein");
        Dataset train = this.load("sequence/sparse.protein.11.train");
        Dataset test = this.load("sequence/sparse.protein.11.test");
        Maxent maxent = new Maxent(train.p, train.x, train.y, 0.1, 1.0E-5, 500);
        int error = countErrors(maxent, test);
        report("Protein", error, test.x.length);
        Assert.assertEquals(1338L, error);
    }

    /** Trains Maxent on the hyphenation training set and checks the error count on the test set. */
    @Test
    public void testLearnHyphen() {
        System.out.println("learn hyphen");
        Dataset train = this.load("sequence/sparse.hyphen.6.train");
        Dataset test = this.load("sequence/sparse.hyphen.6.test");
        Maxent maxent = new Maxent(train.p, train.x, train.y, 0.1, 1.0E-5, 500);
        int error = countErrors(maxent, test);
        // Previously mislabelled as "Protein error is ..." (copy-paste slip).
        report("Hyphen", error, test.x.length);
        Assert.assertEquals(765L, error);
    }

    /**
     * Parsed sparse dataset. It is a static nested class because it never uses the enclosing
     * test instance.
     */
    static class Dataset {
        /** Feature values per sample, one row per data line. */
        int[][] x;
        /** Label per sample, aligned with {@link #x}. */
        int[] y;
        /** Feature dimension taken from the third header field. */
        int p;
    }
}

