/*
 * Decompiled with CFR 0.152.
 */
package com.datumbox.framework.core.machinelearning.clustering;

import com.datumbox.framework.common.Configuration;
import com.datumbox.framework.common.storage.interfaces.StorageEngine;
import com.datumbox.framework.core.common.dataobjects.DataframeMatrix;
import com.datumbox.framework.core.common.dataobjects.Record;
import com.datumbox.framework.core.machinelearning.common.abstracts.algorithms.AbstractDPMM;
import com.datumbox.framework.core.statistics.distributions.ContinuousDistributions;
import java.util.Map;
import org.apache.commons.math3.linear.OpenMapRealVector;
import org.apache.commons.math3.linear.RealVector;

/**
 * DPMM clusterer (see {@link AbstractDPMM}) whose clusters score records with the ratio of
 * log multivariate Beta functions over smoothed per-dimension counts.
 *
 * <p>Each cluster keeps the element-wise sum of the feature vectors assigned to it. It adds the
 * symmetric pseudo-count {@link TrainingParameters#getAlphaWords()} to every dimension before
 * evaluating the log Beta function.
 */
public class MultinomialDPMM
extends AbstractDPMM<MultinomialDPMM.Cluster,
                     MultinomialDPMM.ModelParameters,
                     MultinomialDPMM.TrainingParameters> {
    // The nested types above must be qualified: a member type's scope is the class body only,
    // so simple names are not resolvable in the superclass clause (JLS 6.3).

    /**
     * Creates a new, untrained model.
     *
     * @param trainingParameters training configuration, including the pseudo-count
     * @param configuration      framework configuration
     */
    protected MultinomialDPMM(TrainingParameters trainingParameters, Configuration configuration) {
        super(trainingParameters, configuration);
    }

    /**
     * Creates a model bound to an existing storage name.
     *
     * @param storageName   name under which the model is stored
     * @param configuration framework configuration
     */
    protected MultinomialDPMM(String storageName, Configuration configuration) {
        super(storageName, configuration);
    }

    /**
     * Builds an empty cluster and attaches it to the model's feature-id mapping.
     *
     * <p>The cluster is sized with {@code modelParameters.getD()}, presumably the number of
     * features (confirm). It uses the configured {@code alphaWords} pseudo-count.
     *
     * @param clusterId identifier of the new cluster
     * @return a cluster containing no records
     */
    @Override
    protected Cluster createNewCluster(Integer clusterId) {
        ModelParameters modelParameters = (ModelParameters)this.knowledgeBase.getModelParameters();
        TrainingParameters trainingParameters = (TrainingParameters)this.knowledgeBase.getTrainingParameters();
        Cluster cluster = new Cluster(clusterId, modelParameters.getD(), trainingParameters.getAlphaWords());
        cluster.setFeatureIds(modelParameters.getFeatureIds());
        return cluster;
    }

    /**
     * Training parameters of {@link MultinomialDPMM}.
     */
    public static class TrainingParameters
    extends AbstractDPMM.AbstractTrainingParameters {
        private static final long serialVersionUID = 1L;

        /** Pseudo-count added to every dimension of a cluster's counts; defaults to 50.0. */
        private double alphaWords = 50.0;

        /**
         * Returns the per-dimension pseudo-count.
         *
         * @return the pseudo-count added to every count dimension
         */
        public double getAlphaWords() {
            return this.alphaWords;
        }

        /**
         * Sets the per-dimension pseudo-count.
         *
         * @param alphaWords the pseudo-count added to every count dimension
         */
        public void setAlphaWords(double alphaWords) {
            this.alphaWords = alphaWords;
        }
    }

    /**
     * Model parameters of {@link MultinomialDPMM}; adds nothing beyond the DPMM base.
     */
    public static class ModelParameters
    extends AbstractDPMM.AbstractModelParameters<Cluster> {
        private static final long serialVersionUID = 1L;

        /**
         * @param storageEngine storage engine backing the parameters
         */
        protected ModelParameters(StorageEngine storageEngine) {
            super(storageEngine);
        }
    }

    /**
     * A cluster holding the element-wise sum of its records' feature vectors.
     *
     * <p>It also caches {@code C(wordCounts + alphaWords)}, where {@code C(a)} is the log
     * multivariate Beta function {@code sum_i logGamma(a_i) - logGamma(sum_i a_i)}.
     */
    public static class Cluster
    extends AbstractDPMM.AbstractCluster {
        private static final long serialVersionUID = 2L;

        /** Pseudo-count added to every entry of {@link #wordCounts}. */
        private final double alphaWords;

        /** Element-wise sum of the feature vectors of the records currently in the cluster. */
        private RealVector wordCounts;

        /**
         * Cached {@code C(wordCounts + alphaWords)}, refreshed by {@link #updateClusterParameters()}.
         *
         * <p>NOTE: the boxed type and the field name are part of the serialized form and are
         * deliberately left unchanged.
         */
        private Double wordcounts_plusalpha;

        /**
         * Creates an empty cluster.
         *
         * @param clusterId  identifier of the cluster
         * @param dimensions length of the count vector
         * @param alphaWords pseudo-count added to every dimension
         */
        protected Cluster(Integer clusterId, int dimensions, double alphaWords) {
            super(clusterId);
            this.alphaWords = alphaWords;
            this.wordCounts = new OpenMapRealVector(dimensions);
            // Call the private helper directly instead of the overridable updateClusterParameters().
            this.wordcounts_plusalpha = this.estimateWordCountsPlusAlpha();
        }

        /**
         * Log score of a record under this cluster.
         *
         * <p>Computes {@code C(wordCounts + alpha + x) - C(wordCounts + alpha)}, where {@code x}
         * is the record's feature vector.
         *
         * @param r the record to score
         * @return the log score
         */
        @Override
        protected double posteriorLogPdf(Record r) {
            RealVector x = DataframeMatrix.parseRecord(r, this.featureIds);
            RealVector smoothedCounts = this.wordCounts.mapAdd(this.alphaWords);
            // The subtracted term is the cached C(wordCounts + alpha).
            return this.C(smoothedCounts.add(x)) - this.wordcounts_plusalpha;
        }

        /** {@inheritDoc} */
        @Override
        protected Map<Object, Integer> getFeatureIds() {
            return this.featureIds;
        }

        /** {@inheritDoc} */
        @Override
        protected void setFeatureIds(Map<Object, Integer> featureIds) {
            this.featureIds = featureIds;
        }

        /**
         * Adds a record's feature vector to the counts and refreshes the cached term.
         *
         * @param r the record joining the cluster
         */
        @Override
        protected void add(Record r) {
            RealVector x = DataframeMatrix.parseRecord(r, this.featureIds);
            // An empty cluster takes the parsed vector directly instead of summing with zeros.
            this.wordCounts = this.size == 0 ? x : this.wordCounts.add(x);
            ++this.size;
            this.updateClusterParameters();
        }

        /**
         * Subtracts a record's feature vector from the counts and refreshes the cached term.
         *
         * @param r the record leaving the cluster
         */
        @Override
        protected void remove(Record r) {
            --this.size;
            RealVector x = DataframeMatrix.parseRecord(r, this.featureIds);
            this.wordCounts = this.wordCounts.subtract(x);
            this.updateClusterParameters();
        }

        /** Recomputes the cached {@code C(wordCounts + alphaWords)}. */
        @Override
        protected void updateClusterParameters() {
            this.wordcounts_plusalpha = this.estimateWordCountsPlusAlpha();
        }

        /** No-op: {@link #posteriorLogPdf(Record)} still reads the counts, so nothing is released. */
        @Override
        protected void clear() {
        }

        /**
         * @return {@code C(wordCounts + alphaWords)}
         */
        private double estimateWordCountsPlusAlpha() {
            RealVector smoothedCounts = this.wordCounts.mapAdd(this.alphaWords);
            return this.C(smoothedCounts);
        }

        /**
         * Log multivariate Beta function: {@code sum_i logGamma(a_i) - logGamma(sum_i a_i)}.
         *
         * @param alphaVector the vector {@code a}
         * @return {@code log B(a)}
         */
        private double C(RealVector alphaVector) {
            double sumAi = 0.0;
            double sumLogGammaAi = 0.0;
            int dimension = alphaVector.getDimension();
            // A plain left-to-right loop keeps the floating-point accumulation order stable.
            for (int i = 0; i < dimension; ++i) {
                double ai = alphaVector.getEntry(i);
                sumAi += ai;
                sumLogGammaAi += ContinuousDistributions.logGamma(ai);
            }
            return sumLogGammaAi - ContinuousDistributions.logGamma(sumAi);
        }
    }
}

