/*
 * Decompiled with CFR 0.152.
 */
package jsat.linear.distancemetrics;

import java.util.List;
import jsat.linear.Vec;
import jsat.linear.VecOps;
import jsat.linear.distancemetrics.DistanceMetric;
import jsat.utils.DoubleList;
import jsat.utils.concurrent.ParallelUtils;

/**
 * Weighted Euclidean distance: {@code d(a, b) = sqrt(sum_i w_i * (a_i - b_i)^2)}, where
 * {@code w} is a per-dimension weight vector.
 * <p>
 * The metric properties reported by this class (symmetry, subadditivity, indiscernibility)
 * presume that every weight is non-negative (strictly positive for indiscernibility). The
 * weights are NOT validated here; supplying negative weights yields a non-metric and may
 * produce {@code NaN} from the non-accelerated {@link #dist(Vec, Vec)}.
 * <p>
 * Acceleration: the cache holds the weighted squared norm {@code <v, v>_w} of every vector,
 * so a distance can be computed as {@code sqrt(<a,a>_w + <b,b>_w - 2 <a,b>_w)} with a single
 * weighted dot product.
 */
public class WeightedEuclideanDistance
implements DistanceMetric {
    private static final long serialVersionUID = 2959997330647828673L;
    /** Per-dimension weights; never {@code null}. Held by reference, not copied. */
    private Vec w;

    /**
     * Creates a weighted Euclidean distance using the given weights.
     *
     * @param w the per-dimension weight vector (stored by reference, not copied)
     * @throws NullPointerException if {@code w} is {@code null}
     */
    public WeightedEuclideanDistance(Vec w) {
        this.setWeight(w);
    }

    /**
     * Returns the weight vector used by this metric. This is the internal reference, so
     * modifying the returned vector changes this metric's behavior and invalidates any
     * acceleration cache previously computed with it.
     *
     * @return the weight vector
     */
    public Vec getWeight() {
        return this.w;
    }

    /**
     * Replaces the weight vector used by this metric. Any acceleration cache computed with the
     * previous weights is no longer valid.
     *
     * @param w the new per-dimension weight vector (stored by reference, not copied)
     * @throws NullPointerException if {@code w} is {@code null}
     */
    public void setWeight(Vec w) {
        if (w == null) {
            throw new NullPointerException("weight vector must not be null");
        }
        this.w = w;
    }

    /**
     * Computes the weighted Euclidean distance between two vectors directly (no cache).
     */
    @Override
    public double dist(Vec a, Vec b) {
        return Math.sqrt(VecOps.accumulateSum(this.w, a, b, x -> x * x));
    }

    @Override
    public boolean isSymmetric() {
        return true;
    }

    /** True assuming non-negative weights (see class documentation). */
    @Override
    public boolean isSubadditive() {
        return true;
    }

    /** True assuming strictly positive weights (see class documentation). */
    @Override
    public boolean isIndiscemible() {
        return true;
    }

    @Override
    public double metricBound() {
        return Double.POSITIVE_INFINITY;
    }

    /**
     * Returns a copy of this metric with its own clone of the weight vector.
     */
    @Override
    public WeightedEuclideanDistance clone() {
        return new WeightedEuclideanDistance(this.w.clone());
    }

    @Override
    public boolean supportsAcceleration() {
        return true;
    }

    /**
     * Builds the acceleration cache: entry {@code i} is the weighted squared norm
     * {@code <v_i, v_i>_w} of {@code vecs.get(i)}.
     *
     * @param vecs     the vectors to cache
     * @param parallel whether the cache may be filled in parallel
     * @return a list view over the computed weighted squared norms
     */
    @Override
    public List<Double> getAccelerationCache(List<? extends Vec> vecs, boolean parallel) {
        double[] cache = new double[vecs.size()];
        ParallelUtils.run(parallel, vecs.size(), (start, end) -> {
            for (int i = start; i < end; ++i) {
                Vec v = vecs.get(i);
                cache[i] = VecOps.weightedDot(this.w, v, v);
            }
        });
        return DoubleList.view(cache, vecs.size());
    }

    /**
     * Distance between {@code vecs.get(a)} and {@code vecs.get(b)}, using the cache when
     * available.
     */
    @Override
    public double dist(int a, int b, List<? extends Vec> vecs, List<Double> cache) {
        if (cache == null) {
            return this.dist(vecs.get(a), vecs.get(b));
        }
        double squared = cache.get(a) + cache.get(b)
                - 2.0 * VecOps.weightedDot(this.w, vecs.get(a), vecs.get(b));
        return sqrtOfSquaredDistance(squared);
    }

    /**
     * Distance between {@code vecs.get(a)} and an arbitrary vector {@code b}, using the cache
     * when available (the weighted norm of {@code b} is computed on the fly).
     */
    @Override
    public double dist(int a, Vec b, List<? extends Vec> vecs, List<Double> cache) {
        if (cache == null) {
            return this.dist(vecs.get(a), b);
        }
        double squared = cache.get(a) + VecOps.weightedDot(this.w, b, b)
                - 2.0 * VecOps.weightedDot(this.w, vecs.get(a), b);
        return sqrtOfSquaredDistance(squared);
    }

    /**
     * Returns the query information for {@code q}: a single-element list holding its weighted
     * squared norm {@code <q, q>_w}.
     */
    @Override
    public List<Double> getQueryInfo(Vec q) {
        DoubleList qi = new DoubleList(1);
        qi.add(VecOps.weightedDot(this.w, q, q));
        return qi;
    }

    /**
     * Distance between {@code vecs.get(a)} and {@code b}, reusing the precomputed query
     * information from {@link #getQueryInfo(Vec)} when both it and the cache are available.
     */
    @Override
    public double dist(int a, Vec b, List<Double> qi, List<? extends Vec> vecs, List<Double> cache) {
        if (cache == null) {
            return this.dist(vecs.get(a), b);
        }
        if (qi == null) {
            // No precomputed query norm: fall back to computing it on the fly.
            return this.dist(a, b, vecs, cache);
        }
        double squared = cache.get(a) + qi.get(0)
                - 2.0 * VecOps.weightedDot(this.w, vecs.get(a), b);
        return sqrtOfSquaredDistance(squared);
    }

    /**
     * Square root of a squared distance obtained from the expanded form
     * {@code <a,a> + <b,b> - 2<a,b>}. That form suffers from cancellation when {@code a} and
     * {@code b} are (nearly) equal and can come out slightly negative, for which
     * {@code Math.sqrt} would return NaN; such values are clamped to zero.
     */
    private static double sqrtOfSquaredDistance(double squared) {
        return Math.sqrt(Math.max(squared, 0.0));
    }
}

