package com.datumbox.framework.core.machinelearning.topicmodeling;

import com.datumbox.framework.common.Configuration;
import com.datumbox.framework.common.dataobjects.AssociativeArray;
import com.datumbox.framework.common.dataobjects.AssociativeArray2D;
import com.datumbox.framework.common.dataobjects.TypeInference;
import com.datumbox.framework.common.storage.interfaces.BigMap;
import com.datumbox.framework.common.storage.interfaces.StorageEngine;
import com.datumbox.framework.core.common.dataobjects.Dataframe;
import com.datumbox.framework.core.common.dataobjects.Record;
import com.datumbox.framework.core.common.utilities.MapMethods;
import com.datumbox.framework.core.common.utilities.PHPMethods;
import com.datumbox.framework.core.machinelearning.common.abstracts.AbstractTrainer;
import com.datumbox.framework.core.machinelearning.common.abstracts.modelers.AbstractTopicModeler;
import com.datumbox.framework.core.statistics.descriptivestatistics.Descriptives;
import com.datumbox.framework.core.statistics.sampling.SimpleRandomSampling;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Map;

/* loaded from: input_file:com/datumbox/framework/core/machinelearning/topicmodeling/LatentDirichletAllocation.class */
/**
 * Latent Dirichlet Allocation topic modeler estimated with a collapsed Gibbs sampler.
 *
 * <p>Every {@link Record} is treated as one document. Each entry of its {@code X} map is one word
 * occurrence: the key identifies the occurrence inside the document (presumably its position —
 * confirm against the text extractor that builds the dataframe) and the value is the word itself.
 * After {@code _fit} or {@code _predict}, each record's predicted label is its dominant topic and its
 * predicted probabilities are the fraction of its words assigned to each topic in the last sweep.
 */
public class LatentDirichletAllocation extends AbstractTopicModeler<LatentDirichletAllocation.ModelParameters, LatentDirichletAllocation.TrainingParameters> {

    /**
     * Learned state of the model: the count tables maintained by the Gibbs sampler during training.
     *
     * <p>Field names and types are kept stable because instances are serializable
     * ({@code serialVersionUID}).
     */
    public static class ModelParameters extends AbstractTrainer.AbstractModelParameters {
        private static final long serialVersionUID = 1;

        /**
         * Value of {@code Dataframe.xColumnSize()} on the training data; multiplied by {@code beta}
         * in every smoothing denominator.
         */
        private Integer d;

        /** Number of Gibbs sweeps actually performed by the last training run. */
        private int totalIterations;

        /** Topic currently assigned to each {@code [record id, word key]} pair. */
        @BigMap(keyClass = List.class, valueClass = Integer.class, mapType = StorageEngine.MapType.HASHMAP, storageHint = StorageEngine.StorageHint.IN_CACHE, concurrent = false)
        private Map<List<Object>, Integer> topicAssignmentOfDocumentWord;

        /** Number of words of a record assigned to a topic, keyed by {@code [record id, topic]}. */
        @BigMap(keyClass = List.class, valueClass = Integer.class, mapType = StorageEngine.MapType.HASHMAP, storageHint = StorageEngine.StorageHint.IN_MEMORY, concurrent = false)
        private Map<List<Integer>, Integer> documentTopicCounts;

        /** Number of occurrences of a word assigned to a topic, keyed by {@code [topic, word]}. */
        @BigMap(keyClass = List.class, valueClass = Integer.class, mapType = StorageEngine.MapType.HASHMAP, storageHint = StorageEngine.StorageHint.IN_CACHE, concurrent = false)
        private Map<List<Object>, Integer> topicWordCounts;

        /** Number of word occurrences in each training record, keyed by record id. */
        @BigMap(keyClass = Integer.class, valueClass = Integer.class, mapType = StorageEngine.MapType.HASHMAP, storageHint = StorageEngine.StorageHint.IN_MEMORY, concurrent = false)
        private Map<Integer, Integer> documentWordCounts;

        /** Total number of word occurrences assigned to each topic, keyed by topic id. */
        @BigMap(keyClass = Integer.class, valueClass = Integer.class, mapType = StorageEngine.MapType.HASHMAP, storageHint = StorageEngine.StorageHint.IN_MEMORY, concurrent = false)
        private Map<Integer, Integer> topicCounts;

