package recunn.examples;

import java.util.Random;
import recunn.datasets.bAbI;
import recunn.trainer.Trainer;
import recunn.util.NeuralNetworkHelper;

/* loaded from: input_file:recunn/examples/ExampleQuestionAnswering.class */
/**
 * Example driver that trains an LSTM on every bAbI question-answering task and prints accuracy.
 *
 * <p>For each task listed in {@code bAbI.TASK_NAMES}, the driver does the following
 * {@link #NUM_EXPERIMENTS} times:
 * <ul>
 *   <li>loads the task data;</li>
 *   <li>builds a fresh network with {@code NeuralNetworkHelper.makeLstm};</li>
 *   <li>trains it with {@code Trainer.train}.</li>
 * </ul>
 * The value returned by {@code Trainer.train} is treated as an error fraction: the reported
 * accuracy is {@code 100 * (1 - value)}. After all runs, the per-task error averaged over the
 * experiments is printed as a summary.
 */
public class ExampleQuestionAnswering {

    /** Number of repeated runs per task; the loop bound, banner and summary average all use it. */
    private static final int NUM_EXPERIMENTS = 1;

    /** Horizontal rule printed before each section of the console report. */
    private static final String SEPARATOR =
            "==============================================================";

    /**
     * Runs every bAbI task {@link #NUM_EXPERIMENTS} times and prints per-run and averaged accuracy.
     *
     * @param strArr command-line arguments (unused)
     * @throws Exception propagated from dataset loading or training
     */
    public static void main(String[] strArr) throws Exception {
        Random random = new Random();
        int numTasks = bAbI.TASK_NAMES.length;
        // Sum of the error values returned by Trainer.train, per task, across all experiments.
        double[] summedError = new double[numTasks];

        for (int experiment = 0; experiment < NUM_EXPERIMENTS; experiment++) {
            for (int taskIndex = 0; taskIndex < numTasks; taskIndex++) {
                // The bAbI constructor receives the 1-based task number.
                int taskNumber = taskIndex + 1;
                System.out.println("\n" + SEPARATOR);
                System.out.println("bAbI experiment " + (experiment + 1) + " of " + NUM_EXPERIMENTS);
                System.out.println("Task #" + taskNumber + ": " + bAbI.TASK_NAMES[taskIndex] + "\n");

                double error = runTask(taskNumber, random);
                summedError[taskIndex] += error;
                System.out.println("\nFINAL: " + formatAccuracy(error) + "% accuracy");
            }
        }

        System.out.println("\n\n" + SEPARATOR);
        System.out.println("SUMMED RESULTS:");
        for (int taskIndex = 0; taskIndex < numTasks; taskIndex++) {
            double meanError = summedError[taskIndex] / NUM_EXPERIMENTS;
            System.out.println("\t" + formatAccuracy(meanError) + "% avg. accuracy on #"
                    + (taskIndex + 1) + ": " + bAbI.TASK_NAMES[taskIndex]);
        }
    }

    /**
     * Loads one bAbI task, builds a fresh LSTM for it, trains it and returns the training result.
     *
     * <p>Construction order (dataset, then model, then training) matches the original inline code,
     * so the shared {@code random} is consumed in the same sequence.
     *
     * @param taskNumber 1-based bAbI task number
     * @param random shared source of randomness for data, weight initialisation and training
     * @return the value returned by {@code Trainer.train}; the caller treats it as an error fraction
     * @throws Exception propagated from dataset loading or training
     */
    private static double runTask(int taskNumber, Random random) throws Exception {
        // NOTE(review): 1000 / false are passed through unchanged; presumably example count and a
        // loader flag. Confirm against the bAbI constructor.
        bAbI babi = new bAbI(taskNumber, 1000, false, random);
        // NOTE(review): 10, 1 and 0.08 presumably mean hidden size, layer count and init scale.
        // Confirm against NeuralNetworkHelper.makeLstm.
        return Trainer.train(
                50, 0.005d,
                NeuralNetworkHelper.makeLstm(babi.inputDimension, 10, 1, babi.outputDimension,
                        babi.getModelOutputUnitToUse(), 0.08d, random),
                babi, 10, random);
    }

    /**
     * Formats {@code 100 * (1 - errorFraction)} with one decimal place, e.g. 0.25 gives "75.0".
     *
     * @param errorFraction error value as returned by training
     * @return accuracy percentage rendered with {@code "%.1f"} in the default locale
     */
    private static String formatAccuracy(double errorFraction) {
        return String.format("%.1f", 100.0d * (1.0d - errorFraction));
    }
}
