/*
 * Decompiled with CFR 0.152.
 */
package Catalano.MachineLearning.Performance;

import Catalano.Core.ArraysUtil;
import Catalano.MachineLearning.Classification.IClassifier;
import Catalano.MachineLearning.Dataset.DatasetClassification;
import Catalano.MachineLearning.Performance.IValidation;
import Catalano.MachineLearning.Performance.SuppliedValidation;
import Catalano.Math.Matrix;

/**
 * K-fold cross-validation of a classifier.
 *
 * <p>The samples are split into {@code nFolds} contiguous slices of an index
 * permutation (optionally shuffled). For each slice the classifier is trained
 * on every sample outside the slice and scored, through
 * {@link SuppliedValidation}, on the samples inside it. The value returned by
 * {@code Run} is the arithmetic mean of the per-fold scores.
 */
public class KFoldCrossValidation
implements IValidation {
    /** Smallest accepted number of folds; smaller requests are clamped to this value. */
    private static final int MIN_FOLDS = 2;

    /** Seed value meaning "no fixed seed": {@code Run} then uses the unseeded shuffle. */
    private static final long NO_SEED = 0L;

    private int nFolds;
    private boolean shuffle;
    private long seed;

    /**
     * Gets the number of folds.
     * @return Number of folds.
     */
    public int getNumberOfFolds() {
        return this.nFolds;
    }

    /**
     * Sets the number of folds.
     * @param folds Requested number of folds; values below 2 are raised to 2.
     */
    public void setNumberOfFolds(int folds) {
        // The clamped value must be stored; the bare Math.max call had no effect.
        this.nFolds = Math.max(folds, MIN_FOLDS);
    }

    /**
     * Checks whether the sample order is shuffled before the folds are built.
     * @return True if shuffling is enabled.
     */
    public boolean isShuffle() {
        return this.shuffle;
    }

    /**
     * Enables or disables shuffling of the sample order.
     * @param shuffle True to shuffle before splitting.
     */
    public void setShuffle(boolean shuffle) {
        this.shuffle = shuffle;
    }

    /**
     * Initializes a new instance with 10 folds, shuffling enabled and no fixed seed.
     */
    public KFoldCrossValidation() {
        this(10);
    }

    /**
     * Initializes a new instance with shuffling enabled and no fixed seed.
     * @param nFolds Number of folds; values below 2 are raised to 2.
     */
    public KFoldCrossValidation(int nFolds) {
        this(nFolds, true);
    }

    /**
     * Initializes a new instance with no fixed seed.
     * @param nFolds Number of folds; values below 2 are raised to 2.
     * @param shuffle True to shuffle the sample order before splitting.
     */
    public KFoldCrossValidation(int nFolds, boolean shuffle) {
        // Use the same sentinel that Run tests for, so the default really is unseeded.
        this(nFolds, shuffle, NO_SEED);
    }

    /**
     * Initializes a new instance.
     * @param nFolds Number of folds; values below 2 are raised to 2.
     * @param shuffle True to shuffle the sample order before splitting.
     * @param seed Seed passed to the shuffle; 0 selects the unseeded shuffle.
     */
    public KFoldCrossValidation(int nFolds, boolean shuffle, long seed) {
        // Same lower bound as setNumberOfFolds; also keeps Run free of a division by zero.
        this.nFolds = Math.max(nFolds, MIN_FOLDS);
        this.shuffle = shuffle;
        this.seed = seed;
    }

    /**
     * Runs the validation on a dataset.
     * @param classifier Classifier to train and evaluate.
     * @param dataset Dataset supplying the inputs and the labels.
     * @return Mean of the per-fold scores.
     */
    @Override
    public double Run(IClassifier classifier, DatasetClassification dataset) {
        return this.Run(classifier, dataset.getInput(), dataset.getOutput());
    }

    /**
     * Runs the validation on raw arrays.
     * @param classifier Classifier to train and evaluate; it is re-trained once per fold.
     * @param data Samples, one row per sample.
     * @param labels Label of each sample, parallel to {@code data}.
     * @return Mean of the per-fold scores reported by {@link SuppliedValidation}.
     * @throws IllegalArgumentException If {@code data} and {@code labels} differ in length,
     *         or if there are fewer samples than folds.
     */
    @Override
    public double Run(IClassifier classifier, double[][] data, int[] labels) {
        if (data.length != labels.length) {
            throw new IllegalArgumentException("The number of samples (" + data.length
                    + ") must match the number of labels (" + labels.length + ").");
        }
        if (this.nFolds > data.length) {
            throw new IllegalArgumentException("The number of folds must be less or equal than number of samples.");
        }

        // Spread the samples as evenly as possible: every fold gets baseSize samples and
        // the first `remainder` folds get one more. The sizes always sum to the sample
        // count, so no fold is empty and no boundary runs past the end of the data.
        int samples = labels.length;
        int baseSize = samples / this.nFolds;
        int remainder = samples % this.nFolds;

        int[] indexes = Matrix.Indices(0, samples);
        if (this.shuffle) {
            if (this.seed == NO_SEED) {
                ArraysUtil.Shuffle(indexes);
            } else {
                ArraysUtil.Shuffle(indexes, this.seed);
            }
        }

        double sum = 0.0;
        int start = 0;
        for (int i = 0; i < this.nFolds; ++i) {
            int end = start + baseSize + (i < remainder ? 1 : 0);

            // Row numbers held out for this fold, taken from the (possibly shuffled) order.
            // NOTE(review): relies on Matrix.Indices(start, end) covering start..end-1 - confirm.
            int[] heldOut = Matrix.getColumns(indexes, Matrix.Indices(start, end));

            // Training set: everything except the held-out rows.
            double[][] inputTrain = Matrix.RemoveRows(data, heldOut);
            int[] outputTrain = Matrix.RemoveColumns(labels, heldOut);

            // Test set: exactly the held-out rows. Selecting by raw positions start..end
            // instead would evaluate on rows that are still part of the training set
            // whenever the order has been shuffled.
            double[][] inputTest = Matrix.getRows(data, heldOut);
            int[] outputTest = Matrix.getRows(labels, heldOut);

            classifier.Learn(inputTrain, outputTrain);
            SuppliedValidation sv = new SuppliedValidation();
            sum += sv.Run(classifier, inputTest, outputTest);
            start = end;
        }
        return sum / (double)this.nFolds;
    }
}

