package com.datumbox.framework.core.machinelearning.clustering;

import com.datumbox.framework.common.Configuration;
import com.datumbox.framework.common.concurrency.ForkJoinStream;
import com.datumbox.framework.common.concurrency.StreamMethods;
import com.datumbox.framework.common.dataobjects.AssociativeArray;
import com.datumbox.framework.common.dataobjects.TypeInference;
import com.datumbox.framework.common.storage.interfaces.BigMap;
import com.datumbox.framework.common.storage.interfaces.StorageEngine;
import com.datumbox.framework.core.common.dataobjects.Dataframe;
import com.datumbox.framework.core.common.dataobjects.Record;
import com.datumbox.framework.core.common.utilities.MapMethods;
import com.datumbox.framework.core.common.utilities.PHPMethods;
import com.datumbox.framework.core.machinelearning.common.abstracts.AbstractTrainer;
import com.datumbox.framework.core.machinelearning.common.abstracts.modelers.AbstractClusterer;
import com.datumbox.framework.core.machinelearning.common.interfaces.PredictParallelizable;
import com.datumbox.framework.core.machinelearning.common.interfaces.TrainParallelizable;
import com.datumbox.framework.core.mathematics.distances.Distance;
import com.datumbox.framework.core.statistics.descriptivestatistics.Descriptives;
import com.datumbox.framework.core.statistics.sampling.SimpleRandomSampling;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;

/* loaded from: input_file:com/datumbox/framework/core/machinelearning/clustering/Kmeans.class */
public class Kmeans extends AbstractClusterer<Cluster, ModelParameters, TrainingParameters> implements PredictParallelizable, TrainParallelizable {
    private boolean parallelized;
    protected final ForkJoinStream streamExecutor;

    /* loaded from: input_file:com/datumbox/framework/core/machinelearning/clustering/Kmeans$Cluster.class */
    /**
     * A single k-means cluster. It keeps a running per-feature sum of the
     * records added to it and derives its centroid from that sum.
     */
    public static class Cluster extends AbstractClusterer.AbstractCluster {
        private static final long serialVersionUID = 1L;

        /** Current centroid; only its X values are recomputed, its Y value is carried over. */
        private Record centroid;

        /** Per-feature sum of the X values of every record added since the last clear/reset. */
        private final AssociativeArray xi_sum;

        /**
         * Creates a cluster whose centroid and running sum are both empty.
         *
         * @param clusterId the id handed to the superclass constructor
         */
        protected Cluster(int clusterId) {
            super(Integer.valueOf(clusterId));
            this.centroid = new Record(new AssociativeArray(), null);
            this.xi_sum = new AssociativeArray();
        }

        /**
         * Returns the current centroid of the cluster.
         *
         * @return the centroid record
         */
        public Record getCentroid() {
            return centroid;
        }

        /**
         * Recomputes the centroid as the running sum divided by the cluster
         * size (the raw sum is used when the size is not positive) and
         * replaces the stored centroid only when its X values differ.
         *
         * @return {@code true} if the centroid was replaced, {@code false} otherwise
         */
        protected boolean updateClusterParameters() {
            AssociativeArray meanValues = xi_sum.copy();
            if (size > 0) {
                meanValues.multiplyValues(1.0d / size);
            }

            if (centroid.getX().equals(meanValues)) {
                return false;
            }

            centroid = new Record(meanValues, centroid.getY());
            return true;
        }

        /**
         * Counts the record as a member and adds its X values to the running sum.
         *
         * @param record the record that joins the cluster
         */
        @Override
        protected void add(Record record) {
            size++;
            xi_sum.addValues(record.getX());
        }

        /**
         * Not supported by this cluster type.
         *
         * @param record ignored
         * @throws UnsupportedOperationException always
         */
        @Override
        protected void remove(Record record) {
            throw new UnsupportedOperationException("Remove operation is not supported.");
        }

        /** Empties the running sum; the size and the centroid are left untouched. */
        @Override
        protected void clear() {
            xi_sum.clear();
        }

