/*
 * Decompiled with CFR 0.152.
 */
package recunn.loss;

import java.util.ArrayList;
import java.util.List;
import recunn.autodiff.Graph;
import recunn.datastructs.DataSequence;
import recunn.datastructs.DataStep;
import recunn.loss.Loss;
import recunn.matrix.Matrix;
import recunn.model.Model;
import recunn.util.Util;

/**
 * Softmax cross-entropy loss against a one-hot target vector.
 *
 * <p>The {@code logprobs} matrix handed to {@link #backward} and {@link #measure} holds
 * unnormalized scores (logits); the softmax is applied here. The target class is the first entry
 * of {@code targetOutput.w} that equals {@code 1.0}.
 */
public class LossSoftmax
implements Loss {
    private static final long serialVersionUID = 1L;

    /**
     * Writes the gradient of the softmax cross-entropy loss with respect to the logits,
     * {@code softmax(logprobs.w) - onehot(target)}, into {@code logprobs.dw}.
     *
     * <p>The first {@code logprobs.w.length} entries of {@code logprobs.dw} are overwritten, not
     * accumulated into.
     *
     * @param logprobs logits; {@code w} is read and {@code dw} is written
     * @param targetOutput one-hot target vector
     * @throws Exception if {@code targetOutput} has no entry equal to {@code 1.0}
     */
    @Override
    public void backward(Matrix logprobs, Matrix targetOutput) throws Exception {
        int targetIndex = getTargetIndex(targetOutput);
        Matrix probs = getSoftmaxProbs(logprobs, 1.0);
        System.arraycopy(probs.w, 0, logprobs.dw, 0, probs.w.length);
        logprobs.dw[targetIndex] -= 1.0;
    }

    /**
     * Returns the cross-entropy cost {@code -ln(softmax(logprobs.w)[target])}.
     *
     * <p>The cost is computed in log space. For finite logits it therefore stays finite even when
     * the target's probability would underflow to {@code 0.0}, where {@code -Math.log(p)} would
     * return {@code +Infinity}.
     *
     * @param logprobs logits; not modified
     * @param targetOutput one-hot target vector
     * @return the natural-log cross-entropy for this step
     * @throws Exception if {@code targetOutput} has no entry equal to {@code 1.0}
     */
    @Override
    public double measure(Matrix logprobs, Matrix targetOutput) throws Exception {
        int targetIndex = getTargetIndex(targetOutput);
        return -logSoftmaxAt(logprobs, targetIndex);
    }

    /**
     * Returns the median, over {@code sequences}, of each sequence's perplexity under {@code model}.
     *
     * <p>For every sequence the model state is reset. Each step is then fed through
     * {@code model.forward} using one {@code new Graph(false)} per sequence. That graph presumably
     * does not record operations for backprop (TODO confirm against {@code Graph}). A sequence's
     * perplexity is {@code 2^(sum over steps of -log2 p(target) / number of steps)}. Sequences
     * with no steps have no defined perplexity and are skipped.
     *
     * @param model model to evaluate; its state is reset once per sequence
     * @param sequences sequences whose steps carry an input and a one-hot target
     * @return the median per-sequence perplexity, as computed by {@code Util.median}
     * @throws Exception if the model fails or a step has no target entry equal to {@code 1.0}
     */
    public static double calculateMedianPerplexity(Model model, List<DataSequence> sequences) throws Exception {
        final double ln2 = Math.log(2.0);
        ArrayList<Double> ppls = new ArrayList<Double>();
        for (DataSequence seq : sequences) {
            Graph g = new Graph(false);
            model.resetState();
            int steps = 0;
            double negLog2Likelihood = 0.0;
            for (DataStep step : seq.steps) {
                Matrix logprobs = model.forward(step.input, g);
                int targetIndex = getTargetIndex(step.targetOutput);
                // Change of base: log2(p) = ln(p) / ln(2), taken in log space to avoid ln(0).
                negLog2Likelihood -= logSoftmaxAt(logprobs, targetIndex) / ln2;
                steps++;
            }
            if (steps == 0) {
                // Nothing was scored; 0/0 would put NaN into the median.
                continue;
            }
            // Standard perplexity: average over the scored steps, not over (steps - 2).
            ppls.add(Math.pow(2.0, negLog2Likelihood / steps));
        }
        return Util.median(ppls);
    }

    /**
     * Returns {@code softmax(logprobs.w / temperature)} as a new {@link Matrix} of the same length.
     *
     * <p>{@code logprobs} is not modified. The largest scaled logit is subtracted before
     * exponentiating, so large logits do not overflow {@code Math.exp}.
     *
     * @param logprobs logits; read only
     * @param temperature divisor applied to every logit. Values above 1 flatten the distribution
     *     and values below 1 sharpen it. Must be strictly positive.
     * @return a new matrix whose {@code w} entries sum to 1 (up to rounding) for finite input
     * @throws IllegalArgumentException if {@code temperature} is not strictly positive or is NaN
     */
    public static Matrix getSoftmaxProbs(Matrix logprobs, double temperature) throws Exception {
        if (!(temperature > 0.0)) {
            throw new IllegalArgumentException("temperature must be > 0, got " + temperature);
        }
        double[] logits = logprobs.w;
        Matrix probs = new Matrix(logits.length);
        double[] p = probs.w;
        // Scale into the output buffer rather than dividing logprobs.w in place, so the caller's
        // logits are left untouched. x / 1.0 == x exactly, so temperature 1 needs no special case.
        double maxval = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < logits.length; ++i) {
            p[i] = logits[i] / temperature;
            if (p[i] > maxval) {
                maxval = p[i];
            }
        }
        double sum = 0.0;
        for (int i = 0; i < logits.length; ++i) {
            p[i] = Math.exp(p[i] - maxval);
            sum += p[i];
        }
        for (int i = 0; i < logits.length; ++i) {
            p[i] /= sum;
        }
        return probs;
    }

    /**
     * Returns the natural-log softmax probability (temperature 1) of entry {@code index} of
     * {@code logprobs.w}, computed as {@code x[index] - max(x) - ln(sum(exp(x - max(x))))}.
     *
     * <p>Staying in log space keeps the result finite for finite logits even when the
     * probability itself would underflow to {@code 0.0}.
     */
    private static double logSoftmaxAt(Matrix logprobs, int index) {
        double[] logits = logprobs.w;
        double maxval = Double.NEGATIVE_INFINITY;
        for (double v : logits) {
            if (v > maxval) {
                maxval = v;
            }
        }
        double sum = 0.0;
        for (double v : logits) {
            sum += Math.exp(v - maxval);
        }
        return logits[index] - maxval - Math.log(sum);
    }

    /**
     * Returns the index of the first entry of {@code targetOutput.w} equal to {@code 1.0}.
     *
     * @throws IllegalArgumentException if no entry equals {@code 1.0}
     */
    private static int getTargetIndex(Matrix targetOutput) {
        double[] target = targetOutput.w;
        for (int i = 0; i < target.length; ++i) {
            if (target[i] == 1.0) {
                return i;
            }
        }
        throw new IllegalArgumentException("no target index selected");
    }
}

