/*
 * Decompiled with CFR 0.152.
 */
package smile.neighbor;

import java.io.IOException;
import java.lang.reflect.Array;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import org.junit.Before;
import org.junit.Test;
import smile.data.parser.IOUtils;
import smile.math.distance.HammingDistance;
import smile.neighbor.Neighbor;
import smile.neighbor.SNLSH;
import smile.sort.HeapSelect;

/**
 * Smoke tests and recall measurements for {@link SNLSH}.
 *
 * <p>The "toy" tests index four short hard-coded sentences and print what the index returns.
 * The "recall" tests index the sentences loaded from the msrp training file, query with the
 * sentences from the msrp test file, and compare the index's answers against a brute-force
 * scan over the cached signatures of the training sentences. Results are printed, not asserted.
 */
public class SNLSHTest {
    /** Hard-coded sentences used by the toy tests. */
    private final String[] texts = new String[]{"This is a test case", "This is another test case", "This is another test case too", "I want to be far from other cases"};
    private List<Sentence> testData;
    private List<Sentence> trainData;
    private List<Sentence> toyData;
    /** Maps a training sentence's raw text to its {@link SNLSH#simhash64} signature. */
    private Map<String, Long> signCache;

    /**
     * Loads the training and test sentences, caches the signature of every training sentence
     * and wraps the hard-coded texts as the toy data set.
     *
     * @throws IOException if a data file cannot be read
     */
    @Before
    public void before() throws IOException {
        this.trainData = this.loadData("msrp/msr_paraphrase_train.txt");
        this.testData = this.loadData("msrp/msr_paraphrase_test.txt");
        this.signCache = new HashMap<String, Long>();
        for (Sentence sentence : this.trainData) {
            this.signCache.put(sentence.line, SNLSH.simhash64(sentence.tokens));
        }
        this.toyData = new ArrayList<Sentence>();
        for (String text : this.texts) {
            this.toyData.add(new Sentence(text));
        }
    }

    /**
     * Reads a tab-separated file and turns the last two non-empty fields of every line into
     * sentences (last field first, then the one before it).
     *
     * @param path test-data path handed to {@code IOUtils.getTestDataReader}
     * @return a view of the loaded sentences without the first two entries
     * @throws IOException if the file cannot be read
     * @throws IllegalArgumentException if a line is blank or has fewer than two non-empty fields
     */
    private List<Sentence> loadData(String path) throws IOException {
        List<Sentence> data = new ArrayList<Sentence>();
        List<String> lines = IOUtils.readLines(IOUtils.getTestDataReader(path));
        for (String line : lines) {
            List<String> fields = tokenize(line, "\t");
            int count = fields.size();
            if (count < 2) {
                // Fail with the offending line instead of an opaque IndexOutOfBoundsException.
                throw new IllegalArgumentException("Expected at least two tab-separated fields: " + line);
            }
            data.add(new Sentence(fields.get(count - 1)));
            data.add(new Sentence(fields.get(count - 2)));
        }
        // Drops the two entries produced by the first line of the file.
        // NOTE(review): presumably that line is a column header - confirm against the data files.
        return data.subList(2, data.size());
    }

