package com.datumbox.framework.core.machinelearning.classification;

import com.datumbox.framework.common.Configuration;
import com.datumbox.framework.common.concurrency.ForkJoinStream;
import com.datumbox.framework.common.concurrency.StreamMethods;
import com.datumbox.framework.common.dataobjects.AssociativeArray;
import com.datumbox.framework.common.dataobjects.TypeInference;
import com.datumbox.framework.common.storage.interfaces.BigMap;
import com.datumbox.framework.common.storage.interfaces.StorageEngine;
import com.datumbox.framework.core.common.dataobjects.Dataframe;
import com.datumbox.framework.core.common.dataobjects.Record;
import com.datumbox.framework.core.machinelearning.common.abstracts.AbstractTrainer;
import com.datumbox.framework.core.machinelearning.common.abstracts.modelers.AbstractClassifier;
import com.datumbox.framework.core.machinelearning.common.interfaces.PredictParallelizable;
import com.datumbox.framework.core.machinelearning.common.interfaces.TrainParallelizable;
import com.datumbox.framework.core.mathematics.regularization.L2Regularizer;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.TreeSet;

/**
 * Ordinal Regression classifier.
 *
 * <p>All classes share a single linear score {@code xTw} (features dotted with one weight vector),
 * and every class {@code c} has its own threshold ("thita"). The score of class {@code c} is
 * {@code g(thita_c - xTw) - g(thita_prev(c) - xTw)}, where {@code g} is the logistic function and
 * {@code prev(c)} is the class preceding {@code c} in the classes order. The second term is omitted
 * for the first class. The last (highest) class gets an infinite threshold during training.
 *
 * <p>Training uses batch gradient descent with an adaptive learning rate:
 * <ul>
 *   <li>A candidate step that increases the training error is rejected and the rate is halved.</li>
 *   <li>An accepted step increases the rate by 5%.</li>
 * </ul>
 */
public class OrdinalRegression extends AbstractClassifier<OrdinalRegression.ModelParameters, OrdinalRegression.TrainingParameters> implements PredictParallelizable, TrainParallelizable {
    private boolean parallelized;
    protected final ForkJoinStream streamExecutor;

    /** Learned parameters: per-feature weights and per-class thresholds. */
    public static class ModelParameters extends AbstractClassifier.AbstractModelParameters {
        private static final long serialVersionUID = 1L;

        /** Feature weights, keyed by feature (column) name. */
        @BigMap(keyClass = Object.class, valueClass = Double.class, mapType = StorageEngine.MapType.HASHMAP, storageHint = StorageEngine.StorageHint.IN_MEMORY, concurrent = false)
        private Map<Object, Double> weights;

        /** Per-class thresholds ("thitas"), keyed by class label. */
        private Map<Object, Double> thitas;

        /**
         * @param storageEngine storage engine passed to the parent parameters object
         */
        protected ModelParameters(StorageEngine storageEngine) {
            super(storageEngine);
            this.thitas = new HashMap<>();
        }

        /** @return the feature weights */
        public Map<Object, Double> getWeights() {
            return this.weights;
        }

        /** @param weights the feature weights */
        protected void setWeights(Map<Object, Double> weights) {
            this.weights = weights;
        }

        /** @return the per-class thresholds */
        public Map<Object, Double> getThitas() {
            return this.thitas;
        }

        /** @param thitas the per-class thresholds */
        protected void setThitas(Map<Object, Double> thitas) {
            this.thitas = thitas;
        }
    }

    /** Hyper-parameters controlling training. */
    public static class TrainingParameters extends AbstractTrainer.AbstractTrainingParameters {
        private static final long serialVersionUID = 1L;

        /** Number of gradient-descent iterations (default 100). */
        private int totalIterations = 100;

        /** Initial learning rate; adapted during training (default 0.1). */
        private double learningRate = 0.1d;

        /** L2 regularization coefficient (default 0.0, i.e. disabled). */
        private double l2 = 0.0d;

        /** @return the number of training iterations */
        public int getTotalIterations() {
            return this.totalIterations;
        }

        /** @param totalIterations the number of training iterations */
        public void setTotalIterations(int totalIterations) {
            this.totalIterations = totalIterations;
        }

