/*
 * Decompiled with CFR 0.152.
 */
package smile.classification;

import java.io.BufferedReader;
import java.io.IOException;
import java.util.ArrayList;
import org.junit.After;
import org.junit.AfterClass;
import org.junit.Assert;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Test;
import smile.classification.NaiveBayes;
import smile.data.AttributeDataset;
import smile.data.parser.ArffParser;
import smile.data.parser.IOUtils;
import smile.feature.Bag;
import smile.math.Math;
import smile.stat.distribution.Distribution;
import smile.stat.distribution.GaussianMixture;
import smile.validation.CrossValidation;
import smile.validation.LOOCV;

public class NaiveBayesTest {
    String[] feature = new String[]{"outstanding", "wonderfully", "wasted", "lame", "awful", "poorly", "ridiculous", "waste", "worst", "bland", "unfunny", "stupid", "dull", "fantastic", "laughable", "mess", "pointless", "terrific", "memorable", "superb", "boring", "badly", "subtle", "terrible", "excellent", "perfectly", "masterpiece", "realistic", "flaws"};
    double[][] moviex;
    int[] moviey;

    public NaiveBayesTest() {
        String[][] x = new String[2000][];
        int[] y = new int[2000];
        try (BufferedReader input = IOUtils.getTestDataReader("text/movie.txt");){
            for (int i = 0; i < x.length; ++i) {
                String[] words = input.readLine().trim().split(" ");
                if (words[0].equalsIgnoreCase("pos")) {
                    y[i] = 1;
                } else if (words[0].equalsIgnoreCase("neg")) {
                    y[i] = 0;
                } else {
                    System.err.println("Invalid class label: " + words[0]);
                }
                x[i] = words;
            }
        }
        catch (IOException ex) {
            System.err.println(ex);
        }
        this.moviex = new double[x.length][];
        this.moviey = new int[y.length];
        Bag<String> bag = new Bag<String>(this.feature);
        for (int i = 0; i < x.length; ++i) {
            this.moviex[i] = bag.feature((T[])x[i]);
            this.moviey[i] = y[i];
        }
    }

    @BeforeClass
    public static void setUpClass() throws Exception {
    }

    @AfterClass
    public static void tearDownClass() throws Exception {
    }

    @Before
    public void setUp() {
    }

    @After
    public void tearDown() {
    }

    @Test
    public void testPredict() {
        System.out.println("predict");
        ArffParser arffParser = new ArffParser();
        arffParser.setResponseIndex(4);
        try {
            AttributeDataset iris = arffParser.parse(IOUtils.getTestDataFile("weka/iris.arff"));
            double[][] x = (double[][])iris.toArray((E[])new double[iris.size()][]);
            int[] y = iris.toArray(new int[iris.size()]);
            int n = x.length;
            LOOCV loocv = new LOOCV(n);
            int error = 0;
            for (int l = 0; l < n; ++l) {
                double[][] trainx = (double[][])Math.slice(x, loocv.train[l]);
                int[] trainy = Math.slice(y, loocv.train[l]);
                int p = trainx[0].length;
                int k = Math.max(trainy) + 1;
                double[] priori = new double[k];
                Distribution[][] condprob = new Distribution[k][p];
                for (int i = 0; i < k; ++i) {
                    priori[i] = 1.0 / (double)k;
                    for (int j = 0; j < p; ++j) {
                        ArrayList<Double> axi = new ArrayList<Double>();
                        for (int m = 0; m < trainx.length; ++m) {
                            if (trainy[m] != i) continue;
                            axi.add(trainx[m][j]);
                        }
                        double[] xi = new double[axi.size()];
                        for (int m = 0; m < xi.length; ++m) {
                            xi[m] = (Double)axi.get(m);
                        }
                        condprob[i][j] = new GaussianMixture(xi, 3);
                    }
                }
                NaiveBayes bayes = new NaiveBayes(priori, condprob);
                if (y[loocv.test[l]] == bayes.predict(x[loocv.test[l]])) continue;
                ++error;
            }
            System.out.format("Iris error rate = %.2f%%%n", 100.0 * (double)error / (double)x.length);
            Assert.assertEquals(8L, error);
        }
        catch (Exception ex) {
            System.err.println(ex);
        }
    }

