/*
 * Decompiled with CFR 0.152.
 */
package jsat.linear.vectorcollection.lsh;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import jsat.linear.DenseMatrix;
import jsat.linear.DenseVector;
import jsat.linear.Matrix;
import jsat.linear.RandomMatrix;
import jsat.linear.Vec;
import jsat.linear.distancemetrics.CosineDistance;
import jsat.linear.distancemetrics.CosineDistanceNormalized;
import jsat.linear.distancemetrics.DistanceMetric;
import jsat.linear.vectorcollection.VectorCollection;
import jsat.utils.BoundedSortedList;
import jsat.utils.IndexTable;
import jsat.utils.ProbailityMatch;
import jsat.utils.random.RandomUtil;

/**
 * A {@link VectorCollection} that answers cosine-distance queries approximately using
 * random-projection locality sensitive hashing. Each stored vector is multiplied by a random
 * Gaussian matrix and reduced to a bit signature (bit = 1 when the projection is {@code >= 0}),
 * packed 32 bits per {@code int}. Query distances are estimated from the Hamming distance
 * between signatures, and every query scans all stored signatures.
 *
 * <p>Only {@link CosineDistance} and {@link CosineDistanceNormalized} are accepted as metrics.
 *
 * @param <V> the type of vector stored in the collection
 */
