package org.jdmp.core.algorithm.estimator;

import java.util.Iterator;
import org.jdmp.core.algorithm.classification.AbstractClassifier;
import org.jdmp.core.algorithm.classification.Classifier;
import org.jdmp.core.dataset.ListDataSet;
import org.jdmp.core.sample.Sample;
import org.ujmp.core.DenseMatrix;
import org.ujmp.core.Matrix;
import org.ujmp.core.calculation.Calculation;

/* loaded from: input_file:org/jdmp/core/algorithm/estimator/MultivariateGaussianDensityEstimator.class */
/**
 * Classifier that models the joint distribution of input values and target values with a single
 * multivariate Gaussian.
 *
 * <p>Training stacks each sample's input values followed by its target values into one row and
 * estimates the mean row and the covariance matrix of those rows. Prediction places each class in
 * turn as a one-hot target next to the input values, evaluates the Gaussian there, and normalises
 * the per-class values so that they sum to one.
 */
public class MultivariateGaussianDensityEstimator extends AbstractClassifier {
    private static final long serialVersionUID = -8923432381344117225L;

    /** The constant 2&pi; used in the Gaussian normalisation term (2&pi;)^k. */
    private static final double TWO_PI = 2.0d * Math.PI;

    /** Covariance of the training rows (inputs, then targets); {@code null} until trained. */
    private Matrix covarianceMatrix = null;
    /** Mean of the training rows; {@code null} until trained. */
    private Matrix meanMatrix = null;
    /** Inverse of {@link #covarianceMatrix}, or its pseudo-inverse when inversion is not usable. */
    private Matrix inverse = null;
    /** Gaussian normalisation factor, or 1.0 when the pseudo-inverse fallback is in use. */
    private double factor = 0.0d;
    /** Width of a training row: {@link #featureCount} + {@link #classCount}. */
    private int dimensions = 0;
    /** Number of input values copied per sample. */
    private int featureCount = 0;
    /** Number of target values copied per sample. */
    private int classCount = 0;

    /**
     * Fits the Gaussian to the given samples.
     *
     * <p>Each sample becomes one row holding its input values followed by its target values; the
     * mean and covariance of these rows define the model.
     *
     * @param listDataSet the training samples
     * @throws IllegalArgumentException if {@code listDataSet} is empty
     */
    @Override
    public void trainAll(ListDataSet listDataSet) {
        if (listDataSet.size() == 0) {
            throw new IllegalArgumentException("cannot train on an empty data set");
        }
        this.featureCount = getFeatureCount(listDataSet);
        this.classCount = getClassCount(listDataSet);
        this.dimensions = this.featureCount + this.classCount;

        DenseMatrix data = Matrix.Factory.zeros(listDataSet.size(), this.dimensions);
        long row = 0;
        for (Iterator<?> it = listDataSet.iterator(); it.hasNext(); row++) {
            Sample sample = (Sample) it.next();
            copyRow(sample.getAsMatrix(getInputLabel()).toColumnVector(Calculation.Ret.LINK),
                    this.featureCount, data, row, 0);
            copyRow(sample.getAsMatrix(getTargetLabel()).toColumnVector(Calculation.Ret.LINK),
                    this.classCount, data, row, this.featureCount);
        }

        // Mean is taken along dimension 0, giving one value per column of the training rows.
        this.meanMatrix = data.mean(Calculation.Ret.NEW, 0, true);
        this.covarianceMatrix = data.cov(Calculation.Ret.NEW, true, true);
        computeInverseAndFactor();
    }

    /** Discards the fitted model and returns every field to its untrained initial value. */
    @Override
    public void reset() {
        this.covarianceMatrix = null;
        this.meanMatrix = null;
        this.inverse = null;
        this.factor = 0.0d;
        this.dimensions = 0;
        this.featureCount = 0;
        this.classCount = 0;
    }

    /**
     * Evaluates the fitted Gaussian density at a point.
     *
     * @param matrix a point laid out like a training row (presumably 1 x dimensions — the shape
     *     {@link #predictOne} passes in)
     * @return {@code factor * exp(-0.5 * q)} where {@code q} is the quadratic form of the point's
     *     deviation from the mean under the stored (pseudo-)inverse covariance
     */
    public double getDensity(Matrix matrix) {
        return this.factor * Math.exp(-0.5d * squaredMahalanobisDistance(matrix));
    }

    /**
     * Evaluates the fitted Gaussian density at a point without the normalisation factor.
     *
     * @param matrix a point laid out like a training row
     * @return {@code exp(-0.5 * q)} with {@code q} as in {@link #getDensity(Matrix)}
     */
    public double getDensityUnscaled(Matrix matrix) {
        return Math.exp(-0.5d * squaredMahalanobisDistance(matrix));
    }

