package jsat.distributions.multivariate;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.concurrent.ExecutorService;
import jsat.distributions.empirical.kernelfunc.EpanechnikovKF;
import jsat.distributions.empirical.kernelfunc.KernelFunction;
import jsat.exceptions.UntrainedModelException;
import jsat.linear.Vec;
import jsat.linear.VecPaired;
import jsat.linear.distancemetrics.DistanceMetric;
import jsat.linear.distancemetrics.EuclideanDistance;
import jsat.linear.distancemetrics.TrainableDistanceMetric;
import jsat.linear.vectorcollection.DefaultVectorCollection;
import jsat.linear.vectorcollection.VectorCollection;
import jsat.linear.vectorcollection.VectorCollectionUtils;
import jsat.math.OnLineStatistics;
import jsat.parameters.Parameterized;
import jsat.utils.concurrent.ParallelUtils;

/* loaded from: input_file:jsat/distributions/multivariate/MetricKDE.class */
/**
 * Multivariate kernel density estimator that measures closeness with an
 * arbitrary {@link DistanceMetric}. Training points are stored in a
 * {@link VectorCollection}, each paired with its index in the training list.
 * The density at a query point is computed from the kernel values of the
 * stored points found within {@code bandwidth * kernelFunction.cutOff()} of it.
 */
public class MetricKDE extends MultivariateKDE implements Parameterized {
    private static final long serialVersionUID = -2084039950938740815L;
    /** Kernel applied to bandwidth-scaled distances. */
    private KernelFunction kf;
    /** Current bandwidth; stays 0 until it is set explicitly or fitted from data. */
    private double bandwidth;
    private DistanceMetric distanceMetric;
    /** Training points, each paired with its position in the training list. */
    private VectorCollection<VecPaired<Vec, Integer>> vc;
    /** Neighbor count used when the bandwidth is fitted without an explicit k. */
    private int defaultK;
    /** Standard deviations added to the mean k-th neighbor distance when fitting the bandwidth. */
    private double defaultStndDev;
    public static final int DEFAULT_K = 3;
    public static final double DEFAULT_STND_DEV = 2.0d;
    public static final KernelFunction DEFAULT_KF = EpanechnikovKF.getInstance();

    /**
     * Creates an estimator using the default kernel, Euclidean distance and a
     * new default vector collection.
     */
    public MetricKDE() {
        // Each instance needs its own collection: build() mutates it, so a shared
        // static instance would let estimators overwrite each other's data.
        this(DEFAULT_KF, new EuclideanDistance(), new DefaultVectorCollection());
    }

    /**
     * Creates an estimator using the default kernel and a new default vector collection.
     *
     * @param distanceMetric metric used to compare points
     */
    public MetricKDE(DistanceMetric distanceMetric) {
        this(DEFAULT_KF, distanceMetric, new DefaultVectorCollection());
    }

    /**
     * Creates an estimator using the default kernel.
     *
     * @param distanceMetric metric used to compare points
     * @param vectorCollection collection that will hold the training points
     */
    public MetricKDE(DistanceMetric distanceMetric, VectorCollection<VecPaired<Vec, Integer>> vectorCollection) {
        this(DEFAULT_KF, distanceMetric, vectorCollection);
    }

    /**
     * Creates an estimator using a new default vector collection.
     *
     * @param kernelFunction kernel applied to scaled distances
     * @param distanceMetric metric used to compare points
     */
    public MetricKDE(KernelFunction kernelFunction, DistanceMetric distanceMetric) {
        this(kernelFunction, distanceMetric, new DefaultVectorCollection());
    }

    /**
     * Creates an estimator with {@link #DEFAULT_K} neighbors and
     * {@link #DEFAULT_STND_DEV} standard deviations for bandwidth fitting.
     *
     * @param kernelFunction kernel applied to scaled distances
     * @param distanceMetric metric used to compare points
     * @param vectorCollection collection that will hold the training points
     */
    public MetricKDE(KernelFunction kernelFunction, DistanceMetric distanceMetric, VectorCollection<VecPaired<Vec, Integer>> vectorCollection) {
        this(kernelFunction, distanceMetric, vectorCollection, DEFAULT_K, DEFAULT_STND_DEV);
    }

