package jsat.classifiers;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import jsat.classifiers.evaluation.ClassificationScore;
import jsat.datatransform.DataTransformProcess;
import jsat.exceptions.UntrainedModelException;
import jsat.math.OnLineStatistics;
import jsat.utils.concurrent.ParallelUtils;
import jsat.utils.random.RandomUtil;

/* loaded from: input_file:jsat/classifiers/ClassificationModelEvaluation.class */
/**
 * Evaluates a {@link Classifier} on a {@link ClassificationDataSet}, either by cross validation
 * or by training on the whole data set and testing on a separate test set.
 * <p>
 * Each model trained during an evaluation is a fresh clone of the supplied classifier.
 * After an evaluation the following are available:
 * <ul>
 * <li>a weighted confusion matrix ({@link #getConfusionMatrix()});</li>
 * <li>the total weight of the evaluated points;</li>
 * <li>wall-clock training and classification times, in milliseconds;</li>
 * <li>per-model error rates ({@link #getErrorRateStats()});</li>
 * <li>statistics for every scorer registered with {@link #addScorer(ClassificationScore)};</li>
 * <li>optionally, the individual predictions and the trained models.</li>
 * </ul>
 * <p>
 * The confusion matrix, weights, timings, predictions and kept models are reset by every
 * {@code evaluate*} call. The error-rate statistics and the scorer statistics are NOT reset:
 * they accumulate one value per trained model across all evaluations run on this instance.
 */
public class ClassificationModelEvaluation {
    /** Classifier under evaluation; it is cloned, never trained directly. */
    private final Classifier classifier;
    /** Source of CV folds, training data for test-set evaluation, and class information. */
    private final ClassificationDataSet dataSet;
    /** Whether training and classification may run in parallel. */
    private final boolean parallel;
    /** Weighted confusion matrix indexed as [true class][predicted class]. */
    private double[][] confusionMatrix;
    /** Total weight of all points classified in the last evaluation. */
    private double sumOfWeights;
    /** Wall-clock training time of the last evaluation, in milliseconds. */
    private long totalTrainingTime;
    /** Summed wall-clock classification time of the last evaluation, in milliseconds. */
    private long totalClassificationTime;
    /** Transforms learned on each training set and applied to the matching test points. */
    private DataTransformProcess dtp;
    private boolean keepPredictions;
    private CategoricalResults[] predictions;
    private int[] truths;
    private double[] pointWeights;
    /** One weighted error rate per trained model; accumulates across evaluations. */
    private final OnLineStatistics errorStats;
    /** Registered scorers (in insertion order) and one score per trained model for each. */
    private final Map<ClassificationScore, OnLineStatistics> scoreMap;
    private boolean keepModels;
    private Classifier[] keptModels;
    private Classifier[] warmModels;

    /**
     * Creates a serial evaluator.
     *
     * @param classifier the classifier to evaluate
     * @param dataSet the data set used for training / cross validation
     */
    public ClassificationModelEvaluation(Classifier classifier, ClassificationDataSet dataSet) {
        this(classifier, dataSet, false);
    }

    /**
     * Creates an evaluator.
     *
     * @param classifier the classifier to evaluate
     * @param dataSet the data set used for training / cross validation
     * @param parallel whether training and classification may run in parallel
     */
    public ClassificationModelEvaluation(Classifier classifier, ClassificationDataSet dataSet, boolean parallel) {
        this.classifier = classifier;
        this.dataSet = dataSet;
        this.parallel = parallel;
        this.totalTrainingTime = 0L;
        this.totalClassificationTime = 0L;
        this.keepModels = false;
        this.keepPredictions = false;
        this.dtp = new DataTransformProcess();
        this.errorStats = new OnLineStatistics();
        this.scoreMap = new LinkedHashMap<>();
    }

    /**
     * Sets whether the models trained by the next evaluation are retained.
     *
     * @param keepModels {@code true} to keep the trained models
     */
    public void setKeepModels(boolean keepModels) {
        this.keepModels = keepModels;
    }

    /** @return whether trained models will be retained by the next evaluation */
    public boolean isKeepModels() {
        return this.keepModels;
    }

    /**
     * Returns the models trained by the last evaluation.
     * For cross validation, index {@code i} holds the model trained for fold {@code i}.
     *
     * @return the models, or {@code null} if models were not kept during the last evaluation
     */
    public Classifier[] getKeptModels() {
        return this.keptModels;
    }

    /**
     * Supplies warm-start models.
     * When the evaluated classifier is a {@link WarmClassifier}, model {@code i} is trained
     * starting from {@code warmModels[i]}. The array must therefore have an entry for every
     * fold (or one entry for a test-set evaluation).
     *
     * @param warmModels the warm-start models, or {@code null} for none
     */
    public void setWarmModels(Classifier... warmModels) {
        this.warmModels = warmModels;
    }

