/*
 * Decompiled with CFR 0.152.
 */
package smile.neighbor;

import java.lang.reflect.Array;
import java.util.List;
import smile.math.Math;
import smile.neighbor.KNNSearch;
import smile.neighbor.NearestNeighborSearch;
import smile.neighbor.Neighbor;
import smile.neighbor.RNNSearch;
import smile.sort.HeapSelect;

/**
 * A KD-tree over {@code double[]} keys supporting nearest-neighbor, k-nearest-neighbor and
 * fixed-radius (range) queries.
 *
 * <p>Each internal node splits its points on the dimension with the widest bounding-box extent,
 * at the midpoint of that extent. Leaves own a contiguous slice of an internal permutation array.
 * The key and data arrays passed to the constructor are stored by reference, not copied.
 *
 * @param <E> the type of data object associated with each key
 */
public class KDTree<E>
implements NearestNeighborSearch<double[], E>,
KNNSearch<double[], E>,
RNNSearch<double[], E> {
    /** The keys of the data points (stored by reference). */
    private final double[][] keys;
    /** The data objects, parallel to {@link #keys}. */
    private final E[] data;
    /** The root of the tree. */
    private final Node root;
    /** A permutation of 0..n-1; every node owns the slice [node.index, node.index + node.count). */
    private final int[] index;
    /** Whether a stored key that is the very same array instance as the query is skipped. */
    private boolean identicalExcluded = true;

    /**
     * Builds a KD-tree over the given keys.
     *
     * @param key  the keys of the data points
     * @param data the data object for each key; must have the same length as {@code key}
     * @throws IllegalArgumentException if the arrays differ in length or are empty
     */
    public KDTree(double[][] key, E[] data) {
        if (key.length != data.length) {
            throw new IllegalArgumentException("The array size of keys and data are different.");
        }
        if (key.length == 0) {
            // buildNode reads keys[0]; fail with a clear message instead of an index error.
            throw new IllegalArgumentException("The data set is empty.");
        }
        this.keys = key;
        this.data = data;
        int n = key.length;
        this.index = new int[n];
        for (int i = 0; i < n; ++i) {
            this.index[i] = i;
        }
        this.root = this.buildNode(0, n);
    }

    @Override
    public String toString() {
        return "KD-Tree";
    }

    /**
     * Recursively builds the subtree over the points {@code index[begin, end)}, reordering that
     * slice of {@link #index} in place.
     *
     * <p>The node becomes a leaf when the points have no positive spread in any dimension, or when
     * the midpoint split would leave one side empty. The second case can happen through
     * floating-point rounding or NaN coordinates. Without that check the recursion would revisit
     * the same range forever.
     */
    private Node buildNode(int begin, int end) {
        int d = this.keys[0].length;
        Node node = new Node();
        node.count = end - begin;
        node.index = begin;

        // Axis-aligned bounding box of the points in this slice.
        double[] lowerBound = new double[d];
        double[] upperBound = new double[d];
        double[] first = this.keys[this.index[begin]];
        for (int j = 0; j < d; ++j) {
            lowerBound[j] = first[j];
            upperBound[j] = first[j];
        }
        for (int i = begin + 1; i < end; ++i) {
            double[] point = this.keys[this.index[i]];
            for (int j = 0; j < d; ++j) {
                double c = point[j];
                if (lowerBound[j] > c) {
                    lowerBound[j] = c;
                }
                if (upperBound[j] < c) {
                    upperBound[j] = c;
                }
            }
        }

        // Split on the widest dimension, at the midpoint of its extent.
        double maxRadius = -1.0;
        for (int j = 0; j < d; ++j) {
            double radius = (upperBound[j] - lowerBound[j]) / 2.0;
            if (radius > maxRadius) {
                maxRadius = radius;
                node.split = j;
                // Halving before adding keeps very large coordinates from overflowing to infinity.
                node.cutoff = lowerBound[j] * 0.5 + upperBound[j] * 0.5;
            }
        }

        // No positive spread (all points coincide, NaN extents, or zero-dimensional keys): leaf.
        if (!(maxRadius > 0.0)) {
            return node;
        }

        int size = this.partition(begin, end, node.split, node.cutoff);
        if (size == 0 || size == node.count) {
            // Degenerate split; recursing would rebuild the same range again.
            return node;
        }
        node.lower = this.buildNode(begin, begin + size);
        node.upper = this.buildNode(begin + size, end);
        return node;
    }

    /**
     * Reorders {@code index[begin, end)} so that points whose coordinate along {@code split} is
     * below {@code cutoff} come first.
     *
     * @return the number of points placed in the lower part
     */
    private int partition(int begin, int end, int split, double cutoff) {
        int i1 = begin;
        int i2 = end - 1;
        int size = 0;
        while (i1 <= i2) {
            boolean i1Good = this.keys[this.index[i1]][split] < cutoff;
            boolean i2Good = this.keys[this.index[i2]][split] >= cutoff;
            if (!i1Good && !i2Good) {
                // Both ends are on the wrong side: swap them, after which both are in place.
                int temp = this.index[i1];
                this.index[i1] = this.index[i2];
                this.index[i2] = temp;
                i1Good = true;
                i2Good = true;
            }
            if (i1Good) {
                ++i1;
                ++size;
            }
            if (i2Good) {
                --i2;
            }
        }
        return size;
    }

    /**
     * Sets whether a stored key that is the same array instance as the query ({@code q == key})
     * is excluded from results.
     *
     * @return this tree, for chaining
     */
    public KDTree<E> setIdenticalExcluded(boolean excluded) {
        this.identicalExcluded = excluded;
        return this;
    }

    /** Returns whether a stored key that is the same array instance as the query is excluded. */
    public boolean isIdenticalExcluded() {
        return this.identicalExcluded;
    }

    /**
     * Finds the single nearest neighbor, tracking the best candidate in {@code neighbor} by
     * squared Euclidean distance.
     */
    private void search(double[] q, Node node, Neighbor<double[], E> neighbor) {
        if (node.isLeaf()) {
            for (int idx = node.index; idx < node.index + node.count; ++idx) {
                double[] key = this.keys[this.index[idx]];
                if (q == key && this.identicalExcluded) {
                    continue;
                }
                double distance = Math.squaredDistance(q, key);
                if (distance < neighbor.distance) {
                    neighbor.key = key;
                    neighbor.value = this.data[this.index[idx]];
                    neighbor.index = this.index[idx];
                    neighbor.distance = distance;
                }
            }
        } else {
            double diff = q[node.split] - node.cutoff;
            Node nearer = diff < 0.0 ? node.lower : node.upper;
            Node further = diff < 0.0 ? node.upper : node.lower;
            this.search(q, nearer, neighbor);
            // The far side can only help if the splitting plane is within the current best distance.
            if (neighbor.distance >= diff * diff) {
                this.search(q, further, neighbor);
            }
        }
    }

    /**
     * Finds the k nearest neighbors. {@code heap} keeps the k best candidates by squared
     * distance; its {@code peek()} is compared against as the current worst of the kept k.
     */
    private void search(double[] q, Node node, HeapSelect<Neighbor<double[], E>> heap) {
        if (node.isLeaf()) {
            for (int idx = node.index; idx < node.index + node.count; ++idx) {
                double[] key = this.keys[this.index[idx]];
                if (q == key && this.identicalExcluded) {
                    continue;
                }
                double distance = Math.squaredDistance(q, key);
                Neighbor<double[], E> datum = heap.peek();
                if (distance < datum.distance) {
                    // Overwrite the current worst candidate in place, then restore the heap order.
                    datum.distance = distance;
                    datum.index = this.index[idx];
                    datum.key = key;
                    datum.value = this.data[this.index[idx]];
                    heap.heapify();
                }
            }
        } else {
            double diff = q[node.split] - node.cutoff;
            Node nearer = diff < 0.0 ? node.lower : node.upper;
            Node further = diff < 0.0 ? node.upper : node.lower;
            this.search(q, nearer, heap);
            if (heap.peek().distance >= diff * diff) {
                this.search(q, further, heap);
            }
        }
    }

    /**
     * Collects every point within {@code radius} of {@code q} (inclusive), as measured by
     * {@code Math.distance}, into {@code neighbors}.
     */
    private void search(double[] q, Node node, double radius, List<Neighbor<double[], E>> neighbors) {
        if (node.isLeaf()) {
            for (int idx = node.index; idx < node.index + node.count; ++idx) {
                double[] key = this.keys[this.index[idx]];
                if (q == key && this.identicalExcluded) {
                    continue;
                }
                double distance = Math.distance(q, key);
                if (distance <= radius) {
                    neighbors.add(new Neighbor<double[], E>(key, this.data[this.index[idx]], this.index[idx], distance));
                }
            }
        } else {
            double diff = q[node.split] - node.cutoff;
            Node nearer = diff < 0.0 ? node.lower : node.upper;
            Node further = diff < 0.0 ? node.upper : node.lower;
            this.search(q, nearer, radius, neighbors);
            if (radius >= Math.abs(diff)) {
                this.search(q, further, radius, neighbors);
            }
        }
    }

    /**
     * Returns the nearest neighbor of {@code q}, with its Euclidean distance.
     *
     * <p>If every stored key is excluded (see {@link #setIdenticalExcluded}), the returned neighbor
     * has a {@code null} key and value, index 0, and distance {@code sqrt(Double.MAX_VALUE)}.
     */
    @Override
    public Neighbor<double[], E> nearest(double[] q) {
        Neighbor<double[], E> neighbor = new Neighbor<double[], E>(null, null, 0, Double.MAX_VALUE);
        this.search(q, this.root, neighbor);
        neighbor.distance = Math.sqrt(neighbor.distance);
        return neighbor;
    }

    /**
     * Returns the {@code k} nearest neighbors of {@code q}, with Euclidean distances, in the order
     * left by {@code HeapSelect.sort()}.
     *
     * <p>If fewer than {@code k} keys are eligible (for example, {@code k == n} and the query is a
     * stored key that is excluded), the unused slots have {@code null} key and value and
     * distance {@code sqrt(Double.MAX_VALUE)}.
     *
     * @throws IllegalArgumentException if {@code k <= 0} or {@code k} exceeds the dataset size
     */
    @Override
    public Neighbor<double[], E>[] knn(double[] q, int k) {
        if (k <= 0) {
            throw new IllegalArgumentException("Invalid k: " + k);
        }
        if (k > this.keys.length) {
            throw new IllegalArgumentException("Neighbor array length is larger than the dataset size");
        }
        // Generic arrays cannot be created directly; the reflective array holds only Neighbor instances.
        @SuppressWarnings("unchecked")
        Neighbor<double[], E>[] neighbors = (Neighbor<double[], E>[]) Array.newInstance(Neighbor.class, k);
        HeapSelect<Neighbor<double[], E>> heap = new HeapSelect<Neighbor<double[], E>>(neighbors);
        for (int i = 0; i < k; ++i) {
            // Sentinels at "infinite" distance, replaced as real candidates are found.
            heap.add(new Neighbor<double[], E>(null, null, 0, Double.MAX_VALUE));
        }
        this.search(q, this.root, heap);
        heap.sort();
        for (Neighbor<double[], E> neighbor : neighbors) {
            neighbor.distance = Math.sqrt(neighbor.distance);
        }
        return neighbors;
    }

    /**
     * Appends to {@code neighbors} every key within distance {@code radius} of {@code q}
     * (inclusive). Results are appended in traversal order, not sorted.
     *
     * @throws IllegalArgumentException if {@code radius <= 0}
     */
    @Override
    public void range(double[] q, double radius, List<Neighbor<double[], E>> neighbors) {
        if (radius <= 0.0) {
            throw new IllegalArgumentException("Invalid radius: " + radius);
        }
        this.search(q, this.root, radius, neighbors);
    }

    /**
     * A tree node. It is a leaf when both children are {@code null}. Static, so nodes do not hold
     * a reference to the enclosing tree.
     */
    static class Node {
        /** Number of points in this node's slice of the permutation array. */
        int count;
        /** Start offset of this node's slice in the permutation array. */
        int index;
        /** The dimension this node splits on (internal nodes only). */
        int split;
        /** Split value: the lower child holds points with coordinate < cutoff (internal nodes only). */
        double cutoff;
        /** Child with coordinates below {@link #cutoff}. */
        Node lower;
        /** Child with coordinates at or above {@link #cutoff}. */
        Node upper;

        Node() {
        }

        /** Returns true if this node has no children. */
        boolean isLeaf() {
            return this.lower == null && this.upper == null;
        }
    }
}