    /**
     * Creates a fully specified estimator.
     *
     * @param kernelFunction kernel applied to scaled distances
     * @param distanceMetric metric used to compare points
     * @param vectorCollection collection that will hold the training points
     * @param defaultK neighbor count used when fitting the bandwidth; must be positive
     * @param defaultStndDev standard deviations added when fitting the bandwidth; must be positive and finite
     * @throws ArithmeticException if {@code defaultK} or {@code defaultStndDev} is invalid
     */
    public MetricKDE(KernelFunction kernelFunction, DistanceMetric distanceMetric, VectorCollection<VecPaired<Vec, Integer>> vectorCollection, int defaultK, double defaultStndDev) {
        setKernelFunction(kernelFunction);
        this.distanceMetric = distanceMetric;
        this.vc = vectorCollection;
        setDefaultK(defaultK);
        setDefaultStndDev(defaultStndDev);
    }

    /**
     * Sets the bandwidth (method name spelling kept for compatibility).
     *
     * @param bandwidth a positive, finite bandwidth
     * @throws ArithmeticException if {@code bandwidth} is not positive and finite
     */
    public void setBandwith(double bandwidth) {
        if (bandwidth <= 0.0d || Double.isNaN(bandwidth) || Double.isInfinite(bandwidth)) {
            throw new ArithmeticException("Invalid bandwidth given, bandwidth must be a positive number, not " + bandwidth);
        }
        this.bandwidth = bandwidth;
    }

    /**
     * @return the current bandwidth, or 0 if none has been set or fitted yet
     */
    public double getBandwith() {
        return this.bandwidth;
    }

    /**
     * Sets the neighbor count used when the bandwidth is fitted without an explicit k.
     *
     * @param defaultK a positive neighbor count
     * @throws ArithmeticException if {@code defaultK} is not positive
     */
    public void setDefaultK(int defaultK) {
        if (defaultK <= 0) {
            throw new ArithmeticException("At least one neighbor must be taken into account, " + defaultK + " is invalid");
        }
        this.defaultK = defaultK;
    }

    /**
     * @return the default neighbor count used for bandwidth fitting
     */
    public int getDefaultK() {
        return this.defaultK;
    }

    /**
     * Sets how many standard deviations are added to the mean k-th neighbor
     * distance when the bandwidth is fitted without an explicit value.
     *
     * @param defaultStndDev a positive, finite number of standard deviations
     * @throws ArithmeticException if the value is not positive and finite
     */
    public void setDefaultStndDev(double defaultStndDev) {
        if (Double.isInfinite(defaultStndDev) || Double.isNaN(defaultStndDev) || defaultStndDev <= 0.0d) {
            throw new ArithmeticException("The number of standard deviations to remove must be a positive number, not " + defaultStndDev);
        }
        this.defaultStndDev = defaultStndDev;
    }

    /**
     * @return the default number of standard deviations used for bandwidth fitting
     */
    public double getDefaultStndDev() {
        return this.defaultStndDev;
    }

    /**
     * @return the metric used to compare points
     */
    public DistanceMetric getDistanceMetric() {
        return this.distanceMetric;
    }

    /**
     * Replaces the distance metric. An already-built collection is not rebuilt.
     *
     * @param distanceMetric metric used to compare points
     */
    public void setDistanceMetric(DistanceMetric distanceMetric) {
        this.distanceMetric = distanceMetric;
    }

    /**
     * Creates a copy with cloned metric and collection (when present), the same
     * kernel, and the same bandwidth and fitting defaults.
     */
    @Override // jsat.distributions.multivariate.MultivariateKDE, jsat.distributions.multivariate.MultivariateDistributionSkeleton
    public MetricKDE clone() {
        // Clone each component once, and only when it exists.
        DistanceMetric metricCopy = this.distanceMetric == null ? null : this.distanceMetric.mo185clone();
        VectorCollection<VecPaired<Vec, Integer>> collectionCopy = this.vc == null ? null : this.vc.m199clone();
        MetricKDE copy = new MetricKDE(this.kf, metricCopy, collectionCopy, this.defaultK, this.defaultStndDev);
        copy.bandwidth = this.bandwidth;
        return copy;
    }

