package com.datumbox.framework.core.machinelearning.clustering;

import com.datumbox.framework.common.Configuration;
import com.datumbox.framework.common.storage.interfaces.StorageEngine;
import com.datumbox.framework.core.common.dataobjects.DataframeMatrix;
import com.datumbox.framework.core.common.dataobjects.Record;
import com.datumbox.framework.core.machinelearning.common.abstracts.algorithms.AbstractDPMM;
import java.util.Map;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.LUDecomposition;
import org.apache.commons.math3.linear.OpenMapRealMatrix;
import org.apache.commons.math3.linear.OpenMapRealVector;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;

/**
 * Dirichlet Process Mixture Model whose clusters describe the data with multivariate Gaussian
 * densities. Each {@link Cluster} keeps running sufficient statistics (sum of the record vectors
 * and sum of their outer products) and re-estimates its parameters after every add/remove.
 */
public class GaussianDPMM extends AbstractDPMM<GaussianDPMM.Cluster, GaussianDPMM.ModelParameters, GaussianDPMM.TrainingParameters> {

    /**
     * A Gaussian cluster. The update performed by {@link #updateClusterParameters()} matches the
     * standard Normal-Inverse-Wishart conjugate posterior update with hyperparameters
     * {@code mu0, kappa0, nu0, psi0}; {@link #posteriorLogPdf(Record)} evaluates a multivariate
     * Gaussian density with the resulting {@code mean} and {@code covariance}.
     */
    public static class Cluster extends AbstractDPMM.AbstractCluster {
        private static final long serialVersionUID = 2;

        /** {@code log(2*pi)}: the per-dimension term of the Gaussian normalising constant. */
        private static final double LOG_TWO_PI = Math.log(2.0 * Math.PI);

        /** Number of features, i.e. the length of every parsed record vector. */
        private final int dimensions;
        /** Weight of {@code mu0} in the posterior mean (prior pseudo-count). */
        private final int kappa0;
        /** Prior degrees of freedom; the constructor clamps it to at least {@code dimensions}. */
        private final int nu0;
        /** Prior mean vector. */
        private final RealVector mu0;
        /** Prior scale matrix. */
        private final RealMatrix psi0;

        /** Current estimate of the cluster mean. */
        private RealVector mean;
        /** Covariance used by posteriorLogPdf; written and read under the lock on {@code this}. */
        private RealMatrix covariance;
        /** {@code psi_n / (kappa_n * (nu_n - dimensions + 1))}. */
        private RealMatrix meanError;
        /** {@code nu_n - dimensions + 1}. */
        private int meanDf;

        /** Sum of the record vectors in the cluster; {@code null} once {@link #clear()} ran. */
        private RealVector xi_sum;
        /** Sum of the outer products {@code x x^T} of the records; {@code null} after clear(). */
        private RealMatrix xi_square_sum;

        /** Lazily computed determinant of {@code covariance}; reset whenever it changes. */
        private volatile Double cache_covariance_determinant;
        /** Lazily computed inverse of {@code covariance}; reset whenever it changes. */
        private volatile RealMatrix cache_covariance_inverse;

        /**
         * Creates an empty cluster.
         *
         * @param clusterId  identifier of the cluster
         * @param dimensions number of features
         * @param kappa0     prior pseudo-count of the mean
         * @param nu0        prior degrees of freedom (raised to {@code dimensions} if smaller)
         * @param mu0        prior mean; must be non-null (it is copied here)
         * @param psi0       prior scale matrix; must be non-null (it is copied here)
         */
        protected Cluster(Integer clusterId, int dimensions, int kappa0, int nu0, RealVector mu0, RealMatrix psi0) {
            super(clusterId);
            if (nu0 < dimensions) {
                nu0 = dimensions;
            }
            // Assign the hyperparameters BEFORE deriving anything from them: calculateMeanError()
            // reads this.dimensions, which would otherwise still be 0 here.
            this.dimensions = dimensions;
            this.kappa0 = kappa0;
            this.nu0 = nu0;
            this.mu0 = new OpenMapRealVector(mu0);
            this.psi0 = new OpenMapRealMatrix(dimensions, dimensions).add(psi0);
            resetToEmptyState();
        }

        /**
         * Puts the cluster in the state of a freshly constructed, empty cluster: zero sufficient
         * statistics, zero mean, identity covariance and the prior-based mean error.
         */
        private void resetToEmptyState() {
            xi_sum = new OpenMapRealVector(dimensions);
            xi_square_sum = new OpenMapRealMatrix(dimensions, dimensions);

            RealMatrix identity = new OpenMapRealMatrix(dimensions, dimensions);
            for (int d = 0; d < dimensions; d++) {
                identity.setEntry(d, d, 1.0);
            }
            mean = new OpenMapRealVector(dimensions);
            synchronized (this) {
                covariance = identity;
                cache_covariance_determinant = null;
                cache_covariance_inverse = null;
            }
            meanError = calculateMeanError(psi0, kappa0, nu0);
            meanDf = nu0 - dimensions + 1;
        }