        /**
         * Creates empty model parameters with {@code d = 0}.
         *
         * @param storageEngine passed to the superclass (presumably used to allocate the
         *     {@code @BigMap} fields — confirm in {@code AbstractModelParameters})
         */
        protected ModelParameters(StorageEngine storageEngine) {
            super(storageEngine);
            this.d = 0;
        }

        /** @return the {@code xColumnSize()} recorded during training */
        public Integer getD() {
            return this.d;
        }

        /** @param num the {@code xColumnSize()} of the training data */
        protected void setD(Integer num) {
            this.d = num;
        }

        /** @return the number of Gibbs sweeps performed during training */
        public int getTotalIterations() {
            return this.totalIterations;
        }

        /** @param i the number of Gibbs sweeps performed during training */
        protected void setTotalIterations(int i) {
            this.totalIterations = i;
        }

        /** @return topic assigned to each {@code [record id, word key]} pair */
        public Map<List<Object>, Integer> getTopicAssignmentOfDocumentWord() {
            return this.topicAssignmentOfDocumentWord;
        }

        /** @param map topic assigned to each {@code [record id, word key]} pair */
        protected void setTopicAssignmentOfDocumentWord(Map<List<Object>, Integer> map) {
            this.topicAssignmentOfDocumentWord = map;
        }

        /** @return counts keyed by {@code [record id, topic]} */
        public Map<List<Integer>, Integer> getDocumentTopicCounts() {
            return this.documentTopicCounts;
        }

        /** @param map counts keyed by {@code [record id, topic]} */
        protected void setDocumentTopicCounts(Map<List<Integer>, Integer> map) {
            this.documentTopicCounts = map;
        }

        /** @return counts keyed by {@code [topic, word]} */
        public Map<List<Object>, Integer> getTopicWordCounts() {
            return this.topicWordCounts;
        }

        /** @param map counts keyed by {@code [topic, word]} */
        protected void setTopicWordCounts(Map<List<Object>, Integer> map) {
            this.topicWordCounts = map;
        }

        /** @return number of word occurrences per training record */
        public Map<Integer, Integer> getDocumentWordCounts() {
            return this.documentWordCounts;
        }

        /** @param map number of word occurrences per training record */
        protected void setDocumentWordCounts(Map<Integer, Integer> map) {
            this.documentWordCounts = map;
        }

        /** @return number of word occurrences per topic */
        public Map<Integer, Integer> getTopicCounts() {
            return this.topicCounts;
        }

        /** @param map number of word occurrences per topic */
        protected void setTopicCounts(Map<Integer, Integer> map) {
            this.topicCounts = map;
        }
    }

    /** User-tunable settings of the sampler. */
    public static class TrainingParameters extends AbstractTrainer.AbstractTrainingParameters {
        private static final long serialVersionUID = 1;

        /** Number of topics; topic ids are {@code 0..k-1}. Default 2. */
        private int k = 2;

        /** Upper bound on Gibbs sweeps; sampling also stops early when no record changes topic. Default 50. */
        private int maxIterations = 50;

        /** Smoothing constant added to every per-record topic count. Default 1.0. */
        private double alpha = 1.0d;

        /** Smoothing constant added to every per-topic word count. Default 1.0. */
        private double beta = 1.0d;

        /** @return the number of topics */
        public int getK() {
            return this.k;
        }

        /** @param i the number of topics */
        public void setK(int i) {
            this.k = i;
        }

        /** @return the maximum number of Gibbs sweeps */
        public int getMaxIterations() {
            return this.maxIterations;
        }

        /** @param i the maximum number of Gibbs sweeps */
        public void setMaxIterations(int i) {
            this.maxIterations = i;
        }

        /** @return the record-topic smoothing constant */
        public double getAlpha() {
            return this.alpha;
        }

        /** @param d the record-topic smoothing constant */
        public void setAlpha(double d) {
            this.alpha = d;
        }

        /** @return the topic-word smoothing constant */
        public double getBeta() {
            return this.beta;
        }

        /** @param d the topic-word smoothing constant */
        public void setBeta(double d) {
            this.beta = d;
        }
    }

    /**
     * Creates a modeler from training parameters.
     *
     * @param trainingParameters sampler settings
     * @param configuration framework configuration
     */
    protected LatentDirichletAllocation(TrainingParameters trainingParameters, Configuration configuration) {
        super(trainingParameters, configuration);
    }

