package smile.sequence;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.text.ParseException;
import java.util.ArrayList;
import java.util.List;
import java.util.function.BiFunction;
import org.junit.After;
import org.junit.AfterClass;
import org.junit.Assert;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Test;
import smile.data.Attribute;
import smile.data.NominalAttribute;
import smile.data.parser.IOUtils;
import smile.math.Math;
import smile.sequence.CRF;

/* loaded from: input_file:smile/sequence/CRFTest.class */
public class CRFTest {

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:smile/sequence/CRFTest$Dataset.class */
    public class Dataset {
        Attribute[] attributes;
        double[][][] x;
        int[][] y;
        int p;
        int k;

        Dataset() {
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:smile/sequence/CRFTest$IntDataset.class */
    public class IntDataset {
        int[][][] x;
        int[][] y;
        int p;
        int k;

        IntDataset() {
        }
    }

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r1v10, types: [int[][], int[][][]] */
    /* JADX WARN: Type inference failed for: r1v13, types: [int[], int[][]] */
    IntDataset load(String str) {
        int i = 0;
        int i2 = 0;
        IntDataset intDataset = new IntDataset();
        ArrayList arrayList = new ArrayList();
        ArrayList arrayList2 = new ArrayList();
        ArrayList arrayList3 = new ArrayList();
        ArrayList arrayList4 = new ArrayList();
        int i3 = 1;
        try {
            BufferedReader testDataReader = IOUtils.getTestDataReader(str);
            Throwable th = null;
            try {
                try {
                    String[] split = testDataReader.readLine().split(" ");
                    Integer.parseInt(split[0]);
                    i2 = Integer.parseInt(split[1]);
                    i = Integer.parseInt(split[2]);
                    while (true) {
                        String readLine = testDataReader.readLine();
                        if (readLine == null) {
                            break;
                        }
                        String[] split2 = readLine.split(" ");
                        int parseInt = Integer.parseInt(split2[0]);
                        Integer.parseInt(split2[1]);
                        int parseInt2 = Integer.parseInt(split2[2]);
                        int[] iArr = new int[parseInt2];
                        for (int i4 = 0; i4 < parseInt2; i4++) {
                            try {
                                iArr[i4] = Integer.parseInt(split2[i4 + 3]);
                            } catch (Exception e) {
                                System.err.println(e);
                            }
                        }
                        if (parseInt == i3) {
                            arrayList3.add(iArr);
                            arrayList4.add(Integer.valueOf(split2[parseInt2 + 3]));
                        } else {
                            i3 = parseInt;
                            int[] iArr2 = new int[arrayList3.size()];
                            int[] iArr3 = new int[arrayList3.size()];
                            for (int i5 = 0; i5 < arrayList3.size(); i5++) {
                                iArr2[i5] = (int[]) arrayList3.get(i5);
                                iArr3[i5] = ((Integer) arrayList4.get(i5)).intValue();
                            }
                            arrayList.add(iArr2);
                            arrayList2.add(iArr3);
                            arrayList3 = new ArrayList();
                            arrayList4 = new ArrayList();
                            arrayList3.add(iArr);
                            arrayList4.add(Integer.valueOf(split2[parseInt2 + 3]));
                        }
                    }
                    int[] iArr4 = new int[arrayList3.size()];
                    int[] iArr5 = new int[arrayList3.size()];
                    for (int i6 = 0; i6 < arrayList3.size(); i6++) {
                        iArr4[i6] = (int[]) arrayList3.get(i6);
                        iArr5[i6] = ((Integer) arrayList4.get(i6)).intValue();
                    }
                    arrayList.add(iArr4);
                    arrayList2.add(iArr5);
                    if (testDataReader != null) {
                        if (0 != 0) {
                            try {
                                testDataReader.close();
                            } catch (Throwable th2) {
                                th.addSuppressed(th2);
                            }
                        } else {
                            testDataReader.close();
                        }
                    }
                } finally {
                }
            } finally {
            }
        } catch (IOException e2) {
            System.err.println(e2);
        }
        intDataset.p = i;
        intDataset.k = i2;
        intDataset.x = new int[arrayList.size()];
        intDataset.y = new int[arrayList2.size()];
        for (int i7 = 0; i7 < intDataset.x.length; i7++) {
            intDataset.x[i7] = (int[][]) arrayList.get(i7);
            intDataset.y[i7] = (int[]) arrayList2.get(i7);
        }
        return intDataset;
    }

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r1v11, types: [double[][], double[][][]] */
    /* JADX WARN: Type inference failed for: r1v14, types: [int[], int[][]] */
    Dataset load(String str, Attribute[] attributeArr) {
        BufferedReader testDataReader;
        Throwable th;
        int i = 0;
        int i2 = 0;
        Dataset dataset = new Dataset();
        dataset.attributes = attributeArr;
        ArrayList arrayList = new ArrayList();
        ArrayList arrayList2 = new ArrayList();
        ArrayList arrayList3 = new ArrayList();
        ArrayList arrayList4 = new ArrayList();
        int i3 = 1;
        try {
            testDataReader = IOUtils.getTestDataReader(str);
            th = null;
        } catch (IOException e) {
            System.err.println(e);
        }
        try {
            try {
                String[] split = testDataReader.readLine().split(" ");
                Integer.parseInt(split[0]);
                i2 = Integer.parseInt(split[1]);
                i = Integer.parseInt(split[2]);
                while (true) {
                    String readLine = testDataReader.readLine();
                    if (readLine == null) {
                        break;
                    }
                    String[] split2 = readLine.split(" ");
                    int parseInt = Integer.parseInt(split2[0]);
                    Integer.parseInt(split2[1]);
                    int parseInt2 = Integer.parseInt(split2[2]);
                    if (dataset.attributes == null) {
                        dataset.attributes = new Attribute[parseInt2];
                        for (int i4 = 0; i4 < parseInt2; i4++) {
                            dataset.attributes[i4] = new NominalAttribute("Attr" + (i4 + 1));
                        }
                    }
                    double[] dArr = new double[parseInt2];
                    for (int i5 = 0; i5 < parseInt2; i5++) {
                        try {
                            dArr[i5] = dataset.attributes[i5].valueOf(split2[i5 + 3]);
                        } catch (ParseException e2) {
                            System.err.println(e2);
                        }
                    }
                    if (parseInt == i3) {
                        arrayList3.add(dArr);
                        arrayList4.add(Integer.valueOf(split2[parseInt2 + 3]));
                    } else {
                        i3 = parseInt;
                        double[] dArr2 = new double[arrayList3.size()];
                        int[] iArr = new int[arrayList3.size()];
                        for (int i6 = 0; i6 < arrayList3.size(); i6++) {
                            dArr2[i6] = (double[]) arrayList3.get(i6);
                            iArr[i6] = ((Integer) arrayList4.get(i6)).intValue();
                        }
                        arrayList.add(dArr2);
                        arrayList2.add(iArr);
                        arrayList3 = new ArrayList();
                        arrayList4 = new ArrayList();
                        arrayList3.add(dArr);
                        arrayList4.add(Integer.valueOf(split2[parseInt2 + 3]));
                    }
                }
                double[] dArr3 = new double[arrayList3.size()];
                int[] iArr2 = new int[arrayList3.size()];
                for (int i7 = 0; i7 < arrayList3.size(); i7++) {
                    dArr3[i7] = (double[]) arrayList3.get(i7);
                    iArr2[i7] = ((Integer) arrayList4.get(i7)).intValue();
                }
                arrayList.add(dArr3);
                arrayList2.add(iArr2);
                if (testDataReader != null) {
                    if (0 != 0) {
                        try {
                            testDataReader.close();
                        } catch (Throwable th2) {
                            th.addSuppressed(th2);
                        }
                    } else {
                        testDataReader.close();
                    }
                }
                dataset.p = i;
                dataset.k = i2;
                dataset.x = new double[arrayList.size()];
                dataset.y = new int[arrayList2.size()];
                for (int i8 = 0; i8 < dataset.x.length; i8++) {
                    dataset.x[i8] = (double[][]) arrayList.get(i8);
                    dataset.y[i8] = (int[]) arrayList2.get(i8);
                }
                return dataset;
            } finally {
            }
        } finally {
        }
    }

    @BeforeClass
    public static void setUpClass() throws Exception {
    }

    @AfterClass
    public static void tearDownClass() throws Exception {
    }

    @Before
    public void setUp() {
        Math.setSeed(54217137L);
    }

    @After
    public void tearDown() {
    }

    @Test
    public void testLearnProteinSparse() {
        System.out.println("learn protein sparse");
        IntDataset load = load("sequence/sparse.protein.11.train");
        IntDataset load2 = load("sequence/sparse.protein.11.test");
        CRF.Trainer trainer = new CRF.Trainer(load.p, load.k);
        trainer.setLearningRate(0.3d);
        trainer.setMaxNodes(100);
        trainer.setNumTrees(100);
        CRF train = trainer.train(load.x, load.y);
        int i = 0;
        int i2 = 0;
        for (int i3 = 0; i3 < load2.x.length; i3++) {
            i2 += load2.x[i3].length;
            int[] predict = train.predict(load2.x[i3]);
            for (int i4 = 0; i4 < load2.x[i3].length; i4++) {
                if (load2.y[i3][i4] != predict[i4]) {
                    i++;
                }
            }
        }
        int i5 = 0;
        train.setViterbi(true);
        for (int i6 = 0; i6 < load2.x.length; i6++) {
            i2 += load2.x[i6].length;
            int[] predict2 = train.predict(load2.x[i6]);
            for (int i7 = 0; i7 < load2.x[i6].length; i7++) {
                if (load2.y[i6][i7] != predict2[i7]) {
                    i5++;
                }
            }
        }
        System.out.format("Protein error (forward-backward) is %d of %d%n", Integer.valueOf(i), Integer.valueOf(i2));
        System.out.format("Protein error (forward-backward) rate = %.2f%%%n", Double.valueOf((100.0d * i) / i2));
        System.out.format("Protein error (Viterbi) is %d of %d%n", Integer.valueOf(i5), Integer.valueOf(i2));
        System.out.format("Protein error (Viterbi) rate = %.2f%%%n", Double.valueOf((100.0d * i5) / i2));
        Assert.assertEquals(1234L, i);
        Assert.assertEquals(1318L, i5);
    }

    @Test
    public void testLearnHyphenSparse() {
        System.out.println("learn hyphen sparse");
        IntDataset load = load("sequence/sparse.hyphen.6.train");
        IntDataset load2 = load("sequence/sparse.hyphen.6.test");
        CRF.Trainer trainer = new CRF.Trainer(load.p, load.k);
        trainer.setLearningRate(1.0d);
        trainer.setMaxNodes(100);
        trainer.setNumTrees(100);
        CRF train = trainer.train(load.x, load.y);
        int i = 0;
        int i2 = 0;
        for (int i3 = 0; i3 < load2.x.length; i3++) {
            i2 += load2.x[i3].length;
            int[] predict = train.predict(load2.x[i3]);
            for (int i4 = 0; i4 < load2.x[i3].length; i4++) {
                if (load2.y[i3][i4] != predict[i4]) {
                    i++;
                }
            }
        }
        int i5 = 0;
        train.setViterbi(true);
        for (int i6 = 0; i6 < load2.x.length; i6++) {
            i2 += load2.x[i6].length;
            int[] predict2 = train.predict(load2.x[i6]);
            for (int i7 = 0; i7 < load2.x[i6].length; i7++) {
                if (load2.y[i6][i7] != predict2[i7]) {
                    i5++;
                }
            }
        }
        System.out.format("Hypen error (forward-backward) is %d of %d%n", Integer.valueOf(i), Integer.valueOf(i2));
        System.out.format("Hypen error (forward-backward) rate = %.2f%%%n", Double.valueOf((100.0d * i) / i2));
        System.out.format("Hypen error (Viterbi) is %d of %d%n", Integer.valueOf(i5), Integer.valueOf(i2));
        System.out.format("Hypen error (Viterbi) rate = %.2f%%%n", Double.valueOf((100.0d * i5) / i2));
        Assert.assertEquals(470L, i);
        Assert.assertEquals(478L, i5);
    }

    @Test
    public void testLearnProtein() {
        System.out.println("learn protein");
        Dataset load = load("sequence/sparse.protein.11.train", null);
        Dataset load2 = load("sequence/sparse.protein.11.test", load.attributes);
        CRF.Trainer trainer = new CRF.Trainer(load.attributes, load.k);
        trainer.setLearningRate(0.3d);
        trainer.setMaxNodes(100);
        trainer.setNumTrees(100);
        CRF train = trainer.train(load.x, load.y);
        int i = 0;
        int i2 = 0;
        for (int i3 = 0; i3 < load2.x.length; i3++) {
            i2 += load2.x[i3].length;
            int[] predict = train.predict(load2.x[i3]);
            for (int i4 = 0; i4 < load2.x[i3].length; i4++) {
                if (load2.y[i3][i4] != predict[i4]) {
                    i++;
                }
            }
        }
        int i5 = 0;
        train.setViterbi(true);
        for (int i6 = 0; i6 < load2.x.length; i6++) {
            i2 += load2.x[i6].length;
            int[] predict2 = train.predict(load2.x[i6]);
            for (int i7 = 0; i7 < load2.x[i6].length; i7++) {
                if (load2.y[i6][i7] != predict2[i7]) {
                    i5++;
                }
            }
        }
        System.out.format("Protein error (forward-backward) is %d of %d%n", Integer.valueOf(i), Integer.valueOf(i2));
        System.out.format("Protein error (forward-backward) rate = %.2f%%%n", Double.valueOf((100.0d * i) / i2));
        System.out.format("Protein error (Viterbi) is %d of %d%n", Integer.valueOf(i5), Integer.valueOf(i2));
        System.out.format("Protein error (Viterbi) rate = %.2f%%%n", Double.valueOf((100.0d * i5) / i2));
        Assert.assertEquals(1270L, i);
        Assert.assertEquals(1420L, i5);
    }

    @Test
    public void testLearnHyphen() {
        System.out.println("learn hyphen");
        Dataset load = load("sequence/sparse.hyphen.6.train", null);
        Dataset load2 = load("sequence/sparse.hyphen.6.test", load.attributes);
        CRF.Trainer trainer = new CRF.Trainer(load.attributes, load.k);
        trainer.setLearningRate(1.0d);
        trainer.setMaxNodes(100);
        trainer.setNumTrees(100);
        CRF train = trainer.train(load.x, load.y);
        int i = 0;
        int i2 = 0;
        for (int i3 = 0; i3 < load2.x.length; i3++) {
            i2 += load2.x[i3].length;
            int[] predict = train.predict(load2.x[i3]);
            for (int i4 = 0; i4 < load2.x[i3].length; i4++) {
                if (load2.y[i3][i4] != predict[i4]) {
                    i++;
                }
            }
        }
        int i5 = 0;
        train.setViterbi(true);
        for (int i6 = 0; i6 < load2.x.length; i6++) {
            i2 += load2.x[i6].length;
            int[] predict2 = train.predict(load2.x[i6]);
            for (int i7 = 0; i7 < load2.x[i6].length; i7++) {
                if (load2.y[i6][i7] != predict2[i7]) {
                    i5++;
                }
            }
        }
        System.out.format("Hypen error (forward-backward) is %d of %d%n", Integer.valueOf(i), Integer.valueOf(i2));
        System.out.format("Hypen error (forward-backward) rate = %.2f%%%n", Double.valueOf((100.0d * i) / i2));
        System.out.format("Hypen error (Viterbi) is %d of %d%n", Integer.valueOf(i5), Integer.valueOf(i2));
        System.out.format("Hypen error (Viterbi) rate = %.2f%%%n", Double.valueOf((100.0d * i5) / i2));
        Assert.assertEquals(473L, i);
        Assert.assertEquals(478L, i5);
    }
}
