/*
 * Decompiled with CFR 0.152.
 */
package com.datumbox.framework.core.machinelearning.classification;

import com.datumbox.framework.common.Configuration;
import com.datumbox.framework.common.concurrency.ForkJoinStream;
import com.datumbox.framework.common.concurrency.StreamMethods;
import com.datumbox.framework.common.dataobjects.AssociativeArray;
import com.datumbox.framework.common.dataobjects.TypeInference;
import com.datumbox.framework.common.storage.interfaces.BigMap;
import com.datumbox.framework.common.storage.interfaces.StorageEngine;
import com.datumbox.framework.core.common.dataobjects.Dataframe;
import com.datumbox.framework.core.common.dataobjects.Record;
import com.datumbox.framework.core.machinelearning.common.abstracts.AbstractTrainer;
import com.datumbox.framework.core.machinelearning.common.abstracts.modelers.AbstractClassifier;
import com.datumbox.framework.core.machinelearning.common.interfaces.PredictParallelizable;
import com.datumbox.framework.core.machinelearning.common.interfaces.TrainParallelizable;
import com.datumbox.framework.core.mathematics.regularization.L2Regularizer;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;

/**
 * Ordinal regression classifier based on a cumulative logistic model.
 *
 * <p>The distinct classes seen during training are sorted by their natural ordering and every
 * class {@code c} gets a threshold {@code thita[c]}; the threshold of the highest class is fixed
 * to {@code +Infinity}. For a record with features {@code x} the class scores are
 *
 * <pre>
 *   score(c | x) = g(thita[c] - xTw) - g(thita[prev(c)] - xTw)
 * </pre>
 *
 * where {@code g} is the logistic sigmoid, {@code xTw} is the dot product of the features with
 * the weight vector and {@code prev(c)} is the class immediately below {@code c} (the second term
 * is omitted for the lowest class). Weights and thresholds are fitted with full-batch gradient
 * descent, with an optional L2 penalty on the weights.
 */
