/*
 * Decompiled with CFR 0.152.
 */
package recunn.examples;

import java.util.Random;
import recunn.datasets.bAbI;
import recunn.datastructs.DataSet;
import recunn.model.NeuralNetwork;
import recunn.trainer.Trainer;
import recunn.util.NeuralNetworkHelper;

/**
 * Example driver that trains a fresh LSTM on every bAbI question-answering task and reports
 * per-task accuracy, then prints the accuracy averaged over all experiment repetitions.
 *
 * <p>The value returned by {@code Trainer.train} is treated as a loss in [0, 1] from which
 * accuracy is derived as {@code 1 - loss}.
 */
public class ExampleQuestionAnswering {

    /**
     * Runs {@code experimentCount} passes over all tasks listed in {@code bAbI.TASK_NAMES}.
     *
     * @param args ignored
     * @throws Exception propagated from dataset loading or training
     */
    public static void main(String[] args) throws Exception {
        // A single shared generator drives dataset construction, weight init and training.
        Random random = new Random();

        // Model / optimisation settings.
        final int hiddenSize = 10;
        final int layerCount = 1;
        final double stepSize = 0.005;
        final double initStdDev = 0.08;
        final int epochs = 50;
        final int reportInterval = 10;

        // Experiment settings.
        final int experimentCount = 1;
        final int examplesPerTask = 1000;
        final boolean supportingFactsOnly = false;

        final int taskCount = bAbI.TASK_NAMES.length;
        // Per-task loss summed over all experiment repetitions.
        double[] lossTotals = new double[taskCount];

        for (int run = 1; run <= experimentCount; run++) {
            for (int t = 0; t < taskCount; t++) {
                int taskNumber = t + 1;
                System.out.println("\n==============================================================");
                System.out.println("bAbI experiment " + run + " of " + experimentCount);
                System.out.println("Task #" + taskNumber + ": " + bAbI.TASK_NAMES[t] + "\n");

                bAbI dataset = new bAbI(taskNumber, examplesPerTask, supportingFactsOnly, random);
                NeuralNetwork network = NeuralNetworkHelper.makeLstm(
                        dataset.inputDimension,
                        hiddenSize,
                        layerCount,
                        dataset.outputDimension,
                        ((DataSet) dataset).getModelOutputUnitToUse(),
                        initStdDev,
                        random);

                double taskLoss = Trainer.train(epochs, stepSize, network, dataset, reportInterval, random);
                lossTotals[t] += taskLoss;
                System.out.println("\nFINAL: " + formatAccuracy(taskLoss) + "% accuracy");
            }
        }

        System.out.println("\n\n==============================================================");
        System.out.println("SUMMED RESULTS:");
        for (int t = 0; t < taskCount; t++) {
            double meanLoss = lossTotals[t] / experimentCount;
            System.out.println("\t" + formatAccuracy(meanLoss) + "% avg. accuracy on #" + (t + 1) + ": " + bAbI.TASK_NAMES[t]);
        }
    }

    /** Converts a loss fraction into an accuracy percentage rendered with one decimal place. */
    private static String formatAccuracy(double loss) {
        return String.format("%.1f", 100.0 * (1.0 - loss));
    }
}