        /** Empties the running sum and sets the size back to zero; the centroid is kept. */
        protected void reset() {
            xi_sum.clear();
            size = 0;
        }
    }

    /* loaded from: input_file:com/datumbox/framework/core/machinelearning/clustering/Kmeans$ModelParameters.class */
    /**
     * Model state learned by {@link Kmeans}: the iteration counter written by
     * the training loop and the per-feature weights used in distance
     * calculations.
     */
    public static class ModelParameters extends AbstractClusterer.AbstractModelParameters<Cluster> {
        private static final long serialVersionUID = 1L;

        /** Iteration value recorded by the training loop. */
        private int totalIterations;

        /** Weight of each feature, keyed by the feature's column key. */
        @BigMap(keyClass = Object.class, valueClass = Double.class, mapType = StorageEngine.MapType.HASHMAP, storageHint = StorageEngine.StorageHint.IN_MEMORY, concurrent = true)
        private Map<Object, Double> featureWeights;

        /**
         * Forwards the storage engine to the superclass constructor.
         *
         * @param storageEngine the storage engine passed to the superclass
         */
        protected ModelParameters(StorageEngine storageEngine) {
            super(storageEngine);
        }

        /**
         * @return the iteration value recorded during training
         */
        public int getTotalIterations() {
            return totalIterations;
        }

        /**
         * @param totalIterations the iteration value to record
         */
        protected void setTotalIterations(int totalIterations) {
            this.totalIterations = totalIterations;
        }

        /**
         * @return the map of feature weights
         */
        public Map<Object, Double> getFeatureWeights() {
            return featureWeights;
        }

        /**
         * @param featureWeights the map of feature weights to use
         */
        protected void setFeatureWeights(Map<Object, Double> featureWeights) {
            this.featureWeights = featureWeights;
        }
    }

    /* loaded from: input_file:com/datumbox/framework/core/machinelearning/clustering/Kmeans$TrainingParameters.class */
    /**
     * User-configurable settings of the {@link Kmeans} trainer.
     */
    public static class TrainingParameters extends AbstractTrainer.AbstractTrainingParameters {
        private static final long serialVersionUID = 1L;

        /** Number of clusters requested from the initialization step. */
        private int k = 2;

        /** Strategy used to pick the initial centroids. */
        private Initialization initializationMethod = Initialization.PLUS_PLUS;

        /** Distance function used between records and centroids. */
        private Distance distanceMethod = Distance.EUCLIDIAN;

        /** Upper bound on the number of training iterations. */
        private int maxIterations = 200;

        /** Constant of the sample-size formula used by SUBSET_FURTHEST_FIRST. */
        private double subsetFurthestFirstcValue = 2.0;

        /** Multiplier applied to the weight of every non-numerical feature. */
        private double categoricalGamaMultiplier = 1.0;

        /** Whether feature weights are estimated from the data instead of being fixed. */
        private boolean weighted = false;

        /** Supported distance functions. */
        public enum Distance {
            EUCLIDIAN,
            MANHATTAN
        }

        /** Supported centroid initialization strategies. */
        public enum Initialization {
            FORGY,
            RANDOM_PARTITION,
            SET_FIRST_K,
            FURTHEST_FIRST,
            SUBSET_FURTHEST_FIRST,
            PLUS_PLUS
        }

        /**
         * @return the requested number of clusters
         */
        public int getK() {
            return k;
        }

        /**
         * @param k the requested number of clusters
         */
        public void setK(int k) {
            this.k = k;
        }

        /**
         * @return the centroid initialization strategy
         */
        public Initialization getInitializationMethod() {
            return initializationMethod;
        }

        /**
         * @param initializationMethod the centroid initialization strategy
         */
        public void setInitializationMethod(Initialization initializationMethod) {
            this.initializationMethod = initializationMethod;
        }

        /**
         * @return the distance function
         */
        public Distance getDistanceMethod() {
            return distanceMethod;
        }

        /**
         * @param distanceMethod the distance function
         */
        public void setDistanceMethod(Distance distanceMethod) {
            this.distanceMethod = distanceMethod;
        }

