package smile.classification;

import java.io.BufferedReader;
import java.io.IOException;
import java.util.ArrayList;
import org.junit.After;
import org.junit.AfterClass;
import org.junit.Assert;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Test;
import smile.classification.NaiveBayes;
import smile.data.AttributeDataset;
import smile.data.parser.ArffParser;
import smile.data.parser.IOUtils;
import smile.feature.Bag;
import smile.math.Math;
import smile.stat.distribution.Distribution;
import smile.stat.distribution.GaussianMixture;
import smile.validation.CrossValidation;
import smile.validation.LOOCV;

/* loaded from: input_file:smile/classification/NaiveBayesTest.class */
/**
 * Unit tests for {@link NaiveBayes}.
 *
 * <p>{@link #testPredict()} builds a general naive Bayes classifier from per-class, per-feature
 * Gaussian mixtures on the iris data and evaluates it with leave-one-out cross validation. The
 * remaining tests train multinomial and Bernoulli models (batch and online) on a bag-of-words
 * encoding of the movie review corpus and evaluate them with k-fold cross validation.
 */
public class NaiveBayesTest {
    /** Number of labelled reviews read from {@code text/movie.txt}. */
    private static final int MOVIE_REVIEWS = 2000;

    /** Number of folds used by the movie review cross-validation tests. */
    private static final int FOLDS = 10;

    /** Number of mixture components fitted per (class, feature) pair in {@link #testPredict()}. */
    private static final int MIXTURE_COMPONENTS = 3;

    /** Sentiment vocabulary used as the bag-of-words features. */
    final String[] feature = {"outstanding", "wonderfully", "wasted", "lame", "awful", "poorly", "ridiculous", "waste", "worst", "bland", "unfunny", "stupid", "dull", "fantastic", "laughable", "mess", "pointless", "terrific", "memorable", "superb", "boring", "badly", "subtle", "terrible", "excellent", "perfectly", "masterpiece", "realistic", "flaws"};

    /** Bag-of-words feature vector of each review, in file order. */
    final double[][] moviex;

    /** Label of each review: 1 for "pos", 0 for "neg" (invalid labels are reported and left at 0). */
    final int[] moviey;

    /**
     * Loads the first {@value #MOVIE_REVIEWS} lines of {@code text/movie.txt}.
     *
     * <p>Each line is trimmed and split on single spaces. Its first token is the class label
     * ("pos"/"neg", case-insensitive), and the whole token array is encoded with a {@link Bag} over
     * {@link #feature}.
     *
     * @throws IllegalStateException if the file cannot be read or holds fewer than
     *     {@value #MOVIE_REVIEWS} lines; the underlying {@link IOException} is kept as the cause
     */
    public NaiveBayesTest() {
        String[][] docs = new String[MOVIE_REVIEWS][];
        int[] labels = new int[MOVIE_REVIEWS];

        // try-with-resources closes the reader on every path, including read failures.
        try (BufferedReader reader = IOUtils.getTestDataReader("text/movie.txt")) {
            for (int i = 0; i < docs.length; i++) {
                String line = reader.readLine();
                if (line == null) {
                    throw new IOException("text/movie.txt ended after " + i + " of " + docs.length + " reviews");
                }
                String[] words = line.trim().split(" ");
                if (words[0].equalsIgnoreCase("pos")) {
                    labels[i] = 1;
                } else if (words[0].equalsIgnoreCase("neg")) {
                    labels[i] = 0;
                } else {
                    System.err.println("Invalid class label: " + words[0]);
                }
                docs[i] = words;
            }
        } catch (IOException e) {
            // Fail fast with the real cause instead of letting every test hit an NPE later.
            throw new IllegalStateException("Failed to load movie review test data", e);
        }

        moviex = new double[docs.length][];
        moviey = labels;
        Bag bag = new Bag(feature);
        for (int i = 0; i < docs.length; i++) {
            moviex[i] = bag.feature(docs[i]);
        }
    }

    @BeforeClass
    public static void setUpClass() throws Exception {
    }

    @AfterClass
    public static void tearDownClass() throws Exception {
    }

    @Before
    public void setUp() {
    }

    @After
    public void tearDown() {
    }

    /**
     * Leave-one-out evaluation of a Gaussian-mixture naive Bayes classifier on iris.
     *
     * <p>Any parse or I/O failure propagates and fails the test rather than being printed and
     * ignored.
     */
    @Test
    public void testPredict() throws Exception {
        System.out.println("predict");
        ArffParser arffParser = new ArffParser();
        arffParser.setResponseIndex(4);
        AttributeDataset iris = arffParser.parse(IOUtils.getTestDataFile("weka/iris.arff"));
        double[][] x = (double[][]) iris.toArray(new double[iris.size()][]);
        int[] y = iris.toArray(new int[iris.size()]);

        int n = x.length;
        LOOCV loocv = new LOOCV(n);
        int errors = 0;
        for (int round = 0; round < n; round++) {
            double[][] trainx = (double[][]) Math.slice(x, loocv.train[round]);
            int[] trainy = Math.slice(y, loocv.train[round]);
            NaiveBayes bayes = trainGaussianMixtureBayes(trainx, trainy);

            int held = loocv.test[round];
            if (y[held] != bayes.predict(x[held])) {
                errors++;
            }
        }

        System.out.format("Iris error rate = %.2f%%%n", Double.valueOf((100.0d * errors) / n));
        Assert.assertEquals(8L, errors);
    }