        /** @return the initial learning rate */
        public double getLearningRate() {
            return this.learningRate;
        }

        /** @param learningRate the initial learning rate */
        public void setLearningRate(double learningRate) {
            this.learningRate = learningRate;
        }

        /** @return the L2 regularization coefficient */
        public double getL2() {
            return this.l2;
        }

        /** @param l2 the L2 regularization coefficient */
        public void setL2(double l2) {
            this.l2 = l2;
        }
    }

    /**
     * Creates a new, untrained model.
     *
     * @param trainingParameters training hyper-parameters
     * @param configuration framework configuration
     */
    protected OrdinalRegression(TrainingParameters trainingParameters, Configuration configuration) {
        super(trainingParameters, configuration);
        this.parallelized = true;
        this.streamExecutor = new ForkJoinStream(this.knowledgeBase.getConfiguration().getConcurrencyConfiguration());
    }

    /**
     * Loads a previously stored model.
     *
     * @param storageName name under which the model was stored
     * @param configuration framework configuration
     */
    protected OrdinalRegression(String storageName, Configuration configuration) {
        super(storageName, configuration);
        this.parallelized = true;
        this.streamExecutor = new ForkJoinStream(this.knowledgeBase.getConfiguration().getConcurrencyConfiguration());
    }

    /** {@inheritDoc} */
    @Override
    public boolean isParallelized() {
        return this.parallelized;
    }

    /** {@inheritDoc} */
    @Override
    public void setParallelized(boolean parallelized) {
        this.parallelized = parallelized;
    }

    /** Predicts every record of the dataset via {@link #_predictRecord(Record)}. */
    @Override
    protected void _predict(Dataframe newData) {
        _predictDatasetParallel(newData, this.knowledgeBase.getStorageEngine(), this.knowledgeBase.getConfiguration().getConcurrencyConfiguration());
    }

    /**
     * Scores a single record against every class.
     *
     * @param r the record to score
     * @return the selected class together with the per-class scores
     */
    @Override
    public PredictParallelizable.Prediction _predictRecord(Record r) {
        ModelParameters modelParameters = (ModelParameters) this.knowledgeBase.getModelParameters();
        AssociativeArray classScores = hypothesisFunction(r.getX(), getPreviousThitaMappings(), modelParameters.getWeights(), modelParameters.getThitas());
        return new PredictParallelizable.Prediction(getSelectedClassFromClassScores(classScores), classScores);
    }

    /**
     * Trains the model.
     *
     * <p>Class labels are collected in a {@link TreeSet}, so they must be mutually comparable. All
     * thresholds and weights start at 0, and the highest class threshold is set to +infinity.
     *
     * @param trainingData the training dataset
     * @throws IllegalArgumentException if the dataset contains no records
     */
    @Override
    protected void _fit(Dataframe trainingData) {
        ModelParameters modelParameters = (ModelParameters) this.knowledgeBase.getModelParameters();
        TrainingParameters trainingParameters = (TrainingParameters) this.knowledgeBase.getTrainingParameters();
        Map<Object, Double> weights = modelParameters.getWeights();
        Map<Object, Double> thitas = modelParameters.getThitas();

        // Collect the classes in their natural (sorted) order; every threshold starts at 0.
        TreeSet<Object> sortedClasses = new TreeSet<>();
        for (Record r : trainingData) {
            Object y = r.getY();
            sortedClasses.add(y);
            thitas.put(y, 0.0);
        }
        if (sortedClasses.isEmpty()) {
            // Fail fast: last() below would throw and the gradient step would divide by zero.
            throw new IllegalArgumentException("Cannot train OrdinalRegression on an empty dataset.");
        }
        // The highest class is unbounded above, so g(thita_last - xTw) evaluates to 1.
        thitas.put(sortedClasses.last(), Double.POSITIVE_INFINITY);
        modelParameters.getClasses().addAll(sortedClasses);

        for (Object column : trainingData.getXDataTypes().keySet()) {
            weights.put(column, 0.0);
        }

        Map<Object, Object> previousThitaMappings = getPreviousThitaMappings();
        StorageEngine storageEngine = this.knowledgeBase.getStorageEngine();
        double minError = Double.POSITIVE_INFINITY;
        double learningRate = trainingParameters.getLearningRate();
        int totalIterations = trainingParameters.getTotalIterations();

        for (int iteration = 0; iteration < totalIterations; iteration++) {
            this.logger.debug("Iteration {}", iteration);

            // Candidate parameters: copies of the current ones, updated by one gradient step.
            Map<Object, Double> newThitas = new HashMap<>(thitas);
            Map<Object, Double> newWeights = storageEngine.getBigMap("tmp_newWeights", Object.class, Double.class, StorageEngine.MapType.HASHMAP, StorageEngine.StorageHint.IN_MEMORY, false, true);
            try {
                newWeights.putAll(weights);
                batchGradientDescent(trainingData, previousThitaMappings, newWeights, newThitas, learningRate);

                double newError = calculateError(trainingData, previousThitaMappings, newWeights, newThitas);
                if (newError > minError) {
                    // Step overshot: reject it and retry with a smaller rate.
                    learningRate /= 2.0;
                } else {
                    // Step improved the error: accept it and be slightly more aggressive.
                    learningRate *= 1.05;
                    minError = newError;
                    weights.clear();
                    weights.putAll(newWeights);
                    thitas.clear();
                    thitas.putAll(newThitas);
                }
            } finally {
                // Always release the temporary map, even if the step throws.
                storageEngine.dropBigMap("tmp_newWeights", newWeights);
            }
        }
    }

