/*
 * Decompiled with CFR 0.152.
 */
package com.datumbox.framework.core.machinelearning.modelselection.splitters;

import com.datumbox.framework.common.dataobjects.FlatDataList;
import com.datumbox.framework.core.common.dataobjects.Dataframe;
import com.datumbox.framework.core.common.utilities.PHPMethods;
import com.datumbox.framework.core.machinelearning.common.abstracts.modelselection.AbstractSplitter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.Random;

/**
 * Splits a {@link Dataframe} into {@code k} train/test folds for k-fold cross-validation.
 * <p>
 * On each call to {@link #split(Dataframe)} the record ids are shuffled once with the
 * splitter's {@code random} source. They are then partitioned into {@code k} contiguous test
 * folds. Each produced {@link AbstractSplitter.Split} uses one fold as the test set and every
 * other record as the training set.
 */
public class KFoldSplitter
extends AbstractSplitter {
    /** Number of folds to produce; validated lazily in {@link #split(Dataframe)}. */
    private final int k;

    /**
     * Creates a splitter that produces {@code k} folds, using the no-argument
     * {@link AbstractSplitter} constructor.
     *
     * @param k the number of folds
     */
    public KFoldSplitter(int k) {
        this.k = k;
    }

    /**
     * Creates a splitter that produces {@code k} folds and shuffles with the given random source.
     *
     * @param k the number of folds
     * @param random the random source passed to {@link AbstractSplitter}
     */
    public KFoldSplitter(int k, Random random) {
        super(random);
        this.k = k;
    }

    /**
     * Lazily produces the {@code k} train/test splits of the given dataset.
     * <p>
     * Every record appears in exactly one test fold. When {@code dataset.size()} is not a
     * multiple of {@code k}, fold sizes differ by at most one. When {@code k == 1}, a single
     * split is returned whose training and test sets are both copies of the whole dataset.
     *
     * @param dataset the dataset to split
     * @return an iterator over the {@code k} splits; {@code next()} throws
     *         {@link NoSuchElementException} once all folds have been produced
     * @throws IllegalArgumentException if {@code k <= 0} or {@code dataset.size() <= k}
     */
    @Override
    public Iterator<AbstractSplitter.Split> split(final Dataframe dataset) {
        final int n = dataset.size();
        if (this.k <= 0 || n <= this.k) {
            throw new IllegalArgumentException("Invalid number of folds.");
        }
        if (this.k == 1) {
            return Arrays.asList(new AbstractSplitter.Split(dataset.copy(), dataset.copy())).iterator();
        }

        // Materialize and shuffle the ids once; all folds slice the same shuffled order, so the
        // test folds are disjoint and together cover every record.
        final Integer[] ids = new Integer[n];
        int j = 0;
        for (Integer rId : dataset.index()) {
            ids[j++] = rId;
        }
        PHPMethods.shuffle(ids, this.random);

        return new Iterator<AbstractSplitter.Split>() {
            /** Index of the next fold to produce. */
            private int counter = 0;

            @Override
            public boolean hasNext() {
                return this.counter < KFoldSplitter.this.k;
            }

            @Override
            public AbstractSplitter.Split next() {
                if (!this.hasNext()) {
                    throw new NoSuchElementException();
                }
                KFoldSplitter.this.logger.info("Kfold {}", (Object)this.counter);

                // The test fold is the half-open range [testStart, testEnd) of the shuffled ids.
                final int testStart = foldBoundary(this.counter, n, KFoldSplitter.this.k);
                final int testEnd = foldBoundary(this.counter + 1, n, KFoldSplitter.this.k);
                final int testSize = testEnd - testStart;

                List<Object> trainList = new ArrayList<Object>(n - testSize);
                List<Object> testList = new ArrayList<Object>(testSize);
                FlatDataList trainIds = new FlatDataList(trainList);
                FlatDataList testIds = new FlatDataList(testList);
                for (int i = 0; i < n; ++i) {
                    if (testStart <= i && i < testEnd) {
                        testIds.add(ids[i]);
                    }
                    else {
                        trainIds.add(ids[i]);
                    }
                }
                ++this.counter;
                return new AbstractSplitter.Split(dataset.getSubset(trainIds), dataset.getSubset(testIds));
            }
        };
    }

    /**
     * Returns the start index of fold {@code fold} (equivalently the end of fold
     * {@code fold - 1}) when {@code n} items are split into {@code folds} near-equal
     * contiguous folds.
     * <p>
     * The result is {@code floor(fold * n / folds)}. Adjacent boundaries therefore differ by
     * {@code floor(n / folds)} or {@code ceil(n / folds)}, and {@code foldBoundary(folds, n, folds) == n},
     * so no trailing items are left out. The product is computed in {@code long} to avoid
     * {@code int} overflow. The result is at most {@code n}, so the narrowing cast is safe.
     */
    private static int foldBoundary(int fold, int n, int folds) {
        return (int) ((long) fold * n / folds);
    }
}

