/*
 * Decompiled with CFR 0.152.
 */
package smile.neighbor;

import java.util.ArrayList;
import org.junit.After;
import org.junit.AfterClass;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Test;
import smile.data.AttributeDataset;
import smile.data.Dataset;
import smile.data.NominalAttribute;
import smile.data.parser.DelimitedTextParser;
import smile.data.parser.IOUtils;
import smile.math.Math;
import smile.math.distance.EuclideanDistance;
import smile.neighbor.LinearSearch;
import smile.neighbor.MPLSH;
import smile.neighbor.Neighbor;

/**
 * Compares {@link MPLSH} against an exhaustive {@link LinearSearch} baseline on the
 * USPS data files ("usps/zip.train" and "usps/zip.test").
 *
 * <p>The tests only print recall, average distance and elapsed time to standard output;
 * they contain no assertions, so they fail only when an exception is thrown.
 */
public class MPLSHTest {
    /** Points loaded from the training file; these are what both indexes contain. */
    double[][] x = null;
    /** Points loaded from the test file; each one is used as a query. */
    double[][] testx = null;
    /** Index under test. */
    MPLSH<double[]> lsh = null;
    /** Exhaustive Euclidean search over {@link #x}, used as the reference answer. */
    LinearSearch<double[]> naive = null;

    /**
     * Loads both data files, fills the two indexes with the training points and then
     * calls {@code learn} on the MPLSH index with 500 randomly chosen training points.
     *
     * @throws IllegalStateException if either data file cannot be read or parsed; the
     *         original exception is attached as the cause
     */
    public MPLSHTest() {
        DelimitedTextParser parser = new DelimitedTextParser();
        // Column 0 of each file is registered as the nominal response named "class".
        parser.setResponseIndex(new NominalAttribute("class"), 0);
        try {
            AttributeDataset train = parser.parse("USPS Train", IOUtils.getTestDataFile("usps/zip.train"));
            AttributeDataset test = parser.parse("USPS Test", IOUtils.getTestDataFile("usps/zip.test"));
            this.x = train.toArray(new double[train.size()][]);
            this.testx = test.toArray(new double[test.size()][]);
        }
        catch (Exception ex) {
            // Fail here with the real cause instead of printing it and hitting a
            // NullPointerException on the unset arrays a few lines further down.
            throw new IllegalStateException("Failed to load USPS data for MPLSHTest", ex);
        }

        this.naive = new LinearSearch<double[]>(this.x, new EuclideanDistance());
        // NOTE(review): the arguments are presumably (dimension, hash tables, projections
        // per table, bucket width) - confirm against the MPLSH constructor.
        this.lsh = new MPLSH<double[]>(256, 100, 3, 4.0);
        for (double[] xi : this.x) {
            this.lsh.put(xi, xi);
        }

        // Random subset of the training points handed to learn().
        double[][] samples = new double[500][];
        int[] index = Math.permutate(this.x.length);
        for (int i = 0; i < samples.length; ++i) {
            samples[i] = this.x[index[i]];
        }
        // NOTE(review): 8.0 matches the radius used by testRangePosteriori; presumably it
        // is the query radius used while training - confirm against MPLSH.learn.
        this.lsh.learn(this.naive, samples, 8.0);
    }

    /** Class-level setup hook; intentionally empty. */
    @BeforeClass
    public static void setUpClass() throws Exception {
    }

    /** Class-level teardown hook; intentionally empty. */
    @AfterClass
    public static void tearDownClass() throws Exception {
    }

    /** Per-test setup hook; intentionally empty. */
    @Before
    public void setUp() {
    }

    /** Per-test teardown hook; intentionally empty. */
    @After
    public void tearDown() {
    }

    /**
     * For every test point, asks the MPLSH index for the nearest neighbor and counts how
     * often its index equals the one returned by the linear search. Prints the fraction
     * of matches, the mean distance reported by MPLSH and the elapsed time in seconds.
     */
    @Test
    public void testNearestPosteriori() {
        System.out.println("nearest posteriori");
        long start = System.currentTimeMillis();
        double matches = 0.0;
        double dist = 0.0;
        for (double[] query : this.testx) {
            // NOTE(review): 0.95 and 50 are presumably the target recall and the probe
            // budget - confirm against MPLSH.nearest.
            Neighbor<double[], double[]> neighbor = this.lsh.nearest(query, 0.95, 50);
            dist += neighbor.distance;
            if (neighbor.index == this.naive.nearest(query).index) {
                matches += 1.0;
            }
        }
        double recall = matches / (double)this.testx.length;
        System.out.println("recall is " + recall);
        System.out.println("average distance is " + dist / (double)this.testx.length);
        System.out.println("time is " + (double)(System.currentTimeMillis() - start) / 1000.0);
    }

    /**
     * For every test point, fetches 3 neighbors from both indexes and measures what
     * fraction of the MPLSH results also appear (by index) in the linear-search results.
     * Prints the mean of that fraction and the elapsed time in seconds.
     */
    @Test
    public void testKnnPosteriori() {
        System.out.println("knn posteriori");
        long start = System.currentTimeMillis();
        int k = 3;
        double total = 0.0;
        for (double[] query : this.testx) {
            Neighbor<double[], double[]>[] approx = this.lsh.knn(query, k, 0.95, 50);
            Neighbor<double[], double[]>[] exact = this.naive.knn(query, k);
            int hit = 0;
            // Either scan stops at the first null slot in its result array.
            for (int m = 0; m < k && approx[m] != null; ++m) {
                for (int n = 0; n < k && exact[n] != null; ++n) {
                    if (approx[m].index == exact[n].index) {
                        ++hit;
                        break;
                    }
                }
            }
            total += (double)hit / (double)k;
        }
        double recall = total / (double)this.testx.length;
        System.out.println("recall is " + recall);
        System.out.println("time is " + (double)(System.currentTimeMillis() - start) / 1000.0);
    }

    /**
     * For every test point, runs a radius-8.0 range query on both indexes and measures
     * the number of MPLSH results that also appear (by index) in the linear-search
     * results, divided by the size of the linear-search result. Queries for which the
     * linear search finds nothing add zero to the sum but still count in the
     * denominator of the printed mean.
     */
    @Test
    public void testRangePosteriori() {
        System.out.println("range posteriori");
        long start = System.currentTimeMillis();
        double total = 0.0;
        for (double[] query : this.testx) {
            ArrayList<Neighbor<double[], double[]>> approx = new ArrayList<Neighbor<double[], double[]>>();
            ArrayList<Neighbor<double[], double[]>> exact = new ArrayList<Neighbor<double[], double[]>>();
            this.lsh.range(query, 8.0, approx, 0.95, 50);
            this.naive.range(query, 8.0, exact);
            int hit = 0;
            for (Neighbor<double[], double[]> candidate : approx) {
                for (Neighbor<double[], double[]> truth : exact) {
                    if (candidate.index == truth.index) {
                        ++hit;
                        break;
                    }
                }
            }
            // Skip the division when there is nothing to recall (avoids 0/0).
            if (!exact.isEmpty()) {
                total += (double)hit / (double)exact.size();
            }
        }
        double recall = total / (double)this.testx.length;
        System.out.println("recall is " + recall);
        System.out.println("time is " + (double)(System.currentTimeMillis() - start) / 1000.0);
    }
}

