package jsat.linear.vectorcollection;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveAction;
import jsat.linear.Vec;
import jsat.linear.distancemetrics.DistanceMetric;
import jsat.utils.BoundedSortedList;
import jsat.utils.DoubleList;
import jsat.utils.IndexTable;
import jsat.utils.IntList;
import jsat.utils.ListUtils;

/* loaded from: input_file:jsat/linear/vectorcollection/DualTree.class */
public interface DualTree<V extends Vec> extends VectorCollection<V> {
    /**
     * Sentinel "previous score" (-1) handed to {@link ScoreDT} when a node pair is scored for the
     * first time rather than re-evaluated. The traversal routines in this interface pass -1 when
     * scoring a freshly generated child pair, and the k-NN scorer in this interface treats any
     * negative value as "compute a fresh score".
     */
    public static final double COMP_SCORE = -1.0d;

    /**
     * Fork/join task performing one step of a parallel dual-tree traversal.
     *
     * <p>A task first lets a {@link ScoreDTLazy} scorer re-check its pair (a {@code NaN} result
     * prunes it). It then runs the base case on every (reference point, query point) pair owned
     * directly by {@link #n_r} and {@link #n_q}, scores the child node pairs, and invokes the
     * surviving pairs as sub-tasks in ascending score order. A {@code NaN} score always means the
     * pair is pruned. Tasks compare by {@link #priority}.
     */
    public static class DualTreeTraversalAction extends RecursiveAction implements Comparable<DualTreeTraversalAction> {
        /** Reference-tree node of this pair. */
        IndexNode n_r;
        /** Query-tree node of this pair. */
        IndexNode n_q;
        /** Callback run for every (reference point, query point) index pair owned by this pair. */
        BaseCaseDT base;
        /** Scoring rule; a {@code NaN} score prunes a node pair. */
        ScoreDT score;
        /** Whether to use the tie-aware child expansion (see {@link #expandImproved}). */
        boolean improvedSearch;
        /** Score this pair was created with; lower values are processed first. */
        double priority;

        /**
         * Creates a task with priority {@code 0}.
         *
         * @param indexNode reference node
         * @param indexNode2 query node
         * @param baseCaseDT base case callback
         * @param scoreDT scoring rule
         * @param z whether to use the improved (tie-aware) expansion
         */
        public DualTreeTraversalAction(IndexNode indexNode, IndexNode indexNode2, BaseCaseDT baseCaseDT, ScoreDT scoreDT, boolean z) {
            this(indexNode, indexNode2, baseCaseDT, scoreDT, z, 0.0d);
        }

        /**
         * Creates a task for the given node pair.
         *
         * @param indexNode reference node
         * @param indexNode2 query node
         * @param baseCaseDT base case callback
         * @param scoreDT scoring rule
         * @param z whether to use the improved (tie-aware) expansion
         * @param d the score that produced this pair, used as its priority
         */
        public DualTreeTraversalAction(IndexNode indexNode, IndexNode indexNode2, BaseCaseDT baseCaseDT, ScoreDT scoreDT, boolean z, double d) {
            this.n_r = indexNode;
            this.n_q = indexNode2;
            this.base = baseCaseDT;
            this.score = scoreDT;
            this.improvedSearch = z;
            this.priority = d;
        }