        /**
         * @return the maximum number of training iterations
         */
        public int getMaxIterations() {
            return maxIterations;
        }

        /**
         * @param maxIterations the maximum number of training iterations
         */
        public void setMaxIterations(int maxIterations) {
            this.maxIterations = maxIterations;
        }

        /**
         * @return the constant of the SUBSET_FURTHEST_FIRST sample-size formula
         */
        public double getSubsetFurthestFirstcValue() {
            return subsetFurthestFirstcValue;
        }

        /**
         * @param subsetFurthestFirstcValue the constant of the SUBSET_FURTHEST_FIRST sample-size formula
         */
        public void setSubsetFurthestFirstcValue(double subsetFurthestFirstcValue) {
            this.subsetFurthestFirstcValue = subsetFurthestFirstcValue;
        }

        /**
         * @return the multiplier applied to weights of non-numerical features
         */
        public double getCategoricalGamaMultiplier() {
            return categoricalGamaMultiplier;
        }

        /**
         * @param categoricalGamaMultiplier the multiplier applied to weights of non-numerical features
         */
        public void setCategoricalGamaMultiplier(double categoricalGamaMultiplier) {
            this.categoricalGamaMultiplier = categoricalGamaMultiplier;
        }

        /**
         * @return {@code true} if feature weights are estimated from the data
         */
        public boolean isWeighted() {
            return weighted;
        }

        /**
         * @param weighted {@code true} to estimate feature weights from the data
         */
        public void setWeighted(boolean weighted) {
            this.weighted = weighted;
        }
    }

    /**
     * Builds the clusterer from training parameters. Both arguments are
     * forwarded to the superclass; afterwards the stream executor is created
     * from the knowledge base's concurrency configuration and parallel
     * execution is switched on.
     *
     * @param trainingParameters the training parameters passed to the superclass
     * @param configuration the configuration passed to the superclass
     */
    protected Kmeans(TrainingParameters trainingParameters, Configuration configuration) {
        super(trainingParameters, configuration);
        streamExecutor = new ForkJoinStream(knowledgeBase.getConfiguration().getConcurrencyConfiguration());
        parallelized = true;
    }

    /**
     * Builds the clusterer from a string identifier. Both arguments are
     * forwarded to the superclass; afterwards the stream executor is created
     * from the knowledge base's concurrency configuration and parallel
     * execution is switched on.
     *
     * @param str the string passed to the superclass (presumably the name of a stored model; confirm against the superclass)
     * @param configuration the configuration passed to the superclass
     */
    protected Kmeans(String str, Configuration configuration) {
        super(str, configuration);
        streamExecutor = new ForkJoinStream(knowledgeBase.getConfiguration().getConcurrencyConfiguration());
        parallelized = true;
    }

    /**
     * Reports whether the stream-based steps of this clusterer are asked to
     * run in parallel.
     *
     * @return the current value of the parallelized flag
     */
    @Override
    public boolean isParallelized() {
        return parallelized;
    }

    /**
     * Switches parallel execution of the stream-based steps on or off.
     *
     * @param z the new value of the parallelized flag
     */
    @Override
    public void setParallelized(boolean z) {
        parallelized = z;
    }

    /**
     * Predicts every record of the dataframe by delegating to
     * {@code _predictDatasetParallel} with the knowledge base's storage engine
     * and concurrency configuration.
     *
     * @param dataframe the records to predict
     */
    @Override
    protected void _predict(Dataframe dataframe) {
        StorageEngine storageEngine = knowledgeBase.getStorageEngine();
        _predictDatasetParallel(dataframe, storageEngine, knowledgeBase.getConfiguration().getConcurrencyConfiguration());
    }