    /**
     * Creates a modeler identified by {@code str} (presumably the storage name of a previously
     * saved model — confirm in {@code AbstractTopicModeler}).
     *
     * @param str name passed to the superclass
     * @param configuration framework configuration
     */
    protected LatentDirichletAllocation(String str, Configuration configuration) {
        super(str, configuration);
    }

    /**
     * Estimates, for every topic, the probability of each word seen during training as
     * {@code (n_jw + beta) / (n_j + beta * d)}. Here {@code n_jw} is the number of occurrences of
     * word {@code w} assigned to topic {@code j}, {@code n_j} is the number of occurrences assigned
     * to topic {@code j}, and {@code d} is {@link ModelParameters#getD()}.
     *
     * @return a map from topic id ({@code 0..k-1}) to that topic's word probabilities, each passed
     *     through {@code MapMethods.sortAssociativeArrayByValueDescending}
     */
    public AssociativeArray2D getWordProbabilitiesPerTopic() {
        ModelParameters modelParameters = (ModelParameters) this.knowledgeBase.getModelParameters();
        TrainingParameters trainingParameters = (TrainingParameters) this.knowledgeBase.getTrainingParameters();

        int k = trainingParameters.getK();
        int d = modelParameters.getD();
        double beta = trainingParameters.getBeta();
        Map<Integer, Integer> topicCounts = modelParameters.getTopicCounts();

        AssociativeArray2D wordProbabilities = new AssociativeArray2D();
        for (int topic = 0; topic < k; topic++) {
            wordProbabilities.put(topic, new AssociativeArray());
        }

        for (Map.Entry<List<Object>, Integer> entry : modelParameters.getTopicWordCounts().entrySet()) {
            List<Object> topicWord = entry.getKey();
            Integer topic = (Integer) topicWord.get(0);
            Object word = topicWord.get(1);
            int topicWordCount = entry.getValue();
            int topicCount = topicCounts.getOrDefault(topic, 0);
            wordProbabilities.get(topic).put(word, (topicWordCount + beta) / (topicCount + beta * d));
        }

        for (int topic = 0; topic < k; topic++) {
            wordProbabilities.put(topic, MapMethods.sortAssociativeArrayByValueDescending(wordProbabilities.get(topic)));
        }
        return wordProbabilities;
    }

    /**
     * Trains the model with collapsed Gibbs sampling.
     *
     * <p>Every word occurrence starts with a uniformly random topic. Each sweep then removes every
     * occurrence from the counts, samples a new topic with weight
     * {@code (n_jw + beta) * (n_dj + alpha) / (n_j + beta * d)}, and adds it back. Sampling stops
     * after {@code maxIterations} sweeps or after a sweep in which no record's dominant topic
     * changed. The dataframe's predictions are overwritten with each record's topic shares.
     *
     * @param dataframe the training documents; their predictions are updated in place
     */
    @Override
    protected void _fit(Dataframe dataframe) {
        ModelParameters modelParameters = (ModelParameters) this.knowledgeBase.getModelParameters();
        TrainingParameters trainingParameters = (TrainingParameters) this.knowledgeBase.getTrainingParameters();

        // NOTE(review): xColumnSize() is used as the vocabulary-size term of the smoothing
        // denominator; confirm that it equals the number of distinct words and not the number of
        // distinct word keys.
        modelParameters.setD(dataframe.xColumnSize());
        int d = modelParameters.getD();
        int k = trainingParameters.getK();
        double alpha = trainingParameters.getAlpha();
        double beta = trainingParameters.getBeta();
        int maxIterations = trainingParameters.getMaxIterations();

        CountTables counts = new CountTables(
                modelParameters.getTopicAssignmentOfDocumentWord(),
                modelParameters.getDocumentTopicCounts(),
                modelParameters.getTopicWordCounts(),
                modelParameters.getTopicCounts());
        Map<Integer, Integer> documentWordCounts = modelParameters.getDocumentWordCounts();

        for (Map.Entry<Integer, Record> entry : dataframe.entries()) {
            Integer documentId = entry.getKey();
            Record record = entry.getValue();
            documentWordCounts.put(documentId, record.getX().size());
            counts.assignRandomTopics(documentId, record, k);
        }

        int iteration = 0;
        while (iteration < maxIterations) {
            this.logger.debug("Iteration {}", iteration);
            int reassignedRecords = 0;
            for (Map.Entry<Integer, Record> entry : dataframe.entries()) {
                Integer documentId = entry.getKey();
                Record record = entry.getValue();
                AssociativeArray topicShares = zeroedTopicArray(k);
                int documentSize = record.getX().size();
                for (Map.Entry<Object, Object> wordEntry : record.getX().entrySet()) {
                    Object wordKey = wordEntry.getKey();
                    Object word = wordEntry.getValue();
                    // The occurrence being resampled must not count towards its own posterior.
                    counts.unassign(documentId, wordKey, word);

                    AssociativeArray topicWeights = new AssociativeArray();
                    for (int topic = 0; topic < k; topic++) {
                        double wordTerm = counts.topicWordCount(topic, word) + beta;
                        double documentTerm = counts.documentTopicCount(documentId, topic) + alpha;
                        // getOrDefault: a topic never drawn at initialization has no entry yet.
                        double denominator = counts.topicCount(topic) + beta * d;
                        topicWeights.put(topic, (wordTerm * documentTerm) / denominator);
                    }

                    Integer newTopic = sampleTopic(topicWeights);
                    counts.assign(documentId, wordKey, word, newTopic);
                    addTopicShare(topicShares, newTopic, documentSize);
                }
                if (storeDominantTopic(dataframe, documentId, record, topicShares)) {
                    reassignedRecords++;
                }
            }
            iteration++;
            this.logger.debug("Reassigned Records {}", reassignedRecords);
            if (reassignedRecords == 0) {
                break;
            }
        }
        modelParameters.setTotalIterations(iteration);
    }

