package jsat.linear.vectorcollection;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import java.util.concurrent.ExecutorService;
import jsat.linear.IndexValue;
import jsat.linear.Vec;
import jsat.linear.distancemetrics.ChebyshevDistance;
import jsat.linear.distancemetrics.DistanceMetric;
import jsat.linear.distancemetrics.EuclideanDistance;
import jsat.linear.distancemetrics.ManhattanDistance;
import jsat.linear.distancemetrics.MinkowskiDistance;
import jsat.math.FastMath;
import jsat.math.OnLineStatistics;
import jsat.utils.BoundedSortedList;
import jsat.utils.ClosedHashingUtil;
import jsat.utils.DoubleList;
import jsat.utils.IndexTable;
import jsat.utils.IntList;
import jsat.utils.IntSet;
import jsat.utils.ListUtils;
import jsat.utils.ModifiableCountDownLatch;
import jsat.utils.concurrent.ParallelUtils;

/* loaded from: input_file:jsat/linear/vectorcollection/KDTree.class */
/**
 * A KD-tree over a list of vectors, supporting k-nearest-neighbor and radius
 * queries as well as incremental insertion.
 * <p>
 * Only the axis-aligned norms accepted by {@link #setDistanceMetric(DistanceMetric)}
 * ({@link EuclideanDistance}, {@link ChebyshevDistance}, {@link ManhattanDistance},
 * {@link MinkowskiDistance}) may be used. Internally every point is referred to
 * by its index into the list of stored vectors.
 *
 * @param <V> the vector type stored in the tree
 */
public class KDTree<V extends Vec> implements IncrementalCollection<V> {
    private static final long serialVersionUID = -7401342201406776463L;

    /** Metric used for all distance computations; validated by {@link #setDistanceMetric}. */
    private DistanceMetric distanceMetric;
    /** Root of the tree; {@code null} when the tree was built from an empty list. */
    private KDTree<V>.KDNode root;
    /** Strategy used by {@link #buildTree} to choose the splitting axis and pivot. */
    private PivotSelection pvSelection;
    /** Number of points stored. */
    private int size;
    /**
     * Point count at or below which {@link #buildTree} emits a leaf. A leaf is
     * rebuilt on insertion once it holds twice this many points.
     */
    private int leaf_node_size;
    /** All stored vectors; tree nodes refer to points by index into this list. */
    private List<V> allVecs;
    /** Acceleration cache produced by the distance metric; may be {@code null}. */
    private List<Double> distCache;

    /**
     * Leaf node holding the indices of a small set of points, which are scanned
     * exhaustively during searches.
     */
    public class KDLeaf extends KDTree<V>.KDNode {
        /** Indices (into the stored vectors) of the points held by this leaf. */
        protected IntList owned;

        /**
         * @param axis   axis recorded for this leaf (used as the depth hint when
         *               the leaf is later rebuilt into a subtree)
         * @param points indices of the points to hold; the list is copied
         */
        public KDLeaf(int axis, List<Integer> points) {
            super(axis);
            this.owned = new IntList(points);
        }

        /**
         * Copies {@code toCopy}. The copy is bound to the tree this constructor
         * is invoked on.
         */
        public KDLeaf(KDTree<V>.KDLeaf toCopy) {
            super(toCopy);
            this.owned = new IntList(toCopy.owned);
        }

        @Override
        protected void searchK(int k, BoundedSortedList<IndexDistPair> knn, Vec query, List<Double> qi) {
            for (int idx : owned) {
                double dist = distanceMetric.dist(idx, query, qi, allVecs, distCache);
                knn.add(new IndexDistPair(idx, dist));
            }
        }

        @Override
        protected void searchR(double range, List<Integer> neighbors, List<Double> distances, Vec query, List<Double> qi) {
            for (int idx : owned) {
                double dist = distanceMetric.dist(idx, query, qi, allVecs, distCache);
                if (dist <= range) {
                    neighbors.add(idx);
                    distances.add(dist);
                }
            }
        }

