/*
 * Decompiled with CFR 0.152.
 */
package jsat.linear.distancemetrics;

import java.util.List;
import jsat.DataSet;
import jsat.classifiers.ClassificationDataSet;
import jsat.linear.MatrixStatistics;
import jsat.linear.Vec;
import jsat.linear.VecOps;
import jsat.linear.distancemetrics.TrainableDistanceMetric;
import jsat.regression.RegressionDataSet;
import jsat.utils.DoubleList;
import jsat.utils.concurrent.ParallelUtils;

public class NormalizedEuclideanDistance
extends TrainableDistanceMetric {
    private static final long serialVersionUID = 210109457671623688L;
    private Vec invStndDevs;

    @Override
    public <V extends Vec> void train(List<V> dataSet) {
        this.invStndDevs = MatrixStatistics.covarianceDiag(MatrixStatistics.meanVector(dataSet), dataSet);
        this.invStndDevs.applyFunction(x -> x * x);
        this.invStndDevs.applyFunction(x -> 1.0 / x);
    }

    @Override
    public <V extends Vec> void train(List<V> dataSet, boolean parallel) {
        this.train(dataSet);
    }

    @Override
    public void train(DataSet dataSet) {
        this.invStndDevs = dataSet.getColumnMeanVariance()[1];
        this.invStndDevs.applyFunction(x -> x * x);
        this.invStndDevs.applyFunction(x -> 1.0 / x);
    }

    @Override
    public void train(DataSet dataSet, boolean parallel) {
        this.train(dataSet);
    }

    @Override
    public void train(ClassificationDataSet dataSet) {
        this.train((DataSet)dataSet);
    }

    @Override
    public void train(ClassificationDataSet dataSet, boolean parallel) {
        this.train(dataSet);
    }

    @Override
    public boolean supportsClassificationTraining() {
        return true;
    }

    @Override
    public void train(RegressionDataSet dataSet) {
        this.train((DataSet)dataSet);
    }

    @Override
    public void train(RegressionDataSet dataSet, boolean parallel) {
        this.train(dataSet);
    }

    @Override
    public boolean supportsRegressionTraining() {
        return true;
    }

    @Override
    public boolean needsTraining() {
        return this.invStndDevs == null;
    }

    @Override
    public NormalizedEuclideanDistance clone() {
        NormalizedEuclideanDistance clone = new NormalizedEuclideanDistance();
        if (this.invStndDevs != null) {
            clone.invStndDevs = this.invStndDevs.clone();
        }
        return clone;
    }

    @Override
    public double dist(Vec a, Vec b) {
        double r = VecOps.accumulateSum(this.invStndDevs, a, b, x -> x * x);
        return Math.sqrt(r);
    }

    @Override
    public boolean isSymmetric() {
        return true;
    }

    @Override
    public boolean isSubadditive() {
        return true;
    }

    @Override
    public boolean isIndiscemible() {
        return true;
    }

    @Override
    public double metricBound() {
        return Double.POSITIVE_INFINITY;
    }

    @Override
    public boolean supportsAcceleration() {
        return true;
    }

    @Override
    public List<Double> getAccelerationCache(List<? extends Vec> vecs, boolean parallel) {
        double[] cache = new double[vecs.size()];
        ParallelUtils.run(parallel, vecs.size(), (start, end) -> {
            for (int i = start; i < end; ++i) {
                Vec v = (Vec)vecs.get(i);
                cache[i] = VecOps.weightedDot(this.invStndDevs, v, v);
            }
        });
        return DoubleList.view(cache, vecs.size());
    }

    @Override
    public double dist(int a, int b, List<? extends Vec> vecs, List<Double> cache) {
        if (cache == null) {
            return this.dist(vecs.get(a), vecs.get(b));
        }
        return Math.sqrt(cache.get(a) + cache.get(b) - 2.0 * VecOps.weightedDot(this.invStndDevs, vecs.get(a), vecs.get(b)));
    }

    @Override
    public double dist(int a, Vec b, List<? extends Vec> vecs, List<Double> cache) {
        if (cache == null) {
            return this.dist(vecs.get(a), b);
        }
        return Math.sqrt(cache.get(a) + VecOps.weightedDot(this.invStndDevs, b, b) - 2.0 * VecOps.weightedDot(this.invStndDevs, vecs.get(a), b));
    }

    @Override
    public List<Double> getQueryInfo(Vec q) {
        DoubleList qi = new DoubleList(1);
        qi.add(VecOps.weightedDot(this.invStndDevs, q, q));
        return qi;
    }

    @Override
    public double dist(int a, Vec b, List<Double> qi, List<? extends Vec> vecs, List<Double> cache) {
        if (cache == null) {
            return this.dist(vecs.get(a), b);
        }
        return Math.sqrt(cache.get(a) + qi.get(0) - 2.0 * VecOps.weightedDot(this.invStndDevs, vecs.get(a), b));
    }
}