        @Override
        protected void compute() {
            // A lazy scorer re-checks the pair now that more base cases may have run.
            if (score instanceof ScoreDTLazy && Double.isNaN(score.score(n_r, n_q, priority))) {
                return;
            }
            for (int r = 0; r < n_r.numPoints(); r++) {
                for (int q = 0; q < n_q.numPoints(); q++) {
                    base.base_case(n_r.getPoint(r), n_q.getPoint(q));
                }
            }

            PriorityQueue<DualTreeTraversalAction> pending = new PriorityQueue<>();
            if (n_q.hasChildren() && n_r.hasChildren()) {
                if (improvedSearch) {
                    expandImproved(pending);
                } else {
                    for (int rc = 0; rc < n_r.numChildren(); rc++) {
                        IndexNode refChild = n_r.getChild(rc);
                        for (int qc = 0; qc < n_q.numChildren(); qc++) {
                            IndexNode queryChild = n_q.getChild(qc);
                            enqueue(pending, refChild, queryChild, score.score(refChild, queryChild, COMP_SCORE));
                        }
                    }
                }
            } else if (n_q.hasChildren()) {
                for (int qc = 0; qc < n_q.numChildren(); qc++) {
                    IndexNode queryChild = n_q.getChild(qc);
                    enqueue(pending, n_r, queryChild, score.score(n_r, queryChild, COMP_SCORE));
                }
            } else if (n_r.hasChildren()) {
                for (int rc = 0; rc < n_r.numChildren(); rc++) {
                    IndexNode refChild = n_r.getChild(rc);
                    enqueue(pending, refChild, n_q, score.score(refChild, n_q, COMP_SCORE));
                }
            }

            // A PriorityQueue's iteration order is not sorted, so drain it to hand the sub-tasks
            // to invokeAll in ascending score order (the same order the sequential traversal uses).
            List<DualTreeTraversalAction> ordered = new ArrayList<>(pending.size());
            while (!pending.isEmpty()) {
                ordered.add(pending.poll());
            }
            invokeAll(ordered);
        }

        /**
         * Tie-aware expansion used when {@link #improvedSearch} is set.
         *
         * <p>For each query child, every reference child is scored against it. If no two
         * consecutive scores are within {@code 1e-13} of each other, the whole reference node is
         * paired with the query child. Otherwise every non-pruned (reference child, query child)
         * pair is queued.
         *
         * @param pending queue receiving the generated sub-tasks
         */
        private void expandImproved(PriorityQueue<DualTreeTraversalAction> pending) {
            for (int qc = 0; qc < n_q.numChildren(); qc++) {
                IndexNode queryChild = n_q.getChild(qc);
                List<DualTreeTraversalAction> pairs = new ArrayList<>(n_r.numChildren());
                boolean noTies = true;
                double prevScore = Double.NaN;
                for (int rc = 0; rc < n_r.numChildren(); rc++) {
                    IndexNode refChild = n_r.getChild(rc);
                    double s = score.score(refChild, queryChild, COMP_SCORE);
                    if (rc > 0 && Math.abs(prevScore - s) < 1.0E-13d) {
                        noTies = false;
                    }
                    prevScore = s;
                    // NaN means pruned, exactly as in the non-improved branches.
                    if (!Double.isNaN(s)) {
                        pairs.add(new DualTreeTraversalAction(refChild, queryChild, base, score, improvedSearch, s));
                    }
                }
                if (noTies) {
                    enqueue(pending, n_r, queryChild, score.score(n_r, queryChild, COMP_SCORE));
                } else {
                    pending.addAll(pairs);
                }
            }
        }

        /**
         * Queues a sub-task for the given pair unless its score is {@code NaN} (pruned).
         *
         * @param pending queue receiving the sub-task
         * @param ref reference node
         * @param query query node
         * @param s the pair's score
         */
        private void enqueue(PriorityQueue<DualTreeTraversalAction> pending, IndexNode ref, IndexNode query, double s) {
            if (!Double.isNaN(s)) {
                pending.offer(new DualTreeTraversalAction(ref, query, base, score, improvedSearch, s));
            }
        }

        @Override
        public int compareTo(DualTreeTraversalAction dualTreeTraversalAction) {
            return Double.compare(this.priority, dualTreeTraversalAction.priority);
        }
    }