        /**
         * Adds the point to this leaf.
         *
         * @return {@code true} once the leaf holds at least twice the leaf size,
         *         signalling the parent that it should be rebuilt
         */
        @Override
        protected boolean insert(int index) {
            owned.add(index);
            return owned.size() >= leaf_node_size * 2;
        }

        @Override
        protected KDTree<V>.KDNode copyTo(KDTree<V> owner) {
            return owner.new KDLeaf(this);
        }

        @Override
        protected KDTree<V>.KDLeaf clone() {
            return new KDLeaf(this);
        }
    }

    /**
     * Internal node of the tree. Points whose value on {@link #axis} is
     * {@code <= pivot_s} live under {@link #left}; all others live under
     * {@link #right}.
     * <p>
     * As an inner class, every node is bound to the {@code KDTree} whose vectors,
     * metric and cache it reads. Use {@link #copyTo(KDTree)} to copy a subtree
     * into a different tree.
     */
    public class KDNode implements Cloneable, Serializable {
        /** Index of the coordinate this node splits on. */
        protected int axis;
        /** Split value; points with {@code value <= pivot_s} on {@link #axis} go left. */
        protected double pivot_s;
        protected KDTree<V>.KDNode left;
        protected KDTree<V>.KDNode right;

        public KDNode(int axis) {
            this.axis = axis;
        }

        /**
         * Deep-copies {@code toCopy}. The new node and every copied descendant are
         * bound to the tree this constructor is invoked on, not to {@code toCopy}'s tree.
         */
        public KDNode(KDTree<V>.KDNode toCopy) {
            this(toCopy.axis);
            this.pivot_s = toCopy.pivot_s;
            if (toCopy.left != null) {
                this.left = toCopy.left.copyTo(KDTree.this);
            }
            if (toCopy.right != null) {
                this.right = toCopy.right.copyTo(KDTree.this);
            }
        }

        public void setAxis(int axis) {
            this.axis = axis;
        }

        public void setLeft(KDTree<V>.KDNode left) {
            this.left = left;
        }

        public void setRight(KDTree<V>.KDNode right) {
            this.right = right;
        }

        public int getAxis() {
            return axis;
        }

        public KDTree<V>.KDNode getLeft() {
            return left;
        }

        public KDTree<V>.KDNode getRight() {
            return right;
        }

        /** Returns a deep copy of the subtree rooted here, bound to {@code owner}. */
        protected KDTree<V>.KDNode copyTo(KDTree<V> owner) {
            return owner.new KDNode(this);
        }

        @Override
        protected KDTree<V>.KDNode clone() {
            return copyTo(KDTree.this);
        }

        /**
         * k-NN search.
         * <p>
         * Descends first into the side containing the query. The other side is
         * visited only if the current k-th best distance exceeds the query's
         * distance to the splitting plane.
         */
        protected void searchK(int k, BoundedSortedList<IndexDistPair> knn, Vec query, List<Double> qi) {
            double queryVal = query.get(axis);
            boolean queryOnLeft = queryVal <= pivot_s;
            KDTree<V>.KDNode near = queryOnLeft ? left : right;
            KDTree<V>.KDNode far = queryOnLeft ? right : left;
            near.searchK(k, knn, query, qi);

            double kthDist = Double.MAX_VALUE;
            if (knn.size() >= k) {
                kthDist = ((IndexDistPair) knn.get(k - 1)).getDist();
            }
            if (kthDist > Math.abs(queryVal - pivot_s)) {
                far.searchK(k, knn, query, qi);
            }
        }

        /**
         * Radius search.
         * <p>
         * Visits a child unless every point in it is provably farther than
         * {@code range} along {@link #axis}. The comparisons are inclusive
         * because leaves accept {@code dist <= range}.
         */
        protected void searchR(double range, List<Integer> neighbors, List<Double> distances, Vec query, List<Double> qi) {
            double queryVal = query.get(axis);
            if (range >= queryVal - pivot_s) {
                left.searchR(range, neighbors, distances, query, qi);
            }
            if (range >= pivot_s - queryVal) {
                right.searchR(range, neighbors, distances, query, qi);
            }
        }

