/*
 * Decompiled with CFR 0.152.
 */
package jsat.linear.distancemetrics;

import java.util.List;
import jsat.linear.Vec;
import jsat.linear.distancemetrics.DistanceMetric;
import jsat.utils.DoubleList;
import jsat.utils.concurrent.ParallelUtils;

/**
 * Distance metric based on the cosine of the angle between two vectors.
 * <p>
 * The cosine similarity {@code cos = (a . b) / (||a||_2 * ||b||_2)} is mapped to a
 * distance with {@link #cosineToDistance(double)}, i.e. {@code sqrt((1 - cos) / 2)},
 * which lies in {@code [0, 1]}. When either vector has a zero L2 norm the angle is
 * undefined and the maximal distance {@code cosineToDistance(-1) == 1} is returned.
 * <p>
 * NOTE(review): two vectors that differ only by a positive scale factor have
 * distance 0, and two zero vectors have distance 1. Both cases are in tension with
 * {@link #isIndiscemible()} returning {@code true}; confirm this is intended.
 */
public class CosineDistance implements DistanceMetric {
    private static final long serialVersionUID = -6475546704095989078L;

    /**
     * Computes the cosine distance between two vectors.
     *
     * @param a the first vector
     * @param b the second vector
     * @return {@code sqrt((1 - cos) / 2)} in {@code [0, 1]}, or 1 if either vector has zero norm
     * @throws ArithmeticException if the vectors have different lengths
     */
    @Override
    public double dist(Vec a, Vec b) {
        if (a.length() != b.length()) {
            throw new ArithmeticException("vectors a and b are of differing lengths " + a.length() + " and " + b.length());
        }
        double denom = a.pNorm(2.0) * b.pNorm(2.0);
        if (denom == 0.0) {
            // Angle undefined for a zero vector: report maximal dissimilarity.
            return cosineToDistance(-1.0);
        }
        return boundedDistance(a.dot(b) / denom);
    }

    @Override
    public boolean isSymmetric() {
        return true;
    }

    @Override
    public boolean isSubadditive() {
        return true;
    }

    @Override
    public boolean isIndiscemible() {
        return true;
    }

    /**
     * @return 1.0, the largest value {@link #cosineToDistance(double)} yields for a cosine in [-1, 1]
     */
    @Override
    public double metricBound() {
        return 1.0;
    }

    @Override
    public String toString() {
        return "Cosine Distance";
    }

    @Override
    public CosineDistance clone() {
        // Stateless metric: a fresh instance is an exact copy.
        return new CosineDistance();
    }

    @Override
    public boolean supportsAcceleration() {
        return true;
    }

    /**
     * Precomputes the L2 norm of every vector so the cached {@code dist} overloads
     * only need a dot product per pair.
     *
     * @param vecs the vectors to cache norms for
     * @param parallel whether the norms may be computed in parallel
     * @return a list whose i-th entry is the L2 norm of {@code vecs.get(i)}
     */
    @Override
    public List<Double> getAccelerationCache(List<? extends Vec> vecs, boolean parallel) {
        double[] cache = new double[vecs.size()];
        ParallelUtils.run(parallel, vecs.size(), (start, end) -> {
            for (int i = start; i < end; ++i) {
                cache[i] = vecs.get(i).pNorm(2.0);
            }
        });
        return DoubleList.view(cache, vecs.size());
    }

    /**
     * Cosine distance between {@code vecs.get(a)} and {@code vecs.get(b)}, using the
     * cached norms from {@link #getAccelerationCache(List, boolean)} when available.
     */
    @Override
    public double dist(int a, int b, List<? extends Vec> vecs, List<Double> cache) {
        if (cache == null) {
            return this.dist(vecs.get(a), vecs.get(b));
        }
        double denom = cache.get(a) * cache.get(b);
        if (denom == 0.0) {
            return cosineToDistance(-1.0);
        }
        return boundedDistance(vecs.get(a).dot(vecs.get(b)) / denom);
    }

    /**
     * Cosine distance between {@code vecs.get(a)} and an arbitrary vector {@code b};
     * the norm of {@code b} is computed on every call.
     */
    @Override
    public double dist(int a, Vec b, List<? extends Vec> vecs, List<Double> cache) {
        if (cache == null) {
            return this.dist(vecs.get(a), b);
        }
        double denom = cache.get(a) * b.pNorm(2.0);
        if (denom == 0.0) {
            return cosineToDistance(-1.0);
        }
        return boundedDistance(vecs.get(a).dot(b) / denom);
    }

    /**
     * @param q the query vector
     * @return a single-element list holding the L2 norm of {@code q}
     */
    @Override
    public List<Double> getQueryInfo(Vec q) {
        DoubleList qi = new DoubleList(1);
        qi.add(q.pNorm(2.0));
        return qi;
    }

    /**
     * Cosine distance between {@code vecs.get(a)} and {@code b}, where {@code qi} is
     * expected to be the result of {@link #getQueryInfo(Vec)} for {@code b}.
     */
    @Override
    public double dist(int a, Vec b, List<Double> qi, List<? extends Vec> vecs, List<Double> cache) {
        if (cache == null) {
            return this.dist(vecs.get(a), b);
        }
        double denom = cache.get(a) * qi.get(0);
        if (denom == 0.0) {
            return cosineToDistance(-1.0);
        }
        return boundedDistance(vecs.get(a).dot(b) / denom);
    }

    /**
     * Clamps a cosine ratio to {@code [-1, 1]} and converts it to a distance.
     * Floating-point rounding can push {@code dot / (norm * norm)} slightly outside
     * that range. Clamping the upper end keeps the square-root argument non-negative,
     * so no NaN is produced. Clamping the lower end keeps the result within
     * {@link #metricBound()}. A NaN input is passed through unchanged.
     */
    private static double boundedDistance(double cosAngle) {
        return cosineToDistance(Math.max(-1.0, Math.min(cosAngle, 1.0)));
    }

    /**
     * Converts the cosine of an angle to the distance {@code sqrt((1 - cos) / 2)}.
     *
     * @param cosAngle cosine value, expected in {@code [-1, 1]}
     * @return the corresponding distance in {@code [0, 1]}
     */
    public static double cosineToDistance(double cosAngle) {
        return Math.sqrt(0.5 * (1.0 - cosAngle));
    }

    /**
     * Inverse of {@link #cosineToDistance(double)}: {@code cos = 1 - 2 * dist^2}.
     *
     * @param dist a cosine distance
     * @return the cosine value that maps to {@code dist}
     */
    public static double distanceToCosine(double dist) {
        return 1.0 - 2.0 * (dist * dist);
    }
}

