package recunn.datasets;

import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import recunn.datastructs.DataSequence;
import recunn.datastructs.DataSet;
import recunn.datastructs.DataStep;
import recunn.loss.LossMultiDimensionalBinary;
import recunn.loss.LossSumOfSquares;
import recunn.model.Model;
import recunn.model.Nonlinearity;
import recunn.model.SigmoidUnit;

/**
 * Data set of strings drawn from the embedded Reber grammar, framed as a next-symbol
 * prediction task.
 *
 * <p>Each {@link DataStep} pairs a one-hot encoding of the symbol just emitted with a
 * multi-hot target that marks every symbol the grammar allows next. The training,
 * validation and testing splits each hold {@value #SEQUENCES_PER_SPLIT} independently
 * generated strings.
 */
public class EmbeddedReberGrammar extends DataSet {

    /** Number of distinct grammar symbols, and therefore the input/output vector width. */
    private static final int NUM_SYMBOLS = 7;

    /** Number of sequences generated for each of the training/validation/testing splits. */
    private static final int SEQUENCES_PER_SPLIT = 1000;

    // Symbol indices. The letters follow the conventional Reber alphabet; the mapping is
    // inferred from the shape of the transition table below.
    private static final int B = 0;
    private static final int T = 1;
    private static final int P = 2;
    private static final int S = 3;
    private static final int X = 4;
    private static final int V = 5;
    private static final int E = 6;

    /** Id of the start state; a transition back into it terminates a string. */
    private static final int START_STATE = 0;

    /**
     * Transition table of the embedded grammar, indexed by state id.
     *
     * <p>A string is B, then T or P, then a complete inner Reber string (B ... E), then
     * the same letter chosen in the second position, then a final E. The table is built
     * once and only ever read.
     */
    private static final State[] STATES = {
        // Outer prefix: B, then branch on T (states 2-10) or P (states 11-18).
        /* 0 */ state(edge(1, B)),
        /* 1 */ state(edge(2, T), edge(11, P)),
        // T branch: inner string, then the outer T.
        /* 2 */ state(edge(3, B)),
        /* 3 */ state(edge(4, T), edge(9, P)),
        /* 4 */ state(edge(4, S), edge(5, X)),
        /* 5 */ state(edge(6, S), edge(9, X)),
        /* 6 */ state(edge(7, E)),
        /* 7 */ state(edge(8, T)),
        // Shared outer suffix: the final E leads back to the start state.
        /* 8 */ state(edge(START_STATE, E)),
        // T branch, remaining inner states.
        /* 9 */ state(edge(9, T), edge(10, V)),
        /* 10 */ state(edge(5, P), edge(6, V)),
        // P branch: inner string, then the outer P.
        /* 11 */ state(edge(12, B)),
        /* 12 */ state(edge(13, T), edge(17, P)),
        /* 13 */ state(edge(13, S), edge(14, X)),
        /* 14 */ state(edge(15, S), edge(17, X)),
        /* 15 */ state(edge(16, E)),
        /* 16 */ state(edge(8, P)),
        /* 17 */ state(edge(17, T), edge(18, V)),
        /* 18 */ state(edge(14, P), edge(15, V)),
    };

    /** A grammar state: the set of transitions that may be taken out of it. */
    public static class State {
        /** Outgoing transitions; one of them is followed each time this state is left. */
        public Transition[] transitions;

        /**
         * @param transitions the outgoing transitions of this state
         */
        public State(Transition[] transitions) {
            this.transitions = transitions;
        }
    }

    /** A labelled edge of the grammar: emitting {@code token} moves to {@code next_state_id}. */
    public static class Transition {
        /** Index of the state reached by taking this transition. */
        public int next_state_id;
        /** Index of the symbol emitted when this transition is taken. */
        public int token;

        /**
         * @param nextStateId index of the destination state
         * @param token index of the emitted symbol
         */
        public Transition(int nextStateId, int token) {
            this.next_state_id = nextStateId;
            this.token = token;
        }
    }

