/*
 * Decompiled with CFR 0.152.
 */
package recunn.datasets;

import java.io.File;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Random;
import recunn.datastructs.DataSequence;
import recunn.datastructs.DataSet;
import recunn.datastructs.DataStep;
import recunn.loss.LossArgMax;
import recunn.loss.LossSoftmax;
import recunn.model.LinearUnit;
import recunn.model.Model;
import recunn.model.Nonlinearity;

/**
 * {@link DataSet} for the bAbI question-answering tasks, read from {@code datasets/bAbI/en/}.
 *
 * <p>Each story becomes one {@link DataSequence}. Every word of every fact and question is one
 * step with a one-hot input and no target. Each question is followed by one extra step with an
 * all-zero input, whose target is the one-hot encoded answer.
 */
public class bAbI
extends DataSet {
    /**
     * Display names of the 20 bAbI tasks, in task order (presumably index {@code i} names task
     * {@code qa(i+1)}); not referenced by this class itself.
     */
    public static final String[] TASK_NAMES = new String[]{"Single Supporting Fact", "Two Supporting Facts", "Three Supporting Facts", "Two Arg. Relations", "Three Arg. Relations", "Yes/No Questions", "Counting", "Lists/Sets", "Simple Negation", "Indefinite Knowledge", "Basic Coreference", "Conjunction", "Compound Coreference", "Time Reasoning", "Basic Deduction", "Basic Induction", "Positional Reasoning", "Size Reasoning", "Path Finding", "Agent's Motivations"};
    /** Sorted input vocabulary; the position of a word is its one-hot index. */
    List<String> inputVocab = new ArrayList<String>();
    /** Sorted answer vocabulary; the position of an answer is its one-hot index. */
    List<String> outputVocab = new ArrayList<String>();

    /** Smoke test: loads task 3 with at most 100 stories per split, keeping only supporting facts. */
    public static void main(String[] args) throws Exception {
        System.out.println("testing...");
        Random rng = new Random();
        new bAbI(3, 100, true, rng);
        System.out.println("done.");
    }

    /**
     * Loads bAbI task {@code setId} from {@code datasets/bAbI/en/} (relative to the working directory).
     *
     * @param setId task number; files whose names contain {@code "qa" + setId + "_"} are used
     * @param totalExamples maximum number of stories kept in each of the train and test splits;
     *     surplus stories are removed at random
     * @param onlySupportingFacts if true, each story keeps only its questions and the facts they cite
     * @param rng randomness source for the subsampling
     * @throws IllegalArgumentException if {@code totalExamples} is not positive
     * @throws Exception if the directory is missing or unreadable, contains a file that is neither
     *     train nor test data, yields no training stories, or a file cannot be read or parsed
     */
    public bAbI(int setId, int totalExamples, boolean onlySupportingFacts, Random rng) throws Exception {
        if (totalExamples < 1) {
            throw new IllegalArgumentException("totalExamples must be positive: " + totalExamples);
        }
        File folder = new File("datasets/bAbI/en/");
        File[] entries = folder.listFiles();
        if (entries == null) {
            // listFiles() signals "not a readable directory" by returning null rather than throwing.
            throw new Exception("bAbI data directory not found or unreadable: " + folder.getAbsolutePath());
        }
        String taskTag = "qa" + setId + "_";
        List<String> fileNamesTrain = new ArrayList<String>();
        List<String> fileNamesTest = new ArrayList<String>();
        for (File entry : entries) {
            // Classify by the bare file name so the directory part of the path cannot match.
            String name = entry.getName();
            List<String> split;
            if (name.contains("train")) {
                split = fileNamesTrain;
            } else if (name.contains("test")) {
                split = fileNamesTest;
            } else {
                throw new Exception("Unknown file type: " + entry.getPath());
            }
            if (name.contains(taskTag)) {
                split.add(entry.getPath());
            }
        }
        // listFiles() order is unspecified; sorting makes the story order, and therefore the
        // rng-driven subsampling below, reproducible across platforms.
        Collections.sort(fileNamesTrain);
        Collections.sort(fileNamesTest);
        List<Story> storiesTrain = this.getStories(fileNamesTrain, onlySupportingFacts);
        List<Story> storiesTest = this.getStories(fileNamesTest, onlySupportingFacts);
        if (storiesTrain.isEmpty()) {
            throw new Exception("No training stories found for bAbI task " + setId + " in " + folder.getPath());
        }
        subsample(storiesTrain, totalExamples, rng);
        subsample(storiesTest, totalExamples, rng);
        this.configureVocab(storiesTrain, storiesTest);
        this.training = this.getSequences(storiesTrain);
        this.testing = this.getSequences(storiesTest);
        this.validation = null;
        this.lossTraining = new LossSoftmax();
        this.lossReporting = new LossArgMax();
        // getSequences builds every input/target vector with exactly these lengths. Reading them from
        // the vocabularies avoids probing the first training sequence, which failed when that
        // sequence was empty or had no target step.
        this.inputDimension = this.inputVocab.size();
        this.outputDimension = this.outputVocab.size();
    }

    /** Removes uniformly random elements from {@code items} until at most {@code max} remain. */
    private static <T> void subsample(List<T> items, int max, Random rng) {
        while (items.size() > max) {
            items.remove(rng.nextInt(items.size()));
        }
    }

    /**
     * Parses the given bAbI files (in order) and groups their lines into stories.
     *
     * <p>A new story starts whenever the leading line number does not increase. Stories without a
     * question are dropped and counted in a warning.
     *
     * @param fileNames paths of the files to read, concatenated in list order
     * @param onlySupportingFacts passed through to {@link Story}
     * @return the parsed stories; empty if the files contain no valid story
     * @throws Exception if a file cannot be read or a line cannot be parsed
     */
    List<Story> getStories(List<String> fileNames, boolean onlySupportingFacts) throws Exception {
        List<Statement> statements = new ArrayList<Statement>();
        for (String fileName : fileNames) {
            List<String> lines = Files.readAllLines(new File(fileName).toPath(), StandardCharsets.UTF_8);
            for (String line : lines) {
                // Blank lines carry no statement and would otherwise fail line-number parsing.
                if (line.trim().isEmpty()) {
                    continue;
                }
                statements.add(new Statement(line));
            }
        }
        List<Story> stories = new ArrayList<Story>();
        List<Statement> current = new ArrayList<Statement>();
        boolean containsQuestion = false;
        int prevNum = 0;
        int errors = 0;
        for (Statement statement : statements) {
            // Line numbering restarts for each story, so a non-increasing number closes the current one.
            if (statement.lineNum <= prevNum) {
                if (containsQuestion) {
                    stories.add(new Story(current, onlySupportingFacts));
                } else {
                    ++errors;
                }
                current = new ArrayList<Statement>();
                containsQuestion = false;
            }
            if (!statement.isFact) {
                containsQuestion = true;
            }
            current.add(statement);
            prevNum = statement.lineNum;
        }
        // The loop never closes the last story; apply the same validity rule to it.
        if (containsQuestion) {
            stories.add(new Story(current, onlySupportingFacts));
        } else if (!current.isEmpty()) {
            ++errors;
        }
        if (errors > 0) {
            System.out.println("WARNING: " + errors + " INCORRECT STORIES REMOVED.");
        }
        return stories;
    }

    /**
     * Fills {@link #inputVocab} (fact and question words) and {@link #outputVocab} (answers) from
     * both splits, sorts them, and prints the answer vocabulary with its indices.
     *
     * <p>Both splits are included, so no test word or answer is outside the vocabulary.
     */
    private void configureVocab(List<Story> storiesTrain, List<Story> storiesTest) {
        HashSet<String> inputVocabSet = new HashSet<String>();
        HashSet<String> outputVocabSet = new HashSet<String>();
        List<Story> allStories = new ArrayList<Story>(storiesTrain);
        allStories.addAll(storiesTest);
        for (Story story : allStories) {
            for (Statement statement : story.statements) {
                if (statement.isFact) {
                    inputVocabSet.addAll(statement.fact);
                } else {
                    inputVocabSet.addAll(statement.question);
                    outputVocabSet.add(statement.answer);
                }
            }
        }
        this.inputVocab.addAll(inputVocabSet);
        this.outputVocab.addAll(outputVocabSet);
        Collections.sort(this.inputVocab);
        Collections.sort(this.outputVocab);
        System.out.println("Possible answers: ");
        for (int i = 0; i < this.outputVocab.size(); ++i) {
            System.out.println("\t[" + i + "]: " + this.outputVocab.get(i));
        }
    }

    /**
     * Encodes each story as a {@link DataSequence} of one-hot steps (see the class comment).
     * Tokens missing from the vocabulary encode as all-zero vectors.
     */
    private List<DataSequence> getSequences(List<Story> stories) {
        Map<String, Integer> inputIndex = buildIndex(this.inputVocab);
        Map<String, Integer> outputIndex = buildIndex(this.outputVocab);
        int inputDimension = this.inputVocab.size();
        int outputDimension = this.outputVocab.size();
        List<DataSequence> sequences = new ArrayList<DataSequence>();
        for (Story story : stories) {
            ArrayList<DataStep> steps = new ArrayList<DataStep>();
            for (Statement statement : story.statements) {
                List<String> words = statement.isFact ? statement.fact : statement.question;
                for (String word : words) {
                    steps.add(new DataStep(oneHot(inputIndex, word, inputDimension), null));
                }
                if (!statement.isFact) {
                    // After the question words, feed a blank input and ask for the answer.
                    double[] answer = oneHot(outputIndex, statement.answer, outputDimension);
                    steps.add(new DataStep(new double[inputDimension], answer));
                }
            }
            sequences.add(new DataSequence(steps));
        }
        return sequences;
    }

    /** Maps every vocabulary entry to its position in {@code vocab}. */
    private static Map<String, Integer> buildIndex(List<String> vocab) {
        Map<String, Integer> index = new HashMap<String, Integer>();
        for (int i = 0; i < vocab.size(); ++i) {
            index.put(vocab.get(i), i);
        }
        return index;
    }

    /** Returns a vector of {@code dimension} zeros with a 1.0 at {@code token}'s index, if it has one. */
    private static double[] oneHot(Map<String, Integer> index, String token, int dimension) {
        double[] vector = new double[dimension];
        Integer position = index.get(token);
        if (position != null) {
            vector[position] = 1.0;
        }
        return vector;
    }

    /** Does nothing; no report is produced for this dataset. */
    @Override
    public void DisplayReport(Model model, Random rng) throws Exception {
    }

    /**
     * Returns a {@link LinearUnit} as the model's output nonlinearity. NOTE(review): presumably the
     * softmax is applied by {@link LossSoftmax}, so the model emits raw scores — confirm.
     */
    @Override
    public Nonlinearity getModelOutputUnitToUse() {
        return new LinearUnit();
    }

    /** One bAbI story: its statements in file order, optionally filtered to the relevant ones. */
    static class Story {
        List<Statement> statements;

        /**
         * @param statements the story's statements in file order (stored as-is when not filtering)
         * @param onlySupportingFacts if true, keep only questions and the facts they list as support
         */
        public Story(List<Statement> statements, boolean onlySupportingFacts) {
            if (!onlySupportingFacts) {
                this.statements = statements;
                return;
            }
            HashSet<Integer> keep = new HashSet<Integer>();
            for (Statement statement : statements) {
                if (!statement.isFact) {
                    keep.add(statement.lineNum);
                    keep.addAll(statement.supportingFacts);
                }
            }
            List<Statement> trimmed = new ArrayList<Statement>();
            for (Statement statement : statements) {
                if (keep.contains(statement.lineNum)) {
                    trimmed.add(statement);
                }
            }
            this.statements = trimmed;
        }

        @Override
        public String toString() {
            StringBuilder result = new StringBuilder();
            for (Statement statement : this.statements) {
                result.append(statement).append('\n');
            }
            return result.toString();
        }
    }

    /**
     * One parsed line. Tab-separated lines ({@code "N question?<TAB>answer<TAB>facts"}) are
     * questions; all other lines ({@code "N fact sentence."}) are facts. Words are lower-cased,
     * and the final "?" or "." becomes its own token.
     */
    static class Statement {
        boolean isFact = true;
        List<String> fact = new ArrayList<String>();
        List<String> question = new ArrayList<String>();
        String answer;
        List<Integer> supportingFacts = new ArrayList<Integer>();
        int lineNum;

        /**
         * @param line one raw line of a bAbI file
         * @throws NumberFormatException if the line number or a supporting-fact number is not an integer
         */
        public Statement(String line) {
            String[] parts = line.split("\t");
            if (parts.length > 1) {
                String[] words = tokenize(parts[0].replace("?", " ?"));
                this.lineNum = Integer.parseInt(words[0]);
                for (int i = 1; i < words.length; ++i) {
                    this.question.add(words[i].toLowerCase(Locale.ROOT));
                }
                this.answer = parts[1].trim().toLowerCase(Locale.ROOT);
                // The supporting-facts column may be absent; treat that as "no supporting facts".
                if (parts.length > 2) {
                    for (String factNum : tokenize(parts[2])) {
                        this.supportingFacts.add(Integer.parseInt(factNum));
                    }
                }
                this.isFact = false;
            } else {
                String[] words = tokenize(line.replace(".", " ."));
                this.lineNum = Integer.parseInt(words[0]);
                for (int i = 1; i < words.length; ++i) {
                    this.fact.add(words[i].toLowerCase(Locale.ROOT));
                }
                this.isFact = true;
            }
        }

        /** Splits on runs of whitespace (including a stray '\r'), with no empty tokens. */
        private static String[] tokenize(String text) {
            String trimmed = text.trim();
            return trimmed.isEmpty() ? new String[0] : trimmed.split("\\s+");
        }

        @Override
        public String toString() {
            StringBuilder result = new StringBuilder().append(this.lineNum);
            List<String> words = this.isFact ? this.fact : this.question;
            for (String word : words) {
                result.append(' ').append(word);
            }
            if (!this.isFact) {
                result.append(" -> ").append(this.answer);
                for (Integer i : this.supportingFacts) {
                    result.append(' ').append(i);
                }
            }
            return result.toString();
        }
    }
}