    /**
     * Adapter that re-shapes a tree whose internal nodes may own points into one in which only
     * leaves own points.
     *
     * <p>A non-leaf view ({@code asLeaf == false}) owns no points. Its children are the wrapped
     * node's children (each wrapped in turn) plus one extra child at index
     * {@code wrapping.numChildren()}. That extra child is a leaf view of the same wrapped node and
     * exposes the wrapped node's own points. Distance computations between views are delegated to
     * the wrapped nodes.
     *
     * @param <N> the concrete node type being wrapped
     */
    public static class SelfAsChildNode<N extends IndexNode<N>> implements IndexNode<SelfAsChildNode<N>> {
        /** If {@code true}, this view has no children and exposes the wrapped node's own points. */
        public boolean asLeaf;
        /** The wrapped tree node. */
        N wrapping;

        /**
         * Wraps {@code n}, presenting it as a leaf only if it has no children.
         *
         * @param n the node to wrap
         */
        public SelfAsChildNode(N n) {
            this.wrapping = n;
            this.asLeaf = !n.hasChildren();
        }

        /**
         * Wraps {@code n} with an explicit leaf/non-leaf role.
         *
         * @param z whether this view acts as a leaf
         * @param n the node to wrap
         */
        public SelfAsChildNode(boolean z, N n) {
            this.asLeaf = z;
            this.wrapping = n;
        }

        @Override
        public double furthestPointDistance() {
            // A non-leaf view owns no points.
            return asLeaf ? wrapping.furthestPointDistance() : 0.0d;
        }

        @Override
        public double furthestDescendantDistance() {
            // A leaf view's only descendants are the wrapped node's own points.
            return asLeaf ? wrapping.furthestPointDistance() : wrapping.furthestDescendantDistance();
        }

        @Override
        public int numChildren() {
            // The "+ 1" is the leaf view of the wrapped node itself.
            return asLeaf ? 0 : wrapping.numChildren() + 1;
        }

        @Override
        public IndexNode getChild(int i) {
            if (i == wrapping.numChildren()) {
                return new SelfAsChildNode(true, wrapping);
            }
            return new SelfAsChildNode(wrapping.getChild(i));
        }

        @Override
        public Vec getVec(int i) {
            return wrapping.getVec(i);
        }

        @Override
        public int numPoints() {
            return asLeaf ? wrapping.numPoints() : 0;
        }

        @Override
        public int getPoint(int i) {
            if (asLeaf) {
                return wrapping.getPoint(i);
            }
            // Only the leaf view carries points; the non-leaf view owns none.
            throw new IndexOutOfBoundsException("Non-leaf node view does not own any points (requested index " + i + ")");
        }

        @Override
        public SelfAsChildNode<N> getParrent() {
            // The leaf view of an internal node hangs directly under that node's non-leaf view.
            if (asLeaf && wrapping.hasChildren()) {
                return new SelfAsChildNode<>(false, wrapping);
            }
            IndexNode parrent = wrapping.getParrent();
            if (parrent == null) {
                return null;
            }
            return new SelfAsChildNode<>(false, parrent);
        }

        @Override
        public double minNodeDistance(SelfAsChildNode<N> selfAsChildNode) {
            return wrapping.minNodeDistance(selfAsChildNode.wrapping);
        }

        @Override
        public double maxNodeDistance(SelfAsChildNode<N> selfAsChildNode) {
            return wrapping.maxNodeDistance(selfAsChildNode.wrapping);
        }

        @Override
        public double minNodeDistance(int i) {
            return wrapping.minNodeDistance(i);
        }

        /** Two views are equal when they have the same role and wrap equal nodes. */
        @Override
        public boolean equals(Object obj) {
            if (!(obj instanceof SelfAsChildNode)) {
                return false;
            }
            SelfAsChildNode other = (SelfAsChildNode) obj;
            return asLeaf == other.asLeaf && wrapping.equals(other.wrapping);
        }

        @Override
        public int hashCode() {
            int h = 71 * 5 + (asLeaf ? 1 : 0);
            return 71 * h + wrapping.hashCode();
        }

