package jsat.linear.vectorcollection.lsh;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import jsat.linear.DenseMatrix;
import jsat.linear.DenseVector;
import jsat.linear.Matrix;
import jsat.linear.RandomMatrix;
import jsat.linear.Vec;
import jsat.linear.distancemetrics.CosineDistance;
import jsat.linear.distancemetrics.CosineDistanceNormalized;
import jsat.linear.distancemetrics.DistanceMetric;
import jsat.linear.vectorcollection.VectorCollection;
import jsat.utils.BoundedSortedList;
import jsat.utils.IndexTable;
import jsat.utils.ProbailityMatch;
import jsat.utils.random.RandomUtil;

/* loaded from: input_file:jsat/linear/vectorcollection/lsh/RandomProjectionLSH.class */
/**
 * A {@link VectorCollection} that answers cosine-distance queries approximately using
 * random-projection locality sensitive hashing.
 *
 * <p>Every stored vector is multiplied by a random matrix (see {@link NormalMatrix}, whose
 * entries are Gaussian samples). The sign of each projected coordinate becomes one bit of
 * the vector's signature, and signatures are packed 32 bits per {@code int}. Queries compare
 * signatures by Hamming distance {@code h} and estimate cosine similarity as
 * {@code cos(pi * h / bits)}. Returned distances are therefore estimates derived from the
 * signatures, not exact cosine distances.
 *
 * <p>Only {@link CosineDistance} and {@link CosineDistanceNormalized} are accepted as metrics.
 *
 * @param <V> the vector type stored in this collection
 */
public class RandomProjectionLSH<V extends Vec> implements VectorCollection<V> {
    private static final long serialVersionUID = -2042964665052386855L;
    /** Pool size meaning "no pool": matrix entries are not drawn from a precomputed pool. */
    private static final int NO_POOL = -1;
    /** Projection matrix; one row per signature bit. */
    private Matrix randProjMatrix;
    /** Packed signatures, {@code slotsPerEntry} ints per stored vector, in storage order. */
    private int[] projections;
    /** Number of 32-bit ints used for each signature. */
    private int slotsPerEntry;
    /** The stored vectors, in index order. */
    private List<V> vecs;
    /**
     * Per-thread scratch vector that receives a projection. This field is transient because
     * {@link ThreadLocal} is not serializable. It is created lazily by {@link #getTempVec()},
     * so copies and deserialized instances rebuild it on first use.
     */
    private transient volatile ThreadLocal<Vec> tempVecs;

    /**
     * A random matrix whose entries are Gaussian samples. Entries come either from the
     * generator in {@link RandomMatrix} or, when a pool size greater than 0 is given, from a
     * fixed pool of precomputed Gaussian values that is indexed by a hash of the coordinates.
     */
    private static final class NormalMatrix extends RandomMatrix {
        private static final long serialVersionUID = -5274754647385324984L;
        /** Precomputed Gaussian samples, or {@code null} when no pool is used. */
        private final double[] pool;
        /** Random multiplier that mixes (row, col) into a pool index. */
        private final long seedMult;

        /**
         * @param rows     number of rows
         * @param cols     number of columns
         * @param poolSize if greater than 0, entries are drawn from a pool of this many
         *                 Gaussian samples; otherwise no pool is used
         */
        public NormalMatrix(int rows, int cols, int poolSize) {
            super(rows, cols);
            if (poolSize > 0) {
                pool = new double[poolSize];
                Random rand = RandomUtil.getRandom();
                for (int k = 0; k < pool.length; k++) {
                    pool[k] = rand.nextGaussian();
                }
            } else {
                pool = null;
            }
            seedMult = RandomUtil.getRandom().nextLong();
        }

        /** Copy constructor; the pool array is deep-copied. */
        public NormalMatrix(NormalMatrix toCopy) {
            super(toCopy);
            pool = toCopy.pool == null ? null : Arrays.copyOf(toCopy.pool, toCopy.pool.length);
            seedMult = toCopy.seedMult;
        }

        @Override
        public double get(int i, int j) {
            if (pool == null) {
                return super.get(i, j);
            }
            // Hash (i, j) into the pool. The int product may wrap. It is kept exactly as-is so
            // that previously built or serialized matrices keep reproducing the same entries.
            long mixed = ((i + 1) * (j + cols())) * seedMult;
            return pool[((int) (mixed & Integer.MAX_VALUE)) % pool.length];
        }

        @Override
        protected double getVal(Random random) {
            return pool == null ? random.nextGaussian() : pool[random.nextInt(pool.length)];
        }

