/*
 * Decompiled with CFR 0.152.
 */
package recunn.datasets;

import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import recunn.datastructs.DataSequence;
import recunn.datastructs.DataSet;
import recunn.datastructs.DataStep;
import recunn.loss.LossMultiDimensionalBinary;
import recunn.loss.LossSumOfSquares;
import recunn.matrix.Matrix;
import recunn.model.Model;
import recunn.model.Nonlinearity;
import recunn.model.SigmoidUnit;

/**
 * Synthetic dataset for the sequential parity task.
 *
 * <p>Every sequence is a random-length stream of one-dimensional binary inputs
 * (each step is 1.0 or 0.0 with probability 0.5). Only the final step carries a
 * target: the number of 1.0 inputs in the sequence modulo 2.
 */
public class SequentialParity
extends DataSet {
    /**
     * Builds the training, validation and testing splits.
     *
     * <p>All three splits are drawn from the same {@code Random}, in the order
     * training, validation, testing, so that order determines the generated data.
     *
     * @param r source of randomness for sequence lengths and input bits
     * @param total_sequences number of sequences generated for each split
     * @param max_sequence_length_train maximum length of training and validation sequences
     * @param max_sequence_length_test maximum length of testing sequences
     */
    public SequentialParity(Random r, int total_sequences, int max_sequence_length_train, int max_sequence_length_test) {
        inputDimension = 1;
        outputDimension = 1;
        lossTraining = new LossSumOfSquares();
        lossReporting = new LossMultiDimensionalBinary();
        // Validation deliberately reuses the training length bound; only testing uses its own bound.
        training = generateSequences(r, total_sequences, max_sequence_length_train);
        validation = generateSequences(r, total_sequences, max_sequence_length_train);
        testing = generateSequences(r, total_sequences, max_sequence_length_test);
    }

    /**
     * Generates a list of independent parity sequences.
     *
     * @param r source of randomness
     * @param total_sequences how many sequences to produce (a non-positive value yields an empty list)
     * @param max_sequence_length upper bound (inclusive) on each sequence's length; must be positive
     * @return a mutable list of freshly generated sequences
     */
    private static List<DataSequence> generateSequences(Random r, int total_sequences, int max_sequence_length) {
        List<DataSequence> sequences = new ArrayList<DataSequence>();
        for (int made = 0; made < total_sequences; made++) {
            sequences.add(randomParitySequence(r, max_sequence_length));
        }
        return sequences;
    }

    /**
     * Builds one sequence of length uniform in [1, max_sequence_length].
     *
     * <p>Only the last step gets a target; the other steps leave it unset.
     */
    private static DataSequence randomParitySequence(Random r, int max_sequence_length) {
        DataSequence sequence = new DataSequence();
        int onesSeen = 0;
        int length = 1 + r.nextInt(max_sequence_length);
        for (int idx = 0; idx < length; idx++) {
            DataStep current = new DataStep();
            boolean bitIsOne = r.nextDouble() < 0.5;
            if (bitIsOne) {
                onesSeen++;
            }
            current.input = new Matrix(new double[]{bitIsOne ? 1.0 : 0.0});
            boolean isFinalStep = idx == length - 1;
            if (isFinalStep) {
                // Parity of the ones observed over the whole sequence.
                current.targetOutput = new Matrix(new double[]{onesSeen % 2});
            }
            sequence.steps.add(current);
        }
        return sequence;
    }

    /** Intentionally produces no report for this dataset. */
    @Override
    public void DisplayReport(Model model, Random rng) throws Exception {
        // nothing to display
    }

    /**
     * Returns the output nonlinearity for models trained on this dataset.
     *
     * @return a new {@link SigmoidUnit}
     */
    @Override
    public Nonlinearity getModelOutputUnitToUse() {
        return new SigmoidUnit();
    }
}