    /**
     * Brute-force k-nearest-neighbour search over the training data, by Hamming distance
     * between signatures. Training sentences whose text equals the query's text are skipped.
     *
     * @param q the query sentence
     * @param k the number of neighbours requested
     * @return an array of length {@code k}, or of length {@code hit} when fewer than {@code k}
     *         candidates were pushed into the heap
     */
    private Neighbor<SNLSH.AbstractSentence, SNLSH.AbstractSentence>[] linearKNN(SNLSH.AbstractSentence q, int k) {
        Neighbor<SNLSH.AbstractSentence, SNLSH.AbstractSentence>[] neighbors = newNeighborArray(k);
        HeapSelect<Neighbor<SNLSH.AbstractSentence, SNLSH.AbstractSentence>> heap =
                new HeapSelect<Neighbor<SNLSH.AbstractSentence, SNLSH.AbstractSentence>>(neighbors);
        // Pre-fill the heap with a sentinel at maximal distance so that peek() always has
        // something to compare a candidate against.
        Neighbor<SNLSH.AbstractSentence, SNLSH.AbstractSentence> sentinel =
                new Neighbor<SNLSH.AbstractSentence, SNLSH.AbstractSentence>(null, null, 0, Double.MAX_VALUE);
        for (int i = 0; i < k; ++i) {
            heap.add(sentinel);
        }
        long querySign = SNLSH.simhash64(q.tokens);
        int hit = 0;
        for (Sentence sentence : this.trainData) {
            if (sentence.line.equals(q.line)) {
                continue;
            }
            double distance = this.hammingTo(querySign, sentence);
            if (distance < heap.peek().distance) {
                heap.add(new Neighbor<SNLSH.AbstractSentence, SNLSH.AbstractSentence>(sentence, sentence, 0, distance));
                ++hit;
            }
        }
        heap.sort();
        if (hit < k) {
            // Keep only the last `hit` slots of the sorted array.
            // NOTE(review): this relies on HeapSelect.sort() leaving the untouched sentinels at
            // the front of the array - confirm against HeapSelect.
            Neighbor<SNLSH.AbstractSentence, SNLSH.AbstractSentence>[] trimmed = newNeighborArray(hit);
            System.arraycopy(neighbors, k - hit, trimmed, 0, hit);
            return trimmed;
        }
        return neighbors;
    }

    /**
     * Brute-force nearest-neighbour search over the training data, by Hamming distance between
     * signatures. Training sentences whose text equals the query's text are skipped.
     *
     * @param q the query sentence
     * @return the closest training sentence; key and value are {@code null} and the distance is
     *         {@code Double.MAX_VALUE} when no candidate was found
     */
    private Neighbor<SNLSH.AbstractSentence, SNLSH.AbstractSentence> linearNearest(SNLSH.AbstractSentence q) {
        long querySign = SNLSH.simhash64(q.tokens);
        double minDist = Double.MAX_VALUE;
        Sentence minKey = null;
        for (Sentence sentence : this.trainData) {
            if (sentence.line.equals(q.line)) {
                continue;
            }
            double distance = this.hammingTo(querySign, sentence);
            if (distance < minDist) {
                minDist = distance;
                minKey = sentence;
            }
        }
        return new Neighbor<SNLSH.AbstractSentence, SNLSH.AbstractSentence>(minKey, minKey, 0, minDist);
    }

    /**
     * Brute-force range search: appends every training sentence whose signature is within
     * Hamming distance {@code d} (inclusive) of the query's signature. Training sentences whose
     * text equals the query's text are skipped.
     *
     * @param q the query sentence
     * @param d the inclusive distance threshold
     * @param neighbors output list that receives the matches
     */
    private void linearRange(Sentence q, double d, List<Neighbor<SNLSH.AbstractSentence, SNLSH.AbstractSentence>> neighbors) {
        long querySign = SNLSH.simhash64(q.tokens);
        for (Sentence sentence : this.trainData) {
            if (sentence.line.equals(q.line)) {
                continue;
            }
            double distance = this.hammingTo(querySign, sentence);
            if (distance <= d) {
                neighbors.add(new Neighbor<SNLSH.AbstractSentence, SNLSH.AbstractSentence>(sentence, sentence, 0, distance));
            }
        }
    }

    /** Prints the result of a 10-nearest-neighbour query on the toy data. */
    @Test
    public void testKNN() {
        SNLSH<SNLSH.AbstractSentence> lsh = this.createLSH(this.toyData);
        Sentence sentence = new Sentence(this.texts[0]);
        Neighbor<SNLSH.AbstractSentence, SNLSH.AbstractSentence>[] ns = lsh.knn(sentence, 10);
        System.out.println("-----test knn: ------");
        for (int i = 0; i < ns.length; ++i) {
            System.out.println("neighbor" + i + " : " + ns[i].key.line + ". distance: " + ns[i].distance);
        }
        System.out.println("------test knn end------");
    }