    /**
     * Predicts the cluster of a single record: the distance to every centroid
     * is calculated, the distances are normalized, and the cluster with the
     * smallest value is selected.
     *
     * @param record the record to assign
     * @return a prediction holding the selected cluster id and the normalized distances
     */
    @Override
    public PredictParallelizable.Prediction _predictRecord(Record record) {
        ModelParameters modelParameters = (ModelParameters) knowledgeBase.getModelParameters();
        AssociativeArray clusterDistances = new AssociativeArray();

        for (Map.Entry<Integer, Cluster> clusterEntry : modelParameters.getClusterMap().entrySet()) {
            Record centroid = clusterEntry.getValue().getCentroid();
            double distance = calculateDistance(record, centroid);
            clusterDistances.put(clusterEntry.getKey(), Double.valueOf(distance));
        }

        Descriptives.normalize(clusterDistances);
        Object closestClusterId = getSelectedClusterFromDistances(clusterDistances);
        return new PredictParallelizable.Prediction(closestClusterId, clusterDistances);
    }

    /**
     * Trains the model: collects the non-null Y values of the training data
     * as gold standard classes, computes the feature weights, initializes the
     * clusters, runs the iterative refinement and finally clears the clusters.
     *
     * @param dataframe the training data
     */
    @Override
    protected void _fit(Dataframe dataframe) {
        ModelParameters modelParameters = (ModelParameters) knowledgeBase.getModelParameters();
        Set<Object> knownClasses = modelParameters.getGoldStandardClasses();

        for (Iterator<Record> records = dataframe.iterator(); records.hasNext();) {
            Object label = records.next().getY();
            if (label == null) {
                continue;
            }
            knownClasses.add(label);
        }

        calculateFeatureWeights(dataframe);
        initializeClusters(dataframe);
        calculateClusters(dataframe);
        clearClusters();
    }

