package jsat.distributions.kernels;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import jsat.distributions.kernels.KernelPoint;
import jsat.linear.IndexValue;
import jsat.linear.Vec;
import jsat.utils.DoubleList;
import jsat.utils.random.RandomUtil;

/* loaded from: input_file:jsat/distributions/kernels/KernelPoints.class */
public class KernelPoints {
    private KernelTrick k;
    private double errorTolerance;
    private KernelPoint.BudgetStrategy budgetStrategy;
    private int maxBudget;
    private List<KernelPoint> points;

    public KernelPoints(KernelTrick kernelTrick, int i, double d) {
        this(kernelTrick, i, d, true);
    }

    public KernelPoints(KernelTrick kernelTrick, int i, double d, boolean z) {
        this.budgetStrategy = KernelPoint.BudgetStrategy.PROJECTION;
        this.maxBudget = Integer.MAX_VALUE;
        if (i < 1) {
            throw new IllegalArgumentException("Number of points must be positive, not " + i);
        }
        this.k = kernelTrick;
        this.errorTolerance = d;
        this.points = new ArrayList(i);
        this.points.add(new KernelPoint(kernelTrick, d));
        this.points.get(0).setMaxBudget(this.maxBudget);
        this.points.get(0).setBudgetStrategy(this.budgetStrategy);
        for (int i2 = 1; i2 < i; i2++) {
            addNewKernelPoint();
        }
    }

    public KernelPoints(KernelPoints kernelPoints) {
        this.budgetStrategy = KernelPoint.BudgetStrategy.PROJECTION;
        this.maxBudget = Integer.MAX_VALUE;
        this.k = kernelPoints.k.mo153clone();
        this.errorTolerance = kernelPoints.errorTolerance;
        this.points = new ArrayList(kernelPoints.points.size());
        if (kernelPoints.points.get(0).getBasisSize() == 0) {
            for (int i = 0; i < kernelPoints.points.size(); i++) {
                this.points.add(new KernelPoint(this.k, this.errorTolerance));
            }
            return;
        }
        KernelPoint m156clone = this.points.get(0).m156clone();
        for (int i2 = 1; i2 < kernelPoints.points.size(); i2++) {
            KernelPoint kernelPoint = new KernelPoint(this.k, this.errorTolerance);
            standardMove(kernelPoint, m156clone);
            kernelPoint.kernelAccel = m156clone.kernelAccel;
            kernelPoint.vecs = m156clone.vecs;
            kernelPoint.alpha = new DoubleList(kernelPoints.points.get(i2).alpha);
        }
    }

    public void setBudgetStrategy(KernelPoint.BudgetStrategy budgetStrategy) {
        this.budgetStrategy = budgetStrategy;
        Iterator<KernelPoint> it = this.points.iterator();
        while (it.hasNext()) {
            it.next().setBudgetStrategy(budgetStrategy);
        }
    }

    public KernelPoint.BudgetStrategy getBudgetStrategy() {
        return this.budgetStrategy;
    }

    public KernelTrick getKernel() {
        return this.k;
    }

    public void setErrorTolerance(double d) {
        if (Double.isNaN(d) || d < 0.0d || d > 1.0d) {
            throw new IllegalArgumentException("Error tolerance must be in [0, 1], not " + d);
        }
        this.errorTolerance = d;
        Iterator<KernelPoint> it = this.points.iterator();
        while (it.hasNext()) {
            it.next().setErrorTolerance(d);
        }
    }

    public double getErrorTolerance() {
        return this.errorTolerance;
    }

    public void setMaxBudget(int i) {
        if (i < 1) {
            throw new IllegalArgumentException("Budget must be positive, not " + i);
        }
        this.maxBudget = i;
        Iterator<KernelPoint> it = this.points.iterator();
        while (it.hasNext()) {
            it.next().setMaxBudget(i);
        }
    }

    public int getMaxBudget() {
        return this.maxBudget;
    }

    public double getSqrdNorm(int i) {
        return this.points.get(i).getSqrdNorm();
    }

    public double dot(int i, Vec vec, List<Double> list) {
        return this.points.get(i).dot(vec, list);
    }

    public double[] dot(Vec vec, List<Double> list) {
        double[] dArr = new double[this.points.size()];
        List<Vec> list2 = this.points.get(0).vecs;
        List<Double> list3 = this.points.get(0).kernelAccel;
        for (int i = 0; i < list2.size(); i++) {
            double eval = this.k.eval(i, vec, list, list2, list3);
            for (int i2 = 0; i2 < this.points.size(); i2++) {
                double d = this.points.get(i2).alpha.getD(i);
                if (d != 0.0d) {
                    int i3 = i2;
                    dArr[i3] = dArr[i3] + (eval * d);
                }
            }
        }
        return dArr;
    }

    public double dot(int i, KernelPoint kernelPoint) {
        return this.points.get(i).dot(kernelPoint);
    }

    public double dot(int i, KernelPoints kernelPoints, int i2) {
        return this.points.get(i).dot(kernelPoints.points.get(i2));
    }