    /**
     * Prints the average fraction of the index's 3-nearest neighbours that also appear in the
     * brute-force 3-nearest neighbours, over all test sentences.
     */
    @Test
    public void testKNNRecall() {
        SNLSH<SNLSH.AbstractSentence> lsh = this.createLSH(this.trainData);
        int k = 3;
        double recall = 0.0;
        for (SNLSH.AbstractSentence query : this.testData) {
            Neighbor<SNLSH.AbstractSentence, SNLSH.AbstractSentence>[] n1 = lsh.knn(query, k);
            Neighbor<SNLSH.AbstractSentence, SNLSH.AbstractSentence>[] n2 = this.linearKNN(query, k);
            int hit = countHits(valuesOf(n1), valuesOf(n2));
            recall += 1.0 * (double) hit / (double) k;
        }
        recall /= (double) this.testData.size();
        System.out.println("SNLSH KNN recall is " + recall);
    }

    /** Prints the result of a nearest-neighbour query on the toy data. */
    @Test
    public void testNearest() {
        SNLSH<SNLSH.AbstractSentence> lsh = this.createLSH(this.toyData);
        System.out.println("----------test nearest start:-------");
        Neighbor<SNLSH.AbstractSentence, SNLSH.AbstractSentence> n = lsh.nearest(new Sentence(this.texts[0]));
        System.out.println("neighbor : " + n.key.line + " distance: " + n.distance);
        System.out.println("----------test nearest end-------");
    }

    /**
     * Prints the fraction of test sentences for which the index's nearest neighbour equals the
     * brute-force nearest neighbour.
     */
    @Test
    public void testNearestRecall() {
        SNLSH<SNLSH.AbstractSentence> lsh = this.createLSH(this.trainData);
        double recall = 0.0;
        for (SNLSH.AbstractSentence query : this.testData) {
            Neighbor<SNLSH.AbstractSentence, SNLSH.AbstractSentence> n1 = lsh.nearest(query);
            Neighbor<SNLSH.AbstractSentence, SNLSH.AbstractSentence> n2 = this.linearNearest(query);
            // A missing answer from the index is scored as a miss rather than aborting the run.
            if (n1 != null && n1.value != null && n1.value.equals(n2.value)) {
                recall += 1.0;
            }
        }
        recall /= (double) this.testData.size();
        System.out.println("SNLSH Nearest recall is " + recall);
    }

    /** Prints the result of a range query with radius 10 on the toy data. */
    @Test
    public void testRange() {
        SNLSH<SNLSH.AbstractSentence> lsh = this.createLSH(this.toyData);
        List<Neighbor<SNLSH.AbstractSentence, SNLSH.AbstractSentence>> ns =
                new ArrayList<Neighbor<SNLSH.AbstractSentence, SNLSH.AbstractSentence>>();
        lsh.range(new Sentence(this.texts[0]), 10.0, ns);
        System.out.println("-------test range begin-------");
        for (Neighbor<SNLSH.AbstractSentence, SNLSH.AbstractSentence> neighbor : ns) {
            System.out.println(neighbor.key.line + "  distance: " + neighbor.distance);
        }
        System.out.println("-----test range end ----------");
    }

    /**
     * Prints the average, over all test sentences, of the number of index range results that
     * also appear in the brute-force range results divided by the brute-force result count
     * (radius 15).
     */
    @Test
    public void testRangeRecall() {
        SNLSH<SNLSH.AbstractSentence> lsh = this.createLSH(this.trainData);
        double dist = 15.0;
        double recall = 0.0;
        for (Sentence q : this.testData) {
            List<Neighbor<SNLSH.AbstractSentence, SNLSH.AbstractSentence>> n1 =
                    new ArrayList<Neighbor<SNLSH.AbstractSentence, SNLSH.AbstractSentence>>();
            lsh.range(q, dist, n1);
            List<Neighbor<SNLSH.AbstractSentence, SNLSH.AbstractSentence>> n2 =
                    new ArrayList<Neighbor<SNLSH.AbstractSentence, SNLSH.AbstractSentence>>();
            this.linearRange(q, dist, n2);
            if (n2.isEmpty()) {
                // No ground truth for this query: it adds nothing to the sum but is still
                // counted in the denominator below.
                // NOTE(review): confirm this weighting is intended.
                continue;
            }
            int hit = countHits(valuesOf(n1), valuesOf(n2));
            recall += 1.0 * (double) hit / (double) n2.size();
        }
        recall /= (double) this.testData.size();
        System.out.println("SNLSH range recall is " + recall);
    }

