package com.datumbox.framework.core.machinelearning.common.abstracts.algorithms;

import com.datumbox.framework.common.Configuration;
import com.datumbox.framework.common.concurrency.ForkJoinStream;
import com.datumbox.framework.common.concurrency.StreamMethods;
import com.datumbox.framework.common.dataobjects.AssociativeArray;
import com.datumbox.framework.common.dataobjects.TypeInference;
import com.datumbox.framework.common.storage.interfaces.BigMap;
import com.datumbox.framework.common.storage.interfaces.StorageEngine;
import com.datumbox.framework.core.common.dataobjects.Dataframe;
import com.datumbox.framework.core.common.dataobjects.Record;
import com.datumbox.framework.core.machinelearning.common.abstracts.AbstractTrainer;
import com.datumbox.framework.core.machinelearning.common.abstracts.algorithms.AbstractNaiveBayes.AbstractModelParameters;
import com.datumbox.framework.core.machinelearning.common.abstracts.algorithms.AbstractNaiveBayes.AbstractTrainingParameters;
import com.datumbox.framework.core.machinelearning.common.abstracts.modelers.AbstractClassifier;
import com.datumbox.framework.core.machinelearning.common.interfaces.PredictParallelizable;
import com.datumbox.framework.core.machinelearning.common.interfaces.TrainParallelizable;
import com.datumbox.framework.core.statistics.descriptivestatistics.Descriptives;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;

/* loaded from: input_file:com/datumbox/framework/core/machinelearning/common/abstracts/algorithms/AbstractNaiveBayes.class */
/**
 * Base class for Naive Bayes classifiers that learn log class priors and add-one smoothed log
 * likelihoods from per-class feature occurrence counts.
 *
 * <p>Subclasses choose, via {@link #isBinarized()}, whether every positive feature value is
 * counted as a single occurrence.
 *
 * @param <MP> the model parameters type
 * @param <TP> the training parameters type
 */
public abstract class AbstractNaiveBayes<MP extends AbstractModelParameters, TP extends AbstractTrainingParameters> extends AbstractClassifier<MP, TP> implements PredictParallelizable, TrainParallelizable {
    private boolean parallelized;

    /** Executor used to run the (optionally parallel) training and prediction streams. */
    protected final ForkJoinStream streamExecutor;

    /**
     * Model parameters shared by all Naive Bayes variants.
     */
    public static abstract class AbstractModelParameters extends AbstractClassifier.AbstractModelParameters {
        /** Log prior of each class, i.e. log(class count / number of training records). */
        private Map<Object, Double> logPriors;

        /** Smoothed log likelihood keyed by the tuple [feature, class]. */
        @BigMap(keyClass = List.class, valueClass = Double.class, mapType = StorageEngine.MapType.HASHMAP, storageHint = StorageEngine.StorageHint.IN_MEMORY, concurrent = true)
        private Map<List<Object>, Double> logLikelihoods;

        /**
         * Creates empty model parameters.
         *
         * @param storageEngine the storage engine passed to the parent parameters class
         */
        public AbstractModelParameters(StorageEngine storageEngine) {
            super(storageEngine);
            this.logPriors = new HashMap<>();
        }

        /**
         * Returns the log prior of each class.
         *
         * @return the class to log-prior map
         */
        public Map<Object, Double> getLogPriors() {
            return this.logPriors;
        }

        /**
         * Replaces the class log priors.
         *
         * @param map the new class to log-prior map
         */
        protected void setLogPriors(Map<Object, Double> map) {
            this.logPriors = map;
        }

        /**
         * Returns the log likelihoods keyed by [feature, class].
         *
         * @return the [feature, class] to log-likelihood map
         */
        public Map<List<Object>, Double> getLogLikelihoods() {
            return this.logLikelihoods;
        }

        /**
         * Replaces the log likelihoods map.
         *
         * @param map the new [feature, class] to log-likelihood map
         */
        protected void setLogLikelihoods(Map<List<Object>, Double> map) {
            this.logLikelihoods = map;
        }
    }

