package jsat.linear;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.stream.Collectors;
import jsat.DataSet;
import jsat.classifiers.DataPoint;
import jsat.classifiers.neuralnetwork.SOM;
import jsat.distributions.ChiSquared;
import jsat.linear.distancemetrics.MahalanobisDistance;
import jsat.utils.IndexTable;
import jsat.utils.IntList;
import jsat.utils.IntSet;
import jsat.utils.ListUtils;
import jsat.utils.Tuple3;
import jsat.utils.concurrent.ParallelUtils;
import jsat.utils.random.RandomUtil;

/* loaded from: input_file:jsat/linear/MatrixStatistics.class */
/**
 * Static helpers for computing summary statistics of vectors and data sets:
 * mean vectors, full covariance matrices, diagonal (per-dimension) variances,
 * and a robust mean/covariance estimator ({@link #FastMCD}).
 *
 * <p>The overloads that take an output {@code Vec}/{@code Matrix} accumulate
 * into that storage rather than overwriting it, so callers should pass zeroed
 * storage unless they deliberately want to add to existing contents.
 */
public class MatrixStatistics {
    /** Sample sizes up to this use random starts over the whole data set; larger ones use subsets. */
    private static final int MCD_SMALL_N = 600;
    /** Target number of points per subset in the large-sample stage of FastMCD. */
    private static final int MCD_SUBSET_SIZE = 300;
    /** Maximum number of subsets in the large-sample stage of FastMCD. */
    private static final int MCD_MAX_SUBSETS = 5;
    /** Maximum number of points distributed over the subsets (MCD_MAX_SUBSETS * MCD_SUBSET_SIZE). */
    private static final int MCD_MAX_SUBSET_POINTS = 1500;
    /** Random starts evaluated inside each subset in the large-sample stage. */
    private static final int MCD_STARTS_PER_SUBSET = 100;
    /** Number of lowest-determinant candidates kept after each stage. */
    private static final int MCD_KEEP_BEST = 10;
    /** C-steps applied to each freshly drawn candidate before ranking. */
    private static final int MCD_INITIAL_C_STEPS = 3;
    /** Upper bound on C-steps when refining a surviving candidate. */
    private static final int MCD_MAX_REFINE_STEPS = 20;
    /** Determinant change below which refinement stops. */
    private static final double MCD_CONVERGENCE_TOL = 1.0E-9d;
    /** Ridge added to the covariance diagonal before it is inverted in a C-step. */
    private static final double MCD_RIDGE = 1.0E-4d;

    private MatrixStatistics() {
    }

    /**
     * Computes the unweighted mean of a list of vectors.
     *
     * @param list the vectors; all must have the same length
     * @return a new dense vector holding the mean
     * @throws ArithmeticException if {@code list} is empty
     */
    public static <V extends Vec> Vec meanVector(List<V> list) {
        if (list.isEmpty()) {
            throw new ArithmeticException("Can not compute the mean of zero data points");
        }
        DenseVector denseVector = new DenseVector(list.get(0).length());
        meanVector(denseVector, list);
        return denseVector;
    }

    /**
     * Computes the weighted mean of the numerical features of a data set, using
     * each data point's weight.
     *
     * @param dataSet the data set
     * @return a new dense vector holding the weighted mean
     * @throws ArithmeticException if the data set has no samples
     */
    public static Vec meanVector(DataSet dataSet) {
        DenseVector denseVector = new DenseVector(dataSet.getNumNumericalVars());
        meanVector(denseVector, dataSet);
        return denseVector;
    }

    /**
     * Adds the unweighted mean of {@code list} into {@code vec}. More precisely,
     * {@code vec} becomes (vec + sum of vectors) / n, so {@code vec} should be zeroed first.
     *
     * @param vec  output storage, same length as the data vectors
     * @param list the vectors to average
     * @throws ArithmeticException if {@code list} is empty or the lengths disagree
     */
    public static <V extends Vec> void meanVector(Vec vec, List<V> list) {
        if (list.isEmpty()) {
            throw new ArithmeticException("Can not compute the mean of zero data points");
        }
        if (list.get(0).length() != vec.length()) {
            throw new ArithmeticException("Vector dimensions do not agree");
        }
        for (V v : list) {
            vec.mutableAdd(v);
        }
        vec.mutableDivide(list.size());
    }