    /**
     * Increments the integer counter stored under {@code k}, treating a missing entry as 0.
     *
     * @param map counter table
     * @param k counter key
     */
    private static <K> void increase(Map<K, Integer> map, K k) {
        map.put(k, map.getOrDefault(k, 0) + 1);
    }

    /**
     * Decrements the integer counter stored under {@code k}, treating a missing entry as 0.
     *
     * @param map counter table
     * @param k counter key
     */
    private static <K> void decrease(Map<K, Integer> map, K k) {
        map.put(k, map.getOrDefault(k, 0) - 1);
    }

    /**
     * Infers topic mixtures for new documents with the trained model.
     *
     * <p>Runs the same Gibbs sampler as {@link #_fit}, but only on temporary count tables for the
     * new documents. The trained topic-word and topic counts are added as extra evidence. The
     * temporary tables are always dropped, even if sampling fails. A perplexity estimate is
     * logged after every sweep.
     *
     * @param dataframe the documents to label; their predictions are updated in place
     */
    @Override
    protected void _predict(Dataframe dataframe) {
        ModelParameters modelParameters = (ModelParameters) this.knowledgeBase.getModelParameters();
        TrainingParameters trainingParameters = (TrainingParameters) this.knowledgeBase.getTrainingParameters();
        int d = modelParameters.getD();
        int k = trainingParameters.getK();
        double alpha = trainingParameters.getAlpha();
        double beta = trainingParameters.getBeta();
        int maxIterations = trainingParameters.getMaxIterations();
        Map<List<Object>, Integer> trainedTopicWordCounts = modelParameters.getTopicWordCounts();
        Map<Integer, Integer> trainedTopicCounts = modelParameters.getTopicCounts();

        StorageEngine storageEngine = this.knowledgeBase.getStorageEngine();
        Map<List<Object>, Integer> tmpTopicAssignmentOfDocumentWord = asCountMap(storageEngine.getBigMap(
                "tmp_topicAssignmentOfDocumentWord", List.class, Integer.class,
                StorageEngine.MapType.HASHMAP, StorageEngine.StorageHint.IN_CACHE, false, true));
        Map<List<Integer>, Integer> tmpDocumentTopicCounts = asCountMap(storageEngine.getBigMap(
                "tmp_documentTopicCounts", List.class, Integer.class,
                StorageEngine.MapType.HASHMAP, StorageEngine.StorageHint.IN_MEMORY, false, true));
        Map<List<Object>, Integer> tmpTopicWordCounts = asCountMap(storageEngine.getBigMap(
                "tmp_topicWordCounts", List.class, Integer.class,
                StorageEngine.MapType.HASHMAP, StorageEngine.StorageHint.IN_CACHE, false, true));
        Map<Integer, Integer> tmpTopicCounts = asCountMap(storageEngine.getBigMap(
                "tmp_topicCounts", Integer.class, Integer.class,
                StorageEngine.MapType.HASHMAP, StorageEngine.StorageHint.IN_MEMORY, false, true));
        try {
            CountTables counts = new CountTables(
                    tmpTopicAssignmentOfDocumentWord, tmpDocumentTopicCounts, tmpTopicWordCounts, tmpTopicCounts);
            for (Map.Entry<Integer, Record> entry : dataframe.entries()) {
                counts.assignRandomTopics(entry.getKey(), entry.getValue(), k);
            }

            for (int iteration = 0; iteration < maxIterations; iteration++) {
                this.logger.debug("Iteration {}", iteration);
                int reassignedRecords = 0;
                double logLikelihood = 0.0d;
                double totalWords = 0.0d;
                for (Map.Entry<Integer, Record> entry : dataframe.entries()) {
                    Integer documentId = entry.getKey();
                    Record record = entry.getValue();
                    AssociativeArray topicShares = zeroedTopicArray(k);
                    int documentSize = record.getX().size();
                    totalWords += documentSize;
                    for (Map.Entry<Object, Object> wordEntry : record.getX().entrySet()) {
                        Object wordKey = wordEntry.getKey();
                        Object word = wordEntry.getValue();
                        counts.unassign(documentId, wordKey, word);

                        int otherWordsInDocument = documentSize - 1;
                        AssociativeArray topicWeights = new AssociativeArray();
                        for (int topic = 0; topic < k; topic++) {
                            List<Object> topicWordKey = Arrays.asList(topic, word);
                            // Word evidence from the new documents plus the trained counts.
                            double wordTerm = counts.topicWordCount(topic, word) + beta
                                    + trainedTopicWordCounts.getOrDefault(topicWordKey, 0);
                            double documentTerm = counts.documentTopicCount(documentId, topic) + alpha;
                            // NOTE(review): the current occurrence was already removed by
                            // unassign(), so the extra "- 1.0" may double-count that removal;
                            // it is kept to preserve the existing numeric behaviour.
                            double denominator = (counts.topicCount(topic) + beta * d - 1.0d
                                    + trainedTopicCounts.getOrDefault(topic, 0))
                                    * (otherWordsInDocument + alpha * k);
                            topicWeights.put(topic, (documentTerm == documentTerm ? wordTerm * documentTerm : wordTerm * documentTerm) / denominator);
                        }
                        // The unnormalised weights sum to p(word), so this accumulates the log-likelihood.
                        logLikelihood += Math.log(Descriptives.sum(topicWeights.toFlatDataCollection()));

                        Integer newTopic = sampleTopic(topicWeights);
                        counts.assign(documentId, wordKey, word, newTopic);
                        addTopicShare(topicShares, newTopic, documentSize);
                    }
                    if (storeDominantTopic(dataframe, documentId, record, topicShares)) {
                        reassignedRecords++;
                    }
                }
                this.logger.debug("Reassigned Records {} - Perplexity: {}", reassignedRecords, Math.exp(-logLikelihood / totalWords));
                if (reassignedRecords == 0) {
                    break;
                }
            }
        } finally {
            // Always release the temporary storage, even when sampling throws.
            storageEngine.dropBigMap("tmp_topicAssignmentOfDocumentWord", tmpTopicAssignmentOfDocumentWord);
            storageEngine.dropBigMap("tmp_documentTopicCounts", tmpDocumentTopicCounts);
            storageEngine.dropBigMap("tmp_topicWordCounts", tmpTopicWordCounts);
            storageEngine.dropBigMap("tmp_topicCounts", tmpTopicCounts);
        }
    }