    /**
     * Sets the transforms to learn on every training set before training.
     * A clone of the argument is stored.
     *
     * @param dataTransformProcess the transforms to apply
     */
    public void setDataTransformProcess(DataTransformProcess dataTransformProcess) {
        this.dtp = dataTransformProcess.clone();
    }

    /**
     * Performs {@code folds}-fold cross validation using the default random source.
     *
     * @param folds the number of folds; must be at least 2
     */
    public void evaluateCrossValidation(int folds) {
        evaluateCrossValidation(folds, RandomUtil.getRandom());
    }

    /**
     * Performs {@code folds}-fold cross validation.
     *
     * @param folds the number of folds; must be at least 2
     * @param rand the random source used to split the data set into folds
     * @throws UntrainedModelException if {@code folds < 2}
     */
    public void evaluateCrossValidation(int folds, Random rand) {
        if (folds < 2) {
            throw new UntrainedModelException("Model could not be evaluated because " + folds + " is < 2, and not valid for cross validation");
        }
        evaluateCrossValidation(this.dataSet.cvSet(folds, rand));
    }

    /**
     * Performs cross validation over pre-split folds.
     * The training set for fold {@code i} is the union of all other folds.
     *
     * @param testFolds the folds, each used once as a test set
     */
    public void evaluateCrossValidation(List<ClassificationDataSet> testFolds) {
        List<ClassificationDataSet> trainFolds = new ArrayList<>(testFolds.size());
        for (int fold = 0; fold < testFolds.size(); fold++) {
            trainFolds.add(ClassificationDataSet.comineAllBut(testFolds, fold));
        }
        evaluateCrossValidation(testFolds, trainFolds);
    }

    /**
     * Performs cross validation with explicitly supplied train/test pairs.
     * Model {@code i} is trained on {@code trainFolds.get(i)} and tested on
     * {@code testFolds.get(i)}.
     *
     * @param testFolds the test set of every fold
     * @param trainFolds the training set of every fold
     * @throws IllegalArgumentException if fewer training sets than test folds are given
     */
    public void evaluateCrossValidation(List<ClassificationDataSet> testFolds, List<ClassificationDataSet> trainFolds) {
        if (trainFolds.size() < testFolds.size()) {
            throw new IllegalArgumentException("Need a training set for each of the " + testFolds.size()
                    + " test folds, but only " + trainFolds.size() + " were given");
        }
        int totalSize = this.dataSet.getSampleSize();
        resetResults(testFolds.size(), totalSize);

        // Folds are processed last-to-first. evaluationWork writes a fold's predictions to
        // indices [0, testSize); they are then shifted into [end - testSize, end). Since a fold
        // is never larger than all folds up to and including it, the temporary region never
        // reaches the already-placed tail [end, totalSize).
        // Assumes the test folds together hold exactly totalSize points (TODO confirm for
        // caller-supplied folds).
        int end = totalSize;
        for (int fold = testFolds.size() - 1; fold >= 0; fold--) {
            ClassificationDataSet testSet = testFolds.get(fold);
            evaluationWork(trainFolds.get(fold), testSet, fold);
            int testSize = testSet.getSampleSize();
            if (this.keepPredictions) {
                int start = end - testSize;
                System.arraycopy(this.predictions, 0, this.predictions, start, testSize);
                System.arraycopy(this.truths, 0, this.truths, start, testSize);
                System.arraycopy(this.pointWeights, 0, this.pointWeights, start, testSize);
            }
            end -= testSize;
        }
    }

    /**
     * Trains one model on the whole data set and evaluates it on {@code testSet}.
     *
     * @param testSet the data to classify
     */
    public void evaluateTestSet(ClassificationDataSet testSet) {
        resetResults(1, testSet.getSampleSize());
        evaluationWork(this.dataSet, testSet, 0);
    }

    /**
     * Clears per-evaluation results before a new evaluation starts.
     *
     * @param numberOfModels how many models the evaluation will train
     * @param numberOfPredictions size of the prediction arrays, if predictions are kept
     */
    private void resetResults(int numberOfModels, int numberOfPredictions) {
        int numClasses = this.dataSet.getClassSize();
        this.sumOfWeights = 0.0d;
        this.confusionMatrix = new double[numClasses][numClasses];
        this.totalTrainingTime = 0L;
        this.totalClassificationTime = 0L;
        // Always replace: a stale array from an earlier run would otherwise be overwritten
        // (and might be too short for this run's model indices).
        this.keptModels = this.keepModels ? new Classifier[numberOfModels] : null;
        setUpResults(numberOfPredictions);
    }