    /**
     * Adds the unweighted mean of the vectors {@code list.get(i)} for every index
     * {@code i} in {@code collection} into {@code vec}; {@code vec} should be zeroed first.
     *
     * @param vec        output storage, same length as the data vectors
     * @param list       the full set of vectors
     * @param collection indices into {@code list} selecting the vectors to average
     * @throws ArithmeticException if {@code list} or {@code collection} is empty, or the lengths disagree
     */
    public static <V extends Vec> void meanVector(Vec vec, List<V> list, Collection<Integer> collection) {
        if (list.isEmpty() || collection.isEmpty()) {
            // an empty index set would otherwise divide by zero and silently yield NaN
            throw new ArithmeticException("Can not compute the mean of zero data points");
        }
        if (list.get(0).length() != vec.length()) {
            throw new ArithmeticException("Vector dimensions do not agree");
        }
        for (int idx : collection) {
            vec.mutableAdd(list.get(idx));
        }
        vec.mutableDivide(collection.size());
    }

    /**
     * Adds the weighted mean of the data set's numerical features into {@code vec}.
     * Each point contributes in proportion to its weight, and the sum is divided by
     * the total weight. {@code vec} should be zeroed first.
     *
     * @param vec     output storage of length {@code dataSet.getNumNumericalVars()}
     * @param dataSet the data set
     * @throws ArithmeticException if the data set has no samples
     */
    public static void meanVector(Vec vec, DataSet dataSet) {
        if (dataSet.getSampleSize() == 0) {
            throw new ArithmeticException("Can not compute the mean of zero data points");
        }
        double totalWeight = 0.0d;
        for (int i = 0; i < dataSet.getSampleSize(); i++) {
            DataPoint dataPoint = dataSet.getDataPoint(i);
            double weight = dataPoint.getWeight();
            totalWeight += weight;
            vec.mutableAdd(weight, dataPoint.getNumericalValues());
        }
        vec.mutableDivide(totalWeight);
    }

    /**
     * Computes the sample covariance matrix (normalised by n-1) of a list of
     * vectors around a given mean.
     *
     * @param vec  the mean to center on
     * @param list the vectors
     * @return a new dense covariance matrix
     */
    public static <V extends Vec> Matrix covarianceMatrix(Vec vec, List<V> list) {
        DenseMatrix denseMatrix = new DenseMatrix(vec.length(), vec.length());
        covarianceMatrix(vec, denseMatrix, list);
        return denseMatrix;
    }

    /**
     * Accumulates the outer products of the centered vectors into {@code matrix},
     * then scales the whole matrix by 1/(n-1). {@code matrix} should be zeroed first.
     *
     * @param vec    the mean to center on
     * @param matrix square output storage matching the mean's length
     * @param list   the vectors
     * @throws ArithmeticException on shape mismatch or empty {@code list}
     */
    public static <V extends Vec> void covarianceMatrix(Vec vec, Matrix matrix, List<V> list) {
        if (!matrix.isSquare()) {
            throw new ArithmeticException("Storage for covariance matrix must be square");
        }
        if (matrix.rows() != vec.length()) {
            throw new ArithmeticException("Covariance Matrix size and mean size do not agree");
        }
        if (list.isEmpty()) {
            throw new ArithmeticException("No data points to compute covariance from");
        }
        if (vec.length() != list.get(0).length()) {
            throw new ArithmeticException("Data vectors do not agree with mean and covariance matrix");
        }
        DenseVector centered = new DenseVector(vec.length());
        for (V v : list) {
            v.copyTo(centered);
            centered.mutableSubtract(vec);
            Matrix.OuterProductUpdate(matrix, centered, centered, 1.0d);
        }
        matrix.mutableMultiply(1.0d / (list.size() - 1.0d));
    }