    /**
     * Views a storage-engine map as a typed counter table.
     *
     * <p>The storage engine is asked for maps by raw {@code Class} objects (e.g.
     * {@code List.class}), so the element types cannot be checked at compile time. The keys stored
     * here are always the ones written by {@link CountTables}.
     */
    @SuppressWarnings("unchecked")
    private static <K> Map<K, Integer> asCountMap(Map<?, ?> map) {
        return (Map<K, Integer>) map;
    }

    /** Creates an array mapping every topic id {@code 0..k-1} to {@code 0.0}. */
    private static AssociativeArray zeroedTopicArray(int k) {
        AssociativeArray array = new AssociativeArray();
        for (int topic = 0; topic < k; topic++) {
            array.put(topic, 0.0d);
        }
        return array;
    }

    /** Draws one topic id from unnormalised {@code topic -> weight} pairs. */
    private static Integer sampleTopic(AssociativeArray topicWeights) {
        return (Integer) SimpleRandomSampling.weightedSampling(topicWeights, 1, true).iterator().next();
    }

    /** Adds one word's share ({@code 1 / documentSize}) to {@code topic} in a record's topic mixture. */
    private static void addTopicShare(AssociativeArray topicShares, Integer topic, int documentSize) {
        topicShares.put(topic, TypeInference.toDouble(topicShares.get(topic)).doubleValue() + (1.0d / documentSize));
    }