    /**
     * Training parameters shared by all Naive Bayes variants.
     */
    public static abstract class AbstractTrainingParameters extends AbstractTrainer.AbstractTrainingParameters {
        private boolean multiProbabilityWeighted = false;

        /**
         * Returns whether prediction weights each log likelihood by the raw feature value.
         *
         * <p>When {@code false}, positive feature values count as a single occurrence during
         * prediction.
         *
         * @return {@code true} if feature values are used as weights at prediction time
         */
        public boolean isMultiProbabilityWeighted() {
            return this.multiProbabilityWeighted;
        }

        /**
         * Sets whether prediction weights each log likelihood by the raw feature value.
         *
         * @param z {@code true} to weight by the feature value, {@code false} to cap positive
         *     values at 1
         */
        public void setMultiProbabilityWeighted(boolean z) {
            this.multiProbabilityWeighted = z;
        }
    }

    /**
     * Creates a new, untrained model.
     *
     * @param tp the training parameters
     * @param configuration the framework configuration
     */
    public AbstractNaiveBayes(TP tp, Configuration configuration) {
        super(tp, configuration);
        this.parallelized = true;
        this.streamExecutor = new ForkJoinStream(this.knowledgeBase.getConfiguration().getConcurrencyConfiguration());
    }

    /**
     * Loads an existing model by name.
     *
     * @param str the stored model name
     * @param configuration the framework configuration
     */
    public AbstractNaiveBayes(String str, Configuration configuration) {
        super(str, configuration);
        this.parallelized = true;
        this.streamExecutor = new ForkJoinStream(this.knowledgeBase.getConfiguration().getConcurrencyConfiguration());
    }

    /**
     * Indicates whether positive feature values are treated as a single occurrence during both
     * training and prediction.
     *
     * @return {@code true} if feature values are binarized
     */
    protected abstract boolean isBinarized();

    @Override
    public boolean isParallelized() {
        return this.parallelized;
    }

    @Override
    public void setParallelized(boolean z) {
        this.parallelized = z;
    }

    @Override
    protected void _predict(Dataframe dataframe) {
        _predictDatasetParallel(dataframe, this.knowledgeBase.getStorageEngine(), this.knowledgeBase.getConfiguration().getConcurrencyConfiguration());
    }

    /**
     * Scores a single record against every known class.
     *
     * <p>Each class score starts at its log prior. For every known feature of the record, the
     * class score is increased by (feature occurrences * log likelihood). Features never seen
     * during training are ignored. Features with a null value are also ignored, matching how
     * {@code _fit} treats them.
     *
     * @param record the record to score
     * @return the selected class and the class scores after {@code Descriptives.normalizeExp}
     */
    public PredictParallelizable.Prediction _predictRecord(Record record) {
        AbstractModelParameters modelParameters = (AbstractModelParameters) this.knowledgeBase.getModelParameters();
        Map<List<Object>, Double> logLikelihoods = modelParameters.getLogLikelihoods();
        Map<Object, Double> logPriors = modelParameters.getLogPriors();
        Set<Object> classes = modelParameters.getClasses();

        // Training stores every [feature, class] pair, so probing one class is enough to know
        // whether a feature was seen during training.
        Object someClass = classes.iterator().next();

        // Loop-invariant: whether positive occurrences are capped at 1.
        boolean capOccurrences = !((AbstractTrainingParameters) this.knowledgeBase.getTrainingParameters()).isMultiProbabilityWeighted() || isBinarized();

        AssociativeArray predictionScores = new AssociativeArray(new HashMap<>(logPriors));
        for (Map.Entry<Object, Object> entry : record.getX().entrySet()) {
            Object feature = entry.getKey();
            if (!logLikelihoods.containsKey(Arrays.asList(feature, someClass))) {
                continue; // unknown feature
            }

            Double occurrences = TypeInference.toDouble(entry.getValue());
            if (occurrences == null) {
                continue; // missing value: contributes nothing (previously an NPE)
            }
            if (capOccurrences && occurrences > 0.0) {
                occurrences = 1.0;
            }

            for (Object theClass : classes) {
                Double logLikelihood = logLikelihoods.get(Arrays.asList(feature, theClass));
                predictionScores.put(theClass, predictionScores.getDouble(theClass) + occurrences * logLikelihood);
            }
        }

        // The class is selected on the raw log scores, before normalization.
        Object selectedClass = getSelectedClassFromClassScores(predictionScores);
        Descriptives.normalizeExp(predictionScores);
        return new PredictParallelizable.Prediction(selectedClass, predictionScores);
    }