    /**
     * Fills the model's feature-weight map with one weight per column of the
     * training data.
     *
     * <p>Unweighted mode: numerical features get weight 1 and every other
     * feature gets the categorical gamma multiplier.</p>
     *
     * <p>Weighted mode: a spread estimate is computed per feature and inverted
     * when it is positive. For numerical features the estimate is
     * {@code 2 * (sumXsquare/n - (sumX/n)^2)}; for other features it is
     * {@code 1 - p^2}, where {@code p} is the fraction of records holding a
     * non-zero value. Non-numerical weights are then multiplied by the
     * categorical gamma multiplier.</p>
     *
     * @param dataframe the training data
     */
    private void calculateFeatureWeights(Dataframe dataframe) {
        ModelParameters modelParameters = (ModelParameters) this.knowledgeBase.getModelParameters();
        TrainingParameters trainingParameters = (TrainingParameters) this.knowledgeBase.getTrainingParameters();
        Map<Object, TypeInference.DataType> xDataTypes = dataframe.getXDataTypes();
        Map<Object, Double> featureWeights = modelParameters.getFeatureWeights();
        final double gammaMultiplier = trainingParameters.getCategoricalGamaMultiplier();

        if (!trainingParameters.isWeighted()) {
            this.streamExecutor.forEach(StreamMethods.stream(xDataTypes.entrySet().stream(), isParallelized()), entry -> {
                double weight = entry.getValue() != TypeInference.DataType.NUMERICAL ? gammaMultiplier : 1.0d;
                featureWeights.put(entry.getKey(), Double.valueOf(weight));
            });
            return;
        }

        final int n = dataframe.size();
        StorageEngine storageEngine = this.knowledgeBase.getStorageEngine();
        Map<Object, Double> categoricalFrequencies = storageEngine.getBigMap("tmp_categoricalFrequencies", Object.class, Double.class, StorageEngine.MapType.HASHMAP, StorageEngine.StorageHint.IN_MEMORY, true, true);
        Map<Object, Double> sumX = storageEngine.getBigMap("tmp_varianceSumX", Object.class, Double.class, StorageEngine.MapType.HASHMAP, StorageEngine.StorageHint.IN_MEMORY, true, true);
        Map<Object, Double> sumXSquare = storageEngine.getBigMap("tmp_varianceSumXsquare", Object.class, Double.class, StorageEngine.MapType.HASHMAP, StorageEngine.StorageHint.IN_MEMORY, true, true);
        try {
            // Single pass over the data: count non-zero occurrences of the
            // non-numerical features and accumulate sum(x) and sum(x^2) of the
            // numerical ones. Null and zero values contribute nothing.
            Iterator<Record> records = dataframe.iterator();
            while (records.hasNext()) {
                for (Map.Entry<Object, Object> feature : records.next().getX().entrySet()) {
                    Double value = TypeInference.toDouble(feature.getValue());
                    if (value == null || value.doubleValue() == 0.0d) {
                        continue;
                    }
                    Object key = feature.getKey();
                    if (xDataTypes.get(key) != TypeInference.DataType.NUMERICAL) {
                        double count = categoricalFrequencies.getOrDefault(key, Double.valueOf(0.0d)).doubleValue();
                        categoricalFrequencies.put(key, Double.valueOf(count + 1.0d));
                    } else {
                        double previousSum = sumX.getOrDefault(key, Double.valueOf(0.0d)).doubleValue();
                        double previousSumSquare = sumXSquare.getOrDefault(key, Double.valueOf(0.0d)).doubleValue();
                        sumX.put(key, Double.valueOf(previousSum + value.doubleValue()));
                        sumXSquare.put(key, Double.valueOf(previousSumSquare + (value.doubleValue() * value.doubleValue())));
                    }
                }
            }

            this.streamExecutor.forEach(StreamMethods.stream(xDataTypes.entrySet().stream(), isParallelized()), entry -> {
                Object key = entry.getKey();
                TypeInference.DataType dataType = (TypeInference.DataType) entry.getValue();
                double weight;
                // A feature that never had a non-zero value has no entry in
                // the temporary maps; treat its sums as 0 instead of unboxing
                // a null.
                if (dataType != TypeInference.DataType.NUMERICAL) {
                    double frequency = categoricalFrequencies.getOrDefault(key, Double.valueOf(0.0d)).doubleValue() / n;
                    weight = 1.0d - (frequency * frequency);
                } else {
                    double mean = sumX.getOrDefault(key, Double.valueOf(0.0d)).doubleValue() / n;
                    double meanOfSquares = sumXSquare.getOrDefault(key, Double.valueOf(0.0d)).doubleValue() / n;
                    weight = 2.0d * (meanOfSquares - (mean * mean));
                }
                // Invert only positive estimates; a non-positive estimate is
                // stored as computed.
                if (weight > 0.0d) {
                    weight = 1.0d / weight;
                }
                if (dataType != TypeInference.DataType.NUMERICAL) {
                    weight *= gammaMultiplier;
                }
                featureWeights.put(key, Double.valueOf(weight));
            });
        } finally {
            // Each temporary map is dropped under its own name, even when the
            // calculation above fails.
            storageEngine.dropBigMap("tmp_categoricalFrequencies", categoricalFrequencies);
            storageEngine.dropBigMap("tmp_varianceSumX", sumX);
            storageEngine.dropBigMap("tmp_varianceSumXsquare", sumXSquare);
        }
    }

    /**
     * Calculates the feature-weighted distance between two records with the
     * distance function selected in the training parameters.
     *
     * @param record the first record
     * @param record2 the second record
     * @return the weighted Euclidean or Manhattan distance of the X values
     * @throws IllegalArgumentException if the configured distance method is neither EUCLIDIAN nor MANHATTAN
     */
    private double calculateDistance(Record record, Record record2) {
        ModelParameters modelParameters = (ModelParameters) this.knowledgeBase.getModelParameters();
        TrainingParameters trainingParameters = (TrainingParameters) this.knowledgeBase.getTrainingParameters();
        Map<Object, Double> weights = modelParameters.getFeatureWeights();
        TrainingParameters.Distance method = trainingParameters.getDistanceMethod();

        if (method == TrainingParameters.Distance.EUCLIDIAN) {
            return Distance.euclideanWeighted(record.getX(), record2.getX(), weights);
        }
        if (method == TrainingParameters.Distance.MANHATTAN) {
            return Distance.manhattanWeighted(record.getX(), record2.getX(), weights);
        }
        throw new IllegalArgumentException("Unsupported Distance method.");
    }

