package recunn.datasets;

import java.io.File;
import java.nio.charset.Charset;
import java.nio.file.Files;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Random;
import recunn.autodiff.Graph;
import recunn.datastructs.DataSequence;
import recunn.datastructs.DataSet;
import recunn.datastructs.DataStep;
import recunn.loss.LossSoftmax;
import recunn.matrix.Matrix;
import recunn.model.LinearUnit;
import recunn.model.Model;
import recunn.model.Nonlinearity;
import recunn.util.Util;

/* loaded from: input_file:recunn/datasets/TextGenerationUnbroken.class */
/**
 * Character-level text generation data set.
 *
 * <p>The constructor loads a text file, assigns every distinct character an index in order of
 * first appearance, and cuts randomly positioned windows out of the text. Each window becomes a
 * {@link DataSequence} whose steps pair the one-hot encoding of a character (input) with the
 * one-hot encoding of the character that follows it (target).
 *
 * <p>The character vocabulary is kept in static fields that are overwritten by every constructor
 * call, so {@link #generateText} always decodes with the vocabulary of the most recently
 * constructed instance.
 */
public class TextGenerationUnbroken extends DataSet {
    private static final long serialVersionUID = 1;

    /** Number of characters generated for each sample printed by {@link #DisplayReport}. */
    public static int reportSequenceLength = 100;

    /** Whether {@link #DisplayReport} also computes and prints the median perplexity. */
    public static boolean reportPerplexity = true;

    /** Temperatures at which {@link #DisplayReport} prints a randomly sampled text. */
    private static final double[] REPORT_TEMPERATURES = {1.0d, 0.75d, 0.5d, 0.25d, 0.1d};

    /** Single-character string to vocabulary index; replaced as a whole by the constructor. */
    private static Map<String, Integer> charToIndex = new HashMap<>();

    /** Vocabulary index to single-character string; replaced as a whole by the constructor. */
    private static Map<Integer, String> indexToChar = new HashMap<>();

    /** Vocabulary size, i.e. the length of every one-hot vector. */
    private static int dimension;

    /**
     * Generates text by repeatedly feeding the model the one-hot encoding of the character it
     * produced last, starting from an all-zero input vector.
     *
     * @param model  model to run; its state is reset before generation starts
     * @param i      number of characters to generate
     * @param z      {@code true} to always take the most probable character (argmax),
     *               {@code false} to draw the character via {@code Util.pickIndexFromRandomVector}
     * @param d      temperature passed to {@code LossSoftmax.getSoftmaxProbs}
     * @param random random source used when {@code z} is {@code false}
     * @return the generated text, with every newline replaced by a quote, a newline, a tab and a
     *         quote so that it prints as a sequence of quoted lines
     * @throws IllegalStateException if no valid character index could be chosen for a step
     * @throws Exception             if the model or the softmax computation fails
     */
    public static String generateText(Model model, int i, boolean z, double d, Random random) throws Exception {
        Matrix blank = new Matrix(dimension);
        model.resetState();
        Graph graph = new Graph(false);
        Matrix input = blank.m3clone();
        StringBuilder generated = new StringBuilder();
        for (int step = 0; step < i; step++) {
            Matrix probs = LossSoftmax.getSoftmaxProbs(model.forward(input, graph), d);
            int chosen = z ? indexOfMax(probs.w) : Util.pickIndexFromRandomVector(probs, random);
            if (chosen < 0) {
                // Happens for argmax when no probability exceeds negative infinity (e.g. all NaN).
                throw new IllegalStateException("no character could be selected at step " + step);
            }
            generated.append(indexToChar.get(chosen));
            // Turn the input into the one-hot encoding of the character just emitted.
            for (int k = 0; k < input.w.length; k++) {
                input.w[k] = 0.0d;
            }
            input.w[chosen] = 1.0d;
        }
        return generated.toString().replace("\n", "\"\n\t\"");
    }

    /**
     * Returns the index of the first largest value, or -1 when the array is empty or holds no
     * value greater than negative infinity.
     */
    private static int indexOfMax(double[] values) {
        int best = -1;
        double bestValue = Double.NEGATIVE_INFINITY;
        for (int k = 0; k < values.length; k++) {
            if (values[k] > bestValue) {
                bestValue = values[k];
                best = k;
            }
        }
        return best;
    }

    /** Reads the whole file with the platform default charset, ending every line with a newline. */
    private static String readText(String path) throws Exception {
        StringBuilder text = new StringBuilder();
        for (String line : Files.readAllLines(new File(path).toPath(), Charset.defaultCharset())) {
            text.append(line).append("\n");
        }
        return text.toString();
    }