public class RandomProjectionLSH<V extends Vec>
implements VectorCollection<V> {
    private static final long serialVersionUID = -2042964665052386855L;
    /** Pool size meaning "do not use a pool of pre-generated values". */
    private static final int NO_POOL = -1;
    /** Random projection matrix; each row produces one signature bit. */
    private Matrix randProjMatrix;
    /** Packed signatures: {@code slotsPerEntry} ints per stored vector, in collection order. */
    private int[] projections;
    /** Number of 32-bit ints that make up one vector's signature. */
    private int slotsPerEntry;
    /** The indexed vectors, copied from the collection passed to {@link #build}. */
    private List<V> vecs;
    /**
     * Per-thread scratch vector that receives the raw projection. Transient because
     * {@link ThreadLocal} is not serializable; it is recreated on demand by {@link #tempVec()}.
     */
    private transient ThreadLocal<Vec> tempVecs;

    /**
     * Builds an index over {@code vecs} using signatures of {@code ints * 32} bits.
     *
     * @param vecs     the vectors to index; must be non-empty (its first element fixes the
     *                 dimension)
     * @param ints     number of 32-bit ints per signature
     * @param inMemory if {@code true}, the random matrix is copied into an explicit
     *                 {@link DenseMatrix}; otherwise the implicit {@code NormalMatrix} is used
     */
    public RandomProjectionLSH(List<V> vecs, int ints, boolean inMemory) {
        this.randProjMatrix = new NormalMatrix(ints * Integer.SIZE, vecs.get(0).length(), NO_POOL);
        if (inMemory) {
            DenseMatrix dense = new DenseMatrix(randProjMatrix.rows(), randProjMatrix.cols());
            dense.mutableAdd(randProjMatrix);
            this.randProjMatrix = dense;
        }
        build(true, vecs, new CosineDistance());
    }

    /**
     * Builds an index over {@code vecs} whose random matrix entries are drawn from a pool of
     * pre-generated Gaussian values.
     *
     * @param vecs     the vectors to index; must be non-empty (its first element fixes the
     *                 dimension)
     * @param ints     number of 32-bit ints per signature
     * @param poolSize number of pooled Gaussian values; a non-positive value disables pooling
     */
    public RandomProjectionLSH(List<V> vecs, int ints, int poolSize) {
        this.randProjMatrix = new NormalMatrix(ints * Integer.SIZE, vecs.get(0).length(), poolSize);
        build(true, vecs, new CosineDistance());
    }

    /**
     * Copy constructor. The matrix is cloned and the signatures and vector list are copied;
     * the stored vector objects themselves are shared with {@code toCopy}.
     *
     * @param toCopy the collection to copy
     */
    protected RandomProjectionLSH(RandomProjectionLSH<V> toCopy) {
        this.randProjMatrix = toCopy.randProjMatrix.clone();
        this.projections = Arrays.copyOf(toCopy.projections, toCopy.projections.length);
        this.slotsPerEntry = toCopy.slotsPerEntry;
        this.vecs = new ArrayList<V>(toCopy.vecs);
        this.tempVecs = newTempVecs();
    }

    /**
     * This collection uses no acceleration cache.
     *
     * @return always {@code null}
     */
    @Override
    public List<Double> getAccelerationCache() {
        return null;
    }

    /**
     * (Re)builds the index: validates the metric, copies {@code collection} and computes one
     * packed signature per vector. The {@code parallel} flag is currently ignored.
     *
     * @throws IllegalArgumentException if {@code dm} is not a cosine distance
     */
    @Override
    public void build(boolean parallel, List<V> collection, DistanceMetric dm) {
        setDistanceMetric(dm);
        this.vecs = new ArrayList<V>(collection);
        this.tempVecs = newTempVecs();
        this.slotsPerEntry = randProjMatrix.rows() / Integer.SIZE;
        this.projections = new int[slotsPerEntry * vecs.size()];
        Vec projected = tempVec();
        for (int i = 0; i < vecs.size(); i++) {
            projectVector(vecs.get(i), i * slotsPerEntry, projections, projected);
        }
    }

    /**
     * Appends every stored vector whose estimated cosine distance to {@code query} is within
     * {@code range}, then reorders both lists together via {@link IndexTable} on the distances
     * (presumably ascending; note that any entries already in the lists take part in the sort).
     * The reported distances are estimates derived from Hamming distances, not exact values.
     */
    @Override
    public void search(Vec query, double range, List<Integer> neighbors, List<Double> distances) {
        int maxHamming = (int) cosineToHamming(CosineDistance.distanceToCosine(range));
        int[] queryProj = projectQuery(query);
        for (int i = 0; i < vecs.size(); i++) {
            int hamming = hammingDistance(i, queryProj);
            if (hamming <= maxHamming) {
                neighbors.add(i);
                distances.add(CosineDistance.cosineToDistance(hammingToCosine(hamming)));
            }
        }
        IndexTable it = new IndexTable(distances);
        it.apply(neighbors);
        it.apply(distances);
    }

    /**
     * Appends the {@code numNeighbors} stored vectors with the smallest signature Hamming
     * distance to {@code query}, in the order kept by {@link BoundedSortedList}, together with
     * their estimated cosine distances.
     */
    @Override
    public void search(Vec query, int numNeighbors, List<Integer> neighbors, List<Double> distances) {
        BoundedSortedList<ProbailityMatch<Integer>> best =
                new BoundedSortedList<ProbailityMatch<Integer>>(numNeighbors);
        int[] queryProj = projectQuery(query);
        for (int i = 0; i < vecs.size(); i++) {
            int hamming = hammingDistance(i, queryProj);
            // Skip candidates that cannot enter an already-full list.
            if (best.size() < numNeighbors || hamming < best.last().getProbability()) {
                best.add(new ProbailityMatch<Integer>(hamming, i));
            }
        }
        for (int i = 0; i < best.size(); i++) {
            ProbailityMatch<Integer> match = best.get(i);
            neighbors.add(match.getMatch());
            distances.add(CosineDistance.cosineToDistance(hammingToCosine(match.getProbability())));
        }
    }

    /**
     * Returns the length of each vector's signature in bits. There is one bit per row of the
     * projection matrix (the previous {@code rows() * 32} over-counted by a factor of 32).
     *
     * @return the signature length in bits
     */
    public int getSignatureBitLength() {
        return randProjMatrix.rows();
    }

    /** Returns this thread's scratch projection vector, recreating the holder if needed. */
    private Vec tempVec() {
        ThreadLocal<Vec> local = tempVecs;
        if (local == null) { // e.g. after deserialization, since the field is transient
            local = newTempVecs();
            tempVecs = local;
        }
        return local.get();
    }

    /** Creates a holder that lazily gives each thread a dense vector with one entry per bit. */
    private ThreadLocal<Vec> newTempVecs() {
        return ThreadLocal.withInitial(() -> new DenseVector(randProjMatrix.rows()));
    }

    /** Computes the packed signature of {@code query} into a fresh array. */
    private int[] projectQuery(Vec query) {
        int[] queryProj = new int[slotsPerEntry];
        projectVector(query, 0, queryProj, tempVec());
        return queryProj;
    }

    /** Counts the differing bits between stored signature {@code index} and {@code queryProj}. */
    private int hammingDistance(int index, int[] queryProj) {
        int offset = index * slotsPerEntry;
        int hamming = 0;
        for (int pos = 0; pos < slotsPerEntry; pos++) {
            hamming += Integer.bitCount(projections[offset + pos] ^ queryProj[pos]);
        }
        return hamming;
    }

    /**
     * Projects {@code vec} and writes its packed signature to
     * {@code projLocation[offset .. offset + slotsPerEntry)}. Within each int, the bit for
     * projection row {@code pos * 32} ends up as the most significant bit.
     *
     * @param vec          the vector to hash
     * @param offset       first index of {@code projLocation} to write
     * @param projLocation destination array for the packed bits
     * @param projected    scratch vector of length {@code rows()}; it is zeroed here because
     *                     {@code multiply} adds into it
     */
    private void projectVector(Vec vec, int offset, int[] projLocation, Vec projected) {
        projected.zeroOut();
        randProjMatrix.multiply(vec, 1.0, projected);
        for (int pos = 0; pos < slotsPerEntry; pos++) {
            int base = pos * Integer.SIZE;
            int bits = 0;
            for (int b = 0; b < Integer.SIZE; b++) {
                bits <<= 1;
                if (projected.get(base + b) >= 0.0) {
                    bits |= 1;
                }
            }
            projLocation[offset + pos] = bits;
        }
    }

    @Override
    public int size() {
        return vecs.size();
    }

    @Override
    public V get(int indx) {
        return vecs.get(indx);
    }

    @Override
    public VectorCollection<V> clone() {
        return new RandomProjectionLSH<V>(this);
    }

    /** Maps a Hamming distance to an estimated cosine: {@code cos(ham * PI / bits)}. */
    private double hammingToCosine(double ham) {
        return Math.cos(ham * Math.PI / randProjMatrix.rows());
    }

    /**
     * Inverse of {@link #hammingToCosine}. The input is clamped to [-1, 1] because
     * {@link Math#acos} returns NaN outside that interval, and {@code (int) NaN} would
     * silently become a Hamming radius of 0.
     */
    private double cosineToHamming(double cos) {
        double clamped = Math.max(-1.0, Math.min(1.0, cos));
        return randProjMatrix.rows() * Math.acos(clamped) / Math.PI;
    }

    /**
     * Validates the metric; nothing is stored because only cosine distance is supported.
     *
     * @throws IllegalArgumentException if {@code dm} is neither {@link CosineDistance} nor
     *                                  {@link CosineDistanceNormalized}
     */
    @Override
    public void setDistanceMetric(DistanceMetric dm) {
        if (!(dm instanceof CosineDistance) && !(dm instanceof CosineDistanceNormalized)) {
            throw new IllegalArgumentException("RandomProjectionLSH is only compatible with the Cosine Distance metric");
        }
    }

    /** Always returns a new {@link CosineDistance}, regardless of the metric passed in. */
    @Override
    public DistanceMetric getDistanceMetric() {
        return new CosineDistance();
    }

    /**
     * A {@link RandomMatrix} with Gaussian entries. When a pool is used, entries are taken from
     * {@code poolSize} pre-generated Gaussian values, selected by a seeded hash of
     * {@code (i, j)}.
     */
    private static final class NormalMatrix
    extends RandomMatrix {
        private static final long serialVersionUID = -5274754647385324984L;
        /** Pre-generated Gaussian values, or {@code null} when pooling is disabled. */
        private final double[] pool;
        /** Random multiplier mixed into the pooled index hash. */
        private final long seedMult;

        NormalMatrix(int rows, int cols, int poolSize) {
            super(rows, cols);
            if (poolSize > 0) {
                Random rand = RandomUtil.getRandom();
                double[] values = new double[poolSize];
                for (int k = 0; k < values.length; k++) {
                    values[k] = rand.nextGaussian();
                }
                this.pool = values;
            } else {
                this.pool = null;
            }
            this.seedMult = RandomUtil.getRandom().nextLong();
        }

        NormalMatrix(NormalMatrix toCopy) {
            super(toCopy);
            this.pool = toCopy.pool == null ? null : Arrays.copyOf(toCopy.pool, toCopy.pool.length);
            this.seedMult = toCopy.seedMult;
        }

        @Override
        public double get(int i, int j) {
            if (pool == null) {
                return super.get(i, j);
            }
            // Int overflow in the product is intentional: this is only a deterministic hash.
            long index = ((long) ((i + 1) * (j + cols())) * seedMult) & Integer.MAX_VALUE;
            return pool[(int) index % pool.length];
        }

        @Override
        protected double getVal(Random rand) {
            if (pool == null) {
                return rand.nextGaussian();
            }
            return pool[rand.nextInt(pool.length)];
        }

        @Override
        public Matrix clone() {
            return new NormalMatrix(this);
        }
    }
}

