package com.datumbox.framework.core.machinelearning.clustering;

import com.datumbox.framework.common.Configuration;
import com.datumbox.framework.common.storage.interfaces.StorageEngine;
import com.datumbox.framework.core.common.dataobjects.DataframeMatrix;
import com.datumbox.framework.core.common.dataobjects.Record;
import com.datumbox.framework.core.machinelearning.common.abstracts.algorithms.AbstractDPMM;
import com.datumbox.framework.core.statistics.distributions.ContinuousDistributions;
import java.util.Map;
import org.apache.commons.math3.linear.OpenMapRealVector;
import org.apache.commons.math3.linear.RealVector;

/* loaded from: input_file:com/datumbox/framework/core/machinelearning/clustering/MultinomialDPMM.class */
public class MultinomialDPMM extends AbstractDPMM<Cluster, ModelParameters, TrainingParameters> {

    /* loaded from: input_file:com/datumbox/framework/core/machinelearning/clustering/MultinomialDPMM$Cluster.class */
    /**
     * Cluster implementation used by {@link MultinomialDPMM}.
     *
     * <p>The cluster keeps a running sum of the vectors parsed from its records
     * ({@code wordCounts}) and caches the log multivariate Beta value of that sum
     * shifted by {@code alphaWords}. The cache is refreshed every time a record is
     * added or removed.
     */
    public static class Cluster extends AbstractDPMM.AbstractCluster {
        private static final long serialVersionUID = 2;

        /** Constant added to every entry of {@code wordCounts} before the log-Beta term is evaluated. */
        private final double alphaWords;

        /** Sum of the vectors parsed from the records added to (and not yet removed from) this cluster. */
        private RealVector wordCounts;

        /** Cached result of {@link #estimateWordCountsPlusAlpha()}; refreshed by {@link #updateClusterParameters()}. */
        private Double wordcounts_plusalpha;

        /**
         * Creates an empty cluster.
         *
         * @param clusterId identifier forwarded unchanged to the superclass constructor
         * @param dimensions dimension of the initial (sparse, all-zero) count vector
         * @param alphaWords constant added to every count when evaluating the log-Beta term
         */
        protected Cluster(Integer clusterId, int dimensions, double alphaWords) {
            super(clusterId);
            this.alphaWords = alphaWords;
            this.wordCounts = new OpenMapRealVector(dimensions);
            // Prime the cache so posteriorLogPdf() can be called on an empty cluster.
            this.wordcounts_plusalpha = estimateWordCountsPlusAlpha();
        }

        /**
         * Returns the difference between the log-Beta term of the shifted counts with the
         * record's vector added and the cached log-Beta term of the shifted counts alone.
         *
         * @param record the record to evaluate; parsed with this cluster's feature ids
         * @return {@code logBeta(wordCounts + alphaWords + x) - logBeta(wordCounts + alphaWords)}
         */
        @Override
        protected double posteriorLogPdf(Record record) {
            RealVector shiftedCounts = this.wordCounts.mapAdd(this.alphaWords);
            RealVector recordVector = DataframeMatrix.parseRecord(record, this.featureIds);
            double withRecord = logMultivariateBeta(shiftedCounts.add(recordVector));
            return withRecord - this.wordcounts_plusalpha;
        }

        /**
         * @return the feature-id map currently held by this cluster
         */
        @Override
        protected Map<Object, Integer> getFeatureIds() {
            return this.featureIds;
        }

        /**
         * Replaces the feature-id map used when parsing records.
         *
         * @param map the new feature-id map
         */
        @Override
        protected void setFeatureIds(Map<Object, Integer> map) {
            this.featureIds = map;
        }

        /**
         * Adds the record's vector to the running counts, increments the size and
         * refreshes the cached log-Beta term.
         *
         * @param record the record to add
         */
        @Override
        protected void add(Record record) {
            RealVector recordVector = DataframeMatrix.parseRecord(record, this.featureIds);
            // An empty cluster adopts the parsed vector directly instead of summing into the old one.
            this.wordCounts = (this.size == 0) ? recordVector : this.wordCounts.add(recordVector);
            this.size++;
            updateClusterParameters();
        }

        /**
         * Decrements the size, subtracts the record's vector from the running counts
         * and refreshes the cached log-Beta term.
         *
         * @param record the record to remove
         */
        @Override
        protected void remove(Record record) {
            this.size--;
            RealVector recordVector = DataframeMatrix.parseRecord(record, this.featureIds);
            this.wordCounts = this.wordCounts.subtract(recordVector);
            updateClusterParameters();
        }