    /**
     * Builds the data set, drawing the training, validation and testing sequences from
     * {@code random} in that order.
     *
     * <p>The training loss is sum-of-squares. The reporting loss is the multi-dimensional
     * binary loss.
     *
     * @param random source of randomness shared by all three splits
     */
    public EmbeddedReberGrammar(Random random) throws Exception {
        this.inputDimension = NUM_SYMBOLS;
        this.outputDimension = NUM_SYMBOLS;
        this.lossTraining = new LossSumOfSquares();
        this.lossReporting = new LossMultiDimensionalBinary();
        this.training = generateSequences(random, SEQUENCES_PER_SPLIT);
        this.validation = generateSequences(random, SEQUENCES_PER_SPLIT);
        this.testing = generateSequences(random, SEQUENCES_PER_SPLIT);
    }

    /**
     * Generates random strings from the embedded Reber grammar.
     *
     * <p>Each string starts in the start state and follows transitions until one leads
     * back to it.
     *
     * <ul>
     *   <li>Where a state offers several transitions, one is chosen uniformly with
     *       {@code random}. Single-transition states consume no randomness.</li>
     *   <li>One step is recorded for every emitted symbol except the terminating final E.
     *       Its input is the one-hot symbol, and its target is the multi-hot set of symbols
     *       the new state may emit. The last step of every sequence therefore targets only
     *       E.</li>
     * </ul>
     *
     * @param random source of branch choices; dereferenced whenever {@code count > 0}
     * @param count number of sequences to generate; values {@code <= 0} yield an empty list
     * @return a new mutable list of the generated sequences
     */
    public static List<DataSequence> generateSequences(Random random, int count) {
        List<DataSequence> sequences = new ArrayList<DataSequence>();
        for (int n = 0; n < count; n++) {
            sequences.add(generateSequence(random));
        }
        return sequences;
    }

    /** Walks the grammar once from the start state and records the resulting sequence. */
    private static DataSequence generateSequence(Random random) {
        List<DataStep> steps = new ArrayList<DataStep>();
        int stateId = START_STATE;
        while (true) {
            Transition taken = chooseTransition(STATES[stateId], random);
            stateId = taken.next_state_id;
            if (stateId == START_STATE) {
                // The closing E is not recorded: there is nothing after it to predict.
                break;
            }
            steps.add(new DataStep(oneHot(taken.token), legalNextSymbols(STATES[stateId])));
        }
        return new DataSequence(steps);
    }

    /**
     * Picks the transition to follow out of {@code state}. A lone transition is taken
     * without drawing from {@code random}; otherwise one is drawn uniformly.
     *
     * @throws IllegalStateException if the state has no outgoing transitions
     */
    private static Transition chooseTransition(State state, Random random) {
        Transition[] options = state.transitions;
        switch (options.length) {
            case 0:
                throw new IllegalStateException("grammar state has no outgoing transitions");
            case 1:
                return options[0];
            default:
                return options[random.nextInt(options.length)];
        }
    }

    /** Returns a vector with 1.0 at {@code symbol} and 0.0 elsewhere. */
    private static double[] oneHot(int symbol) {
        double[] vector = new double[NUM_SYMBOLS];
        vector[symbol] = 1.0;
        return vector;
    }

    /** Returns a multi-hot vector marking every symbol {@code state} can emit next. */
    private static double[] legalNextSymbols(State state) {
        double[] target = new double[NUM_SYMBOLS];
        for (Transition transition : state.transitions) {
            target[transition.token] = 1.0;
        }
        return target;
    }

    /** Shorthand used to build {@link #STATES}. */
    private static State state(Transition... transitions) {
        return new State(transitions);
    }

    /** Shorthand used to build {@link #STATES}. */
    private static Transition edge(int nextStateId, int token) {
        return new Transition(nextStateId, token);
    }

    /** Produces no report for this data set (the method body is empty). */
    @Override // recunn.datastructs.DataSet
    public void DisplayReport(Model model, Random random) throws Exception {
    }

    /**
     * Returns a fresh {@link SigmoidUnit}. Targets are multi-hot vectors of 0/1 entries,
     * which suits an independent per-dimension sigmoid output.
     */
    @Override // recunn.datastructs.DataSet
    public Nonlinearity getModelOutputUnitToUse() {
        return new SigmoidUnit();
    }
}