    /**
     * Estimates the class log priors and the add-one smoothed log likelihoods from the data.
     *
     * <p>The log likelihood of a [feature, class] pair is
     * log((occurrences of feature in class + 1) / (total occurrences in class + number of features)).
     *
     * @param dataframe the training data
     */
    @Override
    public void _fit(Dataframe dataframe) {
        AbstractModelParameters modelParameters = (AbstractModelParameters) this.knowledgeBase.getModelParameters();
        int n = dataframe.size();
        int xColumnSize = dataframe.xColumnSize();
        Map<List<Object>, Double> logLikelihoods = modelParameters.getLogLikelihoods();
        Map<Object, Double> logPriors = modelParameters.getLogPriors();
        Set<Object> classes = modelParameters.getClasses();
        boolean isBinarized = isBinarized();

        // Sum of (possibly binarized) feature occurrences per class; the smoothing denominator.
        Map<Object, Double> totalFeatureOccurrencesForEachClass = new HashMap<>();

        // Pass 1: discover the classes and count records per class (stored temporarily in logPriors).
        Iterator<Record> it = dataframe.iterator();
        while (it.hasNext()) {
            Object theClass = it.next().getY();
            if (classes.add(theClass)) {
                logPriors.put(theClass, 1.0);
                totalFeatureOccurrencesForEachClass.put(theClass, 0.0);
            } else {
                logPriors.put(theClass, logPriors.get(theClass) + 1.0);
            }
        }

        // Initialize every [feature, class] pair so each has an entry even with zero occurrences.
        this.streamExecutor.forEach(StreamMethods.stream(dataframe.getXDataTypes().keySet().stream(), isParallelized()), feature -> {
            for (Object theClass : classes) {
                logLikelihoods.put(Arrays.asList(feature, theClass), 0.0);
            }
        });

        // Pass 2: accumulate occurrence counts per [feature, class], possibly in parallel.
        this.streamExecutor.forEach(StreamMethods.stream(dataframe.stream(), isParallelized()), record -> {
            Object theClass = record.getY();
            double sumOfOccurrences = 0.0;
            for (Map.Entry<Object, Object> entry : record.getX().entrySet()) {
                Double occurrences = TypeInference.toDouble(entry.getKey() == null ? null : entry.getValue());
                if (occurrences != null && occurrences > 0.0) {
                    if (isBinarized) {
                        occurrences = 1.0;
                    }
                    // merge instead of get+put: concurrent workers may update the same key, and
                    // a separate read and write would lose increments. Atomic when the map is a
                    // ConcurrentMap (presumably, given @BigMap(concurrent = true)).
                    logLikelihoods.merge(Arrays.asList(entry.getKey(), theClass), occurrences, Double::sum);
                    sumOfOccurrences += occurrences;
                }
            }
            synchronized (totalFeatureOccurrencesForEachClass) {
                totalFeatureOccurrencesForEachClass.put(theClass, totalFeatureOccurrencesForEachClass.get(theClass) + sumOfOccurrences);
            }
        });

        // Convert the class counts into log priors.
        logPriors.replaceAll((clazz, count) -> Math.log(count / n));

        // Convert the occurrence counts into add-one smoothed log likelihoods.
        this.streamExecutor.forEach(StreamMethods.stream(logLikelihoods.entrySet().stream(), isParallelized()), entry -> {
            List<Object> featureClassTuple = entry.getKey();
            Object theClass = featureClassTuple.get(1);
            Double occurrences = entry.getValue();
            if (occurrences == null) {
                occurrences = 0.0;
            }
            double smoothedLikelihood = (occurrences + 1.0) / (totalFeatureOccurrencesForEachClass.get(theClass) + xColumnSize);
            logLikelihoods.put(featureClassTuple, Math.log(smoothedLikelihood));
        });
    }
}
