package jsat.linear.distancemetrics;

import java.util.List;
import jsat.DataSet;
import jsat.classifiers.ClassificationDataSet;
import jsat.linear.MatrixStatistics;
import jsat.linear.Vec;
import jsat.linear.VecOps;
import jsat.regression.RegressionDataSet;
import jsat.utils.DoubleList;
import jsat.utils.concurrent.ParallelUtils;

/* loaded from: input_file:jsat/linear/distancemetrics/NormalizedEuclideanDistance.class */
/**
 * Normalized (standardized) Euclidean distance.
 *
 * <p>Each feature difference is divided by that feature's standard deviation, learned during
 * training, so the distance is {@code sqrt(sum_i (a_i - b_i)^2 / var_i)}. This keeps features with
 * large numeric ranges from dominating the distance.
 *
 * <p>The per-feature weights {@code 1 / var_i} are learned by the {@code train} methods. All
 * distance methods dereference those weights, so they fail if called before training (see
 * {@link #needsTraining()}).
 */
public class NormalizedEuclideanDistance extends TrainableDistanceMetric {
    private static final long serialVersionUID = 210109457671623688L;

    /**
     * Per-feature weights {@code 1 / var_i}, i.e. the inverse <em>squared</em> standard deviations,
     * or {@code null} until trained. Zero-variance features get weight 0. The historical field
     * name is kept for serialization compatibility.
     */
    private Vec invStndDevs;

    /**
     * Learns the per-feature weights from the diagonal of the covariance matrix (the per-feature
     * variances) of the given vectors.
     *
     * @param dataSet the training vectors
     */
    @Override
    public <V extends Vec> void train(List<V> dataSet) {
        Vec variances = MatrixStatistics.covarianceDiag(MatrixStatistics.meanVector(dataSet), dataSet);
        setWeightsFromVariances(variances);
    }

    /**
     * Same as {@link #train(List)}; the {@code parallel} flag is ignored.
     */
    @Override
    public <V extends Vec> void train(List<V> dataSet, boolean parallel) {
        train(dataSet);
    }

    /**
     * Learns the per-feature weights from the column variances reported by the data set.
     *
     * @param dataSet the training data
     */
    @Override
    public void train(DataSet dataSet) {
        // getColumnMeanVariance() returns {means, variances}; index 1 holds the variances.
        setWeightsFromVariances(dataSet.getColumnMeanVariance()[1]);
    }

    /**
     * Same as {@link #train(DataSet)}; the {@code parallel} flag is ignored.
     */
    @Override
    public void train(DataSet dataSet, boolean parallel) {
        train(dataSet);
    }

    /** Trains on the features only; class labels are not used. */
    @Override
    public void train(ClassificationDataSet dataSet) {
        train((DataSet) dataSet);
    }

    /** Trains on the features only; class labels and the {@code parallel} flag are ignored. */
    @Override
    public void train(ClassificationDataSet dataSet, boolean parallel) {
        train(dataSet);
    }

    @Override
    public boolean supportsClassificationTraining() {
        return true;
    }

    /** Trains on the features only; regression targets are not used. */
    @Override
    public void train(RegressionDataSet dataSet) {
        train((DataSet) dataSet);
    }

    /** Trains on the features only; regression targets and the {@code parallel} flag are ignored. */
    @Override
    public void train(RegressionDataSet dataSet, boolean parallel) {
        train(dataSet);
    }

    @Override
    public boolean supportsRegressionTraining() {
        return true;
    }

    /** Returns {@code true} until one of the {@code train} methods has set the weights. */
    @Override
    public boolean needsTraining() {
        return this.invStndDevs == null;
    }

    /** Returns a copy of this metric, including a deep copy of any learned weights. */
    @Override
    /* renamed from: clone */
    public NormalizedEuclideanDistance mo185clone() {
        NormalizedEuclideanDistance copy = new NormalizedEuclideanDistance();
        if (this.invStndDevs != null) {
            copy.invStndDevs = this.invStndDevs.mo46clone();
        }
        return copy;
    }

    /**
     * Computes {@code sqrt(sum_i w_i * (a_i - b_i)^2)} with {@code w = 1 / var}.
     */
    @Override
    public double dist(Vec a, Vec b) {
        // accumulateSum is expected to compute sum_i w_i * f(a_i - b_i) with f = square.
        return Math.sqrt(VecOps.accumulateSum(this.invStndDevs, a, b, diff -> diff * diff));
    }