    /**
     * Like {@link #covarianceMatrix(Vec, Matrix, List)} but only uses the vectors
     * whose indices are in {@code collection}. The result is normalised by
     * {@code collection.size() - 1}.
     *
     * @param vec        the mean to center on
     * @param matrix     square output storage (should be zeroed first)
     * @param list       the full set of vectors
     * @param collection indices into {@code list} to use
     * @throws ArithmeticException on shape mismatch or if {@code list}/{@code collection} is empty
     */
    public static <V extends Vec> void covarianceMatrix(Vec vec, Matrix matrix, List<V> list, Collection<Integer> collection) {
        if (!matrix.isSquare()) {
            throw new ArithmeticException("Storage for covariance matrix must be square");
        }
        if (matrix.rows() != vec.length()) {
            throw new ArithmeticException("Covariance Matrix size and mean size do not agree");
        }
        if (list.isEmpty() || collection.isEmpty()) {
            throw new ArithmeticException("No data points to compute covariance from");
        }
        if (vec.length() != list.get(0).length()) {
            throw new ArithmeticException("Data vectors do not agree with mean and covariance matrix");
        }
        DenseVector centered = new DenseVector(vec.length());
        for (int idx : collection) {
            list.get(idx).copyTo(centered);
            centered.mutableSubtract(vec);
            Matrix.OuterProductUpdate(matrix, centered, centered, 1.0d);
        }
        matrix.mutableMultiply(1.0d / (collection.size() - 1.0d));
    }

    /**
     * Computes the weighted covariance of a list of data points, deriving the
     * sum of weights and the sum of squared weights from the points themselves.
     *
     * @param vec    the (weighted) mean to center on
     * @param list   the data points
     * @param matrix square output storage (should be zeroed first)
     */
    public static void covarianceMatrix(Vec vec, List<DataPoint> list, Matrix matrix) {
        double sumWeights = 0.0d;
        double sumSquaredWeights = 0.0d;
        for (DataPoint dataPoint : list) {
            sumWeights += dataPoint.getWeight();
            sumSquaredWeights += Math.pow(dataPoint.getWeight(), 2.0d);
        }
        covarianceMatrix(vec, list, matrix, sumWeights, sumSquaredWeights);
    }

    /**
     * Accumulates weight * (x - mean)(x - mean)^T over the data points into
     * {@code matrix}, then scales by d / (d^2 - d2). For unit weights this is
     * the usual 1/(n-1) normalisation.
     *
     * @param vec    the (weighted) mean to center on
     * @param list   the data points
     * @param matrix square output storage (should be zeroed first)
     * @param d      sum of the weights
     * @param d2     sum of the squared weights
     * @throws ArithmeticException on shape mismatch or empty {@code list}
     */
    public static void covarianceMatrix(Vec vec, List<DataPoint> list, Matrix matrix, double d, double d2) {
        if (!matrix.isSquare()) {
            throw new ArithmeticException("Storage for covariance matrix must be square");
        }
        if (matrix.rows() != vec.length()) {
            throw new ArithmeticException("Covariance Matrix size and mean size do not agree");
        }
        if (list.isEmpty()) {
            throw new ArithmeticException("No data points to compute covariance from");
        }
        if (vec.length() != list.get(0).getNumericalValues().length()) {
            throw new ArithmeticException("Data vectors do not agree with mean and covariance matrix");
        }
        DenseVector centered = new DenseVector(vec.length());
        for (DataPoint dataPoint : list) {
            dataPoint.getNumericalValues().copyTo(centered);
            centered.mutableSubtract(vec);
            Matrix.OuterProductUpdate(matrix, centered, centered, dataPoint.getWeight());
        }
        matrix.mutableMultiply(d / (Math.pow(d, 2.0d) - d2));
    }

    /**
     * Computes the weighted covariance matrix of a data set's numerical features.
     *
     * @param vec     the (weighted) mean to center on
     * @param dataSet the data set
     * @return a new dense covariance matrix
     */
    public static Matrix covarianceMatrix(Vec vec, DataSet dataSet) {
        DenseMatrix denseMatrix = new DenseMatrix(vec.length(), vec.length());
        covarianceMatrix(vec, dataSet, denseMatrix);
        return denseMatrix;
    }

    /**
     * Computes the weighted covariance of a data set into {@code matrix}, deriving
     * the sum of weights and the sum of squared weights from the data set.
     *
     * @param vec     the (weighted) mean to center on
     * @param dataSet the data set
     * @param matrix  square output storage (should be zeroed first)
     */
    public static void covarianceMatrix(Vec vec, DataSet dataSet, Matrix matrix) {
        double sumWeights = 0.0d;
        double sumSquaredWeights = 0.0d;
        for (int i = 0; i < dataSet.getSampleSize(); i++) {
            DataPoint dataPoint = dataSet.getDataPoint(i);
            sumWeights += dataPoint.getWeight();
            sumSquaredWeights += Math.pow(dataPoint.getWeight(), 2.0d);
        }
        covarianceMatrix(vec, dataSet, matrix, sumWeights, sumSquaredWeights);
    }

