/*
 * Decompiled with CFR 0.152.
 */
package jsat.linear.distancemetrics;

import java.util.List;
import jsat.DataSet;
import jsat.classifiers.ClassificationDataSet;
import jsat.linear.LUPDecomposition;
import jsat.linear.Matrix;
import jsat.linear.MatrixStatistics;
import jsat.linear.SingularValueDecomposition;
import jsat.linear.Vec;
import jsat.linear.distancemetrics.TrainableDistanceMetric;
import jsat.regression.RegressionDataSet;
import jsat.utils.concurrent.ParallelUtils;

/**
 * Distance metric that evaluates {@code sqrt((a - b)^T S (a - b))}, where
 * {@code S} is a matrix obtained either by training (inverse or pseudo-inverse
 * of the data's covariance matrix) or supplied directly through
 * {@link #setInverseCovariance(Matrix)}.
 */
public class MahalanobisDistance
extends TrainableDistanceMetric {
    private static final long serialVersionUID = 7878528119699276817L;

    /**
     * Determinants whose magnitude is at or below this value cause training to
     * fall back to the SVD pseudo-inverse instead of the LUP based inverse.
     */
    private static final double SINGULAR_DET_THRESHOLD = 1.0E-13;

    /** When {@code true}, {@link #needsTraining()} always reports {@code true}. */
    private boolean reTrain = true;

    /** Matrix used in the quadratic form of {@link #dist(Vec, Vec)}; {@code null} until set or trained. */
    private Matrix S;

    /**
     * Reports whether this metric asks to be trained again even when a matrix
     * is already present.
     *
     * @return the current re-train flag
     */
    public boolean isReTrain() {
        return this.reTrain;
    }

    /**
     * Sets the re-train flag consulted by {@link #needsTraining()}.
     *
     * @param reTrain {@code true} to always report that training is needed
     */
    public void setReTrain(boolean reTrain) {
        this.reTrain = reTrain;
    }

    /**
     * Directly installs the matrix used by {@link #dist(Vec, Vec)}. The given
     * reference is stored as-is (no copy is made).
     *
     * @param S the matrix to use; passing {@code null} returns the metric to
     *          its untrained state
     */
    public void setInverseCovariance(Matrix S) {
        this.S = S;
    }

    /**
     * Trains on the given vectors without requesting parallel execution.
     *
     * @param dataSet the vectors to estimate the covariance from
     */
    @Override
    public <V extends Vec> void train(List<V> dataSet) {
        this.train(dataSet, false);
    }

    /**
     * Estimates the covariance matrix of the given vectors and stores its
     * inverse. If the determinant reported by the LUP decomposition is NaN,
     * infinite, or has magnitude at or below {@code 1.0E-13}, the SVD
     * pseudo-inverse is stored instead.
     *
     * @param dataSet  the vectors to estimate the covariance from
     * @param parallel {@code true} to hand {@link ParallelUtils#CACHED_THREAD_POOL}
     *                 to the LUP decomposition and solve steps
     */
    @Override
    public <V extends Vec> void train(List<V> dataSet, boolean parallel) {
        Vec mean = MatrixStatistics.meanVector(dataSet);
        Matrix covariance = MatrixStatistics.covarianceMatrix(mean, dataSet);

        // The decomposition is given a copy so that the covariance matrix itself
        // is still available, unmodified, for the SVD fallback below.
        LUPDecomposition lup = parallel
                ? new LUPDecomposition(covariance.clone(), ParallelUtils.CACHED_THREAD_POOL)
                : new LUPDecomposition(covariance.clone());

        double det = lup.det();
        boolean unusableDeterminant = Double.isNaN(det)
                || Double.isInfinite(det)
                || Math.abs(det) <= SINGULAR_DET_THRESHOLD;

        if (unusableDeterminant) {
            this.S = new SingularValueDecomposition(covariance).getPseudoInverse();
            return;
        }

        // Solving against the identity yields the inverse of the covariance matrix.
        this.S = parallel
                ? lup.solve(Matrix.eye(covariance.cols()), ParallelUtils.CACHED_THREAD_POOL)
                : lup.solve(Matrix.eye(covariance.cols()));
    }

    /**
     * Trains on the data set's vectors without requesting parallel execution.
     *
     * @param dataSet the data set to train from
     */
    @Override
    public void train(DataSet dataSet) {
        this.train(dataSet, false);
    }

    /**
     * Trains on the vectors returned by {@code dataSet.getDataVectors()}.
     *
     * @param dataSet  the data set to train from
     * @param parallel forwarded to {@link #train(List, boolean)}
     */
    @Override
    public void train(DataSet dataSet, boolean parallel) {
        this.train(dataSet.getDataVectors(), parallel);
    }

    /**
     * Trains on a classification data set by delegating to
     * {@link #train(DataSet)}.
     *
     * @param dataSet the data set to train from
     */
    @Override
    public void train(ClassificationDataSet dataSet) {
        // The cast selects the DataSet overload rather than recursing into this one.
        this.train((DataSet)dataSet);
    }

    /**
     * Trains on a classification data set by delegating to
     * {@link #train(DataSet, boolean)}.
     *
     * @param dataSet  the data set to train from
     * @param parallel forwarded to the delegate
     */
    @Override
    public void train(ClassificationDataSet dataSet, boolean parallel) {
        this.train((DataSet)dataSet, parallel);
    }

    /**
     * @return always {@code true}
     */
    @Override
    public boolean supportsClassificationTraining() {
        return true;
    }

    /**
     * Trains on a regression data set by delegating to
     * {@link #train(DataSet)}.
     *
     * @param dataSet the data set to train from
     */
    @Override
    public void train(RegressionDataSet dataSet) {
        // The cast selects the DataSet overload rather than recursing into this one.
        this.train((DataSet)dataSet);
    }

    /**
     * Trains on a regression data set by delegating to
     * {@link #train(DataSet, boolean)}.
     *
     * @param dataSet  the data set to train from
     * @param parallel forwarded to the delegate
     */
    @Override
    public void train(RegressionDataSet dataSet, boolean parallel) {
        this.train((DataSet)dataSet, parallel);
    }

    /**
     * @return always {@code true}
     */
    @Override
    public boolean supportsRegressionTraining() {
        return true;
    }

    /**
     * Reports whether training is required: always when no matrix has been
     * set, otherwise according to {@link #isReTrain()}.
     *
     * @return {@code true} if the metric should be trained before use
     */
    @Override
    public boolean needsTraining() {
        return this.S == null || this.isReTrain();
    }

    /**
     * Computes {@code sqrt((a - b)^T S (a - b))}.
     *
     * <p>A matrix must already be present (via training or
     * {@link #setInverseCovariance(Matrix)}); otherwise dereferencing the
     * {@code null} matrix throws a {@link NullPointerException}.
     *
     * @param a the first vector
     * @param b the second vector
     * @return the distance between {@code a} and {@code b}
     */
    @Override
    public double dist(Vec a, Vec b) {
        Vec difference = a.subtract(b);
        double quadraticForm = difference.dot(this.S.multiply(difference));
        // Rounding in a numerically computed (pseudo-)inverse can push the
        // quadratic form marginally below zero, and sqrt of a negative value is
        // NaN. Clamp at zero; Math.max still propagates a genuine NaN.
        return Math.sqrt(Math.max(quadraticForm, 0.0));
    }

    /**
     * @return always {@code true}
     */
    @Override
    public boolean isSymmetric() {
        return true;
    }

    /**
     * @return always {@code true}
     */
    @Override
    public boolean isSubadditive() {
        return true;
    }

    /**
     * @return always {@code true}
     */
    @Override
    public boolean isIndiscemible() {
        // NOTE(review): when training fell back to a pseudo-inverse, distinct
        // points may presumably be at distance zero - confirm whether callers
        // rely on this property in that case.
        return true;
    }

    /**
     * @return {@link Double#POSITIVE_INFINITY}
     */
    @Override
    public double metricBound() {
        return Double.POSITIVE_INFINITY;
    }

    @Override
    public String toString() {
        return "Mahalanobis Distance";
    }

    /**
     * Creates a copy carrying the same re-train flag and, when a matrix is
     * present, a clone of that matrix.
     *
     * @return the new, independent metric
     */
    @Override
    public MahalanobisDistance clone() {
        MahalanobisDistance copy = new MahalanobisDistance();
        copy.reTrain = this.reTrain;
        if (this.S != null) {
            copy.S = this.S.clone();
        }
        return copy;
    }
}

