package recunn.loss;

import java.util.ArrayList;
import java.util.List;
import recunn.autodiff.Graph;
import recunn.datastructs.DataSequence;
import recunn.datastructs.DataStep;
import recunn.matrix.Matrix;
import recunn.model.Model;
import recunn.util.Util;

/* loaded from: input_file:recunn/loss/LossSoftmax.class */
/**
 * Softmax cross-entropy loss for one-hot classification targets.
 *
 * <p>The output matrix passed to {@link #measure} and {@link #backward} is treated as a vector of
 * unnormalized scores (logits) stored in its {@code w} array. The target matrix marks the correct
 * class with the first entry of its {@code w} array that is exactly {@code 1.0}.
 */
public class LossSoftmax implements Loss {
    private static final long serialVersionUID = 1;

    /** Natural log of 2, used to convert natural-log probabilities to bits. */
    private static final double LN_2 = Math.log(2.0d);

    /**
     * Writes the gradient of the softmax cross-entropy loss with respect to the logits into
     * {@code matrix.dw}, i.e. {@code softmax(matrix.w) - onehot(target)}.
     *
     * <p>Existing values in {@code matrix.dw} are overwritten, not accumulated.
     *
     * @param matrix the logits; its {@code dw} array receives the gradient
     * @param matrix2 the one-hot target
     * @throws Exception if {@code matrix2} has no entry equal to {@code 1.0}
     */
    @Override // recunn.loss.Loss
    public void backward(Matrix matrix, Matrix matrix2) throws Exception {
        int targetIndex = getTargetIndex(matrix2);
        double[] probs = getSoftmaxProbs(matrix, 1.0d).w;
        System.arraycopy(probs, 0, matrix.dw, 0, probs.length);
        matrix.dw[targetIndex] -= 1.0d;
    }

    /**
     * Returns the cross-entropy {@code -ln(softmax(matrix.w)[target])} (natural log).
     *
     * <p>The log-probability is computed directly in log space (log-sum-exp). A very unlikely
     * target therefore yields a large finite loss for finite logits, instead of
     * {@code Infinity} from taking {@code log(0)} after the probability underflows.
     *
     * @param matrix the logits
     * @param matrix2 the one-hot target
     * @return the loss for this single example
     * @throws Exception if {@code matrix2} has no entry equal to {@code 1.0}
     */
    @Override // recunn.loss.Loss
    public double measure(Matrix matrix, Matrix matrix2) throws Exception {
        return -logSoftmaxAt(matrix.w, getTargetIndex(matrix2));
    }

    /**
     * Computes the median, over sequences, of the per-sequence perplexity
     * {@code 2^(mean over steps of -log2 p(target))}.
     *
     * <p>Before scoring each sequence, a fresh {@code Graph(false)} is created and
     * {@code model.resetState()} is called. Sequences with no steps have no defined perplexity
     * and are skipped.
     *
     * @param model the model to evaluate
     * @param list the sequences to score; each step's {@code targetOutput} must contain an entry
     *     equal to {@code 1.0}
     * @return the median per-sequence perplexity, as computed by {@code Util.median}
     * @throws Exception if the model's forward pass fails or a target has no entry equal to 1.0
     */
    public static double calculateMedianPerplexity(Model model, List<DataSequence> list) throws Exception {
        ArrayList<Double> perplexities = new ArrayList<>();
        for (DataSequence sequence : list) {
            Graph graph = new Graph(false);
            model.resetState();
            double negLog2Sum = 0.0d;
            int stepCount = 0;
            for (DataStep step : sequence.steps) {
                Matrix logits = model.forward(step.input, graph);
                negLog2Sum -= logSoftmaxAt(logits.w, getTargetIndex(step.targetOutput)) / LN_2;
                stepCount++;
            }
            if (stepCount == 0) {
                // No predictions were scored; a perplexity value would be meaningless here.
                continue;
            }
            // Perplexity is 2 raised to the MEAN negative log2-likelihood, so divide by the
            // number of terms actually summed above.
            perplexities.add(Math.pow(2.0d, negLog2Sum / stepCount));
        }
        return Util.median(perplexities);
    }

    /**
     * Returns a new matrix holding {@code softmax(matrix.w / d)}.
     *
     * <p>The input matrix is not modified. The maximum scaled logit is subtracted before
     * exponentiating so large logits do not overflow.
     *
     * @param matrix the logits
     * @param d the softmax temperature; values above 1 flatten the distribution and values
     *     below 1 sharpen it
     * @return a new matrix of the same length whose {@code w} entries sum to 1 (up to rounding)
     *     for finite logits
     * @throws IllegalArgumentException if {@code d} is not a positive number
     */
    public static Matrix getSoftmaxProbs(Matrix matrix, double d) throws Exception {
        if (!(d > 0.0d)) {
            throw new IllegalArgumentException("temperature must be positive, got " + d);
        }
        double[] logits = matrix.w;
        Matrix probs = new Matrix(logits.length);
        double[] out = probs.w;
        // Scale into the output buffer rather than dividing matrix.w in place, so the caller's
        // logits survive a temperature-scaled call.
        for (int i = 0; i < logits.length; i++) {
            out[i] = logits[i] / d;
        }
        double max = maxOf(out);
        double sum = 0.0d;
        for (int i = 0; i < out.length; i++) {
            out[i] = Math.exp(out[i] - max);
            sum += out[i];
        }
        for (int i = 0; i < out.length; i++) {
            out[i] /= sum;
        }
        return probs;
    }

    /**
     * Returns {@code ln(softmax(logits)[index])} using the log-sum-exp trick. For finite logits
     * the result stays finite even when the probability itself would underflow to zero.
     */
    private static double logSoftmaxAt(double[] logits, int index) {
        double max = maxOf(logits);
        double sum = 0.0d;
        for (double v : logits) {
            sum += Math.exp(v - max);
        }
        return (logits[index] - max) - Math.log(sum);
    }

    /**
     * Returns the largest value in {@code values}. NaN entries never win the comparison, and the
     * result is negative infinity for an empty array.
     */
    private static double maxOf(double[] values) {
        double max = Double.NEGATIVE_INFINITY;
        for (double v : values) {
            if (v > max) {
                max = v;
            }
        }
        return max;
    }

    /**
     * Returns the index of the first entry of {@code matrix.w} that is exactly {@code 1.0}.
     *
     * @throws IllegalArgumentException if no entry equals {@code 1.0}
     */
    private static int getTargetIndex(Matrix matrix) {
        double[] values = matrix.w;
        for (int i = 0; i < values.length; i++) {
            if (values[i] == 1.0d) {
                return i;
            }
        }
        throw new IllegalArgumentException("no target index selected");
    }
}
