package com.datumbox.framework.core.machinelearning.common.abstracts.algorithms;

import com.datumbox.framework.common.Configuration;
import com.datumbox.framework.common.dataobjects.AssociativeArray;
import com.datumbox.framework.common.dataobjects.DataTable2D;
import com.datumbox.framework.common.storage.interfaces.StorageEngine;
import com.datumbox.framework.core.common.dataobjects.Dataframe;
import com.datumbox.framework.core.common.dataobjects.Record;
import com.datumbox.framework.core.common.utilities.MapMethods;
import com.datumbox.framework.core.machinelearning.MLBuilder;
import com.datumbox.framework.core.machinelearning.common.abstracts.AbstractTrainer;
import com.datumbox.framework.core.machinelearning.common.abstracts.algorithms.AbstractBoostingBagging.AbstractModelParameters;
import com.datumbox.framework.core.machinelearning.common.abstracts.algorithms.AbstractBoostingBagging.AbstractTrainingParameters;
import com.datumbox.framework.core.machinelearning.common.abstracts.modelers.AbstractClassifier;
import com.datumbox.framework.core.machinelearning.common.dataobjects.TrainableBundle;
import com.datumbox.framework.core.machinelearning.ensemblelearning.FixedCombinationRules;
import com.datumbox.framework.core.statistics.descriptivestatistics.Descriptives;
import com.datumbox.framework.core.statistics.sampling.SimpleRandomSampling;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;

/* loaded from: input_file:com/datumbox/framework/core/machinelearning/common/abstracts/algorithms/AbstractBoostingBagging.class */
public abstract class AbstractBoostingBagging<MP extends AbstractModelParameters, TP extends AbstractTrainingParameters> extends AbstractClassifier<MP, TP> {
    // Holds the weak classifiers of the ensemble, keyed by STORAGE_INDICATOR followed by the classifier's index.
    private final TrainableBundle bundle;
    // Key prefix of every weak classifier in the bundle; initBundle() also appends it (plus the index) to the
    // storage name when it loads a weak classifier that is not in the bundle yet.
    private static final String STORAGE_INDICATOR = "Cmp";
    // Retry limit for rejected weak learners: _fit stops training when a learner is rejected (Status.IGNORE)
    // and its retry counter has already reached this value (2).
    private static final int MAX_NUM_OF_RETRIES = 2;

    /* loaded from: input_file:com/datumbox/framework/core/machinelearning/common/abstracts/algorithms/AbstractBoostingBagging$AbstractModelParameters.class */
    /**
     * Model parameters shared by the boosting/bagging ensembles: on top of what the
     * classifier model parameters already hold, it keeps one weight per weak classifier.
     * The position of a weight in the list is the index of the weak classifier it belongs to
     * (this is how {@code _predict} pairs weights with classifiers).
     */
    public static abstract class AbstractModelParameters extends AbstractClassifier.AbstractModelParameters {
        private List<Double> weakClassifierWeights;

        /* JADX INFO: Access modifiers changed from: protected */
        /**
         * Creates the parameters with an empty list of weak classifier weights.
         *
         * @param storageEngine forwarded unchanged to the superclass constructor
         */
        public AbstractModelParameters(StorageEngine storageEngine) {
            super(storageEngine);
            // Diamond instead of the raw ArrayList: same runtime object, no unchecked assignment.
            this.weakClassifierWeights = new ArrayList<>();
        }

        /**
         * Returns the live (not copied) list of weak classifier weights.
         *
         * @return the list currently held by this object
         */
        public List<Double> getWeakClassifierWeights() {
            return this.weakClassifierWeights;
        }

        /**
         * Replaces the list of weak classifier weights; the reference is stored as is.
         *
         * @param list the new list of weights
         */
        protected void setWeakClassifierWeights(List<Double> list) {
            this.weakClassifierWeights = list;
        }
    }

    /* loaded from: input_file:com/datumbox/framework/core/machinelearning/common/abstracts/algorithms/AbstractBoostingBagging$AbstractTrainingParameters.class */
    /**
     * Training parameters shared by the boosting/bagging ensembles: how many weak classifiers
     * may be trained and which training parameters each of them is built from.
     */
    public static abstract class AbstractTrainingParameters extends AbstractTrainer.AbstractTrainingParameters {
        /** Training parameters used to build each weak classifier; null until a value is set. */
        private AbstractTrainer.AbstractTrainingParameters weakClassifierTrainingParameters;

        /** Upper bound on the number of weak classifiers; 5 unless changed. */
        private int maxWeakClassifiers = 5;

        /**
         * @return the training parameters used for the weak classifiers, or null if none were set
         */
        public AbstractTrainer.AbstractTrainingParameters getWeakClassifierTrainingParameters() {
            return weakClassifierTrainingParameters;
        }

