/*
 * Decompiled with CFR 0.152.
 */
package org.neuroph.eval;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import org.apache.commons.lang3.SerializationUtils;
import org.neuroph.core.NeuralNetwork;
import org.neuroph.core.data.DataSet;
import org.neuroph.core.learning.error.MeanSquaredError;
import org.neuroph.eval.ClassifierEvaluator;
import org.neuroph.eval.CrossValidationResult;
import org.neuroph.eval.ErrorEvaluator;
import org.neuroph.eval.Evaluation;
import org.neuroph.eval.EvaluationResult;
import org.neuroph.eval.Evaluator;
import org.neuroph.util.data.sample.Sampling;
import org.neuroph.util.data.sample.SubSampling;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Runs k-fold cross validation of a {@link NeuralNetwork} over a {@link DataSet}.
 *
 * <p>{@link #run()} shuffles the data set in place, cuts it into {@code numberOfFolds} contiguous
 * folds, and for every fold trains an independent (serialization-cloned) copy of the network on
 * the remaining rows and evaluates it on the held-out fold. Folds are processed on a small fixed
 * thread pool. The per-fold {@link EvaluationResult}s are aggregated into a
 * {@link CrossValidationResult}, available from {@link #getResult()} after {@code run()} returns
 * successfully.
 */
public class CrossValidation {
    private static final Logger LOGGER = LoggerFactory.getLogger(CrossValidation.class);

    /** Number of worker threads used to train and evaluate folds in parallel. */
    private static final int THREAD_POOL_SIZE = 4;

    private NeuralNetwork neuralNetwork;
    private DataSet dataSet;
    private Sampling sampling;
    private int numberOfFolds;
    // Rows per fold; computed in run() before the workers are submitted, read by the workers.
    private int foldSize;
    private final Evaluation evaluation = new Evaluation();
    private CrossValidationResult results;

    /**
     * Stores the inputs and registers the default evaluators: a binary classifier evaluator when
     * the network has a single output, otherwise a multi-class evaluator over the data set's
     * column names; plus a mean-squared-error evaluator in both cases.
     */
    private void initialize(NeuralNetwork neuralNetwork, DataSet dataSet, int numberOfFolds) {
        this.neuralNetwork = neuralNetwork;
        this.numberOfFolds = numberOfFolds;
        this.dataSet = dataSet;
        if (neuralNetwork.getOutputsCount() == 1) {
            // 0.5 is presumably the decision threshold of the binary evaluator — confirm.
            this.evaluation.addEvaluator(new ClassifierEvaluator.Binary(0.5));
        } else {
            this.evaluation.addEvaluator(new ClassifierEvaluator.MultiClass(dataSet.getColumnNames()));
        }
        this.evaluation.addEvaluator(new ErrorEvaluator(new MeanSquaredError()));
    }

    /**
     * Creates a cross validation over the given network and data set.
     *
     * @param neuralNetwork the network template; each fold trains its own clone of it
     * @param dataSet the data to split into folds; {@link #run()} shuffles it in place
     * @param foldCount number of folds (k); must be positive
     * @throws IllegalArgumentException if {@code foldCount <= 0}
     */
    public CrossValidation(NeuralNetwork neuralNetwork, DataSet dataSet, int foldCount) {
        if (foldCount <= 0) {
            throw new IllegalArgumentException("foldCount must be positive, got " + foldCount);
        }
        this.initialize(neuralNetwork, dataSet, foldCount);
        this.sampling = new SubSampling(foldCount);
    }

    /**
     * Returns the configured sampling strategy.
     * NOTE(review): {@link #run()} does not currently consult this sampling; it cuts contiguous
     * folds from the shuffled data set itself.
     */
    public Sampling getSampling() {
        return this.sampling;
    }

    /** Replaces the configured sampling strategy (see the note on {@link #getSampling()}). */
    public void setSampling(Sampling sampling) {
        this.sampling = sampling;
    }

    /** Returns the evaluation (set of evaluators) applied to every fold's test set. */
    public Evaluation getEvaluation() {
        return this.evaluation;
    }

    /**
     * Runs the cross validation and stores the aggregated result (see {@link #getResult()}).
     *
     * <p>Side effect: shuffles the data set passed to the constructor in place. When the row count
     * is not divisible by the number of folds, the last fold also receives the remainder rows, so
     * every row is used for testing exactly once.
     *
     * @throws IllegalStateException if the data set has fewer rows than there are folds
     * @throws InterruptedException if interrupted while waiting for the fold workers
     * @throws ExecutionException if training or evaluating any fold failed
     */
    public void run() throws InterruptedException, ExecutionException {
        int rowCount = this.dataSet.size();
        if (this.numberOfFolds > rowCount) {
            throw new IllegalStateException(
                    "Cannot split " + rowCount + " rows into " + this.numberOfFolds + " folds");
        }

        CrossValidationResult result = new CrossValidationResult();
        result.numberOfFolds = this.numberOfFolds;
        result.numberOfInstances = this.dataSet.getRows().size();

        this.dataSet.shuffle();
        this.foldSize = rowCount / this.numberOfFolds;

        List<CrossValidationWorker> workers = new ArrayList<CrossValidationWorker>(this.numberOfFolds);
        for (int foldIdx = 0; foldIdx < this.numberOfFolds; ++foldIdx) {
            workers.add(new CrossValidationWorker(this.neuralNetwork, this.dataSet, foldIdx));
        }

        List<Future<EvaluationResult>> evaluationResults;
        ExecutorService executor = Executors.newFixedThreadPool(THREAD_POOL_SIZE);
        try {
            // invokeAll blocks until every task has completed (or this thread is interrupted).
            evaluationResults = executor.invokeAll(workers);
        } finally {
            // Release the pool threads even if invokeAll was interrupted.
            executor.shutdown();
        }

        for (Future<EvaluationResult> evaluationResult : evaluationResults) {
            result.addEvaluationResult(evaluationResult.get());
        }
        result.calculateStatistics();
        // Publish only a fully computed result.
        this.results = result;
    }

    /** Adds an extra evaluator that is applied to every fold's test set. */
    public void addEvaluator(Evaluator eval) {
        this.evaluation.addEvaluator(eval);
    }

    /** Returns the registered evaluator of the given type, as looked up by the evaluation. */
    public <T extends Evaluator> T getEvaluator(Class<T> type) {
        return this.evaluation.getEvaluator(type);
    }

    /** Returns the result of the last successful {@link #run()}, or {@code null} before that. */
    public CrossValidationResult getResult() {
        return this.results;
    }

    /** Trains a clone of the network on all rows except one fold, then evaluates it on that fold. */
    private class CrossValidationWorker
    implements Callable<EvaluationResult> {
        private final NeuralNetwork neuralNetwork;
        private final DataSet dataSet;
        private final int foldIndex;

        public CrossValidationWorker(NeuralNetwork neuralNetwork, DataSet dataSet, int foldIndex) {
            this.neuralNetwork = neuralNetwork;
            this.dataSet = dataSet;
            this.foldIndex = foldIndex;
        }

        @Override
        public EvaluationResult call() {
            // Each fold trains its own deep copy so folds do not interfere with each other.
            NeuralNetwork neuralNet = (NeuralNetwork) SerializationUtils.clone((Serializable) this.neuralNetwork);

            int rowsPerFold = CrossValidation.this.foldSize;
            int rowCount = this.dataSet.size();
            int startIndex = rowsPerFold * this.foldIndex;
            // The last fold absorbs the remainder rows so every row is tested exactly once.
            int endIndex = (this.foldIndex == CrossValidation.this.numberOfFolds - 1)
                    ? rowCount
                    : startIndex + rowsPerFold;

            // NOTE(review): confirm what the int argument of DataSet(int) means (row capacity vs.
            // input vector size); if it is the input size, these should mirror the source data
            // set's input/output sizes instead.
            DataSet trainingSet = new DataSet(rowCount - rowsPerFold);
            DataSet testSet = new DataSet(rowsPerFold);
            for (int i = 0; i < rowCount; ++i) {
                if (i >= startIndex && i < endIndex) {
                    testSet.add(this.dataSet.getRowAt(i));
                } else {
                    trainingSet.add(this.dataSet.getRowAt(i));
                }
            }

            neuralNet.learn(trainingSet);

            // NOTE(review): this Evaluation instance is shared by all workers running on the pool;
            // if its evaluators accumulate state, concurrent use may need per-worker copies.
            EvaluationResult evaluationResult = CrossValidation.this.evaluation.evaluateDataSet(neuralNet, testSet);
            // Attach the trained network to the result actually returned (previously it was set on
            // a throw-away object that was immediately overwritten).
            if (evaluationResult != null) {
                evaluationResult.setNeuralNetwork(neuralNet);
            }
            return evaluationResult;
        }
    }
}

