package recunn.datasets;

import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import recunn.datastructs.DataSequence;
import recunn.datastructs.DataSet;
import recunn.datastructs.DataStep;
import recunn.loss.LossMultiDimensionalBinary;
import recunn.loss.LossSumOfSquares;
import recunn.matrix.Matrix;
import recunn.model.Model;
import recunn.model.Nonlinearity;
import recunn.model.SigmoidUnit;

/**
 * Synthetic sequential-parity task.
 *
 * <p>Each sequence is a stream of random bits fed one per step. Every step has a one-dimensional
 * input of {@code 0.0} or {@code 1.0}, each chosen with probability 0.5. Only the final step of a
 * sequence carries a target: {@code 1.0} if an odd number of ones was seen, {@code 0.0} if the
 * count was even. Intermediate steps leave {@code targetOutput} unset.
 */
public class SequentialParity extends DataSet {

    /**
     * Builds the training, validation and testing splits of random parity sequences.
     *
     * <p>Splits are generated in that order from the same {@code random}, so a seeded generator
     * yields a reproducible dataset.
     *
     * @param random source of randomness for sequence lengths and bit values
     * @param sequencesPerSplit number of sequences generated for each of the three splits
     * @param maxTrainingLength inclusive upper bound on sequence length for the training and
     *     validation splits; must be positive
     * @param maxTestingLength inclusive upper bound on sequence length for the testing split; must
     *     be positive
     * @throws IllegalArgumentException if either maximum length is not positive
     */
    public SequentialParity(
            Random random, int sequencesPerSplit, int maxTrainingLength, int maxTestingLength) {
        this.inputDimension = 1;
        this.outputDimension = 1;
        this.lossTraining = new LossSumOfSquares();
        this.lossReporting = new LossMultiDimensionalBinary();
        this.training = generateSequences(random, sequencesPerSplit, maxTrainingLength);
        this.validation = generateSequences(random, sequencesPerSplit, maxTrainingLength);
        this.testing = generateSequences(random, sequencesPerSplit, maxTestingLength);
    }

    /**
     * Generates {@code count} independent parity sequences.
     *
     * @param random source of randomness
     * @param count number of sequences to generate; a value of zero or less yields an empty list
     * @param maxLength inclusive upper bound on each sequence's length; must be positive
     * @return a new mutable list of generated sequences
     * @throws IllegalArgumentException if {@code maxLength} is not positive
     */
    private static List<DataSequence> generateSequences(Random random, int count, int maxLength) {
        if (maxLength <= 0) {
            throw new IllegalArgumentException("maxLength must be positive, got " + maxLength);
        }
        List<DataSequence> sequences = new ArrayList<>();
        for (int s = 0; s < count; s++) {
            sequences.add(generateSequence(random, maxLength));
        }
        return sequences;
    }

    /**
     * Generates one parity sequence whose length is uniform in {@code [1, maxLength]}.
     *
     * <p>The length is drawn first, then one bit is drawn per step. This RNG call order is kept
     * stable so that seeded datasets stay reproducible.
     */
    private static DataSequence generateSequence(Random random, int maxLength) {
        DataSequence sequence = new DataSequence();
        int length = random.nextInt(maxLength) + 1;
        int onesSeen = 0;
        for (int t = 0; t < length; t++) {
            DataStep step = new DataStep();
            double bit = 0.0;
            if (random.nextDouble() < 0.5) {
                bit = 1.0;
                onesSeen++;
            }
            step.input = new Matrix(new double[] {bit});
            // Only the last step is supervised: the target is the parity of all bits seen so far.
            if (t == length - 1) {
                step.targetOutput = new Matrix(new double[] {onesSeen % 2});
            }
            sequence.steps.add(step);
        }
        return sequence;
    }

    /** Intentionally produces no report for this dataset. */
    @Override // recunn.datastructs.DataSet
    public void DisplayReport(Model model, Random random) throws Exception {
        // Nothing to display: this dataset has no human-readable sample output.
    }

    /**
     * Returns a sigmoid output unit, which matches the 0/1 parity target at the final step.
     *
     * @return a new {@link SigmoidUnit}
     */
    @Override // recunn.datastructs.DataSet
    public Nonlinearity getModelOutputUnitToUse() {
        return new SigmoidUnit();
    }
}