        /**
         * @param abstractTrainingParameters the training parameters to use for the weak classifiers
         */
        public void setWeakClassifierTrainingParameters(AbstractTrainer.AbstractTrainingParameters abstractTrainingParameters) {
            weakClassifierTrainingParameters = abstractTrainingParameters;
        }

        /**
         * @return the maximum number of weak classifiers
         */
        public int getMaxWeakClassifiers() {
            return maxWeakClassifiers;
        }

        /**
         * @param i the maximum number of weak classifiers; stored without validation
         */
        public void setMaxWeakClassifiers(int i) {
            maxWeakClassifiers = i;
        }
    }

    /* loaded from: input_file:com/datumbox/framework/core/machinelearning/common/abstracts/algorithms/AbstractBoostingBagging$Status.class */
    /**
     * Outcome reported by {@code updateObservationAndClassifierWeights} for the weak learner
     * that was just trained; {@code _fit} uses it to decide how the training loop continues.
     */
    protected enum Status {
        // The learner is stored in the bundle, the retry counter is reset and training moves on to the next index.
        NEXT,
        // The learner is stored in the bundle and training ends immediately (_fit logs this as "low error").
        STOP,
        // The learner is closed and not stored; the same index is retried (_fit logs this as "high error")
        // until the retry limit is reached, at which point training ends.
        IGNORE
    }

    /* JADX INFO: Access modifiers changed from: protected */
    public AbstractBoostingBagging(TP tp, Configuration configuration) {
        super(tp, configuration);
        this.bundle = new TrainableBundle(configuration.getStorageConfiguration().getStorageNameSeparator());
    }

    /* JADX INFO: Access modifiers changed from: protected */
    public AbstractBoostingBagging(String str, Configuration configuration) {
        super(str, configuration);
        this.bundle = new TrainableBundle(configuration.getStorageConfiguration().getStorageNameSeparator());
    }

    /**
     * Predicts every record of the dataframe by combining the weak classifiers of the bundle.
     * <p>
     * Each weak classifier predicts the whole dataframe in turn; the class probabilities it
     * writes on each record are collected per record, keyed by the classifier's index. The
     * collected tables are then merged with a weighted average (weights taken from the model
     * parameters), normalized, and each record is replaced by a copy that carries the combined
     * probabilities and the class with the highest combined value.
     * <p>
     * The per-record tables live in a temporary big map that is dropped in a {@code finally}
     * block, so it is released even when a weak classifier or the combination step throws.
     *
     * @param dataframe the records to predict; updated in place
     */
    @Override // com.datumbox.framework.core.machinelearning.common.abstracts.modelers.AbstractModeler
    protected void _predict(Dataframe dataframe) {
        // Make sure every weak classifier is available in the bundle before it is used.
        initBundle();
        List<Double> weakClassifierWeights = ((AbstractModelParameters) this.knowledgeBase.getModelParameters()).getWeakClassifierWeights();
        StorageEngine storageEngine = this.knowledgeBase.getStorageEngine();
        Map<Object, DataTable2D> recordDecisions = storageEngine.getBigMap("tmp_recordDecisions", Object.class, DataTable2D.class, StorageEngine.MapType.HASHMAP, StorageEngine.StorageHint.IN_DISK, false, true);
        try {
            // One empty decision table per record id.
            for (Integer recordId : dataframe.index()) {
                recordDecisions.put(recordId, new DataTable2D());
            }

            AssociativeArray classifierWeights = new AssociativeArray();
            int totalWeakClassifiers = weakClassifierWeights.size();
            for (int i = 0; i < totalWeakClassifiers; i++) {
                AbstractClassifier weakClassifier = (AbstractClassifier) this.bundle.get(STORAGE_INDICATOR + i);
                weakClassifier.predict(dataframe);
                classifierWeights.put(i, weakClassifierWeights.get(i));

                // Collect what this classifier predicted for every record.
                for (Map.Entry<Integer, Record> entry : dataframe.entries()) {
                    Integer recordId = entry.getKey();
                    DataTable2D decisions = recordDecisions.get(recordId);
                    decisions.put(i, entry.getValue().getYPredictedProbabilities());
                    // Written back explicitly; presumably required because the map is requested as
                    // disk-backed (StorageHint.IN_DISK), where mutating a fetched value may not persist.
                    recordDecisions.put(recordId, decisions);
                }
            }

            // Merge the per-classifier decisions of each record into the final prediction.
            for (Map.Entry<Integer, Record> entry : dataframe.entries()) {
                Integer recordId = entry.getKey();
                Record record = entry.getValue();
                AssociativeArray combined = FixedCombinationRules.weightedAverage(recordDecisions.get(recordId), classifierWeights);
                Descriptives.normalize(combined);
                Object predictedClass = MapMethods.selectMaxKeyValue(combined).getKey();
                dataframe._unsafe_set(recordId, new Record(record.getX(), record.getY(), predictedClass, combined));
            }
        } finally {
            // Always release the temporary map, including on failure.
            storageEngine.dropBigMap("tmp_recordDecisions", recordDecisions);
        }
    }