    /**
     * Trains one model and classifies every point of {@code testSet}.
     * Updates the confusion matrix, weights, timings, error/score statistics and,
     * when enabled, the prediction arrays (at indices {@code [0, testSet size)}).
     *
     * @param trainSet the training data
     * @param testSet the data to classify
     * @param modelIndex index of this model in the kept/warm model arrays
     */
    private void evaluationWork(ClassificationDataSet trainSet, ClassificationDataSet testSet, int modelIndex) {
        final DataTransformProcess transforms = this.dtp.clone();
        ClassificationDataSet trainingData = trainSet;
        if (transforms.getNumberOfTransforms() > 0) {
            // Transforms are learned on, and applied to, a shallow clone rather than trainSet itself.
            trainingData = trainSet.shallowClone2();
            transforms.learnApplyTransforms(trainingData);
        }

        final Classifier model = this.classifier.mo0clone();
        long trainStart = System.currentTimeMillis();
        if (this.warmModels != null && model instanceof WarmClassifier) {
            ((WarmClassifier) model).train(trainingData, this.warmModels[modelIndex], this.parallel);
        } else {
            model.train(trainingData, this.parallel);
        }
        this.totalTrainingTime += System.currentTimeMillis() - trainStart;
        if (this.keptModels != null) {
            this.keptModels[modelIndex] = model;
        }

        // One prepared copy of every registered scorer, keyed by itself; each worker chunk
        // fills private copies and merges them into these under the lock below.
        final Map<ClassificationScore, ClassificationScore> mergedScores = new HashMap<>();
        for (ClassificationScore scorer : this.scoreMap.keySet()) {
            ClassificationScore prepared = scorer.m33clone();
            prepared.prepare(this.dataSet.getPredicting());
            mergedScores.put(prepared, prepared);
        }
        // [0] = misclassified weight, [1] = total weight; only updated under the lock.
        final double[] errorAndTotal = new double[2];

        ParallelUtils.run(this.parallel, testSet.getSampleSize(), (from, to) -> {
            double correctWeight = 0.0d;
            double chunkWeight = 0.0d;
            long chunkClassifyTime = 0L;
            HashSet<ClassificationScore> localScores = new HashSet<>();
            for (ClassificationScore merged : mergedScores.keySet()) {
                localScores.add(merged.m33clone());
            }
            for (int idx = from; idx < to; idx++) {
                DataPoint point = transforms.transform(testSet.getDataPoint(idx));
                double weight = point.getWeight();
                int truth = testSet.getDataPointCategory(idx);

                long classifyStart = System.currentTimeMillis();
                CategoricalResults result = model.classify(point);
                chunkClassifyTime += System.currentTimeMillis() - classifyStart;
                int predicted = result.mostLikely();

                for (ClassificationScore local : localScores) {
                    local.addResult(result, truth, weight);
                }
                if (this.predictions != null) {
                    // each index is written by exactly one chunk
                    this.predictions[idx] = result;
                    this.truths[idx] = truth;
                    this.pointWeights[idx] = weight;
                }
                synchronized (this.confusionMatrix[truth]) {
                    this.confusionMatrix[truth][predicted] += weight;
                }
                if (truth == predicted) {
                    correctWeight += weight;
                }
                chunkWeight += weight;
            }
            synchronized (this.confusionMatrix) {
                this.totalClassificationTime += chunkClassifyTime;
                this.sumOfWeights += chunkWeight;
                errorAndTotal[0] += chunkWeight - correctWeight;
                errorAndTotal[1] += chunkWeight;
                for (ClassificationScore local : localScores) {
                    mergedScores.get(local).addResults(local);
                }
            }
        });

        // An empty (or zero-weight) test set has no defined error rate; adding 0/0 = NaN would
        // poison the accumulated statistics of every other model.
        if (errorAndTotal[1] > 0) {
            this.errorStats.add(errorAndTotal[0] / errorAndTotal[1]);
        }
        for (Map.Entry<ClassificationScore, OnLineStatistics> entry : this.scoreMap.entrySet()) {
            ClassificationScore finalScore = entry.getKey().m33clone();
            finalScore.prepare(this.dataSet.getPredicting());
            finalScore.addResults(mergedScores.get(finalScore));
            entry.getValue().add(finalScore.getScore());
        }
    }

    /**
     * Registers a scorer computed for every trained model.
     * Registering an equal scorer again replaces its statistics.
     *
     * @param scorer the score to compute
     */
    public void addScorer(ClassificationScore scorer) {
        this.scoreMap.put(scorer, new OnLineStatistics());
    }

    /**
     * @param scorer a previously registered scorer
     * @return one score per trained model (accumulated across evaluations), or {@code null}
     *         if the scorer was never registered
     */
    public OnLineStatistics getScoreStats(ClassificationScore scorer) {
        return this.scoreMap.get(scorer);
    }

