/*
 * Decompiled with CFR 0.152.
 */
package org.jdmp.core.dataset;

import java.util.EnumMap;
import java.util.List;
import java.util.Map;
import org.jdmp.core.algorithm.classification.Classifier;
import org.jdmp.core.algorithm.regression.Regressor;
import org.jdmp.core.dataset.ListDataSet;
import org.ujmp.core.Matrix;
import org.ujmp.core.calculation.Calculation;
import org.ujmp.core.listmatrix.DefaultListMatrix;
import org.ujmp.core.listmatrix.ListMatrix;

public class CrossValidation {
    public static ListMatrix<Double> run(Classifier algorithm, ListDataSet dataSet) throws Exception {
        return CrossValidation.run(algorithm, dataSet, 10, 10, System.currentTimeMillis());
    }

    public static ListMatrix<Double> run(Regressor algorithm, ListDataSet dataSet, int folds, int runs, long randomSeed) throws Exception {
        DefaultListMatrix<Double> allacc = new DefaultListMatrix<Double>();
        DefaultListMatrix allfm = new DefaultListMatrix();
        DefaultListMatrix allsens = new DefaultListMatrix();
        DefaultListMatrix allspec = new DefaultListMatrix();
        DefaultListMatrix allprec = new DefaultListMatrix();
        DefaultListMatrix allrec = new DefaultListMatrix();
        DefaultListMatrix allrmse = new DefaultListMatrix();
        for (int run = 0; run < runs; ++run) {
            DefaultListMatrix acc = new DefaultListMatrix();
            DefaultListMatrix fm = new DefaultListMatrix();
            DefaultListMatrix sens = new DefaultListMatrix();
            DefaultListMatrix spec = new DefaultListMatrix();
            DefaultListMatrix prec = new DefaultListMatrix();
            DefaultListMatrix rec = new DefaultListMatrix();
            DefaultListMatrix rmse = new DefaultListMatrix();
            System.out.print("F-Measure (macro) in run " + run + ":\t");
            for (int fold = 0; fold < folds; ++fold) {
                List<ListDataSet> dss = dataSet.splitForCV(folds, fold, randomSeed + (long)run);
                ListDataSet train = dss.get(0);
                ListDataSet test = dss.get(1);
                algorithm.reset();
                algorithm.trainAll(train);
                algorithm.predictAll(train);
                algorithm.predictAll(test);
                acc.add(test.getAccuracy());
                fm.add(test.getAsDouble("FMeasureMacro"));
                sens.add(test.getAsDouble("Sensitivity"));
                spec.add(test.getAsDouble("Specificity"));
                prec.add(test.getAsDouble("Precision"));
                rec.add(test.getAsDouble("Recall"));
                rmse.add(test.getAsDouble("RMSE"));
                System.out.print(test.getAsDouble("FMeasureMacro") + "\t");
            }
            System.out.println();
            double meanacc = acc.getMeanValue();
            allacc.add(meanacc);
            double meanfm = fm.getMeanValue();
            allfm.add(meanfm);
            double meansens = sens.getMeanValue();
            allsens.add(meansens);
            double meanspec = spec.getMeanValue();
            allspec.add(meanspec);
            double meanprec = prec.getMeanValue();
            allprec.add(meanprec);
            double meanrec = rec.getMeanValue();
            allrec.add(meanrec);
            double meanrmse = rmse.getMeanValue();
            allrmse.add(meanrmse);
            System.out.println("Average Accuracy in run " + run + ":\t" + meanacc);
            System.out.println("Average F-Measure in run " + run + ":\t" + meanfm);
            System.out.println("Average Sensitivity in run " + run + ":\t" + meansens);
            System.out.println("Average Specificity in run " + run + ":\t" + meanspec);
            System.out.println("Average Precision in run " + run + ":\t" + meanprec);
            System.out.println("Average Recall in run " + run + ":\t" + meanrec);
            System.out.println("Average RMSE in run " + run + ":\t" + meanrmse);
        }
        if (allacc.size() > 1) {
            Matrix stdacc = allacc.std(Calculation.Ret.NEW, 0, false, true);
            System.out.println("Accuracy: " + allacc.getMeanValue() + "+-" + stdacc.doubleValue());
            Matrix stdfm = allfm.std(Calculation.Ret.NEW, 0, false, true);
            System.out.println("F-Measure (macro): " + allfm.getMeanValue() + "+-" + stdfm.doubleValue());
            Matrix stdsens = allsens.std(Calculation.Ret.NEW, 0, false, true);
            System.out.println("Sensitivity: " + allsens.getMeanValue() + "+-" + stdsens.doubleValue());
            Matrix stdspec = allspec.std(Calculation.Ret.NEW, 0, false, true);
            System.out.println("Specificity: " + allspec.getMeanValue() + "+-" + stdspec.doubleValue());
            Matrix stdprec = allprec.std(Calculation.Ret.NEW, 0, false, true);
            System.out.println("Precision: " + allprec.getMeanValue() + "+-" + stdprec.doubleValue());
            Matrix stdrec = allrec.std(Calculation.Ret.NEW, 0, false, true);
            System.out.println("Recall: " + allrec.getMeanValue() + "+-" + stdrec.doubleValue());
            Matrix stdrmse = allrmse.std(Calculation.Ret.NEW, 0, false, true);
            System.out.println("RMSE: " + allrmse.getMeanValue() + "+-" + stdrmse.doubleValue());
        } else {
            System.out.println("Accuracy: " + allacc.getMeanValue());
            System.out.println("F-Measure (macro): " + allfm.getMeanValue());
            System.out.println("Sensitivity: " + allsens.getMeanValue());
            System.out.println("Specificity: " + allspec.getMeanValue());
            System.out.println("Precision: " + allprec.getMeanValue());
            System.out.println("Recall: " + allrec.getMeanValue());
            System.out.println("RMSE: " + allrmse.getMeanValue());
        }
        return allacc;
    }
}