    /**
     * Performs one batch gradient-descent step.
     *
     * <p>The gradient is evaluated at the current model parameters. It is accumulated into
     * {@code newWeights}/{@code newThitas}, which must already hold copies of the current values.
     * Records may be processed in parallel, so both accumulators are updated under their own
     * monitors. Feature values that convert to {@code null} are skipped, consistent with
     * {@link #xTw(AssociativeArray, Map)}.
     *
     * @param trainingData the training dataset
     * @param previousThitaMappings class to preceding-class mapping (first class maps to null)
     * @param newWeights candidate weights, updated in place
     * @param newThitas candidate thresholds, updated in place
     * @param learningRate the step size
     */
    private void batchGradientDescent(Dataframe trainingData, Map<Object, Object> previousThitaMappings, Map<Object, Double> newWeights, Map<Object, Double> newThitas, double learningRate) {
        ModelParameters modelParameters = (ModelParameters) this.knowledgeBase.getModelParameters();
        double multiplier = -learningRate / trainingData.size();
        Map<Object, Double> weights = modelParameters.getWeights();
        Map<Object, Double> thitas = modelParameters.getThitas();

        this.streamExecutor.forEach(StreamMethods.stream(trainingData.stream(), isParallelized()), r -> {
            Object y = r.getY();
            Object previousClass = previousThitaMappings.get(y);
            double xTw = xTw(r.getX(), weights);

            double gCurrent = g(xTw - thitas.get(y));
            double gPrevious = (previousClass != null) ? g(thitas.get(previousClass) - xTw) : 0.0;
            double weightsMultiplier = (gCurrent - gPrevious) * multiplier;

            synchronized (newWeights) {
                for (Map.Entry<Object, Object> entry : r.getX().entrySet()) {
                    Double xij = TypeInference.toDouble(entry.getValue());
                    if (xij == null) {
                        continue; // missing value contributes nothing, as in xTw()
                    }
                    Object column = entry.getKey();
                    newWeights.put(column, newWeights.get(column) + xij * weightsMultiplier);
                }
            }
            synchronized (newThitas) {
                newThitas.put(y, newThitas.get(y) + multiplier * (-gCurrent));
                if (previousClass != null) {
                    newThitas.put(previousClass, newThitas.get(previousClass) + multiplier * gPrevious);
                }
            }
        });

        L2Regularizer.updateWeights(((TrainingParameters) this.knowledgeBase.getTrainingParameters()).getL2(), learningRate, weights, newWeights);
    }

