package com.datumbox.framework.core.machinelearning.clustering;

import com.datumbox.framework.common.Configuration;
import com.datumbox.framework.common.concurrency.ForkJoinStream;
import com.datumbox.framework.common.concurrency.StreamMethods;
import com.datumbox.framework.common.dataobjects.AssociativeArray;
import com.datumbox.framework.common.storage.interfaces.StorageEngine;
import com.datumbox.framework.core.common.dataobjects.Dataframe;
import com.datumbox.framework.core.common.dataobjects.Record;
import com.datumbox.framework.core.common.utilities.MapMethods;
import com.datumbox.framework.core.machinelearning.common.abstracts.AbstractTrainer;
import com.datumbox.framework.core.machinelearning.common.abstracts.modelers.AbstractClusterer;
import com.datumbox.framework.core.machinelearning.common.interfaces.PredictParallelizable;
import com.datumbox.framework.core.machinelearning.common.interfaces.TrainParallelizable;
import com.datumbox.framework.core.mathematics.distances.Distance;
import com.datumbox.framework.core.statistics.descriptivestatistics.Descriptives;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;

/**
 * Agglomerative (bottom-up) hierarchical clustering.
 *
 * <p>Training places every record in its own cluster and then repeatedly merges the two closest
 * active clusters. Merging stops when the closest pair is at least
 * {@link TrainingParameters#getMaxDistanceThreshold()} apart, when no mergeable pair is left, or
 * when the number of active clusters is at most
 * {@link TrainingParameters#getMinClustersThreshold()}. After every merge the distances from the
 * merged cluster to all other clusters are derived from the pre-merge distances according to the
 * configured {@link TrainingParameters.Linkage}. Prediction assigns a record to the cluster whose
 * centroid has the smallest distance under the configured {@link TrainingParameters.Distance}.
 */
public class HierarchicalAgglomerative extends AbstractClusterer<Cluster, ModelParameters, TrainingParameters> implements PredictParallelizable, TrainParallelizable {

    /** Whether the per-cluster loops run in parallel; {@code true} after construction. */
    private boolean parallelized;

    /** Executor used for the (optionally parallel) per-cluster loops. */
    protected final ForkJoinStream streamExecutor;

    /**
     * A cluster built during training. It keeps the component-wise sum of its members' feature
     * vectors and derives its centroid from that sum and its size.
     */
    public static class Cluster extends AbstractClusterer.AbstractCluster {
        private static final long serialVersionUID = 1;

        /** Mean feature vector of the members, as of the last {@link #updateClusterParameters()}. */
        private Record centroid;

        /** Becomes {@code false} once this cluster has been absorbed by another cluster. */
        private boolean active;

        /** Component-wise sum of the feature vectors of every record added or merged in. */
        private final AssociativeArray xi_sum;

        /**
         * Creates an empty, active cluster with an empty centroid.
         *
         * @param i the id of the cluster
         */
        protected Cluster(int i) {
            super(Integer.valueOf(i));
            this.active = true;
            this.centroid = new Record(new AssociativeArray(), null);
            this.xi_sum = new AssociativeArray();
        }

        /**
         * Returns the centroid of the cluster.
         *
         * @return the centroid computed by the last call to {@link #updateClusterParameters()}
         */
        public Record getCentroid() {
            return this.centroid;
        }

        /**
         * Absorbs another cluster by adding its feature sum and size to this one. The centroid is
         * not recomputed here; call {@link #updateClusterParameters()} afterwards.
         *
         * @param cluster the cluster being absorbed; its sum and size are only read
         */
        protected void merge(Cluster cluster) {
            this.xi_sum.addValues(cluster.xi_sum);
            this.size += cluster.size;
        }

        /**
         * Recomputes the centroid as the feature sum divided by the size. When the cluster is
         * empty, the (empty) sum itself is used.
         *
         * @return {@code true} if the centroid changed
         */
        protected boolean updateClusterParameters() {
            AssociativeArray newCentroid = this.xi_sum.copy();
            if (this.size > 0) {
                newCentroid.multiplyValues(1.0d / this.size);
            }
            if (this.centroid.getX().equals(newCentroid)) {
                return false;
            }
            // Keep the response value (y) of the previous centroid record.
            this.centroid = new Record(newCentroid, this.centroid.getY());
            return true;
        }

        /**
         * Returns whether the cluster is still active.
         *
         * @return {@code false} once the cluster has been merged into another one
         */
        protected boolean isActive() {
            return this.active;
        }