    /**
     * Accumulates weight * (x - mean)(x - mean)^T over the data set into
     * {@code matrix}, then scales by d / (d^2 - d2).
     *
     * @param vec     the (weighted) mean to center on
     * @param dataSet the data set
     * @param matrix  square output storage (should be zeroed first)
     * @param d       sum of the weights
     * @param d2      sum of the squared weights
     * @throws ArithmeticException on shape mismatch or an empty data set
     */
    public static void covarianceMatrix(Vec vec, DataSet dataSet, Matrix matrix, double d, double d2) {
        if (!matrix.isSquare()) {
            throw new ArithmeticException("Storage for covariance matrix must be square");
        }
        if (matrix.rows() != vec.length()) {
            throw new ArithmeticException("Covariance Matrix size and mean size do not agree");
        }
        if (dataSet.getSampleSize() == 0) {
            throw new ArithmeticException("No data points to compute covariance from");
        }
        if (vec.length() != dataSet.getNumNumericalVars()) {
            throw new ArithmeticException("Data vectors do not agree with mean and covariance matrix");
        }
        DenseVector centered = new DenseVector(vec.length());
        for (int i = 0; i < dataSet.getSampleSize(); i++) {
            DataPoint dataPoint = dataSet.getDataPoint(i);
            dataPoint.getNumericalValues().copyTo(centered);
            centered.mutableSubtract(vec);
            Matrix.OuterProductUpdate(matrix, centered, centered, dataPoint.getWeight());
        }
        matrix.mutableMultiply(d / (Math.pow(d, 2.0d) - d2));
    }

    /**
     * Accumulates the weighted per-dimension variance of a data set into {@code vec2}.
     *
     * <p>Only the entries produced by each vector's iterator are visited (which
     * lets sparse vectors be handled efficiently). Every dimension a point does
     * not report is treated as 0 and contributes weight * mean^2. The sum is
     * normalised by the total weight, not by n-1. For unit weights the result
     * matches the previous unweighted correction exactly.
     *
     * @param vec     the (weighted) mean
     * @param vec2    output storage (should be zeroed first)
     * @param dataSet the data set
     */
    public static void covarianceDiag(Vec vec, Vec vec2, DataSet dataSet) {
        final int sampleSize = dataSet.getSampleSize();
        // total weight of the points that reported each dimension
        double[] reportedWeight = new double[dataSet.getNumNumericalVars()];
        double totalWeight = 0.0d;
        for (int i = 0; i < sampleSize; i++) {
            DataPoint dataPoint = dataSet.getDataPoint(i);
            double weight = dataPoint.getWeight();
            totalWeight += weight;
            Iterator<IndexValue> it = dataPoint.getNumericalValues().iterator();
            while (it.hasNext()) {
                IndexValue entry = it.next();
                int index = entry.getIndex();
                reportedWeight[index] += weight;
                vec2.increment(index, weight * Math.pow(entry.getValue() - vec.get(index), 2.0d));
            }
        }
        for (int j = 0; j < reportedWeight.length; j++) {
            // weight of the implicit zeros, each contributing (0 - mean)^2
            vec2.increment(j, Math.pow(vec.get(j), 2.0d) * (totalWeight - reportedWeight[j]));
        }
        vec2.mutableDivide(totalWeight);
    }

    /**
     * Computes the weighted per-dimension variance of a data set.
     *
     * @param vec     the (weighted) mean
     * @param dataSet the data set
     * @return a new dense vector of variances
     */
    public static Vec covarianceDiag(Vec vec, DataSet dataSet) {
        DenseVector denseVector = new DenseVector(dataSet.getNumNumericalVars());
        covarianceDiag(vec, denseVector, dataSet);
        return denseVector;
    }

