/*
 * Decompiled with CFR 0.152.
 */
package jsat.linear.distancemetrics;

import java.util.Iterator;
import jsat.linear.IndexValue;
import jsat.linear.Vec;
import jsat.linear.distancemetrics.DistanceMetric;

/**
 * A distance derived from the Pearson correlation coefficient {@code r} between two vectors.
 * <p>
 * By default the distance is {@code sqrt((1 - r) / 2)}, which lies in [0, 1] and maps
 * perfectly anti-correlated vectors to 1. When constructed with {@code absoluteDistance = true}
 * the distance is {@code sqrt(1 - r^2)}, so strong positive and strong negative correlation are
 * both treated as "close".
 * <p>
 * NOTE(review): vectors related by a positive affine map (e.g. {@code b = 2a + 3}) have
 * {@code r = 1} and therefore distance 0 even though they differ, so on raw vectors this is a
 * pseudo-metric; {@link #isIndiscemible()} nonetheless reports {@code true}. With
 * {@code bothNonZero} each pair is compared over a different index subset, so the triangle
 * inequality may not hold either. The flags are left unchanged for caller compatibility.
 */
public class PearsonDistance
implements DistanceMetric {
    private static final long serialVersionUID = 1090726755301934198L;
    /** If true, only indices where both vectors are non-zero contribute to the correlation. */
    private final boolean bothNonZero;
    /** If true, the distance is {@code sqrt(1 - r^2)} rather than {@code sqrt((1 - r) / 2)}. */
    private final boolean absoluteDistance;

    /**
     * Creates a standard Pearson distance: every index (zeros included) contributes, and the
     * distance is {@code sqrt((1 - r) / 2)}.
     */
    public PearsonDistance() {
        this(false, false);
    }

    /**
     * Creates a Pearson distance.
     *
     * @param bothNonZero if {@code true}, means are taken over each vector's non-zero entries and
     *        only indices where both vectors are non-zero contribute to the correlation
     * @param absoluteDistance if {@code true}, the distance is {@code sqrt(1 - r^2)}; otherwise
     *        it is {@code sqrt((1 - r) / 2)}
     */
    public PearsonDistance(boolean bothNonZero, boolean absoluteDistance) {
        this.bothNonZero = bothNonZero;
        this.absoluteDistance = absoluteDistance;
    }

    /**
     * Computes the correlation-based distance between two vectors.
     *
     * @param a the first vector
     * @param b the second vector; must have the same length as {@code a}
     * @return a value in [0, 1], or {@link Double#MAX_VALUE} if the correlation is undefined
     * @throws ArithmeticException if the vectors have different lengths
     */
    @Override
    public double dist(Vec a, Vec b) {
        double r = PearsonDistance.correlation(a, b, this.bothNonZero);
        if (Double.isNaN(r)) {
            // Correlation could not be computed (e.g. NaN/infinite entries): report "far away".
            return Double.MAX_VALUE;
        }
        // correlation() clamps r to [-1, 1], so neither radicand below can be negative.
        if (this.absoluteDistance) {
            return Math.sqrt(1.0 - r * r);
        }
        return Math.sqrt((1.0 - r) * 0.5);
    }

    /** @return {@code true}: the correlation, and hence the distance, is symmetric in its arguments */
    @Override
    public boolean isSymmetric() {
        return true;
    }

    /** @return {@code true} (see the class-level note regarding {@code bothNonZero}) */
    @Override
    public boolean isSubadditive() {
        return true;
    }

    /** @return {@code true} (see the class-level note: affinely related vectors have distance 0) */
    @Override
    public boolean isIndiscemible() {
        return true;
    }

    /**
     * @return 1.0, the largest distance produced from a defined correlation.
     *         NOTE(review): {@link #dist(Vec, Vec)} returns {@link Double#MAX_VALUE} when the
     *         correlation is NaN, which exceeds this bound.
     */
    @Override
    public double metricBound() {
        return 1.0;
    }

    /** @return a new instance configured with the same flags */
    @Override
    public PearsonDistance clone() {
        return new PearsonDistance(this.bothNonZero, this.absoluteDistance);
    }

    /**
     * Computes the Pearson correlation coefficient between two vectors.
     * <p>
     * If either vector is sparse, its non-zero entries are merged by index and the contribution of
     * the implicit zeros is added in closed form. Otherwise, every index is visited directly.
     * Conventions for degenerate input:
     * <ul>
     * <li>two sparse-path vectors that are both entirely zero give 1.0;</li>
     * <li>exactly one entirely-zero vector gives 0.0;</li>
     * <li>zero variance in both vectors gives 0.0;</li>
     * <li>zero variance in one vector puts a small epsilon in the denominator.</li>
     * </ul>
     *
     * @param a the first vector
     * @param b the second vector; must have the same length as {@code a}
     * @param bothNonZero if {@code true}, each vector's mean is taken over its own non-zero
     *        entries and only indices where both vectors are non-zero contribute to the sums;
     *        if {@code false}, every index (zeros included) contributes
     * @return the correlation clamped to [-1, 1], or {@code NaN} if it could not be computed
     * @throws ArithmeticException if the vectors have different lengths
     */
    public static double correlation(Vec a, Vec b, boolean bothNonZero) {
        if (a.length() != b.length()) {
            throw new ArithmeticException("Vectors must have the same length, not "
                    + a.length() + " and " + b.length());
        }
        final double aMean;
        final double bMean;
        if (bothNonZero) {
            // Mean over the non-zero entries only (NaN when a vector has none).
            aMean = a.sum() / (double) a.nnz();
            bMean = b.sum() / (double) b.nnz();
        } else {
            aMean = a.mean();
            bMean = b.mean();
        }

        double r = 0.0;     // sum of (a_i - aMean) * (b_i - bMean)
        double aSqrd = 0.0; // sum of (a_i - aMean)^2
        double bSqrd = 0.0; // sum of (b_i - bMean)^2

        if (a.isSparse() || b.isSparse()) {
            Iterator<IndexValue> aIter = a.getNonZeroIterator();
            Iterator<IndexValue> bIter = b.getNonZeroIterator();
            if (!aIter.hasNext() && !bIter.hasNext()) {
                // Both vectors are entirely zero, hence identical.
                // NOTE(review): the dense branch yields 0.0 for the same input.
                return 1.0;
            }
            if (!aIter.hasNext() || !bIter.hasNext()) {
                // Exactly one vector is entirely zero: it has no variance and contributes nothing
                // to the covariance. 0.0 matches the dense branch; a value outside [-1, 1] would
                // make dist() take the square root of a negative number.
                return 0.0;
            }

            IndexValue aCur = aIter.next();
            IndexValue bCur = bIter.next();
            // Last index already accounted for. Positions strictly between it and the next
            // non-zero of either vector are zero in both, each contributing aMean * bMean.
            int lastObservedIndex = -1;
            while (aCur != null && bCur != null) {
                final int aIdx = aCur.getIndex();
                final int bIdx = bCur.getIndex();
                if (aIdx == bIdx) {
                    // Both vectors are non-zero at this index.
                    if (!bothNonZero) {
                        r += aMean * bMean * (double) (aIdx - lastObservedIndex - 1);
                    }
                    lastObservedIndex = aIdx;
                    final double aVal = aCur.getValue() - aMean;
                    final double bVal = bCur.getValue() - bMean;
                    r += aVal * bVal;
                    aSqrd += aVal * aVal;
                    bSqrd += bVal * bVal;
                    aCur = nextOrNull(aIter);
                    bCur = nextOrNull(bIter);
                } else if (aIdx > bIdx) {
                    // Only b is non-zero here; a's deviation is (0 - aMean).
                    if (!bothNonZero) {
                        r += aMean * bMean * (double) (bIdx - lastObservedIndex - 1);
                        lastObservedIndex = bIdx;
                        final double bVal = bCur.getValue() - bMean;
                        r += -aMean * bVal;
                        bSqrd += bVal * bVal;
                    }
                    bCur = nextOrNull(bIter);
                } else {
                    // Only a is non-zero here; b's deviation is (0 - bMean).
                    if (!bothNonZero) {
                        r += aMean * bMean * (double) (aIdx - lastObservedIndex - 1);
                        lastObservedIndex = aIdx;
                        final double aVal = aCur.getValue() - aMean;
                        r += aVal * -bMean;
                        aSqrd += aVal * aVal;
                    }
                    aCur = nextOrNull(aIter);
                }
            }

            if (!bothNonZero) {
                // At most one vector still has unvisited non-zeros; the other is zero there.
                for (; aCur != null; aCur = nextOrNull(aIter)) {
                    r += aMean * bMean * (double) (aCur.getIndex() - lastObservedIndex - 1);
                    lastObservedIndex = aCur.getIndex();
                    final double aVal = aCur.getValue() - aMean;
                    r += aVal * -bMean;
                    aSqrd += aVal * aVal;
                }
                for (; bCur != null; bCur = nextOrNull(bIter)) {
                    r += aMean * bMean * (double) (bCur.getIndex() - lastObservedIndex - 1);
                    lastObservedIndex = bCur.getIndex();
                    final double bVal = bCur.getValue() - bMean;
                    r += -aMean * bVal;
                    bSqrd += bVal * bVal;
                }
                // Trailing positions after the last non-zero of either vector.
                r += aMean * bMean * (double) (a.length() - lastObservedIndex - 1);
                // Every zero entry deviates from its vector's mean by exactly that mean.
                aSqrd += aMean * aMean * (double) (a.length() - a.nnz());
                bSqrd += bMean * bMean * (double) (b.length() - b.nnz());
            }
        } else {
            for (int i = 0; i < a.length(); ++i) {
                final double aTmp = a.get(i);
                final double bTmp = b.get(i);
                if (bothNonZero && (aTmp == 0.0 || bTmp == 0.0)) {
                    continue;
                }
                final double aVal = aTmp - aMean;
                final double bVal = bTmp - bMean;
                r += aVal * bVal;
                aSqrd += aVal * aVal;
                bSqrd += bVal * bVal;
            }
        }

        if (bSqrd == 0.0 && aSqrd == 0.0) {
            return 0.0;
        }
        final double result;
        if (bSqrd == 0.0 || aSqrd == 0.0) {
            // Exactly one vector has zero variance; epsilon keeps the denominator non-zero.
            result = r / Math.sqrt((aSqrd + 1.0E-10) * (bSqrd + 1.0E-10));
        } else {
            result = r / Math.sqrt(aSqrd * bSqrd);
        }
        // Rounding can push |result| marginally above 1 (e.g. for b = 2a + 3), which would make
        // dist() return NaN. Math.min/max propagate NaN, so undefined results stay NaN.
        return Math.max(-1.0, Math.min(1.0, result));
    }

    /** Returns the iterator's next element, or {@code null} once it is exhausted. */
    private static IndexValue nextOrNull(Iterator<IndexValue> iter) {
        return iter.hasNext() ? iter.next() : null;
    }
}