    /**
     * Trains the ensemble of weak classifiers.
     * <p>
     * The bundle is cleared, every class label found in the dataframe is added to the model
     * parameters, and each record starts with the observation weight {@code 1/size}. Then, for
     * up to {@code maxWeakClassifiers} learners: a sample of {@code size} records is drawn with
     * {@code SimpleRandomSampling.weightedSampling} using the observation weights, a weak
     * classifier is trained on it, it predicts the full dataframe, and
     * {@link #updateObservationAndClassifierWeights} decides what happens next:
     * <ul>
     *   <li>{@code IGNORE}: the learner is closed and the same index is retried; once
     *       {@code MAX_NUM_OF_RETRIES} retries have been used, training ends;</li>
     *   <li>{@code STOP}: the learner is stored in the bundle and training ends;</li>
     *   <li>{@code NEXT}: the learner is stored, the retry counter is reset and training
     *       moves on to the next index.</li>
     * </ul>
     *
     * @param dataframe the training records; the weak learners write their predictions on it
     */
    @Override // com.datumbox.framework.core.machinelearning.common.abstracts.AbstractTrainer
    protected void _fit(Dataframe dataframe) {
        Configuration configuration = this.knowledgeBase.getConfiguration();
        AbstractTrainingParameters trainingParameters = (AbstractTrainingParameters) this.knowledgeBase.getTrainingParameters();
        AbstractModelParameters modelParameters = (AbstractModelParameters) this.knowledgeBase.getModelParameters();

        // Drop whatever weak classifiers a previous training left behind.
        resetBundle();

        int n = dataframe.size();

        // Register every class label that appears in the training data.
        Set<Object> classes = modelParameters.getClasses();
        for (Record record : dataframe) {
            classes.add(record.getY());
        }

        // Every observation starts with the same weight.
        AssociativeArray observationWeights = new AssociativeArray();
        Double initialWeight = 1.0d / n;
        for (Integer recordId : dataframe.index()) {
            observationWeights.put(recordId, initialWeight);
        }

        AbstractTrainer.AbstractTrainingParameters weakClassifierTrainingParameters = trainingParameters.getWeakClassifierTrainingParameters();
        int maxWeakClassifiers = trainingParameters.getMaxWeakClassifiers();

        int learnerIndex = 0;
        int retries = 0;
        while (learnerIndex < maxWeakClassifiers) {
            this.logger.debug("Training Weak learner {}", learnerIndex);

            // The third argument is presumably "with replacement" - confirm against SimpleRandomSampling.
            Dataframe sample = dataframe.getSubset(SimpleRandomSampling.weightedSampling(observationWeights, n, true).toFlatDataList());
            AbstractClassifier weakClassifier;
            try {
                weakClassifier = (AbstractClassifier) MLBuilder.create(weakClassifierTrainingParameters, configuration);
                weakClassifier.fit(sample);
            } finally {
                // The sample is only needed for fitting; release it even if creation or fitting fails.
                sample.close();
            }

            // The learner's predictions on the full data drive the weight update.
            weakClassifier.predict(dataframe);
            Status status = updateObservationAndClassifierWeights(dataframe, observationWeights);

            if (status == Status.IGNORE) {
                // Rejected learner: discard it and retry the same index, up to the retry limit.
                weakClassifier.close();
                if (retries >= MAX_NUM_OF_RETRIES) {
                    this.logger.debug("Too many retries, skipping further training");
                    return;
                }
                this.logger.debug("Ignoring last weak learner due to high error");
                retries++;
                continue;
            }

            this.bundle.put(STORAGE_INDICATOR + learnerIndex, weakClassifier);

            if (status == Status.STOP) {
                this.logger.debug("Skipping further training due to low error");
                return;
            }
            if (status == Status.NEXT) {
                retries = 0;
            }
            learnerIndex++;
        }
    }