    /**
     * Accumulates the unweighted per-dimension variance of a list of vectors into
     * {@code vec2}, normalised by n (not n-1). Entries not produced by a vector's
     * iterator are treated as 0.
     *
     * @param vec  the mean
     * @param vec2 output storage (should be zeroed first)
     * @param list the vectors (must be non-empty)
     */
    public static <V extends Vec> void covarianceDiag(Vec vec, Vec vec2, List<V> list) {
        final int size = list.size();
        int[] reportedCount = new int[list.get(0).length()];
        for (int i = 0; i < size; i++) {
            Iterator<IndexValue> it = list.get(i).iterator();
            while (it.hasNext()) {
                IndexValue entry = it.next();
                int index = entry.getIndex();
                reportedCount[index]++;
                vec2.increment(index, Math.pow(entry.getValue() - vec.get(index), 2.0d));
            }
        }
        for (int j = 0; j < reportedCount.length; j++) {
            vec2.increment(j, Math.pow(vec.get(j), 2.0d) * (size - reportedCount[j]));
        }
        vec2.mutableDivide(size);
    }

    /**
     * Computes the unweighted per-dimension variance of a list of vectors.
     *
     * @param vec  the mean
     * @param list the vectors (must be non-empty)
     * @return a new dense vector of variances
     */
    public static <V extends Vec> Vec covarianceDiag(Vec vec, List<V> list) {
        DenseVector denseVector = new DenseVector(list.get(0).length());
        covarianceDiag(vec, denseVector, list);
        return denseVector;
    }

    /**
     * Robustly estimates the mean and covariance of {@code list} with a FAST-MCD
     * style search for the h-subset whose covariance has the smallest determinant,
     * with h = ceil((n + p + 1) / 2).
     *
     * <p>The search runs in three stages:
     * <ol>
     * <li>Random starts. For n &lt;= 600 the starts are drawn over the whole data
     * set; otherwise over up to 5 disjoint subsets of the shuffled indices.</li>
     * <li>The best candidates are refined with C-steps until the determinant
     * stops changing.</li>
     * <li>A final reweighting step keeps only the points whose rescaled
     * Mahalanobis distance is within the 0.975 chi-square cutoff, and computes
     * the classical mean and covariance of those points.</li>
     * </ol>
     * If h &gt;= n the classical estimates of the whole data set are returned.
     *
     * @param vec    output: overwritten with the robust mean
     * @param matrix output: overwritten with the robust covariance
     * @param list   the data (must be non-empty)
     * @param z      whether to run the search in parallel
     * @throws ArithmeticException if no candidate with a finite determinant is found
     */
    public static <V extends Vec> void FastMCD(Vec vec, Matrix matrix, List<V> list, boolean z) {
        final int size = list.size();
        final int length = list.get(0).length();
        final int h = (int) Math.ceil(((size + length) + 1) / 2.0d);
        vec.zeroOut();
        matrix.zeroOut();
        if (h >= size) {
            // no proper subset to search over: fall back to the classical estimates
            meanVector(vec, list);
            covarianceMatrix(vec, matrix, list);
            return;
        }

        List<Tuple3<Double, Vec, Matrix>> candidates = size <= MCD_SMALL_N
                ? mcdFullDataStarts(vec, matrix, list, length, h, z)
                : mcdSubsetStarts(vec, matrix, list, length, h, z);

        Vec bestMean = null;
        Matrix bestCov = null;
        double bestDet = Double.POSITIVE_INFINITY;
        for (Tuple3<Double, Vec, Matrix> candidate : candidates) {
            Vec mean = candidate.getY();
            Matrix cov = candidate.getZ();
            double det = mcdRefine(mean, cov, list, candidate.getX(), h, z);
            if (det < bestDet) {
                bestDet = det;
                bestMean = mean;
                bestCov = cov;
            }
        }
        if (bestMean == null) {
            throw new ArithmeticException("FastMCD could not find a subset with a finite covariance determinant");
        }
        mcdReweight(vec, matrix, list, bestMean, bestCov, z);
    }

    /**
     * Small-sample stage: runs SOM.DEFAULT_MAX_ITERS random starts over all
     * indices and keeps the best few. NOTE(review): the count comes from an
     * unrelated constant, presumably inlined by the decompiler (FAST-MCD uses
     * 500 starts) — TODO confirm.
     */
    private static <V extends Vec> List<Tuple3<Double, Vec, Matrix>> mcdFullDataStarts(
            Vec zeroMean, Matrix zeroCov, List<V> list, int length, int h, boolean parallel) {
        final IntList allIndices = ListUtils.range(0, list.size());
        return ParallelUtils.range(SOM.DEFAULT_MAX_ITERS, parallel)
                .mapToObj(seed -> mcdRandomStart(zeroMean, zeroCov, list, allIndices, RandomUtil.getRandom(seed), length, h))
                .sorted(MatrixStatistics::compareByDeterminant)
                .limit(MCD_KEEP_BEST)
                .collect(Collectors.toList());
    }

