/*
 * Decompiled with CFR 0.152.
 */
package com.datumbox.framework.core.machinelearning.modelselection.metrics;

import com.datumbox.framework.core.common.dataobjects.Dataframe;
import com.datumbox.framework.core.common.dataobjects.Record;
import com.datumbox.framework.core.machinelearning.common.abstracts.modelselection.AbstractMetrics;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;

public class ClusteringMetrics
extends AbstractMetrics {
    private static final long serialVersionUID = 1L;
    private double purity = 0.0;
    private double NMI = 0.0;

    public Double getPurity() {
        return this.purity;
    }

    public Double getNMI() {
        return this.NMI;
    }

    public ClusteringMetrics(Dataframe predictedData) {
        super(predictedData);
        int n = predictedData.size();
        HashSet<Object> clusterIdSet = new HashSet<Object>();
        HashSet<Object> goldStandardClassesSet = new HashSet<Object>();
        for (Record r : predictedData) {
            Object y = r.getY();
            if (y != null) {
                goldStandardClassesSet.add(y);
            }
            clusterIdSet.add(r.getYPredicted());
        }
        if (!goldStandardClassesSet.isEmpty()) {
            HashMap<Object, Double> ctMap = new HashMap<Object, Double>();
            HashMap<Object, Double> countOfW = new HashMap<Object, Double>();
            HashMap<Object, Double> countOfC = new HashMap<Object, Double>();
            for (Object e : clusterIdSet) {
                countOfW.put(e, 0.0);
                for (Object e2 : goldStandardClassesSet) {
                    ctMap.put(Arrays.asList(e, e2), 0.0);
                    countOfC.put(e2, 0.0);
                }
            }
            for (Record record : predictedData) {
                Object clusterId = record.getYPredicted();
                Object object = record.getY();
                List<Object> tpk = Arrays.asList(clusterId, object);
                ctMap.put(tpk, (Double)ctMap.get(tpk) + 1.0);
                countOfW.put(clusterId, (Double)countOfW.get(clusterId) + 1.0);
                countOfC.put(object, (Double)countOfC.get(object) + 1.0);
            }
            double logN = Math.log(n);
            double Iwc = 0.0;
            for (Object e : clusterIdSet) {
                double maxCounts = Double.NEGATIVE_INFINITY;
                for (Object e3 : goldStandardClassesSet) {
                    Object[] objectArray = new Object[]{e, e3};
                    List<Object> tpk = Arrays.asList(objectArray);
                    double Nwc = (Double)ctMap.get(tpk);
                    if (Nwc > maxCounts) {
                        maxCounts = Nwc;
                    }
                    if (!(Nwc > 0.0)) continue;
                    Iwc += Nwc / (double)n * (Math.log(Nwc) - Math.log((Double)countOfC.get(e3)) - Math.log((Double)countOfW.get(e)) + logN);
                }
                this.purity += maxCounts;
            }
            double entropyW = 0.0;
            for (Double Nw : countOfW.values()) {
                entropyW -= Nw / (double)n * (Math.log(Nw) - logN);
            }
            double entropyC = 0.0;
            for (Double d : countOfW.values()) {
                entropyC -= d / (double)n * (Math.log(d) - logN);
            }
            this.purity /= (double)n;
            this.NMI = Iwc / ((entropyW + entropyC) / 2.0);
        }
    }

    public ClusteringMetrics(List<ClusteringMetrics> validationMetricsList) {
        super(validationMetricsList);
        int k = 0;
        for (ClusteringMetrics vmSample : validationMetricsList) {
            this.NMI += vmSample.getNMI().doubleValue();
            this.purity += vmSample.getPurity().doubleValue();
            ++k;
        }
        if (k > 0) {
            this.NMI /= (double)k;
            this.purity /= (double)k;
        }
    }

    public String toString() {
        String sep = System.lineSeparator();
        StringBuilder sb = new StringBuilder();
        sb.append(this.getClass().getSimpleName()).append(":").append(sep);
        sb.append("purity=").append(this.purity).append(sep);
        sb.append("NMI=").append(this.NMI).append(sep);
        return sb.toString();
    }
}