    /**
     * Evaluates the weak learner that was just trained and updates the weights accordingly.
     * <p>
     * {@code _fit} calls this right after the new weak learner has run {@code predict} on the
     * training dataframe, and uses the returned status to keep or discard the learner and to
     * continue, retry or end the training loop.
     * <p>
     * NOTE(review): {@code _fit} never adds to the weak classifier weights itself, while
     * {@code _predict} reads one weight per stored learner, so implementations are presumably
     * expected to record the learner's weight in the model parameters here - confirm against
     * the subclasses.
     *
     * @param dataframe        the training dataframe, as left by the weak learner's {@code predict} call
     * @param associativeArray the observation weights, keyed by record id; initialised by
     *                         {@code _fit} to {@code 1/size} for every record and used by it for
     *                         the weighted sampling of the next learner
     * @return the status that tells {@code _fit} how to proceed
     */
    protected abstract Status updateObservationAndClassifierWeights(Dataframe dataframe, AssociativeArray associativeArray);

    /**
     * Saves this model and then its weak classifiers.
     * <p>
     * The bundle is populated first, the superclass saves the model itself, and the bundle is
     * saved under the name produced by {@code createKnowledgeBaseName} for the given string.
     *
     * @param str the name to save under, passed to the superclass and used to derive the bundle's name
     */
    @Override // com.datumbox.framework.core.machinelearning.common.abstracts.AbstractTrainer, com.datumbox.framework.core.common.interfaces.Savable
    public void save(String str) {
        initBundle();
        super.save(str);

        String storageNameSeparator = knowledgeBase.getConfiguration().getStorageConfiguration().getStorageNameSeparator();
        String knowledgeBaseName = createKnowledgeBaseName(str, storageNameSeparator);
        bundle.save(knowledgeBaseName);
    }

    /**
     * Deletes the weak classifiers and then this model.
     * <p>
     * The bundle is populated first so that weak classifiers which are not in it yet are
     * included in the bundle's delete call; the superclass delete runs last.
     */
    @Override // com.datumbox.framework.core.machinelearning.common.abstracts.AbstractTrainer, com.datumbox.framework.core.common.interfaces.Savable
    public void delete() {
        initBundle();
        bundle.delete();
        super.delete();
    }

    /**
     * Closes the weak classifiers and then this model.
     * <p>
     * The bundle is populated first so that weak classifiers which are not in it yet are
     * included in the bundle's close call. The superclass close runs in a {@code finally}
     * block, so the model's own resources are released even if populating or closing the
     * bundle throws; in that case the exception still propagates to the caller.
     */
    @Override // com.datumbox.framework.core.machinelearning.common.abstracts.AbstractTrainer, java.lang.AutoCloseable
    public void close() {
        try {
            initBundle();
            this.bundle.close();
        } finally {
            super.close();
        }
    }

    /**
     * Clears the ensemble by calling delete on the bundle; {@code _fit} uses it before
     * training new weak classifiers.
     */
    private void resetBundle() {
        bundle.delete();
    }

    /**
     * Makes sure the bundle contains every weak classifier the model refers to.
     * <p>
     * The expected number of weak classifiers is the smaller of the number of recorded weak
     * classifier weights and {@code maxWeakClassifiers}. For each index below that number, a
     * classifier missing from the bundle is loaded with {@code MLBuilder.load} from the storage
     * name {@code <storage name><separator>Cmp<index>} and added to the bundle.
     * <p>
     * When no weak classifier is expected the method returns before touching the weak
     * classifier training parameters, so calling it (for example through {@code close()}) on a
     * model whose weak classifier training parameters were never set does not fail with a
     * {@link NullPointerException}.
     */
    private void initBundle() {
        AbstractModelParameters modelParameters = (AbstractModelParameters) this.knowledgeBase.getModelParameters();
        AbstractTrainingParameters trainingParameters = (AbstractTrainingParameters) this.knowledgeBase.getTrainingParameters();

        int expectedWeakClassifiers = Math.min(modelParameters.getWeakClassifierWeights().size(), trainingParameters.getMaxWeakClassifiers());
        if (expectedWeakClassifiers <= 0) {
            // Nothing to load; avoid dereferencing weak classifier parameters that may be unset.
            return;
        }

        Configuration configuration = this.knowledgeBase.getConfiguration();
        StorageEngine storageEngine = this.knowledgeBase.getStorageEngine();
        String storageNameSeparator = configuration.getStorageConfiguration().getStorageNameSeparator();
        // Kept as a raw Class, exactly as before, so the call to MLBuilder.load resolves the same way.
        Class tClass = trainingParameters.getWeakClassifierTrainingParameters().getTClass();

        for (int i = 0; i < expectedWeakClassifiers; i++) {
            String key = STORAGE_INDICATOR + i;
            if (!this.bundle.containsKey(key)) {
                this.bundle.put(key, MLBuilder.load(tClass, storageEngine.getStorageName() + storageNameSeparator + key, configuration));
            }
        }
    }
}