    /**
     * Loads the text file, builds the character vocabulary and creates the training sequences.
     *
     * @param str    path of the text file to load
     * @param i      number of training sequences to create; must be positive
     * @param i2     minimum window length in characters (inclusive)
     * @param i3     maximum window length in characters (inclusive); must not be below {@code i2}
     * @param random random source for window lengths and start positions
     * @throws IllegalArgumentException if the arguments are inconsistent, the text is not longer
     *                                  than a drawn window, or no sequence ends up with a step
     * @throws Exception                if the file cannot be read
     */
    public TextGenerationUnbroken(String str, int i, int i2, int i3, Random random) throws Exception {
        if (i <= 0) {
            throw new IllegalArgumentException("number of sequences must be positive, got " + i);
        }
        if (i3 < i2) {
            throw new IllegalArgumentException(
                    "maximum sequence length " + i3 + " is below minimum sequence length " + i2);
        }
        System.out.println("Text generation task");
        System.out.println("loading " + str + "...");
        String text = readText(str);

        // Build the vocabulary locally so that no entries of an earlier data set survive.
        Map<String, Integer> toIndex = new HashMap<>();
        Map<Integer, String> toChar = new HashMap<>();
        System.out.println("Characters:");
        System.out.print("\t");
        for (int pos = 0; pos < text.length(); pos++) {
            String symbol = String.valueOf(text.charAt(pos));
            if (!toIndex.containsKey(symbol)) {
                System.out.print(symbol.equals("\n") ? "\\n" : symbol);
                int index = toIndex.size();
                toIndex.put(symbol, index);
                toChar.put(index, symbol);
            }
        }
        System.out.println("");
        int vocabularySize = toIndex.size();
        charToIndex = toIndex;
        indexToChar = toChar;
        dimension = vocabularySize;

        ArrayList<DataSequence> sequences = new ArrayList<>();
        for (int n = 0; n < i; n++) {
            int length = random.nextInt((i3 - i2) + 1) + i2;
            if (length >= text.length()) {
                throw new IllegalArgumentException("text of " + text.length()
                        + " characters is too short for a sequence of length " + length);
            }
            int start = random.nextInt(text.length() - length);
            DataSequence sequence = new DataSequence();
            // The target vector of one step is reused as the input vector of the next step.
            double[] previous = null;
            for (int offset = 0; offset < length; offset++) {
                double[] current = new double[vocabularySize];
                current[toIndex.get(String.valueOf(text.charAt(start + offset)))] = 1.0d;
                if (previous != null) {
                    sequence.steps.add(new DataStep(previous, current));
                }
                previous = current;
            }
            sequences.add(sequence);
        }
        System.out.println("Total unique chars = " + vocabularySize);
        this.training = sequences;
        this.lossTraining = new LossSoftmax();
        this.lossReporting = new LossSoftmax();

        // A window shorter than two characters has no steps, so take the dimensions from the
        // first sequence that actually contains one instead of blindly using sequence 0.
        DataSequence reference = null;
        for (DataSequence candidate : sequences) {
            if (!candidate.steps.isEmpty()) {
                reference = candidate;
                break;
            }
        }
        if (reference == null) {
            throw new IllegalArgumentException(
                    "no sequence contains a step; sequences need at least two characters");
        }
        this.inputDimension = reference.steps.get(0).input.w.length;
        boolean targetFound = false;
        for (DataStep step : reference.steps) {
            if (step.targetOutput != null) {
                this.outputDimension = step.targetOutput.w.length;
                targetFound = true;
                break;
            }
        }
        if (!targetFound) {
            throw new IllegalStateException("no step with a target output found");
        }
    }

    /**
     * Prints a report to standard output: optionally the median perplexity over the training
     * sequences, then one randomly sampled text per report temperature, then an argmax text.
     *
     * @param model  model used for the perplexity calculation and for text generation
     * @param random random source for the sampled texts
     * @throws Exception if the perplexity calculation or text generation fails
     */
    @Override
    public void DisplayReport(Model model, Random random) throws Exception {
        System.out.println("========================================");
        System.out.println("REPORT:");
        if (reportPerplexity) {
            System.out.println("\ncalculating perplexity over entire data set...");
            double perplexity = LossSoftmax.calculateMedianPerplexity(model, this.training);
            System.out.println("\nMedian Perplexity = " + String.format("%.4f", perplexity));
        }
        for (double temperature : REPORT_TEMPERATURES) {
            System.out.println("\nTemperature " + temperature + " prediction:");
            System.out.println("\t\"..." + generateText(model, reportSequenceLength, false, temperature, random) + "...\"");
        }
        System.out.println("\nArgmax prediction:");
        System.out.println("\t\"..." + generateText(model, reportSequenceLength, true, 1.0d, random) + "...\"");
        System.out.println("========================================");
    }

    /**
     * Returns a new {@link LinearUnit} as the output nonlinearity for models trained on this data
     * set.
     */
    @Override
    public Nonlinearity getModelOutputUnitToUse() {
        return new LinearUnit();
    }
}
