/*
 * Decompiled with CFR 0.152.
 */
package recunn.datasets;

import java.io.File;
import java.io.IOException;
import java.nio.charset.Charset;
import java.nio.file.Files;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Random;
import recunn.autodiff.Graph;
import recunn.datastructs.DataSequence;
import recunn.datastructs.DataSet;
import recunn.datastructs.DataStep;
import recunn.loss.LossSoftmax;
import recunn.matrix.Matrix;
import recunn.model.LinearUnit;
import recunn.model.Model;
import recunn.model.Nonlinearity;
import recunn.util.Util;

/**
 * Character-level text generation data set built from random windows of one text file.
 *
 * <p>The file is read as a single string (its lines re-joined with {@code '\n'}). Every distinct
 * character becomes one one-hot dimension. Each training sequence is a randomly placed window
 * whose steps map the character at position t to the character at position t + 1.
 *
 * <p>NOTE(review): the vocabulary ({@code charToIndex}, {@code indexToChar}, {@code dimension}) is
 * static and therefore shared by all instances. Each constructor call replaces it, and
 * {@link #generateText} always uses the most recently built one.
 */
public class TextGenerationUnbroken
extends DataSet {
    private static final long serialVersionUID = 1L;
    /** Number of characters generated per sample in {@link #DisplayReport}. */
    public static int reportSequenceLength = 100;
    /** When true, {@link #DisplayReport} also prints the median perplexity over the training set. */
    public static boolean reportPerplexity = true;
    /** Symbol (one character, or one surrogate pair) to its one-hot index. */
    private static final Map<String, Integer> charToIndex = new HashMap<String, Integer>();
    /** Inverse of {@code charToIndex}: one-hot index to symbol. */
    private static final Map<Integer, String> indexToChar = new HashMap<Integer, String>();
    /** Vocabulary size, i.e. the width of every one-hot vector built by this class. */
    private static int dimension;

    /**
     * Generates {@code steps} characters from {@code model}, feeding each chosen character back
     * as the next one-hot input.
     *
     * <p>Generation starts from an all-zero input vector right after {@code model.resetState()}.
     * At each step the model output is turned into probabilities with
     * {@code LossSoftmax.getSoftmaxProbs} at {@code temperature}. The next index is then either
     * the highest-probability one ({@code argmax}) or a random draw from
     * {@code Util.pickIndexFromRandomVector}.
     *
     * @param model       model to sample from; its output width is assumed to equal the
     *                    vocabulary size of the most recently constructed instance
     * @param steps       number of characters to generate (non-positive yields an empty string)
     * @param argmax      {@code true} for greedy decoding, {@code false} for random sampling
     * @param temperature softmax temperature passed to {@code getSoftmaxProbs}
     * @param rng         random source used when sampling
     * @return the generated text, with every {@code "\n"} replaced by {@code "\"\n\t\""} so
     *         multi-line samples stay quoted and indented in the report
     * @throws IllegalStateException if no index could be chosen (e.g. all probabilities are NaN)
     * @throws Exception             propagated from the model's forward pass
     */
    public static String generateText(Model model, int steps, boolean argmax, double temperature, Random rng) throws Exception {
        Matrix input = new Matrix(dimension);
        model.resetState();
        Graph g = new Graph(false);
        StringBuilder result = new StringBuilder();
        for (int s = 0; s < steps; ++s) {
            Matrix logprobs = model.forward(input, g);
            Matrix probs = LossSoftmax.getSoftmaxProbs(logprobs, temperature);
            int indxChosen = -1;
            if (argmax) {
                // First strictly-greater value wins, so ties go to the lowest index and NaNs are skipped.
                double high = Double.NEGATIVE_INFINITY;
                for (int i = 0; i < probs.w.length; ++i) {
                    if (probs.w[i] > high) {
                        high = probs.w[i];
                        indxChosen = i;
                    }
                }
            } else {
                indxChosen = Util.pickIndexFromRandomVector(probs, rng);
            }
            if (indxChosen < 0) {
                throw new IllegalStateException("could not choose a character index at step " + s + " (probabilities may contain NaN)");
            }
            result.append(indexToChar.get(indxChosen));
            // Re-encode the chosen character as the next one-hot input.
            Arrays.fill(input.w, 0.0);
            input.w[indxChosen] = 1.0;
        }
        return result.toString().replace("\n", "\"\n\t\"");
    }

    /**
     * Loads {@code path}, rebuilds the shared vocabulary, and samples {@code totalSequences}
     * random training windows.
     *
     * @param path              text file to learn from, decoded with the platform default charset
     * @param totalSequences    number of training sequences to sample; must be at least 1
     * @param sequenceMinLength minimum window length in characters; windows shorter than 2
     *                          characters produce sequences with no steps
     * @param sequenceMaxLength maximum window length in characters; must be
     *                          {@code >= sequenceMinLength}
     * @param rng               random source for window lengths and positions
     * @throws IllegalArgumentException if {@code totalSequences < 1},
     *                                  {@code sequenceMinLength > sequenceMaxLength}, or a sampled
     *                                  window is longer than the text
     * @throws IOException              if the file cannot be read
     * @throws Exception                declared by the original signature
     */
    public TextGenerationUnbroken(String path, int totalSequences, int sequenceMinLength, int sequenceMaxLength, Random rng) throws Exception {
        if (totalSequences < 1) {
            throw new IllegalArgumentException("totalSequences must be at least 1, got " + totalSequences);
        }
        if (sequenceMinLength > sequenceMaxLength) {
            throw new IllegalArgumentException("sequenceMinLength (" + sequenceMinLength + ") > sequenceMaxLength (" + sequenceMaxLength + ")");
        }
        System.out.println("Text generation task");
        System.out.println("loading " + path + "...");
        List<String> symbols = splitIntoSymbols(readText(path));
        buildVocabulary(symbols);
        ArrayList<DataSequence> sequences = buildSequences(symbols, totalSequences, sequenceMinLength, sequenceMaxLength, rng);
        System.out.println("Total unique chars = " + dimension);
        this.training = sequences;
        this.lossTraining = new LossSoftmax();
        this.lossReporting = new LossSoftmax();
        // Every step built here uses one-hot vectors of width `dimension` for both input and target,
        // so there is no need to inspect (possibly empty) sampled sequences to find the widths.
        this.inputDimension = dimension;
        this.outputDimension = dimension;
    }

    /** Reads {@code path} with the platform default charset and re-joins its lines, appending '\n' after every line. */
    private static String readText(String path) throws IOException {
        List<String> lines = Files.readAllLines(new File(path).toPath(), Charset.defaultCharset());
        StringBuilder text = new StringBuilder();
        for (String line : lines) {
            text.append(line).append('\n');
        }
        return text.toString();
    }

    /**
     * Splits {@code text} into one string per Unicode code point, so a surrogate pair is kept
     * together as a single symbol. For text without supplementary characters this is exactly
     * one symbol per {@code char}.
     */
    private static List<String> splitIntoSymbols(String text) {
        List<String> symbols = new ArrayList<String>(text.length());
        int i = 0;
        while (i < text.length()) {
            int next = i + Character.charCount(text.codePointAt(i));
            symbols.add(text.substring(i, next));
            i = next;
        }
        return symbols;
    }

    /**
     * Clears the shared vocabulary and assigns indices 0, 1, 2, ... to symbols in order of first
     * appearance. Each new symbol is printed (a newline as the two characters backslash-n), and
     * {@code dimension} is set to the resulting vocabulary size.
     */
    private static void buildVocabulary(List<String> symbols) {
        charToIndex.clear();
        indexToChar.clear();
        System.out.println("Characters:");
        System.out.print("\t");
        for (String ch : symbols) {
            if (charToIndex.containsKey(ch)) {
                continue;
            }
            System.out.print(ch.equals("\n") ? "\\n" : ch);
            int id = charToIndex.size();
            charToIndex.put(ch, id);
            indexToChar.put(id, ch);
        }
        System.out.println("");
        dimension = charToIndex.size();
    }

    /**
     * Samples {@code totalSequences} windows from {@code symbols}. Each window length is uniform in
     * [{@code minLength}, {@code maxLength}], and its start is uniform over every position where the
     * window fits (including the one ending at the last symbol). A window of n symbols yields
     * n - 1 steps mapping the one-hot vector of symbol t to that of symbol t + 1.
     */
    private static ArrayList<DataSequence> buildSequences(List<String> symbols, int totalSequences, int minLength, int maxLength, Random rng) {
        ArrayList<DataSequence> sequences = new ArrayList<DataSequence>(totalSequences);
        for (int s = 0; s < totalSequences; ++s) {
            int len = rng.nextInt(maxLength - minLength + 1) + minLength;
            if (len > symbols.size()) {
                throw new IllegalArgumentException("sequence length " + len + " exceeds text length " + symbols.size());
            }
            // + 1 so the window ending at the final symbol can be chosen too.
            int start = rng.nextInt(symbols.size() - len + 1);
            // One vector per position; consecutive steps share the vector at their boundary.
            ArrayList<double[]> vecs = new ArrayList<double[]>(Math.max(len, 0));
            for (int i = 0; i < len; ++i) {
                double[] vec = new double[dimension];
                vec[charToIndex.get(symbols.get(start + i))] = 1.0;
                vecs.add(vec);
            }
            DataSequence sequence = new DataSequence();
            for (int i = 0; i + 1 < vecs.size(); ++i) {
                sequence.steps.add(new DataStep(vecs.get(i), vecs.get(i + 1)));
            }
            sequences.add(sequence);
        }
        return sequences;
    }

    /**
     * Prints a report: optionally the median perplexity over the training set, then one sampled
     * snippet per temperature (1.0, 0.75, 0.5, 0.25, 0.1), then one greedy (argmax) snippet.
     * Each snippet is {@link #reportSequenceLength} characters long.
     */
    @Override
    public void DisplayReport(Model model, Random rng) throws Exception {
        System.out.println("========================================");
        System.out.println("REPORT:");
        if (reportPerplexity) {
            System.out.println("\ncalculating perplexity over entire data set...");
            double perplexity = LossSoftmax.calculateMedianPerplexity(model, this.training);
            System.out.println("\nMedian Perplexity = " + String.format("%.4f", perplexity));
        }
        double[] temperatures = {1.0, 0.75, 0.5, 0.25, 0.1};
        for (double temperature : temperatures) {
            System.out.println("\nTemperature " + temperature + " prediction:");
            String guess = TextGenerationUnbroken.generateText(model, reportSequenceLength, false, temperature, rng);
            System.out.println("\t\"..." + guess + "...\"");
        }
        System.out.println("\nArgmax prediction:");
        String guess = TextGenerationUnbroken.generateText(model, reportSequenceLength, true, 1.0, rng);
        System.out.println("\t\"..." + guess + "...\"");
        System.out.println("========================================");
    }

    /**
     * Returns a {@link LinearUnit} for the model's output layer: the raw outputs are converted
     * to probabilities by {@link LossSoftmax}, as {@link #generateText} does.
     */
    @Override
    public Nonlinearity getModelOutputUnitToUse() {
        return new LinearUnit();
    }
}