    public double dist(int i, Vec vec, List<Double> list) {
        return this.points.get(i).dist(vec, list);
    }

    public double dist(int i, KernelPoint kernelPoint) {
        return this.points.get(i).dist(kernelPoint);
    }

    public double dist(int i, KernelPoints kernelPoints, int i2) {
        return this.points.get(i).dist(kernelPoints.points.get(i2));
    }

    public void mutableMultiply(int i, double d) {
        this.points.get(i).mutableMultiply(d);
    }

    public void mutableMultiply(double d) {
        Iterator<KernelPoint> it = this.points.iterator();
        while (it.hasNext()) {
            it.next().mutableMultiply(d);
        }
    }

    public void mutableAdd(int i, double d, Vec vec, List<Double> list) {
    }

    public void mutableAdd(Vec vec, Vec vec2, List<Double> list) {
        int basisSize = getBasisSize();
        if (vec2.nnz() == 0) {
            return;
        }
        if (this.budgetStrategy == KernelPoint.BudgetStrategy.PROJECTION) {
            Iterator<IndexValue> it = vec2.iterator();
            while (it.hasNext()) {
                IndexValue next = it.next();
                int index = next.getIndex();
                KernelPoint kernelPoint = this.points.get(index);
                double value = next.getValue();
                if (kernelPoint.getBasisSize() == 0) {
                    kernelPoint.mutableAdd(value, vec, list);
                    for (int i = 0; i < this.points.size(); i++) {
                        if (i != index) {
                            KernelPoint kernelPoint2 = this.points.get(i);
                            standardMove(kernelPoint2, kernelPoint);
                            kernelPoint2.kernelAccel = kernelPoint.kernelAccel;
                            kernelPoint2.vecs = kernelPoint.vecs;
                            kernelPoint2.alpha = new DoubleList(16);
                            kernelPoint2.alpha.add(0.0d);
                        }
                    }
                } else {
                    kernelPoint.mutableAdd(value, vec, list);
                    if (basisSize != kernelPoint.getBasisSize()) {
                        for (int i2 = 0; i2 < this.points.size(); i2++) {
                            if (i2 != index) {
                                KernelPoint kernelPoint3 = this.points.get(i2);
                                standardMove(kernelPoint3, kernelPoint);
                                kernelPoint3.alpha.add(0.0d);
                            }
                        }
                    }
                }
                basisSize = getBasisSize();
            }
            return;
        }
        if (this.budgetStrategy != KernelPoint.BudgetStrategy.MERGE_RBF) {
            if (this.budgetStrategy == KernelPoint.BudgetStrategy.STOP) {
                if (getBasisSize() < this.maxBudget) {
                    this.points.get(0).vecs.add(vec);
                    if (this.points.get(0).kernelAccel != null) {
                        this.points.get(0).kernelAccel.addAll(list);
                    }
                    Iterator<IndexValue> it2 = vec2.iterator();
                    while (it2.hasNext()) {
                        IndexValue next2 = it2.next();
                        this.points.get(next2.getIndex()).alpha.add(next2.getValue());
                    }
                    addMissingZeros();
                    return;
                }
                return;
            }
            if (this.budgetStrategy != KernelPoint.BudgetStrategy.RANDOM) {
                throw new RuntimeException("BUG: Report Me!");
            }
            if (getBasisSize() >= this.maxBudget) {
                int nextInt = RandomUtil.getRandom().nextInt(getBasisSize());
                if (getBasisSize() == this.maxBudget) {
                    this.points.get(0).removeIndex(nextInt);
                }
                for (int i3 = 1; i3 < this.points.size(); i3++) {
                    this.points.get(i3).removeIndex(nextInt);
                }
            }
            this.points.get(0).vecs.add(vec);
            if (this.points.get(0).kernelAccel != null) {
                this.points.get(0).kernelAccel.addAll(list);
            }
            Iterator<IndexValue> it3 = vec2.iterator();
            while (it3.hasNext()) {
                IndexValue next3 = it3.next();
                this.points.get(next3.getIndex()).alpha.add(next3.getValue());
            }
            addMissingZeros();
            return;
        }
        Iterator<IndexValue> nonZeroIterator = vec2.getNonZeroIterator();
        if (getBasisSize() < this.maxBudget) {
            IndexValue next4 = nonZeroIterator.next();
            this.points.get(next4.getIndex()).mutableAdd(next4.getValue(), vec, list);
            while (nonZeroIterator.hasNext()) {
                IndexValue next5 = nonZeroIterator.next();
                this.points.get(next5.getIndex()).alpha.add(next5.getValue());
            }
            addMissingZeros();
            return;
        }
        KernelPoint kernelPoint4 = this.points.get(0);
        kernelPoint4.vecs.add(vec);
        if (kernelPoint4.kernelAccel != null) {
            kernelPoint4.kernelAccel.addAll(list);
        }
        Iterator<IndexValue> it4 = vec2.iterator();
        while (it4.hasNext()) {
            IndexValue next6 = it4.next();
            this.points.get(next6.getIndex()).alpha.add(next6.getValue());
        }
        addMissingZeros();
        int i4 = 0;
        double d = 0.0d;
        Iterator<KernelPoint> it5 = this.points.iterator();
        while (it5.hasNext()) {
            d += Math.pow(it5.next().alpha.getD(0), 2.0d);
        }
        for (int i5 = 1; i5 < kernelPoint4.alpha.size(); i5++) {
            double d2 = 0.0d;
            Iterator<KernelPoint> it6 = this.points.iterator();
            while (it6.hasNext()) {
                d2 += Math.pow(it6.next().alpha.getD(i5), 2.0d);
            }
            if (d2 < d) {
                d = d2;
                i4 = i5;
            }
        }
        double d3 = Double.POSITIVE_INFINITY;
        int i6 = -1;
        double d4 = 0.0d;
        double d5 = 0.001d;
        double d6 = 0.0d;
        double d7 = 0.0d;
        while (i6 == -1) {
            for (int i7 = 0; i7 < kernelPoint4.alpha.size(); i7++) {
                if (i7 != i4) {
                    double d8 = 0.0d;
                    double d9 = 0.0d;
                    for (KernelPoint kernelPoint5 : this.points) {
                        double d10 = kernelPoint5.alpha.getD(i4);
                        double d11 = kernelPoint5.alpha.getD(i7);
                        double d12 = d10 + d11;
                        if (d12 >= 1.0E-7d) {
                            d8 += d10 / d12;
                            d9 += d11 / d12;
                        }
                    }
                    if (Math.abs(d8 + d9) < d5) {
                        break;
                    }
                    double eval = this.k.eval(i7, i4, kernelPoint4.vecs, kernelPoint4.kernelAccel);
                    double h = KernelPoint.getH(eval, d8, d9);
                    double pow = Math.pow(eval, (1.0d - h) * (1.0d - h));
                    double pow2 = Math.pow(eval, h * h);
                    double d13 = 0.0d;
                    for (KernelPoint kernelPoint6 : this.points) {
                        double d14 = kernelPoint6.alpha.getD(i4);
                        double d15 = kernelPoint6.alpha.getD(i7);
                        double d16 = (d14 * pow) + (d15 * pow2);
                        d13 += (((d14 * d14) + (d15 * d15)) + (((2.0d * eval) * d14) * d15)) - (d16 * d16);
                    }
                    if (d13 < d3) {
                        d3 = d13;
                        i6 = i7;
                        d4 = h;
                        d6 = pow;
                        d7 = pow2;
                    }
                }
            }
            d5 /= 10.0d;
        }
        Vec multiply = kernelPoint4.vecs.get(i4).multiply(d4);
        multiply.mutableAdd(1.0d - d4, kernelPoint4.vecs.get(i6));
        List<Double> queryInfo = this.k.getQueryInfo(multiply);
        int i8 = 0;
        while (i8 < this.points.size()) {
            KernelPoint kernelPoint7 = this.points.get(i8);
            kernelPoint7.finalMergeStep(i4, i6, multiply, queryInfo, (kernelPoint7.alpha.getD(i4) * d6) + (kernelPoint7.alpha.getD(i6) * d7), i8 == 0);
            i8++;
        }
    }