        /**
         * Ensures the sufficient statistics are still available.
         *
         * @throws IllegalStateException if {@link #clear()} has already been called
         */
        private void assertModifiable() {
            if (xi_sum == null || xi_square_sum == null) {
                throw new IllegalStateException("The cluster parameters are already estimated.");
            }
        }

        @Override
        protected Map<Object, Integer> getFeatureIds() {
            return this.featureIds;
        }

        @Override
        protected void setFeatureIds(Map<Object, Integer> map) {
            this.featureIds = map;
        }

        /** @return {@code psi_n / (kappa_n * (nu_n - dimensions + 1))} for the current records */
        protected RealMatrix getMeanError() {
            return meanError;
        }

        /** @return {@code nu_n - dimensions + 1} for the current records */
        protected int getMeanDf() {
            return meanDf;
        }

        /**
         * Log-density of the record under a multivariate Gaussian with the current
         * {@code mean} and {@code covariance}. The determinant and inverse of the covariance are
         * computed once per covariance change and cached.
         *
         * @param record the record to score
         * @return {@code -0.5 * (m + d*log(2*pi) + log(det))}, where {@code m} is the squared
         *         Mahalanobis distance of the record from the mean
         */
        @Override
        protected double posteriorLogPdf(Record record) {
            RealVector centered = DataframeMatrix.parseRecord(record, this.featureIds).subtract(mean);

            // Work on local copies: a concurrent updateClusterParameters() may null the volatile
            // caches at any time, so they must not be re-read after the null check.
            Double determinant = cache_covariance_determinant;
            RealMatrix inverse = cache_covariance_inverse;
            if (determinant == null || inverse == null) {
                synchronized (this) {
                    determinant = cache_covariance_determinant;
                    inverse = cache_covariance_inverse;
                    if (determinant == null || inverse == null) {
                        LUDecomposition lud = new LUDecomposition(covariance);
                        determinant = lud.getDeterminant();
                        inverse = lud.getSolver().getInverse();
                        cache_covariance_determinant = determinant;
                        cache_covariance_inverse = inverse;
                    }
                }
            }

            double mahalanobis = inverse.preMultiply(centered).dotProduct(centered);
            // Evaluated in log space: pow(2*pi, d/2) overflows and det may underflow in high d.
            return -0.5 * (mahalanobis + dimensions * LOG_TWO_PI + Math.log(determinant));
        }

        /**
         * Adds a record to the cluster and re-estimates the parameters.
         *
         * @param record the record to add
         * @throws IllegalStateException if {@link #clear()} has already been called
         */
        @Override
        protected void add(Record record) {
            assertModifiable();
            RealVector x = DataframeMatrix.parseRecord(record, this.featureIds);
            xi_sum = xi_sum.add(x);
            xi_square_sum = xi_square_sum.add(x.outerProduct(x));
            size++;
            updateClusterParameters();
        }

        /**
         * Removes a previously added record from the cluster and re-estimates the parameters.
         *
         * @param record the record to remove
         * @throws IllegalArgumentException if the cluster is empty
         * @throws IllegalStateException    if {@link #clear()} has already been called
         */
        @Override
        protected void remove(Record record) {
            assertModifiable();
            if (size == 0) {
                throw new IllegalArgumentException("The cluster is empty.");
            }
            // Parse before touching size so a parse failure leaves the cluster consistent.
            RealVector x = DataframeMatrix.parseRecord(record, this.featureIds);
            size--;
            xi_sum = xi_sum.subtract(x);
            xi_square_sum = xi_square_sum.subtract(x.outerProduct(x));
            updateClusterParameters();
        }

        /**
         * @return {@code psi / (kappa * (nu - dimensions + 1))}
         */
        private RealMatrix calculateMeanError(RealMatrix psi, int kappa, int nu) {
            return psi.scalarMultiply(1.0 / (kappa * (nu - dimensions + 1.0)));
        }

        /**
         * Drops the sufficient statistics and the cached decomposition; afterwards the cluster
         * can no longer be modified.
         */
        @Override
        protected void clear() {
            xi_sum = null;
            xi_square_sum = null;
            cache_covariance_determinant = null;
            cache_covariance_inverse = null;
        }