    /**
     * Returns the training points near {@code vec}, each paired with its kernel value.
     *
     * @param vec the query point
     * @return the nearby points; their pair values are overwritten with kernel values in place
     * @throws UntrainedModelException if the model has not been fitted
     */
    @Override // jsat.distributions.multivariate.MultivariateKDE
    public List<? extends VecPaired<VecPaired<Vec, Integer>, Double>> getNearby(Vec vec) {
        requireTrained();
        List<? extends VecPaired<VecPaired<Vec, Integer>, Double>> nearby = getNearbyRaw(vec);
        for (VecPaired<VecPaired<Vec, Integer>, Double> neighbor : nearby) {
            neighbor.setPair(Double.valueOf(this.kf.k(neighbor.getPair().doubleValue())));
        }
        return nearby;
    }

    /**
     * Returns the training points within {@code bandwidth * kf.cutOff()} of
     * {@code vec}, each paired with its distance divided by the bandwidth.
     *
     * @param vec the query point
     * @return the nearby points with bandwidth-scaled distances
     * @throws UntrainedModelException if the model has not been fitted
     */
    @Override // jsat.distributions.multivariate.MultivariateKDE
    public List<? extends VecPaired<VecPaired<Vec, Integer>, Double>> getNearbyRaw(Vec vec) {
        requireTrained();
        List<? extends VecPaired<VecPaired<Vec, Integer>, Double>> found = this.vc.search(vec, this.bandwidth * this.kf.cutOff());
        for (VecPaired<VecPaired<Vec, Integer>, Double> neighbor : found) {
            neighbor.setPair(Double.valueOf(neighbor.getPair().doubleValue() / this.bandwidth));
        }
        return found;
    }

    /**
     * Estimates the density at {@code vec}: the sum of kernel values of nearby
     * points divided by {@code n * bandwidth^d}, where n is the collection size
     * and d the dimension of the nearby points.
     *
     * @param vec the query point
     * @return the estimated density, or 0 when no training point is in range
     * @throws UntrainedModelException if the model has not been fitted
     */
    @Override // jsat.distributions.multivariate.MultivariateDistribution
    public double pdf(Vec vec) {
        List<? extends VecPaired<VecPaired<Vec, Integer>, Double>> nearby = getNearby(vec);
        if (nearby.isEmpty()) {
            return 0.0d;
        }
        double kernelSum = 0.0d;
        for (VecPaired<VecPaired<Vec, Integer>, Double> neighbor : nearby) {
            kernelSum += neighbor.getPair().doubleValue();
        }
        return kernelSum / (this.vc.size() * Math.pow(this.bandwidth, nearby.get(0).length()));
    }

    /**
     * Fits the model, choosing the bandwidth from the default k and default
     * number of standard deviations.
     *
     * @param list the training points
     * @param parallel whether to use a multi-threaded executor
     * @return {@code true}
     */
    @Override // jsat.distributions.multivariate.MultivariateDistribution
    public <V extends Vec> boolean setUsingData(List<V> list, boolean parallel) {
        ExecutorService executor = ParallelUtils.getNewExecutor(parallel);
        try {
            return setUsingData(list, executor);
        } finally {
            // Always release the executor, even if fitting fails.
            executor.shutdownNow();
        }
    }

    /**
     * Fits the model with a fixed bandwidth, single-threaded.
     *
     * @param list the training points
     * @param bandwidth a positive, finite bandwidth
     * @return {@code true}
     */
    public <V extends Vec> boolean setUsingData(List<V> list, double bandwidth) {
        return setUsingData(list, bandwidth, (ExecutorService) null);
    }

    /**
     * Fits the model with a fixed bandwidth.
     *
     * @param list the training points
     * @param bandwidth a positive, finite bandwidth
     * @param executorService executor for parallel work, or {@code null} to run single-threaded
     * @return {@code true}
     * @throws ArithmeticException if {@code bandwidth} is invalid
     */
    public <V extends Vec> boolean setUsingData(List<V> list, double bandwidth, ExecutorService executorService) {
        setBandwith(bandwidth);
        indexDataSet(list, executorService);
        return true;
    }

