package recunn.datasets;

import java.io.File;
import java.nio.charset.Charset;
import java.nio.file.Files;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import recunn.autodiff.Graph;
import recunn.datastructs.DataSequence;
import recunn.datastructs.DataSet;
import recunn.datastructs.DataStep;
import recunn.loss.LossSoftmax;
import recunn.matrix.Matrix;
import recunn.model.LinearUnit;
import recunn.model.Model;
import recunn.model.Nonlinearity;
import recunn.util.Util;

/* loaded from: input_file:recunn/datasets/TextGeneration.class */
/**
 * Character-level text generation data set.
 *
 * <p>The constructor reads a text file line by line, assigns every distinct character an
 * index (index 0 is reserved for the combined start/end token) and turns each line into a
 * {@link DataSequence} of one-hot input/target pairs in which every step's target is the
 * next character of the line.
 *
 * <p>The vocabulary ({@code charToIndex}, {@code indexToChar}, {@code words},
 * {@code dimension}) is held in static fields, so it is shared by every instance and is
 * rebuilt each time the constructor runs. The static helpers {@link #generateText} and
 * {@link #sequenceToSentence} read that shared state.
 */
public class TextGeneration extends DataSet {
    /** Size of the one-hot vectors: number of distinct characters plus the start/end token. */
    private static int dimension;
    /** One-hot vector for the start/end token, reused as first input and last target of each line. */
    private static double[] vecStartEnd;
    /** Index reserved for the combined start/end token. */
    private static final int START_END_TOKEN_INDEX = 0;
    /** Number of characters sampled per generated report section. */
    public static int reportSequenceLength = 100;
    /** When true, sampling is restricted to characters that continue a word seen in the corpus. */
    public static boolean singleWordAutocorrect = false;
    /** When true, {@link #DisplayReport} prints the median perplexity over the training set. */
    public static boolean reportPerplexity = true;
    private static final Map<String, Integer> charToIndex = new HashMap<>();
    private static final Map<Integer, String> indexToChar = new HashMap<>();
    /** Space-separated (trimmed) tokens of every non-empty corpus line. */
    private static final Set<String> words = new HashSet<>();

    /**
     * Samples characters from {@code model}, starting from the start/end token.
     *
     * <p>Whenever the start/end token is chosen, the text accumulated so far is emitted as
     * one list element, the model state is reset and generation restarts from the
     * start/end token. Any text still pending after the last step is appended as the final
     * element.
     *
     * @param model  model to sample from; its state is reset by this method
     * @param i      number of sampling steps to perform
     * @param z      if true pick the most probable index at each step, otherwise sample
     *               via {@code Util.pickIndexFromRandomVector}
     * @param d      temperature passed to {@code LossSoftmax.getSoftmaxProbs}
     * @param random random source used when {@code z} is false
     * @return the generated lines, in generation order
     * @throws Exception if the autocorrect mask leaves probability on a disallowed
     *                   character, or if the model / sampling helpers throw
     */
    public static List<String> generateText(Model model, int i, boolean z, double d, Random random) throws Exception {
        List<String> lines = new ArrayList<>();
        Matrix startToken = new Matrix(dimension);
        startToken.w[START_END_TOKEN_INDEX] = 1.0d;
        model.resetState();
        Graph graph = new Graph(false);
        Matrix input = startToken.m3clone();
        StringBuilder line = new StringBuilder();
        for (int step = 0; step < i; step++) {
            Matrix probs = LossSoftmax.getSoftmaxProbs(model.forward(input, graph), d);
            if (singleWordAutocorrect) {
                applyAutocorrectMask(probs, line.toString());
            }
            int chosen = z ? indexOfMax(probs.w) : Util.pickIndexFromRandomVector(probs, random);
            if (chosen == START_END_TOKEN_INDEX) {
                // End of a line: emit it and restart from a fresh start token.
                lines.add(line.toString());
                line.setLength(0);
                graph = new Graph(false);
                model.resetState();
                input = startToken.m3clone();
            } else {
                line.append(indexToChar.get(chosen));
                // Feed the chosen character back in as the next one-hot input.
                for (int k = 0; k < input.w.length; k++) {
                    input.w[k] = 0.0d;
                }
                input.w[chosen] = 1.0d;
            }
        }
        if (line.length() > 0) {
            lines.add(line.toString());
        }
        return lines;
    }