        /**
         * Marks the cluster as active or inactive.
         *
         * @param z the new active flag
         */
        protected void setActive(boolean z) {
            this.active = z;
        }

        /**
         * Adds a record to the cluster by incrementing the size and accumulating its features.
         * The centroid is not recomputed here.
         *
         * @param record the record to add
         */
        @Override
        protected void add(Record record) {
            this.size++;
            this.xi_sum.addValues(record.getX());
        }

        /**
         * Not supported by this algorithm.
         *
         * @param record ignored
         * @throws UnsupportedOperationException always
         */
        @Override
        protected void remove(Record record) {
            throw new UnsupportedOperationException("Remove operation is not supported.");
        }

        /**
         * Drops the accumulated feature sum. The centroid and size are left untouched, presumably
         * so the trained cluster stays usable for prediction.
         */
        @Override
        protected void clear() {
            this.xi_sum.clear();
        }
    }

    /** Model parameters of the algorithm: the trained clusters, keyed by cluster id. */
    public static class ModelParameters extends AbstractClusterer.AbstractModelParameters<Cluster> {
        private static final long serialVersionUID = 1;

        /**
         * @param storageEngine the storage engine backing the model parameters
         */
        protected ModelParameters(StorageEngine storageEngine) {
            super(storageEngine);
        }
    }

    /** Training parameters of the algorithm. */
    public static class TrainingParameters extends AbstractTrainer.AbstractTrainingParameters {
        private static final long serialVersionUID = 1;

        /** How distances to a freshly merged cluster are derived; defaults to COMPLETE. */
        private Linkage linkageMethod = Linkage.COMPLETE;

        /** Metric used between records/centroids; defaults to EUCLIDIAN. */
        private Distance distanceMethod = Distance.EUCLIDIAN;

        /** Merging stops once the closest active pair is at least this far apart. */
        private double maxDistanceThreshold = Double.MAX_VALUE;

        /** Merging stops once the number of active clusters is at most this value. */
        private double minClustersThreshold = 2.0d;

        /** Supported distance metrics. */
        public enum Distance {
            /** Euclidean distance. */
            EUCLIDIAN,
            /** Manhattan distance. */
            MANHATTAN,
            /** Maximum (Chebyshev-style) distance as computed by the distances utility. */
            MAXIMUM
        }

        /** Rules for updating the distance between a merged cluster and every other cluster. */
        public enum Linkage {
            /** Size-weighted average of the two pre-merge distances. */
            AVERAGE,
            /** Minimum of the two pre-merge distances. */
            SINGLE,
            /** Maximum of the two pre-merge distances. */
            COMPLETE
        }

        /** @return the linkage method */
        public Linkage getLinkageMethod() {
            return this.linkageMethod;
        }

        /** @param linkage the linkage method */
        public void setLinkageMethod(Linkage linkage) {
            this.linkageMethod = linkage;
        }

        /** @return the distance method */
        public Distance getDistanceMethod() {
            return this.distanceMethod;
        }

        /** @param distance the distance method */
        public void setDistanceMethod(Distance distance) {
            this.distanceMethod = distance;
        }

        /** @return the distance at or above which merging stops */
        public double getMaxDistanceThreshold() {
            return this.maxDistanceThreshold;
        }

        /** @param d the distance at or above which merging stops */
        public void setMaxDistanceThreshold(double d) {
            this.maxDistanceThreshold = d;
        }

        /** @return the number of active clusters at or below which merging stops */
        public double getMinClustersThreshold() {
            return this.minClustersThreshold;
        }

        /** @param d the number of active clusters at or below which merging stops */
        public void setMinClustersThreshold(double d) {
            this.minClustersThreshold = d;
        }
    }

    /**
     * Creates a new, untrained model.
     *
     * @param trainingParameters the training parameters
     * @param configuration the framework configuration
     */
    protected HierarchicalAgglomerative(TrainingParameters trainingParameters, Configuration configuration) {
        super(trainingParameters, configuration);
        this.parallelized = true;
        this.streamExecutor = new ForkJoinStream(this.knowledgeBase.getConfiguration().getConcurrencyConfiguration());
    }

    /**
     * Loads a previously stored model.
     *
     * @param str the storage name of the model
     * @param configuration the framework configuration
     */
    protected HierarchicalAgglomerative(String str, Configuration configuration) {
        super(str, configuration);
        this.parallelized = true;
        this.streamExecutor = new ForkJoinStream(this.knowledgeBase.getConfiguration().getConcurrencyConfiguration());
    }