        /**
         * Recomputes mean, covariance, mean error and mean degrees of freedom from the
         * sufficient statistics. An empty cluster falls back to the freshly constructed state
         * instead of dividing by a zero size.
         *
         * @throws IllegalStateException if {@link #clear()} has already been called
         */
        @Override
        protected void updateClusterParameters() {
            assertModifiable();
            if (size == 0) {
                resetToEmptyState();
                return;
            }

            int kappaN = kappa0 + size;
            int nuN = nu0 + size;

            RealVector sampleMean = xi_sum.mapDivide(size);
            RealVector meanShift = sampleMean.subtract(mu0);
            RealMatrix scatter = xi_square_sum.subtract(sampleMean.outerProduct(sampleMean).scalarMultiply(size));
            // Must be floating-point: integer division truncates kappa0*n/kappa_n (e.g. 1*1/2 -> 0).
            double shrinkage = (double) kappa0 * size / kappaN;
            RealMatrix psiN = psi0.add(scatter.add(meanShift.outerProduct(meanShift).scalarMultiply(shrinkage)));

            mean = mu0.mapMultiply(kappa0).add(sampleMean.mapMultiply(size)).mapDivide(kappaN);

            synchronized (this) {
                covariance = psiN.scalarMultiply((kappaN + 1.0) / (kappaN * (nuN - dimensions + 1.0)));
                cache_covariance_determinant = null;
                cache_covariance_inverse = null;
            }

            meanError = calculateMeanError(psiN, kappaN, nuN);
            meanDf = nuN - dimensions + 1;
        }
    }

    /** Model parameters of the {@link GaussianDPMM}; adds nothing to the DPMM base class. */
    public static class ModelParameters extends AbstractDPMM.AbstractModelParameters<Cluster> {
        private static final long serialVersionUID = 1;

        /**
         * @param storageEngine storage engine passed to the base class
         */
        protected ModelParameters(StorageEngine storageEngine) {
            super(storageEngine);
        }
    }

    /**
     * Training parameters of the {@link GaussianDPMM}: the prior hyperparameters handed to every
     * new {@link Cluster}. {@code mu0} and {@code psi0} default to {@code null} and must be set
     * before clusters are created, because the Cluster constructor copies (dereferences) them.
     */
    public static class TrainingParameters extends AbstractDPMM.AbstractTrainingParameters {
        private static final long serialVersionUID = 2;
        private int kappa0 = 0;
        private int nu0 = 1;
        private RealVector mu0;
        private RealMatrix psi0;

        /** @return the prior pseudo-count of the mean (default 0) */
        public int getKappa0() {
            return kappa0;
        }

        /** @param kappa0 the prior pseudo-count of the mean */
        public void setKappa0(int kappa0) {
            this.kappa0 = kappa0;
        }

        /** @return the prior degrees of freedom (default 1) */
        public int getNu0() {
            return nu0;
        }

        /** @param nu0 the prior degrees of freedom; clusters raise it to the dimensionality if smaller */
        public void setNu0(int nu0) {
            this.nu0 = nu0;
        }

        /** @return the prior mean vector, or {@code null} if unset */
        public RealVector getMu0() {
            return mu0;
        }

        /** @param mu0 the prior mean vector */
        public void setMu0(RealVector mu0) {
            this.mu0 = mu0;
        }

        /** @return the prior scale matrix, or {@code null} if unset */
        public RealMatrix getPsi0() {
            return psi0;
        }

        /** @param psi0 the prior scale matrix */
        public void setPsi0(RealMatrix psi0) {
            this.psi0 = psi0;
        }
    }

    /**
     * Creates a new model for training.
     *
     * @param trainingParameters the training parameters
     * @param configuration      the framework configuration
     */
    protected GaussianDPMM(TrainingParameters trainingParameters, Configuration configuration) {
        super(trainingParameters, configuration);
    }

    /**
     * Loads a previously stored model.
     *
     * @param storageName   name of the stored model
     * @param configuration the framework configuration
     */
    protected GaussianDPMM(String storageName, Configuration configuration) {
        super(storageName, configuration);
    }

    /**
     * Creates an empty cluster configured with the model dimensionality, the prior
     * hyperparameters from the training parameters and the model's feature-id mapping.
     *
     * @param num identifier of the new cluster
     * @return the new, empty cluster
     */
    @Override
    public Cluster createNewCluster(Integer num) {
        ModelParameters modelParameters = (ModelParameters) this.knowledgeBase.getModelParameters();
        TrainingParameters trainingParameters = (TrainingParameters) this.knowledgeBase.getTrainingParameters();
        Cluster cluster = new Cluster(num, modelParameters.getD().intValue(), trainingParameters.getKappa0(),
                trainingParameters.getNu0(), trainingParameters.getMu0(), trainingParameters.getPsi0());
        cluster.setFeatureIds(modelParameters.getFeatureIds());
        return cluster;
    }
}