        @Override
        public double[] minMaxDistance(SelfAsChildNode<N> selfAsChildNode) {
            return wrapping.minMaxDistance(selfAsChildNode.wrapping);
        }
    }

    /**
     * Returns the root node of this tree. {@code traverse} starts every dual-tree traversal from
     * the roots of the two trees.
     *
     * @return the root node
     */
    IndexNode getRoot();

    /**
     * Returns a copy of this tree.
     * NOTE(review): whether the copy is deep or shallow is up to the implementation — confirm.
     *
     * @return a copy of this tree
     */
    DualTree<V> clone();

    /**
     * Computes the distance between a point of this tree and a point of another tree, using this
     * tree's distance metric.
     *
     * @param selfIndex index of the point in this tree
     * @param otherIndex index of the point in {@code other}
     * @param other the tree holding the second point
     * @return the distance between {@code get(selfIndex)} and {@code other.get(otherIndex)}
     */
    default double dist(int selfIndex, int otherIndex, DualTree<V> other) {
        // Previously indexed the other tree with selfIndex, ignoring otherIndex entirely.
        return getDistanceMetric().dist(get(selfIndex), other.get(otherIndex));
    }

    /**
     * Single-query search that implementing trees must provide.
     * NOTE(review): from its parameters this presumably finds the {@code i} nearest neighbors of
     * {@code vec}, writing their indices into {@code list} and their distances into
     * {@code list2} — confirm against the implementing classes.
     */
    void search(Vec vec, int i, List<Integer> list, List<Double> list2);

    /**
     * All-k-nearest-neighbor search: for every point of {@code vectorCollection}, finds its
     * {@code k} nearest points among the points of this collection.
     *
     * <p>If {@code vectorCollection} is not a {@link DualTree}, the {@link VectorCollection}
     * default implementation is used. Otherwise a dual-tree traversal runs with this tree as the
     * reference side and {@code vectorCollection} as the query side. Node pairs are scored by
     * minimum node distance and pruned lazily with the bounds produced by {@link #computeKnnBound}.
     *
     * @param vectorCollection the query points
     * @param k number of neighbors per query point
     * @param neighbors output, cleared first; entry {@code q} receives the indices (into this
     *     collection) of query {@code q}'s neighbors, in the order kept by the bounded sorted list
     *     (presumably ascending distance)
     * @param distances output, cleared first; entry {@code q} receives the matching distances
     * @param parallel whether to traverse in parallel on the common fork/join pool
     */
    @Override
    default void search(VectorCollection<V> vectorCollection, int k, List<List<Integer>> neighbors, List<List<Double>> distances, boolean parallel) {
        if (!(vectorCollection instanceof DualTree)) {
            // Plain "super.search" is illegal inside an interface; name the superinterface.
            VectorCollection.super.search(vectorCollection, k, neighbors, distances, parallel);
            return;
        }
        final DualTree<V> query = (DualTree<V>) vectorCollection;
        final int numQueries = query.size();

        // Cached pruning bounds per query node; shared between threads in parallel mode.
        final Map<IndexNode, Double> bounds = parallel
                ? new ConcurrentHashMap<IndexNode, Double>(numQueries)
                : new HashMap<IndexNode, Double>(numQueries);
        // Running k-NN candidates of every query point.
        final List<BoundedSortedList<IndexDistPair>> knn = new ArrayList<>(numQueries);
        for (int q = 0; q < numQueries; q++) {
            knn.add(new BoundedSortedList<IndexDistPair>(k));
        }

        // Reference points occupy [0, offset) of the combined list and query points follow, so a
        // single metric call (with the merged acceleration cache) can compare across the trees.
        final List<Double> selfCache = getAccelerationCache();
        final List<Double> queryCache = query.getAccelerationCache();
        final int offset = size();
        final List<Double> mergedCache = selfCache == null ? null : ListUtils.mergedView(selfCache, queryCache);
        final List<Vec> allVecs = new ArrayList<>(offset + numQueries);
        for (int r = 0; r < offset; r++) {
            allVecs.add(get(r));
        }
        for (int q = 0; q < numQueries; q++) {
            allVecs.add(query.get(q));
        }
        final DistanceMetric dm = getDistanceMetric();

        final BaseCaseDT baseCase;
        if (parallel) {
            baseCase = (r, q) -> {
                double d = dm.dist(r, offset + q, allVecs, mergedCache);
                BoundedSortedList<IndexDistPair> candidates = knn.get(q);
                synchronized (candidates) {
                    candidates.add(new IndexDistPair(r, d));
                }
                return d;
            };
        } else {
            baseCase = (r, q) -> {
                double d = dm.dist(r, offset + q, allVecs, mergedCache);
                knn.get(q).add(new IndexDistPair(r, d));
                return d;
            };
        }

        // Lazy scorer: a negative origScore requests a fresh score (minimum node distance).
        // Otherwise the stored score is re-checked against the query node's current bound, and
        // the pair is pruned (NaN) when the bound is finite and the score exceeds it. It must be
        // typed as ScoreDTLazy so the traversal performs these re-checks.
        final ScoreDTLazy knnScore = (refNode, queryNode, origScore) -> {
            if (origScore < 0.0d) {
                return refNode.minNodeDistance(queryNode);
            }
            double bound = computeKnnBound(queryNode, k, knn, bounds);
            if (Double.isFinite(bound)) {
                bounds.put(queryNode, bound);
                if (origScore > bound) {
                    return Double.NaN;
                }
            }
            return origScore;
        };

        traverse(query, baseCase, knnScore, true, parallel);

        neighbors.clear();
        distances.clear();
        for (int q = 0; q < numQueries; q++) {
            BoundedSortedList<IndexDistPair> candidates = knn.get(q);
            IntList ids = new IntList(k);
            DoubleList dists = new DoubleList(k);
            for (int j = 0; j < candidates.size(); j++) {
                IndexDistPair pair = candidates.get(j);
                ids.add(pair.getIndex());
                dists.add(pair.getDist());
            }
            neighbors.add(ids);
            distances.add(dists);
        }
    }