    /** {@inheritDoc} */
    @Override
    public boolean isParallelized() {
        return this.parallelized;
    }

    /** {@inheritDoc} */
    @Override
    public void setParallelized(boolean z) {
        this.parallelized = z;
    }

    /** Predicts every record of the dataframe via {@link #_predictRecord(Record)}. */
    @Override
    protected void _predict(Dataframe dataframe) {
        _predictDatasetParallel(dataframe, this.knowledgeBase.getStorageEngine(), this.knowledgeBase.getConfiguration().getConcurrencyConfiguration());
    }

    /**
     * Computes the distance from the record to every cluster centroid and normalizes the
     * distances in place with {@code Descriptives.normalize}. The selected cluster is the key
     * chosen by {@code MapMethods.selectMinKeyValue}, i.e. the one with the minimum distance.
     *
     * @param record the record to classify
     * @return the selected cluster id together with the normalized distances
     */
    @Override
    public PredictParallelizable.Prediction _predictRecord(Record record) {
        Map<Integer, Cluster> clusterMap = ((ModelParameters) this.knowledgeBase.getModelParameters()).getClusterMap();
        AssociativeArray clusterDistances = new AssociativeArray();
        for (Map.Entry<Integer, Cluster> entry : clusterMap.entrySet()) {
            clusterDistances.put(entry.getKey(), Double.valueOf(calculateDistance(record, entry.getValue().getCentroid())));
        }
        Descriptives.normalize(clusterDistances);
        return new PredictParallelizable.Prediction(getSelectedClusterFromDistances(clusterDistances), clusterDistances);
    }

    /**
     * Records the non-null response values as gold-standard classes, builds the clusters and
     * finally calls {@code clearClusters()}.
     */
    @Override
    protected void _fit(Dataframe dataframe) {
        Set<Object> goldStandardClasses = ((ModelParameters) this.knowledgeBase.getModelParameters()).getGoldStandardClasses();
        for (Iterator<Record> it = dataframe.iterator(); it.hasNext(); ) {
            Object y = it.next().getY();
            if (y != null) {
                goldStandardClasses.add(y);
            }
        }
        calculateClusters(dataframe);
        clearClusters();
    }

    /**
     * Distance between the feature vectors of two records using the configured metric.
     *
     * @throws IllegalArgumentException if the configured metric is not supported (e.g. null)
     */
    private double calculateDistance(Record record, Record record2) {
        TrainingParameters.Distance distanceMethod = ((TrainingParameters) this.knowledgeBase.getTrainingParameters()).getDistanceMethod();
        if (distanceMethod == TrainingParameters.Distance.EUCLIDIAN) {
            return Distance.euclidean(record.getX(), record2.getX());
        }
        if (distanceMethod == TrainingParameters.Distance.MANHATTAN) {
            return Distance.manhattan(record.getX(), record2.getX());
        }
        if (distanceMethod == TrainingParameters.Distance.MAXIMUM) {
            return Distance.maximum(record.getX(), record2.getX());
        }
        throw new IllegalArgumentException("Unsupported Distance method.");
    }

    /** Returns the key with the minimum value in the distance array. */
    private Object getSelectedClusterFromDistances(AssociativeArray associativeArray) {
        return MapMethods.selectMinKeyValue(associativeArray).getKey();
    }