    /**
     * Builds a naive Bayes classifier with uniform priors and one Gaussian mixture per
     * (class, feature) pair.
     *
     * <p>Classes are assumed to be labelled 0..max(y).
     */
    private static NaiveBayes trainGaussianMixtureBayes(double[][] x, int[] y) {
        int p = x[0].length;
        int k = Math.max(y) + 1;
        double[] priori = new double[k];
        Distribution[][] condprob = new Distribution[k][p];
        for (int c = 0; c < k; c++) {
            priori[c] = 1.0d / k;
            for (int j = 0; j < p; j++) {
                condprob[c][j] = new GaussianMixture(featureValuesOfClass(x, y, c, j), MIXTURE_COMPONENTS);
            }
        }
        return new NaiveBayes(priori, condprob);
    }

    /**
     * Collects column {@code feature} of every row whose label equals {@code label}.
     *
     * <p>Rows are kept in their original order. Primitive arrays are used to avoid boxing.
     */
    private static double[] featureValuesOfClass(double[][] x, int[] y, int label, int feature) {
        int count = 0;
        for (int i = 0; i < x.length; i++) {
            if (y[i] == label) {
                count++;
            }
        }
        double[] values = new double[count];
        int next = 0;
        for (int i = 0; i < x.length; i++) {
            if (y[i] == label) {
                values[next++] = x[i][feature];
            }
        }
        return values;
    }

    @Test
    public void testLearnMultinomial() {
        System.out.println("batch learn Multinomial");
        int errors = crossValidationErrors(NaiveBayes.Model.MULTINOMIAL, false, "Multinomial");
        Assert.assertTrue(errors < 265);
    }

    @Test
    public void testLearnMultinomial2() {
        System.out.println("online learn Multinomial");
        int errors = crossValidationErrors(NaiveBayes.Model.MULTINOMIAL, true, "Multinomial");
        Assert.assertTrue(errors < 265);
    }

    @Test
    public void testLearnBernoulli() {
        System.out.println("batch learn Bernoulli");
        int errors = crossValidationErrors(NaiveBayes.Model.BERNOULLI, false, "Bernoulli");
        Assert.assertTrue(errors < 270);
    }

    @Test
    public void testLearnBernoulli2() {
        System.out.println("online learn Bernoulli");
        int errors = crossValidationErrors(NaiveBayes.Model.BERNOULLI, true, "Bernoulli");
        Assert.assertTrue(errors < 270);
    }

    /**
     * Runs {@value #FOLDS}-fold cross validation over the movie data, training a fresh two-class
     * model per fold.
     *
     * <p>Prints a summary of the form {@code "<label> error = E of N"}.
     *
     * @param model the document model to train
     * @param online if true, train one sample at a time with {@code learn(double[], int)};
     *     otherwise train in a single {@code learn(double[][], int[])} call
     * @param label prefix for the printed summary line
     * @return number of misclassified test samples; predictions equal to -1 are neither counted as
     *     errors nor as evaluated samples
     */
    private int crossValidationErrors(NaiveBayes.Model model, boolean online, String label) {
        CrossValidation cv = new CrossValidation(moviex.length, FOLDS);
        int errors = 0;
        int evaluated = 0;
        for (int fold = 0; fold < FOLDS; fold++) {
            double[][] trainx = (double[][]) Math.slice(moviex, cv.train[fold]);
            int[] trainy = Math.slice(moviey, cv.train[fold]);

            // Two classes: the constructor only produces labels 0 and 1.
            NaiveBayes bayes = new NaiveBayes(model, 2, feature.length);
            if (online) {
                for (int i = 0; i < trainx.length; i++) {
                    bayes.learn(trainx[i], trainy[i]);
                }
            } else {
                bayes.learn(trainx, trainy);
            }

            double[][] testx = (double[][]) Math.slice(moviex, cv.test[fold]);
            int[] testy = Math.slice(moviey, cv.test[fold]);
            for (int i = 0; i < testx.length; i++) {
                int prediction = bayes.predict(testx[i]);
                // NOTE(review): -1 presumably means "no decision"; confirm against NaiveBayes.predict.
                if (prediction == -1) {
                    continue;
                }
                evaluated++;
                if (testy[i] != prediction) {
                    errors++;
                }
            }
        }
        System.out.format("%s error = %d of %d%n", label, Integer.valueOf(errors), Integer.valueOf(evaluated));
        return errors;
    }
}
