package recunn.datasets;

import java.io.File;
import java.nio.charset.Charset;
import java.nio.file.Files;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Random;
import recunn.datastructs.DataSequence;
import recunn.datastructs.DataSet;
import recunn.datastructs.DataStep;
import recunn.loss.LossArgMax;
import recunn.loss.LossSoftmax;
import recunn.model.LinearUnit;
import recunn.model.Model;
import recunn.model.Nonlinearity;

/* loaded from: input_file:recunn/datasets/bAbI.class */
public class bAbI extends DataSet {
    /** Display names of the bAbI tasks (twenty entries). */
    public static final String[] TASK_NAMES = {
        "Single Supporting Fact",
        "Two Supporting Facts",
        "Three Supporting Facts",
        "Two Arg. Relations",
        "Three Arg. Relations",
        "Yes/No Questions",
        "Counting",
        "Lists/Sets",
        "Simple Negation",
        "Indefinite Knowledge",
        "Basic Coreference",
        "Conjunction",
        "Compound Coreference",
        "Time Reasoning",
        "Basic Deduction",
        "Basic Induction",
        "Positional Reasoning",
        "Size Reasoning",
        "Path Finding",
        "Agent's Motivations"
    };

    /** Sorted list of every distinct token seen in facts and questions; filled by configureVocab. */
    List<String> inputVocab = new ArrayList<>();

    /** Sorted list of every distinct answer string; filled by configureVocab. */
    List<String> outputVocab = new ArrayList<>();

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:recunn/datasets/bAbI$Statement.class */
    /**
     * One line of a bAbI data file: either a fact or a question.
     *
     * <p>A line without a tab is a fact: {@code "<lineNum> word word ... ."}.
     * A line with tabs is a question:
     * {@code "<lineNum> word ... ?<TAB>answer<TAB>id id ..."}, where the
     * trailing ids are the line numbers of the supporting facts. The
     * supporting-facts column is optional; when it is absent
     * {@link #supportingFacts} stays empty.
     *
     * <p>All tokens and the answer are lower-cased with {@link Locale#ROOT}
     * so the resulting vocabulary does not depend on the JVM's default locale.
     */
    public class Statement {
        /** True for a fact line, false for a question line. */
        boolean isFact;
        /** Lower-cased answer text; only set for questions. */
        String answer;
        /** Leading line number parsed from the line. */
        int lineNum;
        /** Tokens of a fact (empty for questions). The final "." is its own token. */
        List<String> fact = new ArrayList<>();
        /** Tokens of a question (empty for facts). The final "?" is its own token. */
        List<String> question = new ArrayList<>();
        /** Line numbers of the facts that support the answer (empty for facts). */
        List<Integer> supportingFacts = new ArrayList<>();

        /**
         * Parses a single raw line.
         *
         * @param str the raw line, tab-separated for questions
         * @throws NumberFormatException if the line number or a supporting-fact id is not an integer
         */
        public Statement(String str) {
            String[] columns = str.split("\t");
            if (columns.length <= 1) {
                this.isFact = true;
                // Detach the full stop so it becomes a token of its own.
                this.lineNum = readTokens(str.replace(".", " ."), this.fact);
                return;
            }
            this.isFact = false;
            // Detach the question mark so it becomes a token of its own.
            this.lineNum = readTokens(columns[0].replace("?", " ?"), this.question);
            this.answer = columns[1].toLowerCase(Locale.ROOT);
            // The supporting-facts column may be missing; do not index past the end.
            if (columns.length > 2) {
                for (String id : columns[2].split(" ")) {
                    this.supportingFacts.add(Integer.valueOf(Integer.parseInt(id)));
                }
            }
        }

        /**
         * Splits {@code text} on single spaces, appends every token after the
         * first (lower-cased) to {@code sink}, and returns the first token
         * parsed as the line number.
         */
        private int readTokens(String text, List<String> sink) {
            String[] tokens = text.split(" ");
            int number = Integer.parseInt(tokens[0]);
            for (int k = 1; k < tokens.length; k++) {
                sink.add(tokens[k].toLowerCase(Locale.ROOT));
            }
            return number;
        }