        /**
         * Routes the point to the matching child.
         * <p>
         * If that child is a leaf that has grown too large, it is replaced by a
         * freshly built subtree.
         *
         * @return always {@code false}; only leaves request a rebuild
         */
        protected boolean insert(int index) {
            if (get(index).get(axis) <= pivot_s) {
                if (left.insert(index)) {
                    left = buildTree(((KDLeaf) left).owned, axis + 1, null, null);
                }
            } else if (right.insert(index)) {
                right = buildTree(((KDLeaf) right).owned, axis + 1, null, null);
            }
            return false;
        }
    }

    /** Strategy for choosing the splitting axis (and, for some, the pivot) of a node. */
    public enum PivotSelection {
        /** Axis cycles with depth; the split is at the median. */
        INCREMENTAL,
        /** Axis with the largest variance; the split is at the median. */
        VARIANCE,
        /**
         * Axis with the largest spread. The pivot is the stored value closest to the
         * middle of that spread, falling back to the median when the resulting
         * split is rejected.
         */
        SPREAD_MEDOID
    }

    /** Orders point indices by the value of their vectors on a fixed coordinate. */
    public class VecIndexComparator implements Comparator<Integer> {
        private final int index;

        public VecIndexComparator(int index) {
            this.index = index;
        }

        @Override
        public int compare(Integer a, Integer b) {
            return Double.compare(allVecs.get(a).get(index), allVecs.get(b).get(index));
        }
    }

    /**
     * Builds a tree over {@code list}.
     *
     * @param list           vectors to index (copied)
     * @param distanceMetric metric to use; must be accepted by {@link #setDistanceMetric}
     * @param pivotSelection split strategy
     * @param parallel       whether to build in parallel
     */
    public KDTree(List<V> list, DistanceMetric distanceMetric, PivotSelection pivotSelection, boolean parallel) {
        this.leaf_node_size = 20;
        this.distanceMetric = distanceMetric;
        this.pvSelection = pivotSelection;
        build(parallel, list, distanceMetric);
    }

    public KDTree(List<V> list, DistanceMetric distanceMetric, PivotSelection pivotSelection) {
        this(list, distanceMetric, pivotSelection, false);
    }

    public KDTree(List<V> list, DistanceMetric distanceMetric) {
        this(list, distanceMetric, PivotSelection.SPREAD_MEDOID);
    }

    private KDTree(DistanceMetric distanceMetric, PivotSelection pivotSelection) {
        this.leaf_node_size = 20;
        setDistanceMetric(distanceMetric);
        this.pvSelection = pivotSelection;
    }

    /** Creates an empty Euclidean tree to be filled via {@link #build} or {@link #insert}. */
    public KDTree(PivotSelection pivotSelection) {
        this(new EuclideanDistance(), pivotSelection);
    }

    public KDTree() {
        this(PivotSelection.SPREAD_MEDOID);
    }

    @Override
    public List<Double> getAccelerationCache() {
        return distCache;
    }

    /**
     * Sets the point count at or below which {@link #buildTree} creates leaves.
     *
     * @throws IllegalArgumentException if {@code leafSize < 2}
     */
    public void setLeafSize(int leafSize) {
        if (leafSize < 2) {
            throw new IllegalArgumentException("The leaf size must be >= 2 to support all splitting methods");
        }
        this.leaf_node_size = leafSize;
    }

    public int getLeafSize() {
        return leaf_node_size;
    }