    /**
     * Fits the model, choosing the bandwidth from the {@code k}-th neighbor
     * distances and the default number of standard deviations.
     */
    public <V extends Vec> boolean setUsingData(List<V> list, int k) {
        return setUsingData(list, k, this.defaultStndDev);
    }

    /**
     * Fits the model, choosing the bandwidth from the {@code k}-th neighbor
     * distances and the default number of standard deviations.
     */
    public <V extends Vec> boolean setUsingData(List<V> list, int k, ExecutorService executorService) {
        return setUsingData(list, k, this.defaultStndDev, executorService);
    }

    /**
     * Fits the model single-threaded, choosing the bandwidth from the
     * {@code k}-th neighbor distances.
     */
    public <V extends Vec> boolean setUsingData(List<V> list, int k, double stndDevs) {
        return setUsingData(list, k, stndDevs, null);
    }

    /**
     * Fits the model and sets the bandwidth to
     * {@code mean + stndDevs * standardDeviation} of the k-th neighbor distances
     * of the training points.
     *
     * @param list the training points
     * @param k number of neighbors to consider
     * @param stndDevs standard deviations added to the mean distance
     * @param executorService executor for parallel work, or {@code null} to run single-threaded
     * @return {@code true}
     * @throws ArithmeticException if the resulting bandwidth is not positive and finite
     */
    public <V extends Vec> boolean setUsingData(List<V> list, int k, double stndDevs, ExecutorService executorService) {
        indexDataSet(list, executorService);
        // k + 1: every query point is itself stored in vc (normally found at distance 0).
        OnLineStatistics kthNeighborStats = executorService == null
                ? VectorCollectionUtils.getKthNeighborStats(this.vc, list, k + 1)
                : VectorCollectionUtils.getKthNeighborStats(this.vc, list, k + 1, executorService);
        setBandwith(kthNeighborStats.getMean() + (kthNeighborStats.getStandardDeviation() * stndDevs));
        return true;
    }

    /**
     * Fits the model, choosing the bandwidth from the default k and default
     * number of standard deviations.
     */
    public <V extends Vec> boolean setUsingData(List<V> list, ExecutorService executorService) {
        return setUsingData(list, this.defaultK, executorService);
    }

    /**
     * Not supported.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // jsat.distributions.multivariate.MultivariateDistribution
    public List<Vec> sample(int i, Random random) {
        throw new UnsupportedOperationException("Not supported yet.");
    }

    /**
     * @return the kernel applied to scaled distances
     */
    @Override // jsat.distributions.multivariate.MultivariateKDE
    public KernelFunction getKernelFunction() {
        return this.kf;
    }

    /**
     * @param kernelFunction kernel applied to scaled distances
     */
    public void setKernelFunction(KernelFunction kernelFunction) {
        this.kf = kernelFunction;
    }

    /**
     * Multiplies the current bandwidth by {@code d}.
     *
     * @param d scale factor
     */
    @Override // jsat.distributions.multivariate.MultivariateKDE
    public void scaleBandwidth(double d) {
        this.bandwidth *= d;
    }

    /**
     * Throws if there is no collection or no positive bandwidth. A bandwidth
     * that has never been set is 0, and searching with it would divide by zero.
     */
    private void requireTrained() {
        if (this.vc == null || !(this.bandwidth > 0.0d)) {
            throw new UntrainedModelException("Model has not yet been created");
        }
    }

    /**
     * Trains the distance metric if needed and rebuilds {@code vc} from
     * {@code dataSet}, pairing every point with its index in the list. The build
     * runs in parallel when an executor is supplied.
     */
    private <V extends Vec> void indexDataSet(List<V> dataSet, ExecutorService executorService) {
        List<VecPaired<Vec, Integer>> indexed = new ArrayList<VecPaired<Vec, Integer>>(dataSet.size());
        for (int idx = 0; idx < dataSet.size(); idx++) {
            indexed.add(new VecPaired<Vec, Integer>(dataSet.get(idx), Integer.valueOf(idx)));
        }
        TrainableDistanceMetric.trainIfNeeded(this.distanceMetric, dataSet, executorService);
        this.vc.build(executorService != null, indexed, this.distanceMetric);
    }
}