public class OrdinalRegression
extends AbstractClassifier<OrdinalRegression.ModelParameters, OrdinalRegression.TrainingParameters>
implements PredictParallelizable,
TrainParallelizable {
    /** Whether the per-record loops of training and prediction run on a parallel stream. */
    private boolean parallelized = true;

    /** Executes the (possibly parallel) per-record loops with the configured concurrency settings. */
    protected final ForkJoinStream streamExecutor = new ForkJoinStream(this.knowledgeBase.getConfiguration().getConcurrencyConfiguration());

    /**
     * Creates a new, untrained model.
     *
     * @param trainingParameters the hyper-parameters used by {@link #_fit(Dataframe)}
     * @param configuration      the framework configuration
     */
    protected OrdinalRegression(TrainingParameters trainingParameters, Configuration configuration) {
        super(trainingParameters, configuration);
    }

    /**
     * Creates a model backed by the given storage name.
     *
     * <p>Presumably used to reload a previously saved model — verify against the superclass.
     *
     * @param storageName   the name of the storage holding the model
     * @param configuration the framework configuration
     */
    protected OrdinalRegression(String storageName, Configuration configuration) {
        super(storageName, configuration);
    }

    /** {@inheritDoc} */
    @Override
    public boolean isParallelized() {
        return this.parallelized;
    }

    /** {@inheritDoc} */
    @Override
    public void setParallelized(boolean parallelized) {
        this.parallelized = parallelized;
    }

    /**
     * Predicts every record of {@code newData} by delegating to the parallel prediction helper.
     *
     * @param newData the records to score
     */
    @Override
    protected void _predict(Dataframe newData) {
        this._predictDatasetParallel(newData, this.knowledgeBase.getStorageEngine(), this.knowledgeBase.getConfiguration().getConcurrencyConfiguration());
    }

    /**
     * Scores a single record against every known class and picks the selected class.
     *
     * @param r the record to score
     * @return the selected class together with the per-class scores
     */
    @Override
    public PredictParallelizable.Prediction _predictRecord(Record r) {
        ModelParameters modelParameters = (ModelParameters) this.knowledgeBase.getModelParameters();
        AssociativeArray predictionProbabilities = this.hypothesisFunction(r.getX(), this.getPreviousThitaMappings(), modelParameters.getWeights(), modelParameters.getThitas());
        Object predictedClass = this.getSelectedClassFromClassScores(predictionProbabilities);
        return new PredictParallelizable.Prediction(predictedClass, predictionProbabilities);
    }

    /**
     * Fits weights and class thresholds with batch gradient descent.
     *
     * <p>The learning rate is adapted every iteration: a step that does not improve the error is
     * rejected and the learning rate is halved; an accepted step multiplies it by 1.05.
     *
     * @param trainingData the training records; must contain at least one record, otherwise
     *                     {@link TreeSet#last()} throws {@link java.util.NoSuchElementException}
     */
    @Override
    protected void _fit(Dataframe trainingData) {
        ModelParameters modelParameters = (ModelParameters) this.knowledgeBase.getModelParameters();
        TrainingParameters trainingParameters = (TrainingParameters) this.knowledgeBase.getTrainingParameters();
        Map<Object, Double> weights = modelParameters.getWeights();
        Map<Object, Double> thitas = modelParameters.getThitas();

        // Collect the distinct classes in their natural order; every threshold starts at zero.
        TreeSet<Object> sortedClasses = new TreeSet<Object>();
        for (Record r : trainingData) {
            Object theClass = r.getY();
            sortedClasses.add(theClass);
            thitas.put(theClass, 0.0);
        }
        // The highest class takes all remaining mass, so its threshold is pinned to +Infinity.
        Object finalClass = sortedClasses.last();
        thitas.put(finalClass, Double.POSITIVE_INFINITY);

        Set<Object> classesSet = modelParameters.getClasses();
        classesSet.addAll(sortedClasses);

        for (Object feature : trainingData.getXDataTypes().keySet()) {
            weights.put(feature, 0.0);
        }

        Map<Object, Object> previousThitaMapping = this.getPreviousThitaMappings();
        double minError = Double.POSITIVE_INFINITY;
        double learningRate = trainingParameters.getLearningRate();
        int totalIterations = trainingParameters.getTotalIterations();
        StorageEngine storageEngine = this.knowledgeBase.getStorageEngine();

        for (int iteration = 0; iteration < totalIterations; ++iteration) {
            this.logger.debug("Iteration {}", (Object) iteration);

            HashMap<Object, Double> tmp_newThitas = new HashMap<Object, Double>();
            Map<Object, Double> tmp_newWeights = storageEngine.getBigMap("tmp_newWeights", Object.class, Double.class, StorageEngine.MapType.HASHMAP, StorageEngine.StorageHint.IN_MEMORY, false, true);
            try {
                tmp_newThitas.putAll(thitas);
                tmp_newWeights.putAll(weights);

                this.batchGradientDescent(trainingData, previousThitaMapping, tmp_newWeights, tmp_newThitas, learningRate);
                double newError = this.calculateError(trainingData, previousThitaMapping, tmp_newWeights, tmp_newThitas);

                // Written as !(<=) so that a NaN error (numerical divergence) is rejected instead of
                // being accepted and poisoning minError for every following iteration.
                if (!(newError <= minError)) {
                    learningRate /= 2.0;
                } else {
                    learningRate *= 1.05;
                    minError = newError;
                    weights.clear();
                    weights.putAll(tmp_newWeights);
                    thitas.clear();
                    thitas.putAll(tmp_newThitas);
                }
            } finally {
                // Always release the temporary map, even if an iteration throws.
                storageEngine.dropBigMap("tmp_newWeights", tmp_newWeights);
            }
        }
    }

    /**
     * Performs one full-batch gradient step.
     *
     * <p>The gradient is evaluated at the current model parameters and accumulated into
     * {@code newWeights} / {@code newThitas}, which must be pre-filled with copies of the current
     * values. Records may be processed concurrently, so each accumulator is updated under its own
     * monitor. L2 regularization is applied to {@code newWeights} at the end.
     *
     * @param trainingData         the training records
     * @param previousThitaMapping maps each class to the class immediately below it ({@code null} for the lowest)
     * @param newWeights           the weights to update in place
     * @param newThitas            the thresholds to update in place
     * @param learningRate         the step size
     */
    private void batchGradientDescent(Dataframe trainingData, Map<Object, Object> previousThitaMapping, Map<Object, Double> newWeights, Map<Object, Double> newThitas, double learningRate) {
        ModelParameters modelParameters = (ModelParameters) this.knowledgeBase.getModelParameters();
        double multiplier = -learningRate / trainingData.size();
        Map<Object, Double> weights = modelParameters.getWeights();
        Map<Object, Double> thitas = modelParameters.getThitas();

        this.streamExecutor.forEach(StreamMethods.stream(trainingData.stream(), this.isParallelized()), r -> {
            Object rClass = r.getY();
            Object rPreviousClass = previousThitaMapping.get(rClass);

            double xTw = this.xTw(r.getX(), weights);
            double gOfCurrent = this.g(xTw - thitas.get(rClass));
            double gOfPrevious = rPreviousClass != null ? this.g(thitas.get(rPreviousClass) - xTw) : 0.0;
            double dtG_multiplier = (gOfCurrent - gOfPrevious) * multiplier;

            synchronized (newWeights) {
                for (Map.Entry<Object, Object> entry : r.getX().entrySet()) {
                    Double xij = TypeInference.toDouble(entry.getValue());
                    if (xij == null) {
                        // Missing values contribute nothing, consistent with xTw().
                        continue;
                    }
                    Object column = entry.getKey();
                    newWeights.put(column, newWeights.get(column) + xij * dtG_multiplier);
                }
            }

            synchronized (newThitas) {
                newThitas.put(rClass, newThitas.get(rClass) + multiplier * -gOfCurrent);
                if (rPreviousClass != null) {
                    newThitas.put(rPreviousClass, newThitas.get(rPreviousClass) + multiplier * gOfPrevious);
                }
            }
        });

        L2Regularizer.updateWeights(((TrainingParameters) this.knowledgeBase.getTrainingParameters()).getL2(), learningRate, weights, newWeights);
    }

    /**
     * Computes the per-class scores {@code g(thita[c] - xTw) - g(thita[prev(c)] - xTw)}.
     *
     * @param x                    the record features
     * @param previousThitaMapping maps each class to the class immediately below it ({@code null} for the lowest)
     * @param weights              the feature weights
     * @param thitas               the class thresholds
     * @return a map from class to score
     */
    private AssociativeArray hypothesisFunction(AssociativeArray x, Map<Object, Object> previousThitaMapping, Map<Object, Double> weights, Map<Object, Double> thitas) {
        AssociativeArray probabilities = new AssociativeArray();
        double xTw = this.xTw(x, weights);
        Set<Object> classesSet = ((ModelParameters) this.knowledgeBase.getModelParameters()).getClasses();
        for (Object theClass : classesSet) {
            double upper = this.g(thitas.get(theClass) - xTw);
            Object previousClass = previousThitaMapping.get(theClass);
            double lower = previousClass != null ? this.g(thitas.get(previousClass) - xTw) : 0.0;
            probabilities.put(theClass, upper - lower);
        }
        return probabilities;
    }

    /**
     * Computes the mean loss over the training records plus the L2 penalty on the weights.
     *
     * <p>The per-record loss is {@code h(xTw - thita[y]) + h(thita[prev(y)] - xTw)}, where the
     * second term is omitted for the lowest class.
     *
     * @param trainingData         the training records
     * @param previousThitaMapping maps each class to the class immediately below it
     * @param weights              the candidate weights
     * @param thitas               the candidate thresholds
     * @return the regularized mean loss
     */
    private double calculateError(Dataframe trainingData, Map<Object, Object> previousThitaMapping, Map<Object, Double> weights, Map<Object, Double> thitas) {
        double error = this.streamExecutor.sum(StreamMethods.stream(trainingData.stream(), this.isParallelized()).mapToDouble(r -> {
            double xTw = this.xTw(r.getX(), weights);
            Object theClass = r.getY();
            Object previousClass = previousThitaMapping.get(theClass);

            double e = 0.0;
            if (previousClass != null) {
                e += this.h(thitas.get(previousClass) - xTw);
            }
            e += this.h(xTw - thitas.get(theClass));
            return e;
        }));
        error /= trainingData.size();
        error += L2Regularizer.estimatePenalty(((TrainingParameters) this.knowledgeBase.getTrainingParameters()).getL2(), weights);
        return error;
    }

    /**
     * Softplus {@code log(1 + e^z)}, with linear/zero shortcuts outside {@code [-30, 30]}.
     *
     * <p>{@link Math#log1p} keeps precision when {@code e^z} is tiny.
     */
    private double h(double z) {
        if (z > 30.0) {
            return z;
        }
        if (z < -30.0) {
            return 0.0;
        }
        return Math.log1p(Math.exp(z));
    }

    /** Logistic sigmoid {@code 1 / (1 + e^-z)}, saturated to 1/0 outside {@code [-30, 30]}. */
    private double g(double z) {
        if (z > 30.0) {
            return 1.0;
        }
        if (z < -30.0) {
            return 0.0;
        }
        return 1.0 / (1.0 + Math.exp(-z));
    }

    /**
     * Dot product of the features with the weights.
     *
     * <p>Null or zero feature values, and features without a weight, are skipped.
     */
    private double xTw(AssociativeArray x, Map<Object, Double> weights) {
        double xTw = 0.0;
        for (Map.Entry<Object, Object> entry : x.entrySet()) {
            Double value = TypeInference.toDouble(entry.getValue());
            if (value == null || value == 0.0) {
                continue;
            }
            Double w = weights.get(entry.getKey());
            if (w == null) {
                continue;
            }
            xTw += value * w;
        }
        return xTw;
    }

    /**
     * Maps every class to the class preceding it in the iteration order of the model's class set.
     *
     * <p>The first class maps to {@code null}. NOTE(review): correctness relies on
     * {@code getClasses()} iterating in the sorted order inserted by {@link #_fit(Dataframe)}
     * (e.g. an insertion-ordered set) — verify in AbstractModelParameters.
     *
     * @return a new map from class to its predecessor class
     */
    private Map<Object, Object> getPreviousThitaMappings() {
        HashMap<Object, Object> previousThitaMapping = new HashMap<Object, Object>();
        Object previousThita = null;
        for (Object thita : ((ModelParameters) this.knowledgeBase.getModelParameters()).getClasses()) {
            previousThitaMapping.put(thita, previousThita);
            previousThita = thita;
        }
        return previousThitaMapping;
    }

    /** Hyper-parameters of {@link OrdinalRegression}. */
    public static class TrainingParameters
    extends AbstractTrainer.AbstractTrainingParameters {
        private static final long serialVersionUID = 1L;

        /** Number of gradient-descent iterations; defaults to 100. */
        private int totalIterations = 100;

        /** Initial learning rate (adapted during training); defaults to 0.1. */
        private double learningRate = 0.1;

        /** L2 regularization strength on the weights; defaults to 0.0 (disabled). */
        private double l2 = 0.0;

        /** @return the number of gradient-descent iterations */
        public int getTotalIterations() {
            return this.totalIterations;
        }

        /** @param totalIterations the number of gradient-descent iterations */
        public void setTotalIterations(int totalIterations) {
            this.totalIterations = totalIterations;
        }

        /** @return the initial learning rate */
        public double getLearningRate() {
            return this.learningRate;
        }

        /** @param learningRate the initial learning rate */
        public void setLearningRate(double learningRate) {
            this.learningRate = learningRate;
        }

        /** @return the L2 regularization strength */
        public double getL2() {
            return this.l2;
        }

        /** @param l2 the L2 regularization strength */
        public void setL2(double l2) {
            this.l2 = l2;
        }
    }

    /** Learned state of {@link OrdinalRegression}: feature weights and per-class thresholds. */
    public static class ModelParameters
    extends AbstractClassifier.AbstractModelParameters {
        private static final long serialVersionUID = 1L;

        /**
         * Feature weights. Not initialized here; presumably populated by the storage engine
         * through the {@code @BigMap} annotation — verify.
         */
        @BigMap(keyClass=Object.class, valueClass=Double.class, mapType=StorageEngine.MapType.HASHMAP, storageHint=StorageEngine.StorageHint.IN_MEMORY, concurrent=false)
        private Map<Object, Double> weights;

        /** Per-class thresholds; the highest class holds {@code +Infinity} after training. */
        private Map<Object, Double> thitas = new HashMap<Object, Double>();

        /** @param storageEngine the storage engine backing the parameters */
        protected ModelParameters(StorageEngine storageEngine) {
            super(storageEngine);
        }

        /** @return the feature weights */
        public Map<Object, Double> getWeights() {
            return this.weights;
        }

        /** @param weights the feature weights */
        protected void setWeights(Map<Object, Double> weights) {
            this.weights = weights;
        }

        /** @return the per-class thresholds */
        public Map<Object, Double> getThitas() {
            return this.thitas;
        }

        /** @param thitas the per-class thresholds */
        protected void setThitas(Map<Object, Double> thitas) {
            this.thitas = thitas;
        }
    }
}