    /**
     * Computes the per-class scores of a single feature vector.
     *
     * @param x the feature vector
     * @param previousThitaMappings class to preceding-class mapping (first class maps to null)
     * @param weights feature weights
     * @param thitas per-class thresholds
     * @return class label to score, for every known class
     */
    private AssociativeArray hypothesisFunction(AssociativeArray x, Map<Object, Object> previousThitaMappings, Map<Object, Double> weights, Map<Object, Double> thitas) {
        AssociativeArray scores = new AssociativeArray();
        double xTw = xTw(x, weights);
        for (Object theClass : ((ModelParameters) this.knowledgeBase.getModelParameters()).getClasses()) {
            double score = g(thitas.get(theClass) - xTw);
            Object previousClass = previousThitaMappings.get(theClass);
            if (previousClass != null) {
                score -= g(thitas.get(previousClass) - xTw);
            }
            scores.put(theClass, score);
        }
        return scores;
    }

    /**
     * Computes the training error of candidate parameters.
     *
     * <p>The error is the mean over records of {@code h(xTw - thita_y) + h(thita_prev(y) - xTw)},
     * where the second term is omitted for the first class. The L2 penalty of {@code weights} is
     * added to the mean.
     *
     * @param trainingData the training dataset
     * @param previousThitaMappings class to preceding-class mapping (first class maps to null)
     * @param weights candidate weights
     * @param thitas candidate thresholds
     * @return the regularized mean error
     */
    private double calculateError(Dataframe trainingData, Map<Object, Object> previousThitaMappings, Map<Object, Double> weights, Map<Object, Double> thitas) {
        double totalError = this.streamExecutor.sum(StreamMethods.stream(trainingData.stream(), isParallelized()).mapToDouble(r -> {
            double xTw = xTw(r.getX(), weights);
            Object y = r.getY();
            Object previousClass = previousThitaMappings.get(y);
            double error = 0.0;
            if (previousClass != null) {
                error += h(thitas.get(previousClass) - xTw);
            }
            return error + h(xTw - thitas.get(y));
        }));
        return totalError / trainingData.size() + L2Regularizer.estimatePenalty(((TrainingParameters) this.knowledgeBase.getTrainingParameters()).getL2(), weights);
    }

    /**
     * Softplus {@code log(1 + e^d)}.
     *
     * <p>For {@code d > 30} it returns {@code d}, and for {@code d < -30} it returns 0.
     * {@link Math#log1p} keeps precision when {@code e^d} is tiny.
     */
    private double h(double d) {
        if (d > 30.0) {
            return d;
        }
        if (d < -30.0) {
            return 0.0;
        }
        return Math.log1p(Math.exp(d));
    }

    /**
     * Logistic function {@code 1 / (1 + e^-d)}.
     *
     * <p>It saturates to 1 for {@code d > 30} and to 0 for {@code d < -30}.
     */
    private double g(double d) {
        if (d > 30.0) {
            return 1.0;
        }
        if (d < -30.0) {
            return 0.0;
        }
        return 1.0 / (1.0 + Math.exp(-d));
    }

    /**
     * Dot product of a feature vector with the weights.
     *
     * <p>Null or zero feature values and features without a weight are skipped.
     */
    private double xTw(AssociativeArray x, Map<Object, Double> weights) {
        double sum = 0.0;
        for (Map.Entry<Object, Object> entry : x.entrySet()) {
            Double value = TypeInference.toDouble(entry.getValue());
            if (value == null || value == 0.0) {
                continue;
            }
            Double weight = weights.get(entry.getKey());
            if (weight != null) {
                sum += value * weight;
            }
        }
        return sum;
    }

    /**
     * Maps each class to the class preceding it in {@code getClasses()} iteration order.
     *
     * <p>The first class maps to {@code null}.
     * NOTE(review): this relies on {@code getClasses()} iterating in sorted order. {@code _fit}
     * fills it from a {@link TreeSet}, so this presumably holds if the container preserves
     * insertion order — confirm.
     */
    private Map<Object, Object> getPreviousThitaMappings() {
        Map<Object, Object> previousClassOf = new HashMap<>();
        Object previous = null;
        for (Object theClass : ((ModelParameters) this.knowledgeBase.getModelParameters()).getClasses()) {
            previousClassOf.put(theClass, previous);
            previous = theClass;
        }
        return previousClassOf;
    }
}