    /**
     * Sets the distance metric.
     *
     * @throws ArithmeticException if the metric is not Euclidean, Chebyshev,
     *                             Manhattan or Minkowski
     */
    @Override
    public void setDistanceMetric(DistanceMetric distanceMetric) {
        if (!(distanceMetric instanceof EuclideanDistance) && !(distanceMetric instanceof ChebyshevDistance)
                && !(distanceMetric instanceof ManhattanDistance) && !(distanceMetric instanceof MinkowskiDistance)) {
            throw new ArithmeticException("KD Trees are not compatible with the given distance metric.");
        }
        this.distanceMetric = distanceMetric;
    }

    @Override
    public DistanceMetric getDistanceMetric() {
        return distanceMetric;
    }

    /**
     * Replaces the tree's contents with a tree built over {@code list}.
     *
     * @param parallel       whether to build subtrees concurrently
     * @param list           vectors to index (copied)
     * @param distanceMetric metric to use
     */
    @Override
    public void build(boolean parallel, List<V> list, DistanceMetric distanceMetric) {
        setDistanceMetric(distanceMetric);
        this.size = list.size();
        this.allVecs = new ArrayList<>(list);
        this.distCache = this.distanceMetric.getAccelerationCache(allVecs, parallel);
        IntList indices = new IntList(size);
        ListUtils.addRange(indices, 0, size, 1);
        if (!parallel) {
            root = buildTree(indices, 0, null, null);
            return;
        }
        ModifiableCountDownLatch latch = new ModifiableCountDownLatch(1);
        root = buildTree(indices, 0, ParallelUtils.CACHED_THREAD_POOL, latch);
        try {
            latch.await();
        } catch (InterruptedException e) {
            // Preserve the caller's interrupt status, then fall back to a serial build.
            Thread.currentThread().interrupt();
            root = buildTree(indices, 0, null, null);
        }
    }

    /**
     * Appends {@code v} to the collection.
     * <p>
     * If the receiving leaf grows too large it is rebuilt; if the root itself is
     * a leaf that grows too large, the whole tree is rebuilt.
     */
    @Override
    public void insert(V v) {
        if (allVecs == null) {
            allVecs = new ArrayList<>();
            distCache = distanceMetric.getAccelerationCache(allVecs);
            size = 0;
        }
        if (root == null) {
            // Fresh tree, or one built from an empty list: start from an empty leaf.
            root = new KDLeaf(0, new IntList());
        }
        int index = size++;
        allVecs.add(v);
        if (distCache != null) {
            distCache.addAll(distanceMetric.getQueryInfo(v));
        }
        if (root.insert(index)) {
            root = buildTree(IntList.range(size), 0, null, null);
        }
    }

