/*
 * Decompiled with CFR 0.152.
 */
package org.ujmp.core.doublematrix.calculation.general.misc;

import org.ujmp.core.Matrix;
import org.ujmp.core.calculation.Calculation;
import org.ujmp.core.doublematrix.calculation.AbstractDoubleCalculation;
import org.ujmp.core.util.MathUtil;
import org.ujmp.core.util.VerifyUtil;

/**
 * Pairwise cosine similarity between the rows of a source matrix.
 * <p>
 * The result is a square {@code n x n} matrix, where {@code n} is the row count
 * of the source. Entry {@code (i, j)} is the cosine similarity of source rows
 * {@code i} and {@code j}:
 * {@code sum(a*b) / (sqrt(sum(a*a)) * sqrt(sum(b*b)))}.
 * Each entry is computed on demand in {@link #getDouble(long...)}.
 */
public class CosineSimilarity
extends AbstractDoubleCalculation {
    private static final long serialVersionUID = 9144182368353349242L;
    private final long[] size;
    private final boolean ignoreNaN;

    /**
     * Creates a row-vs-row cosine similarity calculation over {@code matrix}.
     *
     * @param matrix    source matrix whose rows are compared
     * @param ignoreNaN if {@code true}, columns where either value is NaN or
     *                  infinite (per {@code MathUtil.isNaNOrInfinite}) are
     *                  excluded from all three sums
     */
    public CosineSimilarity(Matrix matrix, boolean ignoreNaN) {
        super(matrix);
        // one output row and one output column per source row
        this.size = new long[]{matrix.getRowCount(), matrix.getRowCount()};
        this.ignoreNaN = ignoreNaN;
    }

    /**
     * Computes the cosine similarity of two rows of the source matrix.
     *
     * @param coordinates {@code [i, j]}: the indices of the two source rows to
     *                    compare; only the first two entries are read
     * @return the cosine similarity of rows {@code i} and {@code j}; NaN or
     *         infinite when either row has zero norm
     */
    @Override
    public double getDouble(long ... coordinates) {
        final Matrix source = this.getSource();
        final Matrix rowA = source.selectRows(Calculation.Ret.LINK, coordinates[0]);
        final Matrix rowB = source.selectRows(Calculation.Ret.LINK, coordinates[1]);
        final long columnCount = rowA.getColumnCount();
        double dot = 0.0;
        double aSquares = 0.0;
        double bSquares = 0.0;
        for (long col = 0L; col < columnCount; ++col) {
            final double a = rowA.getAsDouble(0L, col);
            final double b = rowB.getAsDouble(0L, col);
            if (this.ignoreNaN && (MathUtil.isNaNOrInfinite(a) || MathUtil.isNaNOrInfinite(b))) {
                // skip the whole column so all three sums stay aligned
                continue;
            }
            dot += a * b;
            aSquares += a * a;
            bSquares += b * b;
        }
        return cosine(dot, aSquares, bSquares);
    }

    /**
     * Returns the size of the result matrix, {@code {rows, rows}} of the source.
     * <p>
     * Note: this is the internal array, not a copy; callers must not modify it.
     *
     * @return the result size
     */
    @Override
    public long[] getSize() {
        return this.size;
    }

    /**
     * Computes the cosine similarity of two equally sized matrices, treating
     * each one as a flat vector of all its cells.
     *
     * @param m1        first matrix
     * @param m2        second matrix; must have the same size as {@code m1}
     *                  (checked by {@code VerifyUtil.verifySameSize}, which
     *                  presumably throws on mismatch; confirm)
     * @param ignoreNaN if {@code true}, cells where either value is NaN or
     *                  infinite are excluded from all three sums
     * @return the cosine similarity; NaN or infinite when either matrix has
     *         zero norm
     */
    public static double getCosineSimilarity(Matrix m1, Matrix m2, boolean ignoreNaN) {
        VerifyUtil.verifySameSize(m1, m2);
        double dot = 0.0;
        double aSquares = 0.0;
        double bSquares = 0.0;
        for (long[] c : m1.allCoordinates()) {
            final double a = m1.getAsDouble(c);
            final double b = m2.getAsDouble(c);
            if (ignoreNaN && (MathUtil.isNaNOrInfinite(a) || MathUtil.isNaNOrInfinite(b))) {
                continue;
            }
            dot += a * b;
            aSquares += a * a;
            bSquares += b * b;
        }
        return cosine(dot, aSquares, bSquares);
    }

    /**
     * Misspelled alias of {@link #getCosineSimilarity(Matrix, Matrix, boolean)},
     * kept for backward compatibility.
     *
     * @param m1        first matrix
     * @param m2        second matrix; must have the same size as {@code m1}
     * @param ignoreNaN if {@code true}, NaN/infinite cells are skipped
     * @return the cosine similarity
     * @deprecated use {@link #getCosineSimilarity(Matrix, Matrix, boolean)}
     */
    @Deprecated
    public static double getCosineSimilartiy(Matrix m1, Matrix m2, boolean ignoreNaN) {
        return getCosineSimilarity(m1, m2, ignoreNaN);
    }

    /**
     * Combines accumulated sums into the cosine ratio.
     * When either sum of squares is zero the denominator is zero, so the
     * result is NaN or infinite rather than a finite similarity.
     */
    private static double cosine(double dot, double aSquares, double bSquares) {
        // square-root each norm separately instead of sqrt(aSquares * bSquares)
        // so the intermediate product does not overflow as early
        return dot / (Math.sqrt(aSquares) * Math.sqrt(bSquares));
    }
}