    /**
     * Picks the cluster with the smallest distance.
     *
     * @param associativeArray the distance of a record to each cluster, keyed by cluster id
     * @return the key of the minimum entry as returned by {@code MapMethods.selectMinKeyValue}
     */
    private Object getSelectedClusterFromDistances(AssociativeArray associativeArray) {
        Object closestClusterId = MapMethods.selectMinKeyValue(associativeArray).getKey();
        return closestClusterId;
    }

    /**
     * Creates the initial clusters according to the configured initialization
     * method and stores them in the model's cluster map. An unrecognized (or
     * null) method creates no clusters.
     *
     * <p>Every strategy creates at most {@code k} clusters and never more
     * clusters than there are records to seed them with.</p>
     *
     * @param dataframe the training data
     */
    private void initializeClusters(Dataframe dataframe) {
        ModelParameters modelParameters = (ModelParameters) this.knowledgeBase.getModelParameters();
        TrainingParameters trainingParameters = (TrainingParameters) this.knowledgeBase.getTrainingParameters();
        int k = trainingParameters.getK();
        TrainingParameters.Initialization initializationMethod = trainingParameters.getInitializationMethod();
        Map<Integer, Cluster> clusterMap = modelParameters.getClusterMap();

        if (initializationMethod == TrainingParameters.Initialization.SET_FIRST_K
                || initializationMethod == TrainingParameters.Initialization.FORGY) {
            // FORGY is handled exactly like SET_FIRST_K in this implementation.
            initializeFromFirstRecords(dataframe, clusterMap, k);
        } else if (initializationMethod == TrainingParameters.Initialization.RANDOM_PARTITION) {
            initializeByPartition(dataframe, clusterMap, k);
        } else if (initializationMethod == TrainingParameters.Initialization.FURTHEST_FIRST) {
            initializeFurthestFirst(dataframe, clusterMap, k, dataframe.size());
        } else if (initializationMethod == TrainingParameters.Initialization.SUBSET_FURTHEST_FIRST) {
            int sampleSize = (int) Math.max(Math.ceil(trainingParameters.getSubsetFurthestFirstcValue() * k * PHPMethods.log(k, 2.0d)), k);
            initializeFurthestFirst(dataframe, clusterMap, k, sampleSize);
        } else if (initializationMethod == TrainingParameters.Initialization.PLUS_PLUS) {
            initializePlusPlus(dataframe, clusterMap, k);
        }
    }

    /** Builds a cluster with the given id whose centroid is computed from the single seed record. */
    private Cluster createSeededCluster(int clusterId, Record seed) {
        Cluster cluster = new Cluster(clusterId);
        cluster.add(seed);
        cluster.updateClusterParameters();
        return cluster;
    }

    /** Returns the smallest distance from the record to any existing centroid, or Double.MAX_VALUE when there are no clusters. */
    private double minDistanceToCentroids(Record record, Map<Integer, Cluster> clusterMap) {
        double minDistance = Double.MAX_VALUE;
        for (Cluster cluster : clusterMap.values()) {
            double distance = calculateDistance(record, cluster.getCentroid());
            if (distance < minDistance) {
                minDistance = distance;
            }
        }
        return minDistance;
    }

    /** Seeds clusters 0..k-1 with the first k records in the iteration order of the dataframe values. */
    private void initializeFromFirstRecords(Dataframe dataframe, Map<Integer, Cluster> clusterMap, int k) {
        int clusterId = 0;
        for (Record record : dataframe.values()) {
            if (clusterId >= k) {
                break;
            }
            clusterMap.put(Integer.valueOf(clusterId), createSeededCluster(clusterId, record));
            clusterId++;
        }
    }