    /**
     * Large-sample stage: deals up to 1500 shuffled indices round-robin into
     * disjoint subsets and runs random starts inside each subset, with h scaled
     * to the subset size. The pooled survivors then get further C-steps with h
     * scaled to the merged set size.
     */
    private static <V extends Vec> List<Tuple3<Double, Vec, Matrix>> mcdSubsetStarts(
            Vec zeroMean, Matrix zeroCov, List<V> list, int length, int h, boolean parallel) {
        final int size = list.size();
        final int numSubsets = size >= MCD_MAX_SUBSET_POINTS
                ? MCD_MAX_SUBSETS
                : (int) Math.floor(size / (double) MCD_SUBSET_SIZE);
        IntList shuffled = ListUtils.range(0, size);
        Collections.shuffle(shuffled, RandomUtil.getLocalRandom());
        IntList[] subsets = new IntList[numSubsets];
        for (int s = 0; s < numSubsets; s++) {
            subsets[s] = new IntList();
        }
        final int used = Math.min(MCD_MAX_SUBSET_POINTS, shuffled.size());
        for (int j = 0; j < used; j++) {
            subsets[j % numSubsets].add(shuffled.get(j));
        }

        final int subsetH = (subsets[0].size() * h) / size;
        List<Tuple3<Double, Vec, Matrix>> subsetBest = Arrays.stream(subsets)
                .flatMap(subset -> ParallelUtils.range(MCD_STARTS_PER_SUBSET, parallel)
                        .mapToObj(seed -> mcdRandomStart(zeroMean, zeroCov, list, subset, RandomUtil.getRandom(seed), length, subsetH))
                        .sorted(MatrixStatistics::compareByDeterminant)
                        .limit(MCD_KEEP_BEST))
                .collect(Collectors.toList());

        // the subsets hold distinct indices, so the merged set has exactly `used` points
        final int mergedH = (used * h) / size;
        return subsetBest.parallelStream()
                .map(t -> {
                    IntList chosen = new IntList();
                    double det = 0.0d;
                    for (int step = 0; step < MCD_INITIAL_C_STEPS; step++) {
                        det = MCD_C_step(t.getY(), t.getZ(), list, chosen, mergedH, false);
                    }
                    return new Tuple3<Double, Vec, Matrix>(det, t.getY(), t.getZ());
                })
                .sorted(MatrixStatistics::compareByDeterminant)
                .limit(MCD_KEEP_BEST)
                .collect(Collectors.toList());
    }

    /**
     * Draws p+1 random indices from {@code pool} and computes their mean and
     * covariance into fresh copies of the (zeroed) templates. Then applies a few
     * C-steps with the given h and returns (determinant, mean, covariance).
     */
    private static <V extends Vec> Tuple3<Double, Vec, Matrix> mcdRandomStart(
            Vec zeroMean, Matrix zeroCov, List<V> list, IntList pool, Random random, int length, int h) {
        Vec mean = zeroMean.mo46clone();
        Matrix cov = zeroCov.mo171clone();
        IntList shuffled = new IntList(pool);
        Collections.shuffle(shuffled, random);
        IntList chosen = new IntList(shuffled.subList(0, length + 1));
        meanVector(mean, list, chosen);
        covarianceMatrix(mean, cov, list, chosen);
        double det = 0.0d;
        for (int step = 0; step < MCD_INITIAL_C_STEPS; step++) {
            det = MCD_C_step(mean, cov, list, chosen, h, false);
        }
        return new Tuple3<Double, Vec, Matrix>(det, mean, cov);
    }

    /**
     * Applies C-steps to (mean, cov) in place until the returned determinant
     * changes by less than the tolerance, or the step limit is reached.
     *
     * @return the last determinant recorded before convergence or the step limit
     */
    private static <V extends Vec> double mcdRefine(Vec mean, Matrix cov, List<V> list, double startDet, int h, boolean parallel) {
        IntList chosen = new IntList(h);
        double det = startDet;
        for (int iter = 0; iter < MCD_MAX_REFINE_STEPS; iter++) {
            double next = MCD_C_step(mean, cov, list, chosen, h, parallel);
            if (Math.abs(next - det) < MCD_CONVERGENCE_TOL) {
                break;
            }
            det = next;
        }
        return det;
    }