    /**
     * Computes normalised per-class scores for one input.
     *
     * <p>For each class {@code c}, the input values are combined with a one-hot target at position
     * {@code c} and the Gaussian is evaluated there; the values are then scaled to sum to one.
     *
     * @param matrix the input values
     * @return a 1 x classCount matrix of normalised class scores
     */
    @Override
    public Matrix predictOne(Matrix matrix) {
        Matrix input = matrix.toColumnVector(Calculation.Ret.NEW);
        DenseMatrix point = Matrix.Factory.zeros(1L, this.dimensions);
        copyRow(input, this.featureCount, point, 0L, 0);

        // Work with log of the *unscaled* density: the normalisation factor is identical for
        // every class and cancels when normalising, so leaving it out keeps the result finite
        // even when the factor itself is not. Subtracting the maximum before exponentiating
        // prevents every term underflowing to 0 (which would give 0/0 = NaN).
        double[] relative = new double[this.classCount];
        double maxLogDensity = Double.NEGATIVE_INFINITY;
        for (int c = 0; c < this.classCount; c++) {
            if (c > 0) {
                point.setAsDouble(0.0d, 0, this.featureCount + c - 1);
            }
            point.setAsDouble(1.0d, 0, this.featureCount + c);
            relative[c] = -0.5d * squaredMahalanobisDistance(point);
            maxLogDensity = Math.max(maxLogDensity, relative[c]);
        }

        double total = 0.0d;
        for (int c = 0; c < this.classCount; c++) {
            relative[c] = Math.exp(relative[c] - maxLogDensity);
            total += relative[c];
        }

        DenseMatrix probabilities = Matrix.Factory.zeros(1L, this.classCount);
        for (int c = 0; c < this.classCount; c++) {
            probabilities.setAsDouble(relative[c] / total, 0, c);
        }
        return probabilities;
    }

    /**
     * Creates an untrained estimator with the same input and target labels.
     *
     * @return a new, untrained {@code MultivariateGaussianDensityEstimator}
     */
    @Override
    public Classifier emptyCopy() {
        MultivariateGaussianDensityEstimator copy = new MultivariateGaussianDensityEstimator();
        copy.setInputLabel(getInputLabel());
        copy.setTargetLabel(getTargetLabel());
        return copy;
    }

    /**
     * Evaluates a Gaussian density with the given mean and covariance at a point.
     *
     * @param point the point, as a row whose column count is the dimensionality
     * @param mean the mean, shaped like {@code point}
     * @param covariance the covariance matrix; inverted on every call
     * @return {@code exp(-0.5 * q) / sqrt(det(covariance) * (2*pi)^k)}
     */
    public static double getDensity(Matrix point, Matrix mean, Matrix covariance) {
        Matrix deviation = point.minus(mean);
        double normalizer = Math.sqrt(covariance.det() * Math.pow(TWO_PI, point.getColumnCount()));
        return (1.0d / normalizer) * Math.exp(-0.5d * quadraticForm(deviation, covariance.inv()));
    }

    /**
     * Evaluates a Gaussian density with the given mean and covariance at a point, without the
     * normalisation factor.
     *
     * @param point the point
     * @param mean the mean, shaped like {@code point}
     * @param covariance the covariance matrix; inverted on every call
     * @return {@code exp(-0.5 * q)}
     */
    public static double getDensityUnscaled(Matrix point, Matrix mean, Matrix covariance) {
        Matrix deviation = point.minus(mean);
        return Math.exp(-0.5d * quadraticForm(deviation, covariance.inv()));
    }

    /**
     * Inverts the covariance matrix and computes the normalisation factor
     * {@code 1 / sqrt((2*pi)^k * det)}.
     *
     * <p>Falls back to the pseudo-inverse with a factor of 1.0 in two cases: inversion or the
     * determinant throws, or the normaliser is not finite and positive. The second case covers a
     * zero or slightly negative determinant of a singular covariance matrix, which does not throw
     * but would make {@code Math.sqrt} return NaN or the factor become infinite.
     */
    private void computeInverseAndFactor() {
        try {
            double determinant = this.covarianceMatrix.det();
            double normalizer = Math.sqrt(determinant * Math.pow(TWO_PI, this.dimensions));
            if (normalizer > 0.0d && normalizer < Double.POSITIVE_INFINITY) {
                this.inverse = this.covarianceMatrix.inv();
                this.factor = 1.0d / normalizer;
                return;
            }
        } catch (Exception ignored) {
            // Not invertible: use the pseudo-inverse below.
        }
        this.inverse = this.covarianceMatrix.pinv();
        this.factor = 1.0d;
    }

    /**
     * Returns the quadratic form {@code (x - mean) * inverse * (x - mean)^T} for the fitted model.
     */
    private double squaredMahalanobisDistance(Matrix matrix) {
        return quadraticForm(matrix.minus(this.meanMatrix), this.inverse);
    }

    /** Returns {@code deviation * inverseCovariance * deviation^T} as a scalar. */
    private static double quadraticForm(Matrix deviation, Matrix inverseCovariance) {
        return deviation.mtimes(inverseCovariance).mtimes(deviation.transpose()).doubleValue();
    }

    /**
     * Copies the first {@code count} values of row 0 of {@code source} into row {@code row} of
     * {@code target}, starting at column {@code offset}.
     *
     * <p>NOTE(review): callers pass the result of {@code toColumnVector(...)} and read it at
     * coordinates (0, i), i.e. as a single row — confirm this matches the UJMP version in use.
     */
    private static void copyRow(Matrix source, int count, DenseMatrix target, long row, int offset) {
        for (int i = 0; i < count; i++) {
            target.setAsDouble(source.getAsDouble(0, i), row, offset + i);
        }
    }
}