    /** Deals the records round-robin into k clusters (record i goes to cluster i % k) and then computes every centroid. */
    private void initializeByPartition(Dataframe dataframe, Map<Integer, Cluster> clusterMap, int k) {
        int position = 0;
        for (Record record : dataframe.values()) {
            Integer clusterId = Integer.valueOf(position % k);
            Cluster cluster = clusterMap.get(clusterId);
            if (cluster == null) {
                cluster = new Cluster(clusterId.intValue());
            }
            cluster.add(record);
            // The cluster is written back after every change so that the map
            // holds the updated state.
            clusterMap.put(clusterId, cluster);
            position++;
        }
        for (Map.Entry<Integer, Cluster> entry : clusterMap.entrySet()) {
            Cluster cluster = entry.getValue();
            cluster.updateClusterParameters();
            clusterMap.put(entry.getKey(), cluster);
        }
    }

    /**
     * Greedy furthest-first seeding: each round picks the not-yet-selected record whose distance to its
     * nearest existing centroid is largest, looking only at a bounded window of candidates per round.
     */
    private void initializeFurthestFirst(Dataframe dataframe, Map<Integer, Cluster> clusterMap, int k, int sampleSize) {
        Set<Integer> selectedIds = new HashSet<>();
        for (int i = 0; i < k; i++) {
            Integer furthestId = null;
            // Start below every possible distance so that a candidate is still
            // chosen when all remaining records coincide with existing
            // centroids (distance 0). Ties keep the first candidate seen.
            double furthestDistance = Double.NEGATIVE_INFINITY;
            int examined = 0;
            for (Map.Entry<Integer, Record> entry : dataframe.entries()) {
                // NOTE(review): this bound lets sampleSize + 1 candidates
                // through; kept as-is to leave the selected seeds unchanged.
                if (examined > sampleSize) {
                    break;
                }
                Integer recordId = entry.getKey();
                if (selectedIds.contains(recordId)) {
                    continue;
                }
                double minDistance = minDistanceToCentroids(entry.getValue(), clusterMap);
                if (minDistance > furthestDistance) {
                    furthestDistance = minDistance;
                    furthestId = recordId;
                }
                examined++;
            }
            if (furthestId == null) {
                // No unselected record is left to seed another cluster.
                break;
            }
            selectedIds.add(furthestId);
            Integer clusterId = Integer.valueOf(clusterMap.size());
            clusterMap.put(clusterId, createSeededCluster(clusterId.intValue(), dataframe.get(furthestId)));
        }
    }

    /**
     * k-means++ style seeding: each round samples one not-yet-selected record with a probability
     * proportional to its (normalized) distance to the nearest existing centroid.
     */
    private void initializePlusPlus(Dataframe dataframe, Map<Integer, Cluster> clusterMap, int k) {
        StorageEngine storageEngine = this.knowledgeBase.getStorageEngine();
        Set<Integer> selectedIds = new HashSet<>();
        int totalRecords = dataframe.size();
        for (int i = 0; i < k; i++) {
            if (selectedIds.size() >= totalRecords) {
                // Every record already seeds a cluster; nothing is left to sample.
                break;
            }
            // Kept as a raw Map, as before: it is handed straight to the
            // AssociativeArray constructor, whose parameter type is not
            // visible from this file.
            Map tmpMinClusterDistance = storageEngine.getBigMap("tmp_minClusterDistance", Object.class, Double.class, StorageEngine.MapType.HASHMAP, StorageEngine.StorageHint.IN_MEMORY, true, true);
            Integer selectedId;
            try {
                AssociativeArray minClusterDistances = new AssociativeArray(tmpMinClusterDistance);
                this.streamExecutor.forEach(StreamMethods.stream(dataframe.entries(), isParallelized()), entry -> {
                    Integer recordId = (Integer) entry.getKey();
                    if (selectedIds.contains(recordId)) {
                        return;
                    }
                    Record record = (Record) entry.getValue();
                    // With no clusters yet every record gets the same weight.
                    double minDistance = clusterMap.size() > 0 ? minDistanceToCentroids(record, clusterMap) : 1.0d;
                    minClusterDistances.put(recordId, Double.valueOf(minDistance));
                });
                Descriptives.normalize(minClusterDistances);
                selectedId = (Integer) SimpleRandomSampling.weightedSampling(minClusterDistances, 1, true).iterator().next();
            } finally {
                // Release the temporary map even when sampling fails.
                storageEngine.dropBigMap("tmp_minClusterDistance", tmpMinClusterDistance);
            }
            selectedIds.add(selectedId);
            Integer clusterId = Integer.valueOf(clusterMap.size());
            clusterMap.put(clusterId, createSeededCluster(clusterId.intValue(), dataframe.get(selectedId)));
        }
    }

