package com.datumbox.framework.core.machinelearning.featureselection;

import com.datumbox.framework.common.Configuration;
import com.datumbox.framework.common.concurrency.StreamMethods;
import com.datumbox.framework.common.dataobjects.TypeInference;
import com.datumbox.framework.common.storage.interfaces.StorageEngine;
import com.datumbox.framework.core.common.dataobjects.Dataframe;
import com.datumbox.framework.core.common.dataobjects.Record;
import com.datumbox.framework.core.machinelearning.common.abstracts.featureselectors.AbstractScoreBasedFeatureSelector;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import java.util.function.BiFunction;

/* loaded from: input_file:com/datumbox/framework/core/machinelearning/featureselection/TFIDF.class */
/**
 * Score-based feature selector that ranks features by their TF-IDF score.
 *
 * <p>For every feature the inverse document frequency is computed as
 * {@code log10(N / df)}, where {@code N} is the number of records in the
 * Dataframe and {@code df} is the number of records in which the feature has
 * a strictly positive value. The score stored for a feature is the largest
 * {@code tf * idf} observed across all records, where {@code tf} is the
 * feature value of the record (or {@code 1.0} when binarization is enabled).
 */
public class TFIDF extends AbstractScoreBasedFeatureSelector<ModelParameters, TrainingParameters> {

    /** Name of the temporary storage-engine map that holds the per-feature IDF values. */
    private static final String TMP_IDF_MAP_NAME = "tmp_idf";

    /**
     * Model parameters of the TFIDF selector; adds nothing to the inherited state.
     */
    public static class ModelParameters extends AbstractScoreBasedFeatureSelector.AbstractModelParameters {
        private static final long serialVersionUID = 2L;

        /**
         * @param storageEngine the storage engine handed to the parent constructor
         */
        protected ModelParameters(StorageEngine storageEngine) {
            super(storageEngine);
        }
    }

    /**
     * Training parameters of the TFIDF selector.
     */
    public static class TrainingParameters extends AbstractScoreBasedFeatureSelector.AbstractTrainingParameters {
        private static final long serialVersionUID = 1L;

        /** When true, any positive feature value is treated as a term frequency of 1. */
        private boolean binarized = false;

        /**
         * @return whether term frequencies are replaced by 1 during scoring
         */
        public boolean isBinarized() {
            return this.binarized;
        }

        /**
         * @param z true to replace every positive term frequency by 1 during scoring
         */
        public void setBinarized(boolean z) {
            this.binarized = z;
        }
    }

    /**
     * @param trainingParameters the training parameters passed to the parent constructor
     * @param configuration      the framework configuration passed to the parent constructor
     */
    protected TFIDF(TrainingParameters trainingParameters, Configuration configuration) {
        super(trainingParameters, configuration);
    }

    /**
     * @param str           the string argument passed to the parent constructor
     * @param configuration the framework configuration passed to the parent constructor
     */
    protected TFIDF(String str, Configuration configuration) {
        super(str, configuration);
    }

    /**
     * Validates that every X column of the Dataframe has a supported data type
     * and then delegates to the parent implementation.
     *
     * @param dataframe the training data
     * @throws IllegalArgumentException if a column has a data type that is not
     *                                  in {@link #getSupportedXDataTypes()}
     */
    @Override
    public void fit(Dataframe dataframe) {
        Set<TypeInference.DataType> supportedXDataTypes = getSupportedXDataTypes();
        for (TypeInference.DataType dataType : dataframe.getXDataTypes().values()) {
            if (!supportedXDataTypes.contains(dataType)) {
                throw new IllegalArgumentException("A DataType that is not supported by this method was detected in the Dataframe.");
            }
        }
        super.fit(dataframe);
    }