    /**
     * Runs the full agglomerative training. It creates one cluster per record, builds the
     * pairwise distance table, merges until a stopping condition holds, then keeps only the
     * active clusters with refreshed centroids. The temporary maps are always dropped, even if
     * training fails.
     */
    private void calculateClusters(Dataframe dataframe) {
        ModelParameters modelParameters = (ModelParameters) this.knowledgeBase.getModelParameters();
        TrainingParameters trainingParameters = (TrainingParameters) this.knowledgeBase.getTrainingParameters();
        Map<Integer, Cluster> clusterMap = modelParameters.getClusterMap();
        StorageEngine storageEngine = this.knowledgeBase.getStorageEngine();

        // List.class is a raw Class; the double cast gives getBigMap the parameterized key type.
        @SuppressWarnings("unchecked")
        Class<List<Object>> pairKeyClass = (Class<List<Object>>) (Class<?>) List.class;
        // [id1, id2] -> distance between clusters id1 and id2, stored for both orders of the pair.
        Map<List<Object>, Double> distanceArray = storageEngine.getBigMap("tmp_distanceArray", pairKeyClass, Double.class, StorageEngine.MapType.HASHMAP, StorageEngine.StorageHint.IN_CACHE, true, true);
        // id -> id of its closest active cluster (the id itself while no other one is closer).
        Map<Integer, Integer> minClusterDistanceId = storageEngine.getBigMap("tmp_minClusterDistanceId", Integer.class, Integer.class, StorageEngine.MapType.HASHMAP, StorageEngine.StorageHint.IN_CACHE, true, true);

        try {
            int nextId = 0;
            for (Record record : dataframe.values()) {
                Cluster singleton = new Cluster(nextId);
                singleton.add(record);
                singleton.updateClusterParameters();
                clusterMap.put(nextId, singleton);
                nextId++;
            }

            buildInitialDistances(clusterMap, distanceArray, minClusterDistanceId);

            int activeClusters = 0;
            for (Cluster candidate : clusterMap.values()) {
                if (candidate.isActive()) {
                    activeClusters++;
                }
            }
            // Check the threshold BEFORE merging so a dataset that already has no more clusters
            // than minClustersThreshold is left alone. Each successful merge deactivates exactly
            // one cluster, so the count is tracked instead of recounted.
            while (activeClusters > trainingParameters.getMinClustersThreshold() && mergeClosest(minClusterDistanceId, distanceArray)) {
                activeClusters--;
            }

            Iterator<Map.Entry<Integer, Cluster>> it = clusterMap.entrySet().iterator();
            while (it.hasNext()) {
                Map.Entry<Integer, Cluster> clusterEntry = it.next();
                Cluster cluster = clusterEntry.getValue();
                if (cluster.isActive()) {
                    cluster.updateClusterParameters();
                    // Re-store the mutated value; presumably needed when the map is storage-backed.
                    clusterMap.put(clusterEntry.getKey(), cluster);
                } else {
                    it.remove();
                }
            }
        } finally {
            storageEngine.dropBigMap("tmp_distanceArray", distanceArray);
            storageEngine.dropBigMap("tmp_minClusterDistanceId", minClusterDistanceId);
        }
    }

    /**
     * Fills the pairwise distance table between the initial clusters. For every cluster it also
     * records the id of its closest other cluster. A cluster's distance to itself is
     * {@code Double.MAX_VALUE}, so it is never selected as its own neighbour.
     */
    private void buildInitialDistances(Map<Integer, Cluster> clusterMap, Map<List<Object>, Double> distanceArray, Map<Integer, Integer> minClusterDistanceId) {
        this.streamExecutor.forEach(StreamMethods.stream(clusterMap.entrySet().stream(), isParallelized()), entry1 -> {
            Integer id1 = entry1.getKey();
            Record centroid1 = entry1.getValue().getCentroid();
            for (Map.Entry<Integer, Cluster> entry2 : clusterMap.entrySet()) {
                Integer id2 = entry2.getKey();
                double pairDistance = Objects.equals(id1, id2) ? Double.MAX_VALUE : calculateDistance(centroid1, entry2.getValue().getCentroid());
                distanceArray.put(Arrays.asList(id1, id2), pairDistance);
                distanceArray.put(Arrays.asList(id2, id1), pairDistance);
                Integer currentMinId = minClusterDistanceId.get(id1);
                if (currentMinId == null || pairDistance < distanceArray.get(Arrays.asList(id1, currentMinId))) {
                    minClusterDistanceId.put(id1, id2);
                }
            }
        });
    }