        /** Deep copy. The name is a decompiler-renamed {@code clone()} and is preserved as-is. */
        @Override
        public Matrix mo171clone() {
            return new NormalMatrix(this);
        }
    }

    /**
     * Creates the collection and builds signatures for {@code collection}.
     *
     * @param collection the vectors to index; the first element fixes the dimension, so an
     *                   empty list throws from {@code get(0)}
     * @param ints       number of 32-bit ints per signature ({@code ints * 32} bits)
     * @param inMemory   if {@code true}, the random matrix is copied into a {@link DenseMatrix}
     *                   before use; otherwise the {@link NormalMatrix} is used directly
     */
    public RandomProjectionLSH(List<V> collection, int ints, boolean inMemory) {
        randProjMatrix = new NormalMatrix(ints * Integer.SIZE, collection.get(0).length(), NO_POOL);
        if (inMemory) {
            DenseMatrix dense = new DenseMatrix(randProjMatrix.rows(), randProjMatrix.cols());
            dense.mutableAdd(randProjMatrix);
            randProjMatrix = dense;
        }
        build(true, collection, new CosineDistance());
    }

    /**
     * Creates the collection using a pooled random matrix and builds signatures.
     *
     * @param collection the vectors to index; the first element fixes the dimension
     * @param ints       number of 32-bit ints per signature ({@code ints * 32} bits)
     * @param poolSize   number of precomputed Gaussian samples that matrix entries are drawn
     *                   from; values of 0 or less disable the pool
     */
    public RandomProjectionLSH(List<V> collection, int ints, int poolSize) {
        randProjMatrix = new NormalMatrix(ints * Integer.SIZE, collection.get(0).length(), poolSize);
        build(true, collection, new CosineDistance());
    }

    /** Copy constructor. The matrix is cloned, and the signatures and vector list are copied. */
    protected RandomProjectionLSH(RandomProjectionLSH<V> toCopy) {
        this.randProjMatrix = toCopy.randProjMatrix.mo171clone();
        this.projections = Arrays.copyOf(toCopy.projections, toCopy.projections.length);
        this.slotsPerEntry = toCopy.slotsPerEntry;
        this.vecs = new ArrayList<>(toCopy.vecs);
        // tempVecs is intentionally left null; getTempVec() creates it on demand.
    }

    /** No acceleration cache is kept; always returns {@code null}. */
    @Override
    public List<Double> getAccelerationCache() {
        return null;
    }

    /**
     * Replaces the stored vectors with a copy of {@code collection} and computes their
     * signatures.
     *
     * @param parallel   not used by this implementation; signatures are computed on the
     *                   calling thread
     * @param collection the vectors to store
     * @param dm         must be a cosine metric
     * @throws IllegalArgumentException if {@code dm} is not a cosine metric
     */
    @Override
    public void build(boolean parallel, List<V> collection, DistanceMetric dm) {
        setDistanceMetric(dm);
        vecs = new ArrayList<>(collection);
        slotsPerEntry = randProjMatrix.rows() / Integer.SIZE;
        projections = new int[slotsPerEntry * vecs.size()];
        Vec scratch = getTempVec();
        for (int idx = 0; idx < vecs.size(); idx++) {
            projectVector(vecs.get(idx), idx * slotsPerEntry, projections, scratch);
        }
    }

    /**
     * Range search. Appends every stored vector whose signature is within the Hamming distance
     * corresponding to {@code range}, then reorders both output lists by the estimated
     * distances using {@link IndexTable}. If the lists were non-empty on entry, their existing
     * contents take part in that reordering.
     *
     * @param query     the query vector
     * @param range     maximum cosine distance
     * @param neighbors receives the indices of matching vectors
     * @param distances receives the estimated cosine distance of each match
     */
    @Override
    public void search(Vec query, double range, List<Integer> neighbors, List<Double> distances) {
        int maxHamming = (int) cosineToHamming(CosineDistance.distanceToCosine(range));
        int[] querySig = signatureOf(query);
        for (int idx = 0; idx < vecs.size(); idx++) {
            int hamming = hammingDistance(querySig, idx);
            if (hamming <= maxHamming) {
                neighbors.add(idx);
                distances.add(CosineDistance.cosineToDistance(hammingToCosine(hamming)));
            }
        }
        IndexTable order = new IndexTable(distances);
        order.apply(neighbors);
        order.apply(distances);
    }