    /**
     * Computes the maximum TF-IDF score of every feature and stores it in the
     * model parameters, optionally dropping rare features beforehand and
     * keeping only the top-scoring features afterwards.
     *
     * @param dataframe the training data
     */
    @Override
    protected void _fit(Dataframe dataframe) {
        ModelParameters modelParameters = (ModelParameters) this.knowledgeBase.getModelParameters();
        TrainingParameters trainingParameters = (TrainingParameters) this.knowledgeBase.getTrainingParameters();
        boolean isBinarized = trainingParameters.isBinarized();
        int size = dataframe.size();
        StorageEngine storageEngine = this.knowledgeBase.getStorageEngine();
        Map<Object, Double> idfMap = storageEngine.getBigMap(TMP_IDF_MAP_NAME, Object.class, Double.class, StorageEngine.MapType.HASHMAP, StorageEngine.StorageHint.IN_MEMORY, true, true);

        // Pass 1: document frequency, i.e. the number of records in which each feature is positive.
        for (Record currentRecord : dataframe) {
            for (Map.Entry<Object, Object> entry : currentRecord.getX().entrySet()) {
                if (TypeInference.toDouble(entry.getValue()) > 0.0d) {
                    Object feature = entry.getKey();
                    idfMap.put(feature, idfMap.getOrDefault(feature, 0.0d) + 1.0d);
                }
            }
        }

        Integer rareFeatureThreshold = trainingParameters.getRareFeatureThreshold();
        if (rareFeatureThreshold != null && rareFeatureThreshold > 0) {
            removeRareFeatures(idfMap, rareFeatureThreshold);
        }

        // Turn every document frequency into an IDF value in place. Only existing keys are
        // overwritten, so the key set is not changed while its entries are being streamed.
        this.streamExecutor.forEach(StreamMethods.stream(idfMap.entrySet().stream(), isParallelized()), documentFrequency -> {
            idfMap.put(documentFrequency.getKey(), Math.log10(size / documentFrequency.getValue()));
        });

        Map<Object, Double> featureScores = modelParameters.getFeatureScores();

        // True when no score is stored for the feature yet or the candidate beats the stored one.
        // The parameter and the looked-up value need distinct names: the decompiled form reused
        // one name for both, which does not compile and would compare a value with itself.
        BiFunction<Object, Double, Boolean> beatsStoredScore = (feature, candidateScore) -> {
            Double storedScore = featureScores.get(feature);
            return storedScore == null || storedScore < candidateScore;
        };

        // Pass 2: keep the maximum tf*idf of each feature across all records.
        this.streamExecutor.forEach(StreamMethods.stream(dataframe.stream(), isParallelized()), row -> {
            for (Map.Entry<Object, Object> entry : row.getX().entrySet()) {
                Double termFrequency = TypeInference.toDouble(entry.getValue());
                if (termFrequency > 0.0d) {
                    Object feature = entry.getKey();
                    double tf = isBinarized ? 1.0d : termFrequency;
                    double tfidf = tf * idfMap.getOrDefault(feature, 0.0d);
                    // Check once without the lock, then again while holding it before writing,
                    // because another record may have stored a higher score in between.
                    if (tfidf > 0.0d && beatsStoredScore.apply(feature, tfidf)) {
                        synchronized (featureScores) {
                            if (beatsStoredScore.apply(feature, tfidf)) {
                                featureScores.put(feature, tfidf);
                            }
                        }
                    }
                }
            }
        });

        storageEngine.dropBigMap(TMP_IDF_MAP_NAME, idfMap);

        Integer maxFeatures = trainingParameters.getMaxFeatures();
        if (maxFeatures != null && maxFeatures < featureScores.size()) {
            keepTopFeatures(featureScores, maxFeatures);
        }
    }

    /**
     * @return the X data types accepted by this selector: BOOLEAN and NUMERICAL
     */
    @Override
    public Set<TypeInference.DataType> getSupportedXDataTypes() {
        return new HashSet<>(Arrays.asList(TypeInference.DataType.BOOLEAN, TypeInference.DataType.NUMERICAL));
    }

    /**
     * @return always {@code null}; NOTE(review): presumably this means the Y
     *         data type is not restricted - confirm against the parent class
     */
    @Override
    protected Set<TypeInference.DataType> getSupportedYDataTypes() {
        return null;
    }
}