    /**
     * Computes the pruning bound for the query points under {@code node} during a k-NN
     * traversal. The bound is stored in {@code cache} as the minimum of the previously cached value
     * and the new one.
     *
     * <p>The returned value is the minimum of:
     * <ul>
     * <li>b1: if every point owned by {@code node} already has {@code k} candidates, the largest
     * k-th candidate distance among them, maxed with the cached bounds of the children; otherwise
     * (or if {@code node} owns no points) {@code +Infinity};</li>
     * <li>the smallest k-th candidate distance among owned points with full lists, plus
     * {@code furthestPointDistance() + furthestDescendantDistance()};</li>
     * <li>for each child, its cached bound plus
     * {@code 2 * (furthestDescendantDistance() - child.furthestDescendantDistance())};</li>
     * <li>the cached bound of the parent, if any.</li>
     * </ul>
     * Missing cache entries count as {@code +Infinity}.
     *
     * @param node query node to bound
     * @param k number of neighbors searched for
     * @param knnLists candidate list of every query point, indexed by query point
     * @param cache bounds cached per query node; updated with the result
     * @return the new bound for {@code node}
     */
    default double computeKnnBound(IndexNode node, int k, List<BoundedSortedList<IndexDistPair>> knnLists, Map<IndexNode, Double> cache) {
        boolean allFull = true;
        double maxKth = Double.NEGATIVE_INFINITY;
        double minKth = Double.POSITIVE_INFINITY;
        for (int p = 0; p < node.numPoints(); p++) {
            BoundedSortedList<IndexDistPair> candidates = knnLists.get(node.getPoint(p));
            int size;
            double kth = Double.NaN;
            // In parallel mode the base case appends to these lists while holding their monitor;
            // take a consistent snapshot of size and last element under the same lock.
            synchronized (candidates) {
                size = candidates.size();
                if (size >= k) {
                    kth = candidates.last().dist;
                }
            }
            if (size != k) {
                allFull = false;
            } else {
                maxKth = Math.max(maxKth, kth);
            }
            if (size >= k) {
                minKth = Math.min(minKth, kth);
            }
        }

        double lambda = node.furthestDescendantDistance();
        double maxChildBound = Double.NEGATIVE_INFINITY;
        double childBound = Double.POSITIVE_INFINITY;
        for (int c = 0; c < node.numChildren(); c++) {
            IndexNode child = node.getChild(c);
            double cached = cache.getOrDefault(child, Double.POSITIVE_INFINITY);
            maxChildBound = Math.max(maxChildBound, cached);
            childBound = Math.min(childBound, cached + 2.0d * (lambda - child.furthestDescendantDistance()));
        }

        // No owned points (maxKth stays -Infinity) or any non-full list makes b1 uninformative.
        double b1 = (!allFull || Double.isInfinite(maxKth)) ? Double.POSITIVE_INFINITY : Math.max(maxKth, maxChildBound);
        double pointBound = minKth + node.furthestPointDistance() + lambda;
        IndexNode parent = node.getParrent();
        double parentBound = parent == null ? Double.POSITIVE_INFINITY : cache.getOrDefault(parent, Double.POSITIVE_INFINITY);

        final double bound = Math.min(Math.min(b1, pointBound), Math.min(childBound, parentBound));
        cache.compute(node, (key, old) -> Math.min(old == null ? Double.POSITIVE_INFINITY : old, bound));
        return bound;
    }