    @Override
    public boolean isSymmetric() {
        return true;
    }

    /** A weighted Euclidean distance with non-negative weights satisfies the triangle inequality. */
    @Override
    public boolean isSubadditive() {
        return true;
    }

    /**
     * Reported as {@code true}. Strictly, distinct points that differ only in a feature whose
     * training variance was zero (weight 0) are at distance 0.
     */
    @Override
    public boolean isIndiscemible() {
        return true;
    }

    @Override
    public double metricBound() {
        return Double.POSITIVE_INFINITY;
    }

    @Override
    public boolean supportsAcceleration() {
        return true;
    }

    /**
     * Precomputes the weighted squared norm {@code sum_i w_i * x_i^2} of every vector. The cached
     * distance methods use it to expand {@code ||a - b||_w^2 = ||a||_w^2 + ||b||_w^2 - 2<a,b>_w}.
     *
     * @param vecs the vectors to cache
     * @param parallel whether the work may be split across threads
     * @return one cached value per vector, in the same order as {@code vecs}
     */
    @Override
    public List<Double> getAccelerationCache(List<? extends Vec> vecs, boolean parallel) {
        final double[] cache = new double[vecs.size()];
        ParallelUtils.run(parallel, vecs.size(), (start, end) -> {
            for (int idx = start; idx < end; idx++) {
                Vec x = vecs.get(idx);
                cache[idx] = VecOps.weightedDot(this.invStndDevs, x, x);
            }
        });
        return DoubleList.view(cache, vecs.size());
    }

    /** Distance between {@code vecs[a]} and {@code vecs[b]}, using the cache when available. */
    @Override
    public double dist(int a, int b, List<? extends Vec> vecs, List<Double> cache) {
        if (cache == null) {
            return dist(vecs.get(a), vecs.get(b));
        }
        double cross = VecOps.weightedDot(this.invStndDevs, vecs.get(a), vecs.get(b));
        return sqrtClamped(cache.get(a) + cache.get(b) - 2.0 * cross);
    }

    /** Distance between {@code vecs[a]} and {@code b}, using the cache when available. */
    @Override
    public double dist(int a, Vec b, List<? extends Vec> vecs, List<Double> cache) {
        if (cache == null) {
            return dist(vecs.get(a), b);
        }
        double bNorm = VecOps.weightedDot(this.invStndDevs, b, b);
        double cross = VecOps.weightedDot(this.invStndDevs, vecs.get(a), b);
        return sqrtClamped(cache.get(a) + bNorm - 2.0 * cross);
    }

    /**
     * Returns the single-element query cache {@code [sum_i w_i * q_i^2]} for {@code q}.
     */
    @Override
    public List<Double> getQueryInfo(Vec q) {
        DoubleList info = new DoubleList(1);
        info.add(VecOps.weightedDot(this.invStndDevs, q, q));
        return info;
    }

    /**
     * Distance between {@code vecs[a]} and {@code b}, where {@code qi} is the result of
     * {@link #getQueryInfo(Vec)} for {@code b}.
     */
    @Override
    public double dist(int a, Vec b, List<Double> qi, List<? extends Vec> vecs, List<Double> cache) {
        if (cache == null) {
            return dist(vecs.get(a), b);
        }
        double cross = VecOps.weightedDot(this.invStndDevs, vecs.get(a), b);
        return sqrtClamped(cache.get(a) + qi.get(0) - 2.0 * cross);
    }

    /**
     * Converts per-feature variances into the weights used by every distance method and stores
     * them. Each weight is {@code 1 / variance}, so that {@code sum_i w_i (a_i - b_i)^2} is the
     * squared Euclidean distance between the standardized vectors. Zero-variance features get
     * weight 0 (ignored) rather than an infinite weight, which would turn distances into NaN via
     * {@code 0 * Infinity}.
     *
     * @param variances per-feature variances; modified in place and retained
     */
    private void setWeightsFromVariances(Vec variances) {
        variances.applyFunction(v -> v == 0.0 ? 0.0 : 1.0 / v);
        this.invStndDevs = variances;
    }

    /**
     * Square root of a squared distance obtained by algebraic expansion. Small negative values
     * caused by floating-point cancellation are clamped to zero instead of producing NaN.
     */
    private static double sqrtClamped(double squaredDistance) {
        return Math.sqrt(Math.max(squaredDistance, 0.0));
    }
}
