package smile.classification;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.List;
import org.junit.After;
import org.junit.AfterClass;
import org.junit.Assert;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Test;
import smile.data.parser.IOUtils;

/* loaded from: input_file:smile/classification/MaxentTest.class */
public class MaxentTest {

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:smile/classification/MaxentTest$Dataset.class */
    public class Dataset {
        int[][] x;
        int[] y;
        int p;

        Dataset() {
        }
    }

    /* JADX WARN: Finally extract failed */
    /* JADX WARN: Type inference failed for: r1v7, types: [int[], int[][]] */
    Dataset load(String str) {
        int i = 0;
        ArrayList arrayList = new ArrayList();
        ArrayList arrayList2 = new ArrayList();
        try {
            BufferedReader testDataReader = IOUtils.getTestDataReader(str);
            Throwable th = null;
            try {
                String[] split = testDataReader.readLine().split(" ");
                Integer.parseInt(split[0]);
                Integer.parseInt(split[1]);
                i = Integer.parseInt(split[2]);
                while (true) {
                    String readLine = testDataReader.readLine();
                    if (readLine == null) {
                        break;
                    }
                    String[] split2 = readLine.split(" ");
                    Integer.parseInt(split2[0]);
                    Integer.parseInt(split2[1]);
                    int parseInt = Integer.parseInt(split2[2]);
                    int[] iArr = new int[parseInt];
                    for (int i2 = 0; i2 < parseInt; i2++) {
                        iArr[i2] = Integer.parseInt(split2[i2 + 3]);
                    }
                    arrayList.add(iArr);
                    arrayList2.add(Integer.valueOf(split2[parseInt + 3]));
                }
                if (testDataReader != null) {
                    if (0 != 0) {
                        try {
                            testDataReader.close();
                        } catch (Throwable th2) {
                            th.addSuppressed(th2);
                        }
                    } else {
                        testDataReader.close();
                    }
                }
            } catch (Throwable th3) {
                if (testDataReader != null) {
                    if (0 != 0) {
                        try {
                            testDataReader.close();
                        } catch (Throwable th4) {
                            th.addSuppressed(th4);
                        }
                    } else {
                        testDataReader.close();
                    }
                }
                throw th3;
            }
        } catch (IOException e) {
            System.err.println(e);
        }
        Dataset dataset = new Dataset();
        dataset.p = i;
        dataset.x = new int[arrayList.size()];
        dataset.y = new int[arrayList2.size()];
        for (int i3 = 0; i3 < dataset.x.length; i3++) {
            dataset.x[i3] = (int[]) arrayList.get(i3);
            dataset.y[i3] = ((Integer) arrayList2.get(i3)).intValue();
        }
        return dataset;
    }

    @BeforeClass
    public static void setUpClass() throws Exception {
    }

    @AfterClass
    public static void tearDownClass() throws Exception {
    }

    @Before
    public void setUp() {
    }

    @After
    public void tearDown() {
    }

    @Test
    public void testLearnProtein() {
        System.out.println("learn protein");
        Dataset load = load("sequence/sparse.protein.11.train");
        Dataset load2 = load("sequence/sparse.protein.11.test");
        Maxent maxent = new Maxent(load.p, load.x, load.y, 0.1d, 1.0E-5d, 500);
        int i = 0;
        for (int i2 = 0; i2 < load2.x.length; i2++) {
            if (load2.y[i2] != maxent.predict(load2.x[i2])) {
                i++;
            }
        }
        System.out.format("Protein error is %d of %d%n", Integer.valueOf(i), Integer.valueOf(load2.x.length));
        System.out.format("Protein error rate = %.2f%%%n", Double.valueOf((100.0d * i) / load2.x.length));
        Assert.assertEquals(1338L, i);
    }

    @Test
    public void testLearnHyphen() {
        System.out.println("learn hyphen");
        Dataset load = load("sequence/sparse.hyphen.6.train");
        Dataset load2 = load("sequence/sparse.hyphen.6.test");
        Maxent maxent = new Maxent(load.p, load.x, load.y, 0.1d, 1.0E-5d, 500);
        int i = 0;
        for (int i2 = 0; i2 < load2.x.length; i2++) {
            if (load2.y[i2] != maxent.predict(load2.x[i2])) {
                i++;
            }
        }
        System.out.format("Protein error is %d of %d%n", Integer.valueOf(i), Integer.valueOf(load2.x.length));
        System.out.format("Hyphen error rate = %.2f%%%n", Double.valueOf((100.0d * i) / load2.x.length));
        Assert.assertEquals(765L, i);
    }
}