    /**
     * Sets whether the next evaluation keeps every prediction, true label and point weight.
     *
     * @param keep {@code true} to keep predictions
     */
    public void keepPredictions(boolean keep) {
        this.keepPredictions = keep;
    }

    /** @return whether predictions will be kept (method name retained, typo included, for compatibility) */
    public boolean doseStoreResults() {
        return this.keepPredictions;
    }

    /** @return predictions of the last evaluation, or {@code null} if they were not kept */
    public CategoricalResults[] getPredictions() {
        return this.predictions;
    }

    /** @return true class of each kept prediction, or {@code null} if predictions were not kept */
    public int[] getTruths() {
        return this.truths;
    }

    /** @return weight of each kept prediction, or {@code null} if predictions were not kept */
    public double[] getPointWeights() {
        return this.pointWeights;
    }

    /** @return the weighted confusion matrix [true class][predicted class] of the last evaluation */
    public double[][] getConfusionMatrix() {
        return this.confusionMatrix;
    }

    /** Prints the confusion matrix to standard out: rows are true classes, columns predicted classes. */
    public void prettyPrintConfusionMatrix() {
        CategoricalData predicting = this.dataSet.getPredicting();
        int numCategories = predicting.getNumOfCategories();
        int width = 10;
        for (int c = 0; c < numCategories; c++) {
            width = Math.max(width, predicting.getOptionName(c).length() + 2);
        }
        String column = "%-" + width;
        System.out.printf(column + "s ", "Matrix");
        for (int c = 0; c < numCategories - 1; c++) {
            System.out.printf(column + "s ", predicting.getOptionName(c).toUpperCase());
        }
        System.out.printf(column + "s\n", predicting.getOptionName(numCategories - 1).toUpperCase());
        for (int row = 0; row < this.confusionMatrix.length; row++) {
            System.out.printf(column + "s ", predicting.getOptionName(row).toUpperCase());
            for (int col = 0; col < numCategories - 1; col++) {
                System.out.printf(column + "f ", this.confusionMatrix[row][col]);
            }
            System.out.printf(column + "f\n", this.confusionMatrix[row][numCategories - 1]);
        }
    }

    /** Prints the mean (and standard deviation, when values differ) of every registered score. */
    public void prettyPrintClassificationScores() {
        int width = 10;
        for (ClassificationScore scorer : this.scoreMap.keySet()) {
            width = Math.max(width, scorer.getName().length() + 2);
        }
        String column = "%-" + width;
        for (Map.Entry<ClassificationScore, OnLineStatistics> entry : this.scoreMap.entrySet()) {
            OnLineStatistics stats = entry.getValue();
            String name = entry.getKey().getName();
            if (stats.getMax() == stats.getMin()) {
                System.out.printf(column + "s %-5f\n", name, stats.getMean());
            } else {
                System.out.printf(column + "s %-5f (%-5f)\n", name, stats.getMean(), stats.getStandardDeviation());
            }
        }
    }

    /** @return total weight on the confusion matrix diagonal (correctly classified weight) */
    public double getCorrectWeights() {
        double correct = 0.0d;
        for (int c = 0; c < this.confusionMatrix.length; c++) {
            correct += this.confusionMatrix[c][c];
        }
        return correct;
    }

    /** @return total weight of all points classified in the last evaluation */
    public double getSumOfWeights() {
        return this.sumOfWeights;
    }

    /** @return weighted error rate of the last evaluation ({@code NaN} if no weight was evaluated) */
    public double getErrorRate() {
        return 1.0d - (getCorrectWeights() / this.sumOfWeights);
    }

    /** @return one weighted error rate per trained model, accumulated across evaluations */
    public OnLineStatistics getErrorRateStats() {
        return this.errorStats;
    }

    /** @return wall-clock training time of the last evaluation, in milliseconds */
    public long getTotalTrainingTime() {
        return this.totalTrainingTime;
    }

    /** @return summed wall-clock classification time of the last evaluation, in milliseconds */
    public long getTotalClassificationTime() {
        return this.totalClassificationTime;
    }

    /** @return the classifier under evaluation (never trained directly) */
    public Classifier getClassifier() {
        return this.classifier;
    }

    /**
     * Allocates (or clears) the prediction arrays depending on {@link #keepPredictions}.
     *
     * @param size number of predictions to hold
     */
    private void setUpResults(int size) {
        if (this.keepPredictions) {
            this.predictions = new CategoricalResults[size];
            this.truths = new int[size];
            this.pointWeights = new double[size];
        } else {
            this.predictions = null;
            this.truths = null;
            this.pointWeights = null;
        }
    }
}