    /**
     * k-nearest-neighbor search by signature Hamming distance. The best {@code numNeighbors}
     * candidates are kept in a {@link BoundedSortedList} and appended in that list's order.
     *
     * @param query        the query vector
     * @param numNeighbors number of neighbors to return
     * @param neighbors    receives the indices of the neighbors
     * @param distances    receives the estimated cosine distance of each neighbor
     */
    @Override
    public void search(Vec query, int numNeighbors, List<Integer> neighbors, List<Double> distances) {
        BoundedSortedList<ProbailityMatch<Integer>> best = new BoundedSortedList<>(numNeighbors);
        int[] querySig = signatureOf(query);
        for (int idx = 0; idx < vecs.size(); idx++) {
            int hamming = hammingDistance(querySig, idx);
            if (best.size() < numNeighbors || hamming < best.last().getProbability()) {
                best.add(new ProbailityMatch<>(hamming, idx));
            }
        }
        for (int k = 0; k < best.size(); k++) {
            ProbailityMatch<Integer> match = best.get(k);
            neighbors.add(match.getMatch());
            distances.add(CosineDistance.cosineToDistance(hammingToCosine(match.getProbability())));
        }
    }

    /**
     * Returns the signature length in bits. Each row of the projection matrix contributes one
     * bit, so this is the number of rows (always a multiple of 32).
     *
     * @return the number of bits in each signature
     */
    public int getSignatureBitLength() {
        return randProjMatrix.rows();
    }

    /** Returns this thread's scratch projection vector, creating the ThreadLocal if needed. */
    private Vec getTempVec() {
        ThreadLocal<Vec> local = tempVecs;
        if (local == null) {
            local = ThreadLocal.withInitial(() -> new DenseVector(randProjMatrix.rows()));
            tempVecs = local;
        }
        return local.get();
    }

    /** Computes the packed signature of {@code query} into a new array. */
    private int[] signatureOf(Vec query) {
        int[] sig = new int[slotsPerEntry];
        projectVector(query, 0, sig, getTempVec());
        return sig;
    }

    /** Hamming distance between {@code querySig} and the stored signature at {@code index}. */
    private int hammingDistance(int[] querySig, int index) {
        int start = index * slotsPerEntry;
        int dist = 0;
        for (int s = 0; s < slotsPerEntry; s++) {
            dist += Integer.bitCount(projections[start + s] ^ querySig[s]);
        }
        return dist;
    }

    /**
     * Projects {@code vec} and writes its packed sign bits into
     * {@code dest[offset .. offset + slotsPerEntry)}. Bit {@code b} of slot {@code s} comes from
     * projected coordinate {@code s * 32 + b}; the first coordinate of a slot is placed in the
     * most significant bit. A coordinate that is {@code >= 0} gives a 1 bit.
     *
     * @param scratch vector of length {@code rows()}. It is cleared first because
     *                {@code multiply(vec, 1.0, scratch)} is assumed to accumulate into it
     *                (TODO confirm against the Matrix docs).
     */
    private void projectVector(Vec vec, int offset, int[] dest, Vec scratch) {
        scratch.zeroOut();
        randProjMatrix.multiply(vec, 1.0, scratch);
        for (int slot = 0; slot < slotsPerEntry; slot++) {
            int base = slot * Integer.SIZE;
            int bits = 0;
            for (int b = 0; b < Integer.SIZE; b++) {
                bits <<= 1;
                if (scratch.get(base + b) >= 0.0) {
                    bits |= 1;
                }
            }
            dest[offset + slot] = bits;
        }
    }

    @Override
    public int size() {
        return vecs.size();
    }

    @Override
    public V get(int i) {
        return vecs.get(i);
    }

    /** Deep copy. The name is a decompiler-renamed {@code clone()} and is preserved as-is. */
    @Override
    public VectorCollection<V> m214clone() {
        return new RandomProjectionLSH<>(this);
    }

    /** Estimated cosine similarity for a Hamming distance over {@code rows()} bits. */
    private double hammingToCosine(double hamming) {
        return Math.cos((hamming * Math.PI) / randProjMatrix.rows());
    }

    /** Inverse of {@link #hammingToCosine}: the Hamming distance for a cosine similarity. */
    private double cosineToHamming(double cosine) {
        return (randProjMatrix.rows() * Math.acos(cosine)) / Math.PI;
    }

    /**
     * Validates the metric; only cosine metrics are supported. The metric is not stored.
     *
     * @throws IllegalArgumentException if {@code dm} is not a cosine metric
     */
    @Override
    public void setDistanceMetric(DistanceMetric dm) {
        if (!(dm instanceof CosineDistance) && !(dm instanceof CosineDistanceNormalized)) {
            throw new IllegalArgumentException("RandomProjectionLSH is only compatible with the Cosine Distance metric");
        }
    }

    /** Always returns a new {@link CosineDistance}. */
    @Override
    public DistanceMetric getDistanceMetric() {
        return new CosineDistance();
    }
}
