package jsat.linear.distancemetrics;

import java.util.List;
import jsat.DataSet;
import jsat.classifiers.ClassificationDataSet;
import jsat.linear.LUPDecomposition;
import jsat.linear.Matrix;
import jsat.linear.MatrixStatistics;
import jsat.linear.SingularValueDecomposition;
import jsat.linear.Vec;
import jsat.regression.RegressionDataSet;
import jsat.utils.concurrent.ParallelUtils;

/* loaded from: input_file:jsat/linear/distancemetrics/MahalanobisDistance.class */
/**
 * Distance metric that measures the distance between two vectors through a
 * matrix {@code S}: {@code sqrt((a - b)^T * S * (a - b))}.
 *
 * <p>{@code S} is either learned by one of the {@code train} methods, which
 * set it to the inverse (or, for a numerically singular matrix, the
 * pseudo-inverse) of the covariance matrix of the training vectors, or it is
 * supplied directly through {@link #setInverseCovariance(Matrix)}.
 *
 * <p>{@link #dist(Vec, Vec)} requires {@code S} to be set; calling it before
 * training or before {@link #setInverseCovariance(Matrix)} results in a
 * {@link NullPointerException}.
 */
public class MahalanobisDistance extends TrainableDistanceMetric {
    private static final long serialVersionUID = 7878528119699276817L;

    /**
     * Determinants whose magnitude is at or below this value are treated as
     * singular, and the pseudo-inverse is used instead of the LUP solution.
     */
    private static final double SINGULAR_DET_THRESHOLD = 1.0E-13d;

    /** When {@code true}, {@link #needsTraining()} reports {@code true} even if {@code S} is already set. */
    private boolean reTrain = true;

    /** Matrix used in the quadratic form of {@link #dist(Vec, Vec)}; {@code null} until trained or set. */
    private Matrix S;

    /**
     * Returns the re-train flag.
     *
     * @return {@code true} if {@link #needsTraining()} should keep reporting
     *         {@code true} after a matrix has been set
     */
    public boolean isReTrain() {
        return this.reTrain;
    }

    /**
     * Sets the re-train flag consulted by {@link #needsTraining()}.
     *
     * @param z the new value of the flag
     */
    public void setReTrain(boolean z) {
        this.reTrain = z;
    }

    /**
     * Sets the matrix used by {@link #dist(Vec, Vec)} directly, replacing any
     * previously trained or set matrix. The given instance is stored as-is,
     * without copying.
     *
     * @param matrix the inverse covariance matrix to use
     */
    public void setInverseCovariance(Matrix matrix) {
        this.S = matrix;
    }

    /**
     * Trains on the given vectors without using the shared thread pool.
     * Equivalent to {@code train(list, false)}.
     *
     * @param list the training vectors
     */
    @Override
    public <V extends Vec> void train(List<V> list) {
        train(list, false);
    }

    /**
     * Computes the covariance matrix of the given vectors and stores its
     * inverse as the matrix used by {@link #dist(Vec, Vec)}.
     *
     * <p>The determinant is obtained from an LUP decomposition. If it is
     * {@code NaN}, infinite, or its magnitude is at most
     * {@link #SINGULAR_DET_THRESHOLD}, the pseudo-inverse from a singular
     * value decomposition is used instead of solving with the LUP factors.
     *
     * @param list the training vectors
     * @param z    if {@code true}, the LUP decomposition and solve are given
     *             {@code ParallelUtils.CACHED_THREAD_POOL}
     */
    @Override
    public <V extends Vec> void train(List<V> list, boolean z) {
        Matrix covarianceMatrix = MatrixStatistics.covarianceMatrix(MatrixStatistics.meanVector(list), list);

        // The decomposition receives a copy; the original matrix is what the
        // SVD fallback below is built from.
        LUPDecomposition lup = z
                ? new LUPDecomposition(covarianceMatrix.mo171clone(), ParallelUtils.CACHED_THREAD_POOL)
                : new LUPDecomposition(covarianceMatrix.mo171clone());

        double det = lup.det();
        if (Double.isNaN(det) || Double.isInfinite(det) || Math.abs(det) <= SINGULAR_DET_THRESHOLD) {
            this.S = new SingularValueDecomposition(covarianceMatrix).getPseudoInverse();
            return;
        }

        // Solving against the identity yields the inverse of the covariance matrix.
        Matrix identity = Matrix.eye(covarianceMatrix.cols());
        this.S = z ? lup.solve(identity, ParallelUtils.CACHED_THREAD_POOL) : lup.solve(identity);
    }