    /**
     * Reweighting step.
     *
     * <p>The raw MCD covariance is rescaled so that the median squared distance
     * matches the chi-square median. The points within the 0.975 chi-square cutoff
     * are kept, and their classical mean and covariance are written into
     * {@code vec}/{@code matrix}.
     *
     * <p>The distances are treated as non-squared Mahalanobis distances: they are
     * squared for the median and compared with sqrt(quantile). Multiplying the
     * covariance by {@code scale} therefore divides every distance by
     * sqrt(scale), not by scale.
     */
    private static <V extends Vec> void mcdReweight(Vec vec, Matrix matrix, List<V> list, Vec center, Matrix rawCov, boolean parallel) {
        final int size = list.size();
        final MahalanobisDistance mahalanobis = new MahalanobisDistance();
        mahalanobis.setInverseCovariance(new LUPDecomposition(rawCov.mo171clone()).solve(Matrix.eye(rawCov.cols())));
        ChiSquared chiSquared = new ChiSquared(rawCov.cols());
        final double[] dists = new double[size];
        ParallelUtils.run(parallel, size, (start, end) -> {
            for (int i = start; i < end; i++) {
                dists[i] = mahalanobis.dist(center, list.get(i));
            }
        });
        double scale = Math.pow(dists[new IndexTable(dists).index(size / 2)], 2.0d) / chiSquared.invCdf(0.5d);
        double distScale = Math.sqrt(scale);
        double cutoff = Math.sqrt(chiSquared.invCdf(0.975d));
        List<V> inliers = new ArrayList<>(size);
        for (int i = 0; i < size; i++) {
            if (dists[i] / distScale <= cutoff) {
                inliers.add(list.get(i));
            }
        }
        vec.zeroOut();
        meanVector(vec, inliers);
        matrix.zeroOut();
        covarianceMatrix(vec, matrix, inliers);
    }

    /** Orders (determinant, mean, covariance) candidates by ascending determinant. */
    private static int compareByDeterminant(Tuple3<Double, Vec, Matrix> a, Tuple3<Double, Vec, Matrix> b) {
        return Double.compare(a.getX(), b.getX());
    }

    /**
     * Performs one MCD concentration step (C-step).
     *
     * <p>The step ranks all points by Mahalanobis distance under the current
     * (mean, covariance), with a small ridge added to a copy of the covariance
     * before it is inverted. The {@code i} closest indices are written into
     * {@code intList}, and {@code vec}/{@code matrix} are overwritten with the
     * mean and covariance of those points.
     *
     * @param vec     current mean; overwritten with the new mean
     * @param matrix  current covariance; overwritten with the new covariance
     * @param list    the data
     * @param intList output: cleared and filled with the selected indices
     * @param i       number of points to select (h)
     * @param z       parallel flag (currently unused by this method)
     * @return the determinant of the ridge-regularised input covariance used for the ranking
     */
    protected static <V extends Vec> double MCD_C_step(Vec vec, Matrix matrix, List<V> list, IntList intList, int i, boolean z) {
        final int size = list.size();
        Matrix regularized = matrix.mo171clone();
        for (int d = 0; d < regularized.rows(); d++) {
            regularized.increment(d, d, MCD_RIDGE);
        }
        LUPDecomposition lup = new LUPDecomposition(regularized);
        MahalanobisDistance mahalanobis = new MahalanobisDistance();
        mahalanobis.setInverseCovariance(lup.solve(Matrix.eye(matrix.cols())));
        double[] dists = new double[size];
        for (int j = 0; j < size; j++) {
            dists[j] = mahalanobis.dist(vec, list.get(j));
        }
        IndexTable order = new IndexTable(dists);
        intList.clear();
        for (int k = 0; k < i; k++) {
            intList.add(order.index(k));
        }
        // meanVector/covarianceMatrix accumulate into their output, so clear the old estimate first
        vec.zeroOut();
        matrix.zeroOut();
        meanVector(vec, list, intList);
        covarianceMatrix(vec, matrix, list, intList);
        return lup.det();
    }
}