    @Test
    public void testLearnMultinomial() {
        System.out.println("batch learn Multinomial");
        double[][] x = this.moviex;
        int[] y = this.moviey;
        int n = x.length;
        int k = 10;
        CrossValidation cv = new CrossValidation(n, k);
        int error = 0;
        int total = 0;
        for (int i = 0; i < k; ++i) {
            double[][] trainx = (double[][])Math.slice(x, cv.train[i]);
            int[] trainy = Math.slice(y, cv.train[i]);
            NaiveBayes bayes = new NaiveBayes(NaiveBayes.Model.MULTINOMIAL, 2, this.feature.length);
            bayes.learn(trainx, trainy);
            double[][] testx = (double[][])Math.slice(x, cv.test[i]);
            int[] testy = Math.slice(y, cv.test[i]);
            for (int j = 0; j < testx.length; ++j) {
                int label = bayes.predict(testx[j]);
                if (label == -1) continue;
                ++total;
                if (testy[j] == label) continue;
                ++error;
            }
        }
        System.out.format("Multinomial error = %d of %d%n", error, total);
        Assert.assertTrue(error < 265);
    }

    @Test
    public void testLearnMultinomial2() {
        System.out.println("online learn Multinomial");
        double[][] x = this.moviex;
        int[] y = this.moviey;
        int n = x.length;
        int k = 10;
        CrossValidation cv = new CrossValidation(n, k);
        int error = 0;
        int total = 0;
        for (int i = 0; i < k; ++i) {
            double[][] trainx = (double[][])Math.slice(x, cv.train[i]);
            int[] trainy = Math.slice(y, cv.train[i]);
            NaiveBayes bayes = new NaiveBayes(NaiveBayes.Model.MULTINOMIAL, 2, this.feature.length);
            for (int j = 0; j < trainx.length; ++j) {
                bayes.learn(trainx[j], trainy[j]);
            }
            double[][] testx = (double[][])Math.slice(x, cv.test[i]);
            int[] testy = Math.slice(y, cv.test[i]);
            for (int j = 0; j < testx.length; ++j) {
                int label = bayes.predict(testx[j]);
                if (label == -1) continue;
                ++total;
                if (testy[j] == label) continue;
                ++error;
            }
        }
        System.out.format("Multinomial error = %d of %d%n", error, total);
        Assert.assertTrue(error < 265);
    }

    @Test
    public void testLearnBernoulli() {
        System.out.println("batch learn Bernoulli");
        double[][] x = this.moviex;
        int[] y = this.moviey;
        int n = x.length;
        int k = 10;
        CrossValidation cv = new CrossValidation(n, k);
        int error = 0;
        int total = 0;
        for (int i = 0; i < k; ++i) {
            double[][] trainx = (double[][])Math.slice(x, cv.train[i]);
            int[] trainy = Math.slice(y, cv.train[i]);
            NaiveBayes bayes = new NaiveBayes(NaiveBayes.Model.BERNOULLI, 2, this.feature.length);
            bayes.learn(trainx, trainy);
            double[][] testx = (double[][])Math.slice(x, cv.test[i]);
            int[] testy = Math.slice(y, cv.test[i]);
            for (int j = 0; j < testx.length; ++j) {
                int label = bayes.predict(testx[j]);
                if (label == -1) continue;
                ++total;
                if (testy[j] == label) continue;
                ++error;
            }
        }
        System.out.format("Bernoulli error = %d of %d%n", error, total);
        Assert.assertTrue(error < 270);
    }

    @Test
    public void testLearnBernoulli2() {
        System.out.println("online learn Bernoulli");
        double[][] x = this.moviex;
        int[] y = this.moviey;
        int n = x.length;
        int k = 10;
        CrossValidation cv = new CrossValidation(n, k);
        int error = 0;
        int total = 0;
        for (int i = 0; i < k; ++i) {
            double[][] trainx = (double[][])Math.slice(x, cv.train[i]);
            int[] trainy = Math.slice(y, cv.train[i]);
            NaiveBayes bayes = new NaiveBayes(NaiveBayes.Model.BERNOULLI, 2, this.feature.length);
            for (int j = 0; j < trainx.length; ++j) {
                bayes.learn(trainx[j], trainy[j]);
            }
            double[][] testx = (double[][])Math.slice(x, cv.test[i]);
            int[] testy = Math.slice(y, cv.test[i]);
            for (int j = 0; j < testx.length; ++j) {
                int label = bayes.predict(testx[j]);
                if (label == -1) continue;
                ++total;
                if (testy[j] == label) continue;
                ++error;
            }
        }
        System.out.format("Bernoulli error = %d of %d%n", error, total);
        Assert.assertTrue(error < 270);
    }
}