    /**
     * Builds an index, constructed with the argument 8, in which every sentence is stored as
     * both key and value.
     *
     * @param data the sentences to index
     * @return the populated index
     */
    private SNLSH<SNLSH.AbstractSentence> createLSH(List<Sentence> data) {
        SNLSH<SNLSH.AbstractSentence> lsh = new SNLSH<SNLSH.AbstractSentence>(8);
        for (Sentence sentence : data) {
            lsh.put(sentence, sentence);
        }
        return lsh;
    }

    /**
     * Returns the Hamming distance between a query signature and the cached signature of a
     * training sentence.
     *
     * @param querySign signature of the query
     * @param candidate a sentence whose text is present in the signature cache
     * @return the distance reported by {@code HammingDistance.d}, widened to {@code double}
     */
    private double hammingTo(long querySign, Sentence candidate) {
        long candidateSign = this.signCache.get(candidate.line).longValue();
        return HammingDistance.d(querySign, candidateSign);
    }

    /**
     * Allocates a neighbour array of the given length. Isolated in its own method so that the
     * unavoidable unchecked cast of a reflectively created generic array is suppressed in the
     * smallest possible scope.
     */
    @SuppressWarnings("unchecked")
    private static Neighbor<SNLSH.AbstractSentence, SNLSH.AbstractSentence>[] newNeighborArray(int length) {
        return (Neighbor<SNLSH.AbstractSentence, SNLSH.AbstractSentence>[]) Array.newInstance(Neighbor.class, length);
    }

    /** Collects the values of the leading non-null entries of a neighbour array. */
    private static List<SNLSH.AbstractSentence> valuesOf(Neighbor<SNLSH.AbstractSentence, SNLSH.AbstractSentence>[] neighbors) {
        List<SNLSH.AbstractSentence> values = new ArrayList<SNLSH.AbstractSentence>(neighbors.length);
        // Stop at the first empty slot: the array is treated as a null-terminated prefix.
        for (int i = 0; i < neighbors.length && neighbors[i] != null; ++i) {
            values.add(neighbors[i].value);
        }
        return values;
    }

    /** Collects the values of every entry of a neighbour list. */
    private static List<SNLSH.AbstractSentence> valuesOf(List<Neighbor<SNLSH.AbstractSentence, SNLSH.AbstractSentence>> neighbors) {
        List<SNLSH.AbstractSentence> values = new ArrayList<SNLSH.AbstractSentence>(neighbors.size());
        for (Neighbor<SNLSH.AbstractSentence, SNLSH.AbstractSentence> neighbor : neighbors) {
            values.add(neighbor.value);
        }
        return values;
    }

    /**
     * Counts how many entries of {@code found} are equal to some entry of {@code truth}.
     * Each entry of {@code found} is counted at most once; {@code null} entries never match.
     */
    private static int countHits(List<SNLSH.AbstractSentence> found, List<SNLSH.AbstractSentence> truth) {
        int hits = 0;
        for (SNLSH.AbstractSentence candidate : found) {
            if (candidate != null && truth.contains(candidate)) {
                ++hits;
            }
        }
        return hits;
    }

    /**
     * Splits a line on the given regular expression and drops empty pieces.
     *
     * @param line the text to split
     * @param regex the delimiter pattern passed to {@link String#split(String)}
     * @return the non-empty pieces, in order
     * @throws IllegalArgumentException if {@code line} is {@code null} or empty
     */
    private static List<String> tokenize(String line, String regex) {
        if (line == null || line.isEmpty()) {
            throw new IllegalArgumentException("Line should not be blank!");
        }
        List<String> tokens = new ArrayList<String>();
        for (String piece : line.split(regex)) {
            if (piece == null || piece.isEmpty()) {
                continue;
            }
            tokens.add(piece);
        }
        return tokens;
    }

    /** A sentence whose tokens are the non-empty pieces obtained by splitting on single spaces. */
    private static final class Sentence
    extends SNLSH.AbstractSentence {
        /**
         * @param line the raw sentence text
         * @throws IllegalArgumentException if {@code line} is {@code null} or empty
         */
        public Sentence(String line) {
            this.line = line;
            this.tokens = this.tokenize(line);
        }

        @Override
        List<String> tokenize(String line) {
            // Qualified call: an unqualified tokenize(...) would resolve against this class's
            // own one-argument method and fail to compile.
            return SNLSHTest.tokenize(line, " ");
        }
    }
}

