package com.datumbox.framework.core.machinelearning.classification;

import com.datumbox.framework.common.Configuration;
import com.datumbox.framework.common.dataobjects.AssociativeArray;
import com.datumbox.framework.common.dataobjects.TypeInference;
import com.datumbox.framework.common.storage.interfaces.BigMap;
import com.datumbox.framework.common.storage.interfaces.StorageEngine;
import com.datumbox.framework.common.utilities.RandomGenerator;
import com.datumbox.framework.core.common.dataobjects.Dataframe;
import com.datumbox.framework.core.common.dataobjects.Record;
import com.datumbox.framework.core.machinelearning.common.abstracts.AbstractTrainer;
import com.datumbox.framework.core.machinelearning.common.abstracts.modelers.AbstractClassifier;
import com.datumbox.framework.core.machinelearning.common.interfaces.PredictParallelizable;
import com.datumbox.framework.core.statistics.descriptivestatistics.Descriptives;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import libsvm.svm;
import libsvm.svm_model;
import libsvm.svm_node;
import libsvm.svm_parameter;
import libsvm.svm_problem;

/* loaded from: input_file:com/datumbox/framework/core/machinelearning/classification/SupportVectorMachine.class */
/**
 * Support Vector Machine classifier that delegates training and prediction to LIBSVM.
 *
 * <p>Features and classes are mapped to consecutive integer ids during {@link #_fit(Dataframe)}.
 * Every record is then converted to a dense {@code svm_node} vector (one node per known feature)
 * before it is handed to {@code svm.svm_train} or {@code svm.svm_predict_probability}.
 */
