package com.datumbox.framework.core.machinelearning.common.interfaces;

import com.datumbox.framework.common.concurrency.ConcurrencyConfiguration;
import com.datumbox.framework.common.concurrency.ForkJoinStream;
import com.datumbox.framework.common.concurrency.StreamMethods;
import com.datumbox.framework.common.dataobjects.AssociativeArray;
import com.datumbox.framework.common.storage.interfaces.StorageEngine;
import com.datumbox.framework.core.common.dataobjects.Dataframe;
import com.datumbox.framework.core.common.dataobjects.Record;
import java.io.Serializable;
import java.util.Map;

/**
 * Mixin for models that predict each {@link Record} of a {@link Dataframe} independently.
 *
 * <p>Implementors supply {@link #_predictRecord(Record)}. The default methods stream over the
 * dataframe's entries via {@link ForkJoinStream}, in parallel when {@link #isParallelized()}
 * is true. They buffer the predictions, then write them back into the dataframe.
 */
public interface PredictParallelizable extends Parallelizable {

    /**
     * Serializable holder for the output of a single-record prediction.
     */
    public static class Prediction implements Serializable {
        private static final long serialVersionUID = 1L;

        /** The predicted response value. */
        private final Object yPredicted;

        /** The per-class predicted probabilities; whatever the model supplied (may be null). */
        private final AssociativeArray yPredictedProbabilities;

        /**
         * Creates a prediction.
         *
         * @param obj the predicted response value
         * @param associativeArray the predicted probabilities, stored as-is (no copy is made)
         */
        public Prediction(Object obj, AssociativeArray associativeArray) {
            this.yPredicted = obj;
            this.yPredictedProbabilities = associativeArray;
        }

        /**
         * @return the predicted response value passed to the constructor
         */
        public Object getYPredicted() {
            return this.yPredicted;
        }

        /**
         * @return the predicted probabilities passed to the constructor (the same instance, not a copy)
         */
        public AssociativeArray getYPredictedProbabilities() {
            return this.yPredictedProbabilities;
        }
    }

    /**
     * Predicts a single record.
     *
     * @param record the record to predict
     * @return the prediction for {@code record}
     */
    Prediction _predictRecord(Record record);

    /**
     * Predicts every record of {@code dataframe} and stores the results back into it.
     *
     * <p>This runs in two passes. The first pass computes a {@link Prediction} per record id
     * and puts it into {@code map}. The second pass replaces each record with a new
     * {@link Record} that keeps the original X and Y and carries the predicted value and
     * probabilities. Both passes use {@link #isParallelized()} to decide whether to stream in
     * parallel.
     *
     * @param dataframe the dataframe whose records are predicted and updated in place
     * @param map buffer receiving predictions keyed by record id; it is written from
     *     concurrent tasks when parallelized, so it must tolerate concurrent {@code put}s
     * @param concurrencyConfiguration configuration for the {@link ForkJoinStream}
     */
    default void _predictDatasetParallel(Dataframe dataframe, Map<Integer, Prediction> map, ConcurrencyConfiguration concurrencyConfiguration) {
        ForkJoinStream forkJoinStream = new ForkJoinStream(concurrencyConfiguration);

        // Pass 1: compute predictions into the buffer without touching the dataframe.
        forkJoinStream.forEach(StreamMethods.stream(dataframe.entries(), isParallelized()), entry -> {
            map.put(entry.getKey(), _predictRecord((Record) entry.getValue()));
        });

        // Pass 2: write each buffered prediction back as a new Record under the same id.
        forkJoinStream.forEach(StreamMethods.stream(dataframe.entries(), isParallelized()), entry -> {
            Integer id = entry.getKey();
            Record record = (Record) entry.getValue();
            Prediction prediction = map.get(id);
            dataframe._unsafe_set(id, new Record(record.getX(), record.getY(), prediction.getYPredicted(), prediction.getYPredictedProbabilities()));
        });
    }

    /**
     * Predicts every record of {@code dataframe}, buffering intermediate predictions in a
     * temporary big map obtained from {@code storageEngine}.
     *
     * <p>The temporary map is dropped even if prediction fails, so a failed run does not leave
     * the buffer behind in the storage engine.
     *
     * @param dataframe the dataframe whose records are predicted and updated in place
     * @param storageEngine engine used to allocate (and afterwards drop) the temporary buffer
     * @param concurrencyConfiguration configuration for the {@link ForkJoinStream}
     */
    default void _predictDatasetParallel(Dataframe dataframe, StorageEngine storageEngine, ConcurrencyConfiguration concurrencyConfiguration) {
        final String bufferName = "tmp_resultsBuffer";
        Map<Integer, Prediction> bigMap = storageEngine.getBigMap(bufferName, Integer.class, Prediction.class, StorageEngine.MapType.HASHMAP, StorageEngine.StorageHint.IN_DISK, true, true);
        try {
            _predictDatasetParallel(dataframe, bigMap, concurrencyConfiguration);
        }
        finally {
            // Always release the temporary buffer; previously an exception during prediction leaked it.
            storageEngine.dropBigMap(bufferName, bigMap);
        }
    }
}