    /**
     * Range search: for every point of {@code vectorCollection}, finds all points of this
     * collection whose distance lies in the closed interval {@code [minRange, maxRange]}.
     *
     * <p>If {@code vectorCollection} is not a {@link DualTree}, the {@link VectorCollection}
     * default implementation is used. Otherwise a dual-tree traversal runs with this tree as the
     * reference side. Node pairs entirely out of range are pruned. Node pairs entirely inside the
     * range have all their descendant pairs recorded directly and are then pruned. Each query's
     * results are finally sorted by distance.
     *
     * @param vectorCollection the query points
     * @param minRange smallest accepted distance (inclusive)
     * @param maxRange largest accepted distance (inclusive)
     * @param neighbors output, cleared first; entry {@code q} receives indices into this collection
     * @param distances output, cleared first; entry {@code q} receives the matching distances
     * @param parallel whether to traverse in parallel on the common fork/join pool
     */
    @Override
    default void search(VectorCollection<V> vectorCollection, double minRange, double maxRange, List<List<Integer>> neighbors, List<List<Double>> distances, boolean parallel) {
        if (!(vectorCollection instanceof DualTree)) {
            // Plain "super.search" is illegal inside an interface; name the superinterface.
            VectorCollection.super.search(vectorCollection, minRange, maxRange, neighbors, distances, parallel);
            return;
        }
        final DualTree<V> query = (DualTree<V>) vectorCollection;
        final int numQueries = query.size();
        neighbors.clear();
        distances.clear();
        for (int q = 0; q < numQueries; q++) {
            if (parallel) {
                neighbors.add(Collections.synchronizedList(new IntList()));
                distances.add(Collections.synchronizedList(new DoubleList()));
            } else {
                neighbors.add(new IntList());
                distances.add(new DoubleList());
            }
        }

        // Reference points occupy [0, offset) of the combined list and query points follow.
        final List<Double> selfCache = getAccelerationCache();
        final List<Double> queryCache = query.getAccelerationCache();
        final int offset = size();
        final List<Double> mergedCache = selfCache == null ? null : ListUtils.mergedView(selfCache, queryCache);
        final List<Vec> allVecs = new ArrayList<>(offset + numQueries);
        for (int r = 0; r < offset; r++) {
            allVecs.add(get(r));
        }
        for (int q = 0; q < numQueries; q++) {
            allVecs.add(query.get(q));
        }
        final DistanceMetric dm = getDistanceMetric();

        traverse(query, (r, q) -> {
            double dist = dm.dist(r, offset + q, allVecs, mergedCache);
            if (minRange <= dist && dist <= maxRange) {
                List<Integer> qIds = neighbors.get(q);
                List<Double> qDists = distances.get(q);
                // Append index and distance as one unit. Locking each synchronized list separately
                // would let concurrent writers interleave and misalign the two lists.
                synchronized (qIds) {
                    qIds.add(r);
                    qDists.add(dist);
                }
            }
            return dist;
        }, (refNode, queryNode) -> {
            double[] minMax = refNode.minMaxDistance(queryNode);
            double minDist = minMax[0];
            double maxDist = minMax[1];
            if (minDist > maxRange || maxDist < minRange) {
                return Double.NaN; // no pair can be in range: prune
            }
            if (minRange >= minDist || maxDist >= maxRange) {
                return minDist; // partial overlap: keep descending
            }
            // Every descendant pair is in range: record them all now and prune the pair.
            IntList refPoints = new IntList();
            Iterator<Integer> refIter = refNode.DescendantIterator();
            while (refIter.hasNext()) {
                refPoints.add(refIter.next());
            }
            Iterator<Integer> queryIter = queryNode.DescendantIterator();
            while (queryIter.hasNext()) {
                int q = queryIter.next();
                double[] pairDists = new double[refPoints.size()];
                for (int j = 0; j < pairDists.length; j++) {
                    pairDists[j] = dm.dist(refPoints.get(j), offset + q, allVecs, mergedCache);
                }
                List<Integer> qIds = neighbors.get(q);
                List<Double> qDists = distances.get(q);
                synchronized (qIds) {
                    for (int j = 0; j < pairDists.length; j++) {
                        qIds.add(refPoints.get(j));
                        qDists.add(pairDists[j]);
                    }
                }
            }
            return Double.NaN;
        }, false, parallel);

        for (int q = 0; q < neighbors.size(); q++) {
            IndexTable order = new IndexTable(distances.get(q));
            order.apply(distances.get(q));
            order.apply(neighbors.get(q));
        }
    }