    /**
     * Runs the iterative refinement. Each iteration resets all clusters,
     * assigns every record to the cluster with the nearest centroid, adds the
     * records to their clusters and recomputes the centroids. The loop stops
     * as soon as no centroid changes or after the configured maximum number
     * of iterations.
     *
     * <p>NOTE(review): on convergence the stored total is the 0-based index
     * of the last iteration, while without convergence it is the iteration
     * count ({@code maxIterations}); kept as-is for compatibility.</p>
     *
     * @param dataframe the training data
     */
    private void calculateClusters(Dataframe dataframe) {
        ModelParameters modelParameters = (ModelParameters) this.knowledgeBase.getModelParameters();
        TrainingParameters trainingParameters = (TrainingParameters) this.knowledgeBase.getTrainingParameters();
        Map<Integer, Cluster> clusterMap = modelParameters.getClusterMap();
        StorageEngine storageEngine = this.knowledgeBase.getStorageEngine();
        int maxIterations = trainingParameters.getMaxIterations();
        modelParameters.setTotalIterations(maxIterations);

        for (int iteration = 0; iteration < maxIterations; iteration++) {
            this.logger.debug("Iteration {}", Integer.valueOf(iteration));

            // Forget the members of the previous iteration but keep the
            // centroids. Clusters are written back after every change so that
            // the map holds the updated state.
            for (Map.Entry<Integer, Cluster> clusterEntry : clusterMap.entrySet()) {
                Cluster cluster = clusterEntry.getValue();
                cluster.reset();
                clusterMap.put(clusterEntry.getKey(), cluster);
            }

            Map<Integer, Integer> assignments = storageEngine.getBigMap("tmp_clusterAssignments", Integer.class, Integer.class, StorageEngine.MapType.HASHMAP, StorageEngine.StorageHint.IN_MEMORY, true, true);
            try {
                // Step 1 (possibly parallel): find the nearest cluster of
                // every record without touching the clusters themselves.
                this.streamExecutor.forEach(StreamMethods.stream(dataframe.entries(), isParallelized()), recordEntry -> {
                    Integer recordId = (Integer) recordEntry.getKey();
                    Record record = (Record) recordEntry.getValue();
                    AssociativeArray distances = new AssociativeArray();
                    for (Map.Entry<Integer, Cluster> centroidEntry : clusterMap.entrySet()) {
                        distances.put(centroidEntry.getKey(), Double.valueOf(calculateDistance(record, centroidEntry.getValue().getCentroid())));
                    }
                    assignments.put(recordId, (Integer) getSelectedClusterFromDistances(distances));
                });

                // Step 2 (sequential): add every record to its cluster.
                for (Map.Entry<Integer, Record> recordEntry : dataframe.entries()) {
                    Integer clusterId = assignments.get(recordEntry.getKey());
                    Cluster cluster = clusterMap.get(clusterId);
                    cluster.add(recordEntry.getValue());
                    clusterMap.put(clusterId, cluster);
                }
            } finally {
                // Release the temporary map even when an assignment fails.
                storageEngine.dropBigMap("tmp_clusterAssignments", assignments);
            }

            // Step 3: recompute the centroids and remember whether any moved.
            boolean changed = false;
            for (Map.Entry<Integer, Cluster> clusterEntry : clusterMap.entrySet()) {
                Cluster cluster = clusterEntry.getValue();
                changed |= cluster.updateClusterParameters();
                clusterMap.put(clusterEntry.getKey(), cluster);
            }

            if (!changed) {
                modelParameters.setTotalIterations(iteration);
                return;
            }
        }
    }
}