    /**
     * Merges the closest pair of active clusters and updates the distance table and the
     * nearest-neighbour pointers.
     *
     * @param minClusterDistanceId id -> id of its closest active cluster
     * @param distanceArray [id1, id2] -> distance between the two clusters
     * @return {@code false} if nothing was merged, either because no mergeable pair exists or
     *     because the closest pair is at least {@code maxDistanceThreshold} apart
     */
    private boolean mergeClosest(Map<Integer, Integer> minClusterDistanceId, Map<List<Object>, Double> distanceArray) {
        ModelParameters modelParameters = (ModelParameters) this.knowledgeBase.getModelParameters();
        TrainingParameters trainingParameters = (TrainingParameters) this.knowledgeBase.getTrainingParameters();
        Map<Integer, Cluster> clusterMap = modelParameters.getClusterMap();

        Integer closestId = null;
        double minDistance = Double.MAX_VALUE;
        for (Map.Entry<Integer, Cluster> candidate : clusterMap.entrySet()) {
            if (!candidate.getValue().isActive()) {
                continue;
            }
            Integer candidateId = candidate.getKey();
            double candidateDistance = distanceArray.get(Arrays.asList(candidateId, minClusterDistanceId.get(candidateId)));
            if (candidateDistance < minDistance) {
                closestId = candidateId;
                minDistance = candidateDistance;
            }
        }

        // closestId is null when no active cluster has a neighbour closer than Double.MAX_VALUE
        // (e.g. only one cluster is left); without this guard a threshold above MAX_VALUE would
        // lead to a NullPointerException below.
        if (closestId == null || minDistance >= trainingParameters.getMaxDistanceThreshold()) {
            return false;
        }

        final Integer mergerId = closestId;
        final Integer mergedId = minClusterDistanceId.get(mergerId);
        Cluster merger = clusterMap.get(mergerId);
        Cluster merged = clusterMap.get(mergedId);

        // Sizes are captured before merging because AVERAGE linkage weights by them.
        final double mergerSize = merger.size();
        final double mergedSize = merged.size();

        merger.merge(merged);
        clusterMap.put(mergerId, merger);
        merged.setActive(false);
        clusterMap.put(mergedId, merged);

        final TrainingParameters.Linkage linkageMethod = trainingParameters.getLinkageMethod();
        this.streamExecutor.forEach(StreamMethods.stream(clusterMap.entrySet().stream(), isParallelized()), entry -> {
            Integer otherId = entry.getKey();
            if (entry.getValue().isActive()) {
                double linkedDistance;
                if (Objects.equals(mergerId, otherId)) {
                    linkedDistance = Double.MAX_VALUE;
                } else {
                    linkedDistance = linkageDistance(linkageMethod, distanceArray.get(Arrays.asList(mergerId, otherId)), distanceArray.get(Arrays.asList(mergedId, otherId)), mergerSize, mergedSize);
                }
                distanceArray.put(Arrays.asList(mergerId, otherId), linkedDistance);
                distanceArray.put(Arrays.asList(otherId, mergerId), linkedDistance);
            }
        });

        // Only pointers to one of the two merged clusters can have become invalid: under every
        // supported linkage the new distance never drops below the smaller of the two old ones.
        this.streamExecutor.forEach(StreamMethods.stream(clusterMap.entrySet().stream(), isParallelized()), entry -> {
            Integer id = entry.getKey();
            if (entry.getValue().isActive()) {
                Integer previousMinId = minClusterDistanceId.get(id);
                if (Objects.equals(previousMinId, mergerId) || Objects.equals(previousMinId, mergedId)) {
                    minClusterDistanceId.put(id, findClosestActiveCluster(id, clusterMap, distanceArray));
                }
            }
        });
        return true;
    }

    /**
     * Distance between a freshly merged cluster and another cluster, derived from the distances
     * of the two pre-merge clusters to that other cluster.
     *
     * @throws IllegalArgumentException if the linkage method is not supported (e.g. null)
     */
    private static double linkageDistance(TrainingParameters.Linkage linkageMethod, double mergerDistance, double mergedDistance, double mergerSize, double mergedSize) {
        if (linkageMethod == TrainingParameters.Linkage.SINGLE) {
            return Math.min(mergerDistance, mergedDistance);
        }
        if (linkageMethod == TrainingParameters.Linkage.COMPLETE) {
            return Math.max(mergerDistance, mergedDistance);
        }
        if (linkageMethod == TrainingParameters.Linkage.AVERAGE) {
            return (mergerDistance * mergerSize + mergedDistance * mergedSize) / (mergerSize + mergedSize);
        }
        throw new IllegalArgumentException("Unsupported Linkage method.");
    }

    /**
     * Returns the id of the active cluster strictly closest to {@code id}, or {@code id} itself
     * when no active cluster is closer than its self-distance ({@code Double.MAX_VALUE}).
     */
    private static Integer findClosestActiveCluster(Integer id, Map<Integer, Cluster> clusterMap, Map<List<Object>, Double> distanceArray) {
        Integer bestId = id;
        for (Map.Entry<Integer, Cluster> candidate : clusterMap.entrySet()) {
            Integer candidateId = candidate.getKey();
            if (candidate.getValue().isActive() && distanceArray.get(Arrays.asList(id, candidateId)) < distanceArray.get(Arrays.asList(id, bestId))) {
                bestId = candidateId;
            }
        }
        return bestId;
    }
}