    /**
     * Recursively builds a subtree over the given point indices.
     * <p>
     * {@code points} may be reordered in place. In parallel mode
     * ({@code threadpool != null}), each call that produces a leaf or an empty
     * subtree counts the latch down once. Each internal node counts it up once
     * before handing its right child to the pool, so the latch reaches zero
     * only when every subtree is finished.
     *
     * @param points     indices of the points to place in the subtree
     * @param depth      depth of the subtree's root, used for axis cycling and depth limits
     * @param threadpool executor for the right children, or {@code null} to build serially
     * @param latch      completion latch; only used when {@code threadpool != null}
     * @return the subtree root, or {@code null} when {@code points} is null or empty
     */
    public KDTree<V>.KDNode buildTree(List<Integer> points, int depth, ExecutorService threadpool, ModifiableCountDownLatch latch) {
        if (points == null || points.isEmpty()) {
            if (threadpool != null) {
                latch.countDown();
            }
            return null;
        }
        final int dims = allVecs.get(0).length();
        if (points.size() <= leaf_node_size) {
            if (threadpool != null) {
                latch.countDown();
            }
            return new KDLeaf(depth % dims, points);
        }

        int axis = -1;
        double pivot = Double.NaN;
        switch (pvSelection) {
            case VARIANCE: {
                OnLineStatistics[] stats = new OnLineStatistics[dims];
                for (int d = 0; d < dims; d++) {
                    stats[d] = new OnLineStatistics();
                }
                for (int p : points) {
                    V v = get(p);
                    for (int d = 0; d < dims; d++) {
                        stats[d].add(v.get(d));
                    }
                }
                double bestVariance = -1.0;
                for (int d = 0; d < dims; d++) {
                    if (stats[d].getVarance() > bestVariance) {
                        bestVariance = stats[d].getVarance();
                        axis = d;
                    }
                }
                if (axis < 0) {
                    axis = depth % dims;
                }
                break;
            }
            case SPREAD_MEDOID: {
                double[] mins = new double[dims];
                double[] maxs = new double[dims];
                int[] seen = new int[dims];
                Arrays.fill(mins, Double.POSITIVE_INFINITY);
                Arrays.fill(maxs, Double.NEGATIVE_INFINITY);
                for (int p : points) {
                    Iterator<IndexValue> iter = get(p).iterator();
                    while (iter.hasNext()) {
                        IndexValue iv = iter.next();
                        int dim = iv.getIndex();
                        mins[dim] = Math.min(mins[dim], iv.getValue());
                        maxs[dim] = Math.max(maxs[dim], iv.getValue());
                        seen[dim]++;
                    }
                }
                int widest = 0;
                double widestSpread = 0.0;
                for (int d = 0; d < dims; d++) {
                    // NOTE(review): assumes Vec iteration may skip zero entries (as the
                    // original sparse handling did). Any point that yielded no value on
                    // this axis therefore holds an implicit 0.
                    if (seen[d] < points.size()) {
                        mins[d] = Math.min(mins[d], 0.0);
                        maxs[d] = Math.max(maxs[d], 0.0);
                    }
                    double spread = maxs[d] - mins[d];
                    if (spread > widestSpread) {
                        widest = d;
                        widestSpread = spread;
                    }
                }
                axis = widest;
                // The pivot is the actual value among these points closest to the middle of the spread.
                double mid = (maxs[widest] - mins[widest]) / 2.0 + mins[widest];
                double closest = maxs[widest];
                for (int p : points) {
                    double val = get(p).get(widest);
                    if (Math.abs(mid - val) < Math.abs(mid - closest)) {
                        closest = val;
                    }
                }
                pivot = closest;
                break;
            }
            case INCREMENTAL:
            default:
                axis = depth % dims;
                break;
        }

        KDTree<V>.KDNode node = new KDNode(axis);
        int splitIndex = -1;
        if (!Double.isNaN(pivot)) {
            // Partition in place: every point with value <= pivot moves to the front.
            int leftCount = 0;
            for (int j = 0; j < points.size(); j++) {
                if (get(points.get(j)).get(axis) <= pivot) {
                    ListUtils.swap(points, leftCount++, j);
                }
            }
            // Accept the pivot while the tree is shallow, or while both sides keep at
            // least a third of a leaf's worth of points. Never accept it beyond depth
            // 3*floor(log2 n); presumably this bounds depth on skewed data.
            int logN = FastMath.floor_log2(allVecs.size());
            boolean shallow = (logN * 3) / 2 >= depth;
            boolean balanced = leftCount >= leaf_node_size / 3 && points.size() - leftCount >= leaf_node_size / 3;
            if ((shallow || balanced) && logN * 3 >= depth) {
                splitIndex = leftCount - 1;
                node.pivot_s = pivot;
            } else {
                pivot = Double.NaN;
            }
        }
        if (splitIndex <= 0 || splitIndex >= points.size() - 1) {
            pivot = Double.NaN;
        }
        if (Double.isNaN(pivot)) {
            Collections.sort(points, new VecIndexComparator(axis));
            splitIndex = getMedianIndex(points, axis);
            if (splitIndex == points.size() - 1) {
                // Everything from the median onward is equal on this axis, so no split
                // is possible. This is still a terminal subtree, so signal the latch.
                if (threadpool != null) {
                    latch.countDown();
                }
                return new KDLeaf(depth % dims, points);
            }
            node.pivot_s = get(points.get(splitIndex)).get(axis);
        }

        if (threadpool == null) {
            node.setLeft(buildTree(points.subList(0, splitIndex + 1), depth + 1, null, null));
            node.setRight(buildTree(points.subList(splitIndex + 1, points.size()), depth + 1, null, null));
        } else {
            latch.countUp();
            // Copy the halves so concurrent subtrees never touch the same backing list.
            IntList leftPoints = new IntList(points.subList(0, splitIndex + 1));
            IntList rightPoints = new IntList(points.subList(splitIndex + 1, points.size()));
            threadpool.submit(() -> node.setRight(buildTree(rightPoints, depth + 1, threadpool, latch)));
            node.setLeft(buildTree(leftPoints, depth + 1, threadpool, latch));
        }
        return node;
    }