    public void addNewKernelPoint() {
        KernelPoint kernelPoint = this.points.get(0);
        KernelPoint kernelPoint2 = new KernelPoint(this.k, this.errorTolerance);
        kernelPoint2.setMaxBudget(this.maxBudget);
        kernelPoint2.setBudgetStrategy(this.budgetStrategy);
        standardMove(kernelPoint2, kernelPoint);
        kernelPoint2.kernelAccel = kernelPoint.kernelAccel;
        kernelPoint2.vecs = kernelPoint.vecs;
        kernelPoint2.alpha = new DoubleList(kernelPoint.alpha.size());
        for (int i = 0; i < kernelPoint.alpha.size(); i++) {
            kernelPoint2.alpha.add(0.0d);
        }
        this.points.add(kernelPoint2);
    }

    private void standardMove(KernelPoint kernelPoint, KernelPoint kernelPoint2) {
        kernelPoint.InvK = kernelPoint2.InvK;
        kernelPoint.InvKExpanded = kernelPoint2.InvKExpanded;
        kernelPoint.K = kernelPoint2.K;
        kernelPoint.KExpanded = kernelPoint2.KExpanded;
    }

    public int getBasisSize() {
        return this.points.get(0).getBasisSize();
    }

    public List<Vec> getRawBasisVecs() {
        ArrayList arrayList = new ArrayList(getBasisSize());
        arrayList.addAll(this.points.get(0).vecs);
        return arrayList;
    }

    public int size() {
        return this.points.size();
    }

    /* renamed from: clone, reason: merged with bridge method [inline-methods] */
    public KernelPoints m158clone() {
        return new KernelPoints(this);
    }

    private void addMissingZeros() {
        for (int i = 0; i < this.points.size(); i++) {
            while (this.points.get(i).alpha.size() < this.points.get(0).vecs.size()) {
                this.points.get(i).alpha.add(0.0d);
            }
        }
    }
}