        /**
         * Renders the statement as {@code "<lineNum> tokens..."} for a fact, or
         * {@code "<lineNum> tokens... -> answer ids..."} for a question.
         */
        @Override
        public String toString() {
            StringBuilder text = new StringBuilder().append(this.lineNum);
            if (this.isFact) {
                for (String word : this.fact) {
                    text.append(' ').append(word);
                }
            } else {
                for (String word : this.question) {
                    text.append(' ').append(word);
                }
                text.append(" -> ").append(this.answer);
                for (Integer id : this.supportingFacts) {
                    text.append(' ').append(id);
                }
            }
            return text.toString();
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:recunn/datasets/bAbI$Story.class */
    /**
     * An ordered group of statements that belong to one story.
     */
    public class Story {
        /** The statements of this story, in their original order. */
        List<Statement> statements;

        /**
         * Builds a story from the given statements.
         *
         * @param list the statements of the story, in file order
         * @param z    when false the list is stored as-is; when true only the
         *             statements whose line number is that of a question, or is
         *             listed among some question's supporting facts, are kept
         */
        public Story(List<Statement> list, boolean z) {
            if (z) {
                // First pass: gather the line numbers worth keeping.
                HashSet<Integer> wantedLines = new HashSet<>();
                for (Statement candidate : list) {
                    if (candidate.isFact) {
                        continue;
                    }
                    wantedLines.add(candidate.lineNum);
                    wantedLines.addAll(candidate.supportingFacts);
                }
                // Second pass: keep matching statements, preserving order.
                List<Statement> kept = new ArrayList<>();
                for (Statement candidate : list) {
                    if (wantedLines.contains(candidate.lineNum)) {
                        kept.add(candidate);
                    }
                }
                this.statements = kept;
            } else {
                this.statements = list;
            }
        }

        /** Returns every statement's text, each followed by a newline. */
        public String toString() {
            StringBuilder text = new StringBuilder();
            for (Statement statement : this.statements) {
                text.append(statement.toString()).append("\n");
            }
            return text.toString();
        }
    }

    /**
     * Smoke test: builds the dataset for task 3 with at most 100 stories per
     * split, keeping only supporting facts, and prints progress markers.
     *
     * @param strArr command-line arguments (unused)
     * @throws Exception if constructing the dataset fails
     */
    public static void main(String[] strArr) throws Exception {
        final Random rng = new Random();
        System.out.println("testing...");
        new bAbI(3, 100, true, rng);
        System.out.println("done.");
    }

    /**
     * Loads one bAbI task from {@code datasets/bAbI/en/} and prepares the
     * training and testing sequences.
     *
     * <p>Files whose path contains {@code "train"} or {@code "test"} and the
     * tag {@code "qa" + i + "_"} are read. Each split is then cut down to at
     * most {@code i2} stories by removing randomly chosen ones. The training
     * loss is {@link LossSoftmax}, the reporting loss is {@link LossArgMax},
     * and no validation set is created.
     *
     * @param i      task number used to select the data files
     * @param i2     maximum number of stories kept per split (negative values are treated as 0)
     * @param z      forwarded to {@link Story}: when true, only questions and
     *               their supporting facts are kept in each story
     * @param random source of randomness for choosing which surplus stories to drop
     * @throws Exception if the data directory cannot be listed, if it contains
     *                   a file that is neither a train nor a test file, if a
     *                   file cannot be read, or if no usable training
     *                   sequence results
     */
    public bAbI(int i, int i2, boolean z, Random random) throws Exception {
        File dataDir = new File("datasets/bAbI/en/");
        // listFiles() returns null (not an empty array) when the directory is
        // missing or unreadable; fail with a clear message instead of an NPE.
        File[] entries = dataDir.listFiles();
        if (entries == null) {
            throw new Exception("Cannot list bAbI data directory: " + dataDir.getPath());
        }

        String taskTag = "qa" + i + "_";
        List<String> trainPaths = new ArrayList<>();
        List<String> testPaths = new ArrayList<>();
        for (File entry : entries) {
            String path = entry.getPath();
            if (path.contains("train")) {
                if (path.contains(taskTag)) {
                    trainPaths.add(path);
                }
            } else if (path.contains("test")) {
                if (path.contains(taskTag)) {
                    testPaths.add(path);
                }
            } else {
                throw new Exception("Unknown file type");
            }
        }

        List<Story> trainStories = getStories(trainPaths, z);
        List<Story> testStories = getStories(testPaths, z);

        // Randomly discard surplus stories; training is trimmed first so the
        // sequence of random draws matches the historical behaviour.
        int limit = Math.max(i2, 0);
        while (trainStories.size() > limit) {
            trainStories.remove(random.nextInt(trainStories.size()));
        }
        while (testStories.size() > limit) {
            testStories.remove(random.nextInt(testStories.size()));
        }

        configureVocab(trainStories, testStories);
        this.training = getSequences(trainStories);
        this.testing = getSequences(testStories);
        this.validation = null;
        this.lossTraining = new LossSoftmax();
        this.lossReporting = new LossArgMax();

        // The dimensions are read off the first training sequence, so it must
        // exist, be non-empty, and contain at least one step with a target.
        if (this.training.isEmpty() || this.training.get(0).steps.size() == 0) {
            throw new Exception("No training data found for bAbI task " + i);
        }
        DataSequence first = this.training.get(0);
        this.inputDimension = first.steps.get(0).input.w.length;
        int stepCount = first.steps.size();
        int targetIndex = 0;
        while (targetIndex < stepCount && first.steps.get(targetIndex).targetOutput == null) {
            targetIndex++;
        }
        if (targetIndex == stepCount) {
            throw new Exception("First training sequence of bAbI task " + i + " has no target output");
        }
        this.outputDimension = first.steps.get(targetIndex).targetOutput.w.length;
    }

    /**
     * Reads the given files and groups their lines into stories.
     *
     * <p>Lines are parsed into {@link Statement}s in file order; blank lines
     * are skipped. A new story starts whenever the line number fails to
     * increase relative to the previous statement (the numbering restarts at
     * the beginning of each story). A story that contains no question is
     * dropped and counted; the count is reported on standard output. The same
     * rule applies to the final story, and an empty input yields an empty list.
     *
     * @param list paths of the files to read
     * @param z    forwarded to {@link Story}: when true, only questions and
     *             their supporting facts are kept
     * @return the stories that contain at least one question, in file order
     * @throws Exception if a file cannot be read or a line cannot be parsed
     */
    List<Story> getStories(List<String> list, boolean z) throws Exception {
        List<Statement> parsed = new ArrayList<>();
        for (String path : list) {
            for (String line : Files.readAllLines(new File(path).toPath(), Charset.defaultCharset())) {
                // A blank line has no line number to parse; ignore it.
                if (line.trim().isEmpty()) {
                    continue;
                }
                parsed.add(new Statement(line));
            }
        }

        List<Story> stories = new ArrayList<>();
        List<Statement> current = new ArrayList<>();
        boolean hasQuestion = false;
        int previousLineNum = 0;
        int discarded = 0;
        for (Statement statement : parsed) {
            // "<=" rather than "<": line numbers strictly increase inside a
            // story, so a repeat also marks a boundary (e.g. after a
            // one-line story).
            if (!current.isEmpty() && statement.lineNum <= previousLineNum) {
                if (hasQuestion) {
                    stories.add(new Story(current, z));
                } else {
                    discarded++;
                }
                hasQuestion = false;
                current = new ArrayList<>();
            }
            if (!statement.isFact) {
                hasQuestion = true;
            }
            current.add(statement);
            previousLineNum = statement.lineNum;
        }

        // Flush the trailing story under the same rule as all the others.
        if (!current.isEmpty()) {
            if (hasQuestion) {
                stories.add(new Story(current, z));
            } else {
                discarded++;
            }
        }

        if (discarded > 0) {
            System.out.println("WARNING: " + discarded + " INCORRECT STORIES REMOVED.");
        }
        return stories;
    }

    /**
     * Collects the vocabularies from both splits.
     *
     * <p>Every distinct fact or question token is appended to
     * {@code inputVocab}, and every distinct answer to {@code outputVocab};
     * both lists are then sorted. The answers are printed with their indices.
     *
     * @param list  the first group of stories to scan
     * @param list2 the second group of stories to scan
     */
    private void configureVocab(List<Story> list, List<Story> list2) {
        HashSet<String> inputWords = new HashSet<>();
        HashSet<String> answerWords = new HashSet<>();

        List<Story> everyStory = new ArrayList<>();
        everyStory.addAll(list);
        everyStory.addAll(list2);

        for (Story story : everyStory) {
            for (Statement statement : story.statements) {
                if (!statement.isFact) {
                    inputWords.addAll(statement.question);
                    answerWords.add(statement.answer);
                } else {
                    inputWords.addAll(statement.fact);
                }
            }
        }

        this.inputVocab.addAll(inputWords);
        this.outputVocab.addAll(answerWords);
        Collections.sort(this.inputVocab);
        Collections.sort(this.outputVocab);

        System.out.println("Possible answers: ");
        int index = 0;
        for (String answer : this.outputVocab) {
            System.out.println("\t[" + index + "]: " + answer);
            index++;
        }
    }

    /**
     * Converts stories into one-hot encoded data sequences.
     *
     * <p>Each fact or question token becomes one step whose input is a one-hot
     * vector over {@code inputVocab} and whose target is {@code null}. After
     * the tokens of a question, one extra step is added whose input is an
     * all-zero vector of the input-vocabulary size and whose target is the
     * one-hot encoding of the answer over {@code outputVocab}. A token or
     * answer that is not in the vocabulary encodes as an all-zero vector.
     *
     * @param list the stories to convert
     * @return one sequence per story, in the same order
     */
    private List<DataSequence> getSequences(List<Story> list) {
        int inputSize = this.inputVocab.size();
        List<DataSequence> sequences = new ArrayList<>();
        for (Story story : list) {
            ArrayList<DataStep> steps = new ArrayList<>();
            for (Statement statement : story.statements) {
                List<String> words = statement.isFact ? statement.fact : statement.question;
                for (String word : words) {
                    steps.add(new DataStep(oneHot(this.inputVocab, word), null));
                }
                if (!statement.isFact) {
                    // Answer step: blank input, target marks the expected answer.
                    double[] blankInput = new double[inputSize];
                    steps.add(new DataStep(blankInput, oneHot(this.outputVocab, statement.answer)));
                }
            }
            sequences.add(new DataSequence(steps));
        }
        return sequences;
    }

    /**
     * Returns a vector of length {@code vocab.size()} that is 1.0 at the index
     * of the first occurrence of {@code token} and 0.0 elsewhere; all zeros
     * when {@code token} is not in {@code vocab}.
     */
    private static double[] oneHot(List<String> vocab, String token) {
        double[] encoded = new double[vocab.size()];
        int position = vocab.indexOf(token);
        if (position >= 0) {
            encoded[position] = 1.0d;
        }
        return encoded;
    }

    /**
     * Does nothing: this dataset produces no report.
     *
     * @param model  unused
     * @param random unused
     */
    @Override // recunn.datastructs.DataSet
    public void DisplayReport(Model model, Random random) throws Exception {
        // Intentionally left empty.
    }

    /**
     * Supplies the unit a model should use for its output.
     *
     * @return a new {@link LinearUnit} on every call
     */
    @Override // recunn.datastructs.DataSet
    public Nonlinearity getModelOutputUnitToUse() {
        Nonlinearity outputUnit = new LinearUnit();
        return outputUnit;
    }
}