    /**
     * Runs a dual-tree traversal with this tree as the reference side and {@code dualTree} as the
     * query side.
     *
     * @param dualTree the query tree
     * @param baseCase callback run for every (reference index, query index) pair visited
     * @param score scoring rule ordering node pairs; a {@code NaN} score prunes a pair
     * @param improvedSearch whether to use the tie-aware child expansion
     * @param parallel whether to run on the common fork/join pool instead of sequentially
     */
    default void traverse(DualTree<V> dualTree, BaseCaseDT baseCase, ScoreDT score, boolean improvedSearch, boolean parallel) {
        IndexNode refRoot = getRoot();
        IndexNode queryRoot = dualTree.getRoot();
        // The traversal only pairs the points owned by the two nodes of the current pair, so
        // points stored in internal nodes would never meet the other tree's descendants. Wrapping
        // moves such points into an extra leaf child. Check both trees (either one may store
        // internal points) and wrap both so they present matching node types.
        if (!refRoot.allPointsInLeaves() || !queryRoot.allPointsInLeaves()) {
            refRoot = new SelfAsChildNode(refRoot);
            queryRoot = new SelfAsChildNode(queryRoot);
        }
        if (parallel) {
            ForkJoinPool.commonPool().invoke(new DualTreeTraversalAction(refRoot, queryRoot, baseCase, score, improvedSearch));
        } else {
            dual_depth_first(refRoot, queryRoot, baseCase, score, improvedSearch);
        }
    }