    /**
     * Multiplies {@code probs} in place by the autocorrect mask for {@code partialLine},
     * then renormalises so the entries sum to one.
     *
     * @throws Exception if a disallowed character still has positive probability
     */
    private static void applyAutocorrectMask(Matrix probs, String partialLine) throws Exception {
        Matrix allowed = Matrix.ones(dimension, 1);
        try {
            allowed = singleWordAutocorrect(partialLine);
        } catch (Exception ignored) {
            // Deliberate best effort: if no mask can be computed for the current partial
            // word, fall back to the all-ones mask so generation continues unconstrained.
        }
        double total = 0.0d;
        for (int k = 0; k < probs.w.length; k++) {
            probs.w[k] *= allowed.w[k];
            total += probs.w[k];
        }
        for (int k = 0; k < probs.w.length; k++) {
            probs.w[k] /= total;
        }
        for (int k = 0; k < probs.w.length; k++) {
            if (probs.w[k] > 0.0d && allowed.w[k] == 0.0d) {
                throw new Exception("Illegal transition");
            }
        }
    }

    /**
     * Returns the index of the first strictly largest value, or -1 when no value compares
     * greater than negative infinity (empty array, or only NaN / negative infinity).
     */
    private static int indexOfMax(double[] values) {
        int best = -1;
        double bestValue = Double.NEGATIVE_INFINITY;
        for (int k = 0; k < values.length; k++) {
            if (values[k] > bestValue) {
                bestValue = values[k];
                best = k;
            }
        }
        return best;
    }

    /**
     * Builds a 0/1 mask over character indices that keeps the last, unfinished word of
     * {@code str} a prefix of some word in {@code words}.
     *
     * <p>If {@code str} is empty or ends with a space there is no partial word and an
     * all-ones mask is returned. Otherwise, for every known word starting with the partial
     * word, the next character of that word is allowed; if the partial word is itself a
     * known word, the space character and the start/end token are allowed as well.
     *
     * @throws Exception if the partial word contains a space after trimming or no known
     *                   word starts with it
     */
    private static Matrix singleWordAutocorrect(String str) throws Exception {
        // The search string is the literal three-character sequence quote, newline, quote.
        String text = str.replace("\"\n\"", " ");
        if (text.equals("") || text.endsWith(" ")) {
            return Matrix.ones(dimension, 1);
        }
        String[] tokens = text.split(" ");
        String partial = tokens[tokens.length - 1].trim();
        if (partial.equals(" ") || partial.contains(" ")) {
            throw new Exception("unexpected");
        }
        List<String> matches = new ArrayList<>();
        for (String word : words) {
            if (word.startsWith(partial)) {
                matches.add(word);
            }
        }
        if (matches.isEmpty()) {
            throw new Exception("unexpected, no matches for '" + trimForMessage(partial) + "'");
        }
        Matrix allowed = new Matrix(dimension);
        for (String match : matches) {
            // startsWith() guarantees match.length() >= partial.length(), so charAt is safe
            // in the non-equal branch.
            if (partial.equals(match)) {
                int spaceIndex = charToIndex.get(" ").intValue();
                allowed.w[spaceIndex] = 1.0d;
                allowed.w[START_END_TOKEN_INDEX] = 1.0d;
            } else {
                int nextIndex = charToIndex.get(String.valueOf(match.charAt(partial.length()))).intValue();
                allowed.w[nextIndex] = 1.0d;
            }
        }
        return allowed;
    }

    /** Identity helper that keeps the exception message construction on one line. */
    private static String trimForMessage(String partial) {
        return partial;
    }

    /**
     * Renders a sequence as a quoted sentence followed by a newline.
     *
     * <p>For every step except the last, the first index whose target value equals 1.0 is
     * looked up in {@code indexToChar}; if no such index exists the lookup key is -1.
     *
     * @param dataSequence sequence whose step targets are decoded
     * @return the decoded characters wrapped in double quotes, terminated by {@code '\n'}
     */
    public static String sequenceToSentence(DataSequence dataSequence) {
        StringBuilder sentence = new StringBuilder("\"");
        for (int s = 0; s < dataSequence.steps.size() - 1; s++) {
            double[] target = dataSequence.steps.get(s).targetOutput.w;
            int hot = -1;
            for (int k = 0; k < target.length; k++) {
                if (target[k] == 1.0d) {
                    hot = k;
                    break;
                }
            }
            sentence.append(indexToChar.get(hot));
        }
        return sentence.append("\"\n").toString();
    }