    /**
     * Trains on the data vectors of the given data set without using the
     * shared thread pool. Equivalent to {@code train(dataSet, false)}.
     *
     * @param dataSet the data set to train on
     */
    @Override
    public void train(DataSet dataSet) {
        train(dataSet, false);
    }

    /**
     * Trains on the vectors returned by {@code dataSet.getDataVectors()}.
     *
     * @param dataSet the data set to train on
     * @param z       forwarded to {@link #train(List, boolean)}
     */
    @Override
    public void train(DataSet dataSet, boolean z) {
        train(dataSet.getDataVectors(), z);
    }

    /**
     * Trains on the given classification data set; delegates to
     * {@link #train(DataSet)}.
     *
     * @param classificationDataSet the data set to train on
     */
    @Override
    public void train(ClassificationDataSet classificationDataSet) {
        // The cast selects the DataSet overload; without it this method would call itself.
        train((DataSet) classificationDataSet);
    }

    /**
     * Trains on the given classification data set; delegates to
     * {@link #train(DataSet, boolean)}.
     *
     * @param classificationDataSet the data set to train on
     * @param z                     forwarded to {@link #train(DataSet, boolean)}
     */
    @Override
    public void train(ClassificationDataSet classificationDataSet, boolean z) {
        // The cast selects the DataSet overload; without it this method would call itself.
        train((DataSet) classificationDataSet, z);
    }

    /**
     * @return always {@code true}
     */
    @Override
    public boolean supportsClassificationTraining() {
        return true;
    }

    /**
     * Trains on the given regression data set; delegates to
     * {@link #train(DataSet)}.
     *
     * @param regressionDataSet the data set to train on
     */
    @Override
    public void train(RegressionDataSet regressionDataSet) {
        // The cast selects the DataSet overload; without it this method would call itself.
        train((DataSet) regressionDataSet);
    }

    /**
     * Trains on the given regression data set; delegates to
     * {@link #train(DataSet, boolean)}.
     *
     * @param regressionDataSet the data set to train on
     * @param z                 forwarded to {@link #train(DataSet, boolean)}
     */
    @Override
    public void train(RegressionDataSet regressionDataSet, boolean z) {
        // The cast selects the DataSet overload; without it this method would call itself.
        train((DataSet) regressionDataSet, z);
    }

    /**
     * @return always {@code true}
     */
    @Override
    public boolean supportsRegressionTraining() {
        return true;
    }

    /**
     * Reports whether this metric should be trained before use.
     *
     * @return {@code true} if no matrix has been set yet; otherwise the value
     *         of the re-train flag
     */
    @Override
    public boolean needsTraining() {
        return this.S == null || isReTrain();
    }

    /**
     * Computes {@code sqrt((vec - vec2)^T * S * (vec - vec2))}.
     *
     * <p>A negative quadratic form is clamped to zero before the square root
     * is taken, so that a tiny negative value produced by floating-point
     * rounding yields {@code 0} rather than {@code NaN}. Note that this also
     * maps the result for a matrix that is not positive semi-definite to
     * {@code 0}. A {@code NaN} quadratic form still yields {@code NaN}.
     *
     * @param vec  the first vector
     * @param vec2 the second vector
     * @return the distance between the two vectors
     * @throws NullPointerException if no matrix has been trained or set
     */
    @Override
    public double dist(Vec vec, Vec vec2) {
        Vec difference = vec.subtract(vec2);
        double quadraticForm = difference.dot(this.S.multiply(difference));
        // Math.max propagates NaN, so only genuinely negative values are clamped.
        return Math.sqrt(Math.max(0.0, quadraticForm));
    }

    /**
     * @return always {@code true}
     */
    @Override
    public boolean isSymmetric() {
        return true;
    }

    /**
     * @return always {@code true}
     */
    @Override
    public boolean isSubadditive() {
        return true;
    }

    /**
     * @return always {@code true}
     */
    @Override
    public boolean isIndiscemible() {
        return true;
    }

    /**
     * @return {@link Double#POSITIVE_INFINITY}
     */
    @Override
    public double metricBound() {
        return Double.POSITIVE_INFINITY;
    }

    /**
     * @return the fixed name {@code "Mahalanobis Distance"}
     */
    @Override
    public String toString() {
        return "Mahalanobis Distance";
    }

    /**
     * Creates a copy of this metric carrying the same re-train flag and, if a
     * matrix has been set, a clone of that matrix.
     *
     * @return a new, independent {@code MahalanobisDistance}
     */
    @Override
    /* renamed from: clone */
    public MahalanobisDistance mo185clone() {
        MahalanobisDistance copy = new MahalanobisDistance();
        copy.reTrain = this.reTrain;
        if (this.S != null) {
            copy.S = this.S.mo171clone();
        }
        return copy;
    }
}