        /**
         * Recomputes the cached log-Beta term from the current counts.
         */
        @Override
        protected void updateClusterParameters() {
            this.wordcounts_plusalpha = estimateWordCountsPlusAlpha();
        }

        /**
         * Intentionally empty: this implementation releases nothing here.
         */
        @Override
        protected void clear() {
        }

        /**
         * @return the log multivariate Beta value of {@code wordCounts + alphaWords}
         */
        private double estimateWordCountsPlusAlpha() {
            RealVector shiftedCounts = this.wordCounts.mapAdd(this.alphaWords);
            return logMultivariateBeta(shiftedCounts);
        }

        /**
         * Computes {@code sum_i logGamma(v_i) - logGamma(sum_i v_i)}, i.e. the logarithm
         * of the multivariate Beta function of the given vector.
         *
         * @param values the vector to evaluate
         * @return the log multivariate Beta value
         */
        private double logMultivariateBeta(RealVector values) {
            final int length = values.getDimension();
            double entrySum = 0.0d;
            double logGammaSum = 0.0d;
            int index = 0;
            while (index < length) {
                double value = values.getEntry(index);
                entrySum += value;
                logGammaSum += ContinuousDistributions.logGamma(value);
                index++;
            }
            return logGammaSum - ContinuousDistributions.logGamma(entrySum);
        }
    }

    /* loaded from: input_file:com/datumbox/framework/core/machinelearning/clustering/MultinomialDPMM$ModelParameters.class */
    /**
     * Model parameters of {@link MultinomialDPMM}; adds nothing to
     * {@link AbstractDPMM.AbstractModelParameters} beyond binding the cluster type.
     */
    public static class ModelParameters extends AbstractDPMM.AbstractModelParameters<Cluster> {
        private static final long serialVersionUID = 1;

        /**
         * Creates the model parameters.
         *
         * @param engine storage engine forwarded unchanged to the superclass constructor
         */
        protected ModelParameters(StorageEngine engine) {
            super(engine);
        }
    }

    /* loaded from: input_file:com/datumbox/framework/core/machinelearning/clustering/MultinomialDPMM$TrainingParameters.class */
    /**
     * Training parameters of {@link MultinomialDPMM}. Extends the common DPMM
     * training parameters with the {@code alphaWords} value.
     */
    public static class TrainingParameters extends AbstractDPMM.AbstractTrainingParameters {
        private static final long serialVersionUID = 1;

        /** Value handed to each new {@link Cluster}; defaults to 50. */
        private double alphaWords = 50.0d;

        /**
         * @return the current {@code alphaWords} value
         */
        public double getAlphaWords() {
            return alphaWords;
        }

        /**
         * Sets the {@code alphaWords} value.
         *
         * @param alphaWords the new value
         */
        public void setAlphaWords(double alphaWords) {
            this.alphaWords = alphaWords;
        }
    }

    /**
     * Creates an instance from the given training parameters.
     *
     * @param parameters training parameters forwarded unchanged to the superclass
     * @param config configuration forwarded unchanged to the superclass
     */
    protected MultinomialDPMM(TrainingParameters parameters, Configuration config) {
        super(parameters, config);
    }

    /**
     * Creates an instance from a string key and a configuration.
     *
     * @param storageName forwarded unchanged to the superclass; presumably the name
     *        of a previously stored model (TODO confirm against {@code AbstractDPMM})
     * @param config configuration forwarded unchanged to the superclass
     */
    protected MultinomialDPMM(String storageName, Configuration config) {
        super(storageName, config);
    }

    /**
     * Builds a new, empty {@link Cluster} sized from the model parameters and
     * configured with the {@code alphaWords} training parameter, then hands it the
     * model's feature-id map.
     *
     * <p>NOTE(review): the decompiler reported that this method's access modifier was
     * changed from {@code protected} and that a naming collision could not be resolved.
     *
     * @param clusterId identifier passed to the {@link Cluster} constructor
     * @return the newly created cluster
     */
    @Override
    public Cluster createNewCluster(Integer clusterId) {
        ModelParameters model = (ModelParameters) this.knowledgeBase.getModelParameters();
        int dimensions = model.getD().intValue();
        TrainingParameters training = (TrainingParameters) this.knowledgeBase.getTrainingParameters();
        double alphaWords = training.getAlphaWords();

        Cluster created = new Cluster(clusterId, dimensions, alphaWords);
        created.setFeatureIds(model.getFeatureIds());
        return created;
    }
}