    /**
     * Loads the text file at {@code str} and builds the training sequences.
     *
     * <p>The shared static vocabulary is cleared and rebuilt from this file, so indices
     * from an earlier instance never leak into the new one.
     *
     * @param str path of the text file; read with the platform default charset
     * @throws Exception if the file cannot be read
     * @throws IllegalArgumentException if the file contains no lines
     */
    public TextGeneration(String str) throws Exception {
        System.out.println("Text generation task");
        System.out.println("loading " + str + "...");
        List<String> lines = Files.readAllLines(new File(str).toPath(), Charset.defaultCharset());

        // Reset only after the read succeeded, so a failed load leaves prior state intact.
        charToIndex.clear();
        indexToChar.clear();
        words.clear();

        Set<String> seenChars = new HashSet<>();
        charToIndex.put("[START/END]", START_END_TOKEN_INDEX);
        indexToChar.put(START_END_TOKEN_INDEX, "[START/END]");
        int nextIndex = START_END_TOKEN_INDEX + 1;
        System.out.println("Characters:");
        System.out.print("\t");
        for (String line : lines) {
            // Words are collected once per non-empty line (empty lines contribute nothing).
            if (!line.isEmpty()) {
                for (String word : line.split(" ")) {
                    words.add(word.trim());
                }
            }
            for (int c = 0; c < line.length(); c++) {
                String ch = String.valueOf(line.charAt(c));
                if (seenChars.add(ch)) {
                    System.out.print(ch);
                    charToIndex.put(ch, nextIndex);
                    indexToChar.put(nextIndex, ch);
                    nextIndex++;
                }
            }
        }
        dimension = seenChars.size() + 1;
        vecStartEnd = new double[dimension];
        vecStartEnd[START_END_TOKEN_INDEX] = 1.0d;

        List<DataSequence> sequences = new ArrayList<>();
        int totalSteps = 0;
        for (String line : lines) {
            // start token, one one-hot vector per character, end token
            List<double[]> vectors = new ArrayList<>();
            vectors.add(vecStartEnd);
            for (int c = 0; c < line.length(); c++) {
                int index = charToIndex.get(String.valueOf(line.charAt(c))).intValue();
                double[] oneHot = new double[dimension];
                oneHot[index] = 1.0d;
                vectors.add(oneHot);
            }
            vectors.add(vecStartEnd);
            DataSequence sequence = new DataSequence();
            // Each step predicts the vector that follows it.
            for (int k = 0; k < vectors.size() - 1; k++) {
                sequence.steps.add(new DataStep(vectors.get(k), vectors.get(k + 1)));
                totalSteps++;
            }
            sequences.add(sequence);
        }
        System.out.println("Total unique chars = " + seenChars.size());
        System.out.println(totalSteps + " steps in training set.");
        if (sequences.isEmpty()) {
            throw new IllegalArgumentException("No lines found in '" + str + "'; cannot build a training set");
        }
        this.training = sequences;
        this.lossTraining = new LossSoftmax();
        this.lossReporting = new LossSoftmax();
        DataSequence first = sequences.get(0);
        this.inputDimension = first.steps.get(0).input.w.length;
        // Output dimension comes from the first step of the first sequence that has a target.
        int firstWithTarget = 0;
        while (first.steps.get(firstWithTarget).targetOutput == null) {
            firstWithTarget++;
        }
        this.outputDimension = first.steps.get(firstWithTarget).targetOutput.w.length;
    }

    /**
     * Prints a report to standard output: optionally the median perplexity over the
     * training set, then sampled text at temperatures 1.0, 0.75, 0.5, 0.25 and 0.1, then
     * an argmax prediction.
     *
     * @param model  model to evaluate and sample from
     * @param random random source used for temperature sampling
     * @throws Exception if perplexity calculation or text generation fails
     */
    @Override // recunn.datastructs.DataSet
    public void DisplayReport(Model model, Random random) throws Exception {
        System.out.println("========================================");
        System.out.println("REPORT:");
        if (reportPerplexity) {
            System.out.println("\ncalculating perplexity over entire data set...");
            System.out.println("\nMedian Perplexity = " + String.format("%.4f", LossSoftmax.calculateMedianPerplexity(model, this.training)));
        }
        double[] temperatures = {1.0d, 0.75d, 0.5d, 0.25d, 0.1d};
        for (double temperature : temperatures) {
            if (singleWordAutocorrect) {
                System.out.println("\nTemperature " + temperature + " prediction (with single word autocorrect):");
            } else {
                System.out.println("\nTemperature " + temperature + " prediction:");
            }
            printGenerated(generateText(model, reportSequenceLength, false, temperature, random));
        }
        if (singleWordAutocorrect) {
            System.out.println("\nArgmax prediction (with single word autocorrect):");
        } else {
            System.out.println("\nArgmax prediction:");
        }
        printGenerated(generateText(model, reportSequenceLength, true, 1.0d, random));
        System.out.println("========================================");
    }

    /** Prints each generated line quoted and tab-indented; the last one gets a trailing "...". */
    private static void printGenerated(List<String> generated) {
        int last = generated.size() - 1;
        for (int k = 0; k < generated.size(); k++) {
            String suffix = (k == last) ? "...\"" : "\"";
            System.out.println("\t\"" + generated.get(k) + suffix);
        }
    }

    /** Returns a new {@link LinearUnit} as the output nonlinearity for this data set. */
    @Override // recunn.datastructs.DataSet
    public Nonlinearity getModelOutputUnitToUse() {
        return new LinearUnit();
    }
}