    /**
     * Returns the index of the median of {@code list}, which must already be sorted
     * by coordinate {@code i}.
     * <p>
     * The index is advanced past any following entries with an equal value, so
     * equal values never straddle the split.
     */
    public int getMedianIndex(List<Integer> list, int i) {
        int median = list.size() / 2;
        while (median < list.size() - 1
                && allVecs.get(list.get(median)).get(i) == allVecs.get(list.get(median + 1)).get(i)) {
            median++;
        }
        return median;
    }

    /**
     * Finds the {@code i} nearest neighbors of {@code vec}.
     * <p>
     * Indices and distances are written into {@code list} and {@code list2}
     * (both cleared first), ordered by the sorted result list.
     *
     * @throws RuntimeException if {@code i < 1}
     */
    @Override
    public void search(Vec vec, int i, List<Integer> list, List<Double> list2) {
        if (i < 1) {
            throw new RuntimeException("Invalid number of neighbors to search for");
        }
        BoundedSortedList<IndexDistPair> knn = new BoundedSortedList<>(i);
        if (root != null) {
            root.searchK(i, knn, vec, distanceMetric.getQueryInfo(vec));
        }
        list.clear();
        list2.clear();
        for (int j = 0; j < knn.size(); j++) {
            IndexDistPair pair = (IndexDistPair) knn.get(j);
            list.add(pair.getIndex());
            list2.add(pair.getDist());
        }
    }

    @Override
    public int size() {
        return size;
    }

    @Override
    public V get(int i) {
        return allVecs.get(i);
    }

    /**
     * Finds all points within distance {@code d} (inclusive) of {@code vec}.
     * <p>
     * Results are written into {@code list} and {@code list2} (both cleared first)
     * and sorted by distance.
     *
     * @throws RuntimeException if {@code d <= 0}
     */
    @Override
    public void search(Vec vec, double d, List<Integer> list, List<Double> list2) {
        if (d <= 0.0) {
            throw new RuntimeException("Range must be a positive number");
        }
        list.clear();
        list2.clear();
        if (root != null) {
            root.searchR(d, list, list2, vec, distanceMetric.getQueryInfo(vec));
        }
        IndexTable order = new IndexTable(list2);
        order.apply(list);
        order.apply(list2);
    }

    /**
     * Returns a copy of this tree.
     * <p>
     * The vector list, acceleration cache and node structure are copied, and all
     * copied nodes are bound to the new tree. The vectors and the metric object
     * themselves are shared.
     */
    @Override
    public KDTree<V> clone() {
        KDTree<V> copy = new KDTree<>(distanceMetric, pvSelection);
        copy.leaf_node_size = leaf_node_size;
        if (distCache != null) {
            copy.distCache = new DoubleList(distCache);
        }
        if (allVecs != null) {
            copy.allVecs = new ArrayList<>(allVecs);
        }
        copy.size = size;
        if (root != null) {
            copy.root = root.copyTo(copy);
        }
        return copy;
    }

    /**
     * Decompiler-generated alias of {@link #clone()}, kept for compatibility.
     *
     * @deprecated use {@link #clone()} instead
     */
    @Deprecated
    public KDTree<V> m199clone() {
        return clone();
    }
}