    /**
     * Stores the record's dominant topic and topic mixture as its prediction.
     *
     * @return {@code true} if the dominant topic differs from the record's previous prediction
     */
    private static boolean storeDominantTopic(Dataframe dataframe, Integer documentId, Record record, AssociativeArray topicShares) {
        Object dominantTopic = MapMethods.selectMaxKeyValue(topicShares).getKey();
        boolean changed = !dominantTopic.equals(record.getYPredicted());
        dataframe._unsafe_set(documentId, new Record(record.getX(), record.getY(), dominantTopic, topicShares));
        return changed;
    }

    /**
     * The four count tables that the Gibbs sampler updates together. This keeps the
     * add/remove-word bookkeeping in one place for both training and prediction.
     */
    private static final class CountTables {
        private final Map<List<Object>, Integer> topicAssignmentOfDocumentWord;
        private final Map<List<Integer>, Integer> documentTopicCounts;
        private final Map<List<Object>, Integer> topicWordCounts;
        private final Map<Integer, Integer> topicCounts;

        CountTables(Map<List<Object>, Integer> topicAssignmentOfDocumentWord,
                    Map<List<Integer>, Integer> documentTopicCounts,
                    Map<List<Object>, Integer> topicWordCounts,
                    Map<Integer, Integer> topicCounts) {
            this.topicAssignmentOfDocumentWord = topicAssignmentOfDocumentWord;
            this.documentTopicCounts = documentTopicCounts;
            this.topicWordCounts = topicWordCounts;
            this.topicCounts = topicCounts;
        }

        /** Assigns a uniformly random topic in {@code [0, k-1]} to every word of {@code record}. */
        void assignRandomTopics(Integer documentId, Record record, int k) {
            for (Map.Entry<Object, Object> wordEntry : record.getX().entrySet()) {
                Integer topic = PHPMethods.mt_rand(0, k - 1);
                assign(documentId, wordEntry.getKey(), wordEntry.getValue(), topic);
            }
        }

        /** Records {@code topic} for the occurrence and increments every affected counter. */
        void assign(Integer documentId, Object wordKey, Object word, Integer topic) {
            topicAssignmentOfDocumentWord.put(Arrays.asList(documentId, wordKey), topic);
            increase(topicCounts, topic);
            increase(documentTopicCounts, Arrays.asList(documentId, topic));
            increase(topicWordCounts, Arrays.asList(topic, word));
        }

        /** Decrements every counter affected by the occurrence's current topic assignment. */
        void unassign(Integer documentId, Object wordKey, Object word) {
            Integer topic = topicAssignmentOfDocumentWord.get(Arrays.asList(documentId, wordKey));
            decrease(topicCounts, topic);
            decrease(documentTopicCounts, Arrays.asList(documentId, topic));
            decrease(topicWordCounts, Arrays.asList(topic, word));
        }

        /** @return occurrences assigned to {@code topic}, or 0 if the topic has no entry yet */
        int topicCount(int topic) {
            return topicCounts.getOrDefault(topic, 0);
        }

        /** @return occurrences of the record assigned to {@code topic}, or 0 if absent */
        int documentTopicCount(Integer documentId, int topic) {
            return documentTopicCounts.getOrDefault(Arrays.asList(documentId, topic), 0);
        }

        /** @return occurrences of {@code word} assigned to {@code topic}, or 0 if absent */
        int topicWordCount(int topic, Object word) {
            return topicWordCounts.getOrDefault(Arrays.asList(topic, word), 0);
        }
    }
}