public class SupportVectorMachine
        extends AbstractClassifier<SupportVectorMachine.ModelParameters, SupportVectorMachine.TrainingParameters>
        implements PredictParallelizable {

    /** Flag exposed through {@link #isParallelized()}; enabled by default. */
    private boolean parallelized = true;

    /**
     * Learned state of the classifier: the feature/class id mappings and the trained LIBSVM model.
     */
    public static class ModelParameters extends AbstractClassifier.AbstractModelParameters {
        private static final long serialVersionUID = 1L;

        /** Maps each feature key seen during training to its 0-based id. */
        @BigMap(keyClass = Object.class, valueClass = Integer.class, mapType = StorageEngine.MapType.HASHMAP, storageHint = StorageEngine.StorageHint.IN_MEMORY, concurrent = false)
        private Map<Object, Integer> featureIds;

        /** Maps each class value seen during training to the integer label given to LIBSVM. */
        private Map<Object, Integer> classIds = new HashMap<>();

        /** The model returned by {@code svm.svm_train}. */
        private svm_model svmModel;

        /**
         * Creates the model parameters.
         *
         * @param storageEngine forwarded to the superclass constructor
         */
        protected ModelParameters(StorageEngine storageEngine) {
            super(storageEngine);
        }

        /**
         * @return the feature key to feature id mapping
         */
        public Map<Object, Integer> getFeatureIds() {
            return this.featureIds;
        }

        /**
         * @param featureIds the feature key to feature id mapping
         */
        protected void setFeatureIds(Map<Object, Integer> featureIds) {
            this.featureIds = featureIds;
        }

        /**
         * @return the trained LIBSVM model
         */
        public svm_model getSvmModel() {
            return this.svmModel;
        }

        /**
         * @param svmModel the trained LIBSVM model
         */
        protected void setSvmModel(svm_model svmModel) {
            this.svmModel = svmModel;
        }

        /**
         * @return the class value to class id mapping
         */
        public Map<Object, Integer> getClassIds() {
            return this.classIds;
        }

        /**
         * @param classIds the class value to class id mapping
         */
        protected void setClassIds(Map<Object, Integer> classIds) {
            this.classIds = classIds;
        }
    }

    /**
     * Training configuration: a thin wrapper around LIBSVM's {@code svm_parameter}.
     */
    public static class TrainingParameters extends AbstractTrainer.AbstractTrainingParameters {
        private static final long serialVersionUID = 1L;

        private svm_parameter svmParameter = new svm_parameter();

        /**
         * Initialises the wrapped {@code svm_parameter} with the defaults used by this classifier:
         * a C-SVC with a linear kernel and probability estimates switched on.
         */
        public TrainingParameters() {
            this.svmParameter.svm_type = svm_parameter.C_SVC;
            this.svmParameter.kernel_type = svm_parameter.LINEAR;
            this.svmParameter.degree = 3;
            this.svmParameter.gamma = 0.0d;
            this.svmParameter.coef0 = 0.0d;
            this.svmParameter.nu = 0.5d;
            this.svmParameter.cache_size = 100.0d;
            this.svmParameter.C = 1.0d;
            this.svmParameter.eps = 0.001d;
            this.svmParameter.p = 0.1d;
            this.svmParameter.shrinking = 1;
            this.svmParameter.probability = 1;
            // No per-class weights by default.
            this.svmParameter.nr_weight = 0;
            this.svmParameter.weight_label = new int[0];
            this.svmParameter.weight = new double[0];
        }

        /**
         * @return the LIBSVM parameter object used for training
         */
        public svm_parameter getSvmParameter() {
            return this.svmParameter;
        }

        /**
         * @param svmParameter the LIBSVM parameter object to use for training
         */
        public void setSvmParameter(svm_parameter svmParameter) {
            this.svmParameter = svmParameter;
        }
    }

    /**
     * Creates a classifier from training parameters.
     *
     * @param trainingParameters forwarded to the superclass constructor
     * @param configuration forwarded to the superclass constructor
     */
    protected SupportVectorMachine(TrainingParameters trainingParameters, Configuration configuration) {
        super(trainingParameters, configuration);
        seedLibsvmRandom();
    }

    /**
     * Creates a classifier from a string identifier.
     *
     * @param str forwarded to the superclass constructor
     * @param configuration forwarded to the superclass constructor
     */
    protected SupportVectorMachine(String str, Configuration configuration) {
        super(str, configuration);
        seedLibsvmRandom();
    }

    /** Seeds LIBSVM's shared random generator from the framework's thread-local generator. */
    private static void seedLibsvmRandom() {
        svm.rand.setSeed(RandomGenerator.getThreadLocalRandom().nextLong());
    }

    /**
     * @return whether the parallelized flag is set
     */
    @Override // com.datumbox.framework.core.machinelearning.common.interfaces.Parallelizable
    public boolean isParallelized() {
        return this.parallelized;
    }

    /**
     * @param parallelized the new value of the parallelized flag
     */
    @Override // com.datumbox.framework.core.machinelearning.common.interfaces.Parallelizable
    public void setParallelized(boolean parallelized) {
        this.parallelized = parallelized;
    }

    /**
     * Predicts all records of the dataframe by delegating to {@code _predictDatasetParallel}.
     *
     * @param dataframe the records to predict
     */
    @Override // com.datumbox.framework.core.machinelearning.common.abstracts.modelers.AbstractModeler
    protected void _predict(Dataframe dataframe) {
        _predictDatasetParallel(
                dataframe,
                this.knowledgeBase.getStorageEngine(),
                this.knowledgeBase.getConfiguration().getConcurrencyConfiguration());
    }

    /**
     * Predicts a single record.
     *
     * @param record the record whose features are scored
     * @return the selected class together with the per-class scores
     */
    @Override // com.datumbox.framework.core.machinelearning.common.interfaces.PredictParallelizable
    public PredictParallelizable.Prediction _predictRecord(Record record) {
        AssociativeArray classScores = calculateClassScores(record.getX());
        // The winning class is chosen before the scores are handed to Descriptives.normalize.
        Object predictedClass = getSelectedClassFromClassScores(classScores);
        Descriptives.normalize(classScores);
        return new PredictParallelizable.Prediction(predictedClass, classScores);
    }

    /**
     * Assigns integer ids to every class and feature found in the dataframe and trains the model.
     *
     * @param dataframe the training records
     */
    @Override // com.datumbox.framework.core.machinelearning.common.abstracts.AbstractTrainer
    protected void _fit(Dataframe dataframe) {
        // Prediction relies on svm_predict_probability, so probability estimates are always enabled.
        ((TrainingParameters) this.knowledgeBase.getTrainingParameters()).getSvmParameter().probability = 1;

        ModelParameters modelParameters = (ModelParameters) this.knowledgeBase.getModelParameters();
        Map<Object, Integer> featureIds = modelParameters.getFeatureIds();
        Map<Object, Integer> classIds = modelParameters.getClassIds();
        Set<Object> classes = modelParameters.getClasses();

        int nextClassId = 0;
        int nextFeatureId = 0;
        Iterator<Record> records = dataframe.iterator();
        while (records.hasNext()) {
            Record record = records.next();

            // Ids are handed out in order of first appearance.
            Object label = record.getY();
            if (classes.add(label)) {
                classIds.put(label, Integer.valueOf(nextClassId));
                nextClassId++;
            }

            for (Map.Entry<Object, Object> feature : record.getX().entrySet()) {
                if (featureIds.putIfAbsent(feature.getKey(), Integer.valueOf(nextFeatureId)) == null) {
                    nextFeatureId++;
                }
            }
        }

        libSVMTrainer(dataframe);
    }

    /**
     * Builds the LIBSVM problem from the dataframe, trains the model and stores it in the
     * model parameters.
     *
     * @param dataframe the training records; every class and feature must already have an id
     */
    private void libSVMTrainer(Dataframe dataframe) {
        ModelParameters modelParameters = (ModelParameters) this.knowledgeBase.getModelParameters();
        Map<Object, Integer> featureIds = modelParameters.getFeatureIds();
        Map<Object, Integer> classIds = modelParameters.getClassIds();

        int recordCount = dataframe.size();
        int featureCount = featureIds.size();

        svm_problem problem = new svm_problem();
        problem.l = recordCount;
        problem.y = new double[recordCount];
        problem.x = new svm_node[recordCount][];

        int row = 0;
        for (Record record : dataframe.values()) {
            problem.y[row] = classIds.get(record.getY()).intValue();
            problem.x[row] = toDenseNodes(record.getX(), featureIds, featureCount);
            row++;
        }

        svm_parameter svmParameter = ((TrainingParameters) this.knowledgeBase.getTrainingParameters()).getSvmParameter();

        // Route LIBSVM's console output to the debug log.
        svm.svm_set_print_string_function(message -> {
            if (message != null) {
                this.logger.debug(message.trim());
            }
        });

        modelParameters.setSvmModel(svm.svm_train(problem, svmParameter));
    }

    /**
     * Computes the probability estimate of every known class for the given feature vector.
     *
     * @param x the features of the record to score; features unknown to the model are ignored
     * @return an array mapping each class value to its probability estimate
     */
    private AssociativeArray calculateClassScores(AssociativeArray x) {
        ModelParameters modelParameters = (ModelParameters) this.knowledgeBase.getModelParameters();
        Map<Object, Integer> featureIds = modelParameters.getFeatureIds();
        Map<Object, Integer> classIds = modelParameters.getClassIds();
        svm_model model = modelParameters.getSvmModel();

        svm_node[] nodes = toDenseNodes(x, featureIds, featureIds.size());

        // Size the buffers from the model itself, because LIBSVM writes one entry per model class.
        int modelClassCount = svm.svm_get_nr_class(model);
        int[] labels = new int[modelClassCount];
        svm.svm_get_labels(model, labels);
        double[] probabilityEstimates = new double[modelClassCount];
        svm.svm_predict_probability(model, nodes, probabilityEstimates);

        // probabilityEstimates[k] belongs to the class whose label is labels[k]; key the estimates
        // by label so the lookup below does not depend on LIBSVM's internal class ordering.
        Map<Integer, Double> probabilityByLabel = new HashMap<>();
        for (int k = 0; k < modelClassCount; k++) {
            probabilityByLabel.put(Integer.valueOf(labels[k]), Double.valueOf(probabilityEstimates[k]));
        }

        AssociativeArray classScores = new AssociativeArray();
        for (Map.Entry<Object, Integer> classEntry : classIds.entrySet()) {
            Double probability = probabilityByLabel.get(classEntry.getValue());
            classScores.put(classEntry.getKey(), probability != null ? probability : Double.valueOf(0.0d));
        }
        return classScores;
    }

    /**
     * Converts a feature map into a dense LIBSVM vector with exactly one node per known feature.
     *
     * <p>The node of feature id {@code i} is stored at position {@code i} and carries the 1-based
     * index {@code i + 1}, so the nodes are in ascending index order. Features that are absent,
     * or whose value cannot be converted to a number, get the value 0. Keys that are not present
     * in {@code featureIds} are skipped.
     *
     * @param x the feature key to value map of a record
     * @param featureIds the feature key to 0-based feature id mapping
     * @param featureCount the length of the resulting vector
     * @return the dense node vector
     */
    private static svm_node[] toDenseNodes(AssociativeArray x, Map<Object, Integer> featureIds, int featureCount) {
        svm_node[] nodes = new svm_node[featureCount];

        for (Map.Entry<Object, Object> entry : x.entrySet()) {
            Integer featureId = featureIds.get(entry.getKey());
            if (featureId == null) {
                continue;
            }
            Double value = TypeInference.toDouble(entry.getValue());
            nodes[featureId.intValue()] = newNode(featureId.intValue() + 1, value == null ? 0.0d : value.doubleValue());
        }

        // Fill the gaps with explicit zeros so that every vector has the same length.
        for (int position = 0; position < featureCount; position++) {
            if (nodes[position] == null) {
                nodes[position] = newNode(position + 1, 0.0d);
            }
        }
        return nodes;
    }

    /**
     * Creates a LIBSVM node.
     *
     * @param index the 1-based feature index
     * @param value the feature value
     * @return the populated node
     */
    private static svm_node newNode(int index, double value) {
        svm_node node = new svm_node();
        node.index = index;
        node.value = value;
        return node;
    }
}