    /**
     * Sequential depth-first dual-tree traversal starting at the pair ({@code n_r}, {@code n_q}).
     *
     * <p>Runs the base case on every pair of points owned directly by the two nodes, scores the
     * child node pairs (a {@code NaN} score prunes a pair), and recurses into the survivors in
     * ascending score order. With a {@link ScoreDTLazy} scorer, each queued pair is re-scored just
     * before recursion and skipped if that returns {@code NaN}.
     *
     * @param n_r reference node
     * @param n_q query node
     * @param base base case callback
     * @param score scoring rule
     * @param improvedSearch whether to use the tie-aware child expansion: for each query child, if
     *     no two consecutive reference-child scores are within {@code 1e-13}, the whole reference
     *     node is paired with the query child; otherwise each reference-child pair is queued
     */
    static void dual_depth_first(IndexNode n_r, IndexNode n_q, BaseCaseDT base, ScoreDT score, boolean improvedSearch) {
        for (int r = 0; r < n_r.numPoints(); r++) {
            for (int q = 0; q < n_q.numPoints(); q++) {
                base.base_case(n_r.getPoint(r), n_q.getPoint(q));
            }
        }

        PriorityQueue<IndexTuple> pending = new PriorityQueue<>();
        if (n_q.hasChildren() && n_r.hasChildren()) {
            if (improvedSearch) {
                for (int qc = 0; qc < n_q.numChildren(); qc++) {
                    IndexNode queryChild = n_q.getChild(qc);
                    List<IndexTuple> pairs = new ArrayList<>(n_r.numChildren());
                    boolean noTies = true;
                    double prevScore = Double.NaN;
                    for (int rc = 0; rc < n_r.numChildren(); rc++) {
                        IndexNode refChild = n_r.getChild(rc);
                        double s = score.score(refChild, queryChild, COMP_SCORE);
                        if (rc > 0 && Math.abs(prevScore - s) < 1.0E-13d) {
                            noTies = false;
                        }
                        prevScore = s;
                        // NaN means pruned, exactly as in the non-improved branches.
                        if (!Double.isNaN(s)) {
                            pairs.add(new IndexTuple(refChild, queryChild, s));
                        }
                    }
                    if (noTies) {
                        double whole = score.score(n_r, queryChild, COMP_SCORE);
                        if (!Double.isNaN(whole)) {
                            pending.offer(new IndexTuple(n_r, queryChild, whole));
                        }
                    } else {
                        pending.addAll(pairs);
                    }
                }
            } else {
                for (int rc = 0; rc < n_r.numChildren(); rc++) {
                    IndexNode refChild = n_r.getChild(rc);
                    for (int qc = 0; qc < n_q.numChildren(); qc++) {
                        IndexNode queryChild = n_q.getChild(qc);
                        double s = score.score(refChild, queryChild, COMP_SCORE);
                        if (!Double.isNaN(s)) {
                            pending.offer(new IndexTuple(refChild, queryChild, s));
                        }
                    }
                }
            }
        } else if (n_q.hasChildren()) {
            for (int qc = 0; qc < n_q.numChildren(); qc++) {
                IndexNode queryChild = n_q.getChild(qc);
                double s = score.score(n_r, queryChild, COMP_SCORE);
                if (!Double.isNaN(s)) {
                    pending.offer(new IndexTuple(n_r, queryChild, s));
                }
            }
        } else if (n_r.hasChildren()) {
            for (int rc = 0; rc < n_r.numChildren(); rc++) {
                IndexNode refChild = n_r.getChild(rc);
                double s = score.score(refChild, n_q, COMP_SCORE);
                if (!Double.isNaN(s)) {
                    pending.offer(new IndexTuple(refChild, n_q, s));
                }
            }
        }

        while (!pending.isEmpty()) {
            IndexTuple next = pending.poll();
            // A lazy scorer re-checks the pair with its stored score; NaN prunes it.
            if (!(score instanceof ScoreDTLazy) || !Double.isNaN(score.score(next.a, next.b, next.priority))) {
                dual_depth_first(next.a, next.b, base, score, improvedSearch);
            }
        }
    }
}
