package recunn.loss;

import recunn.matrix.Matrix;

/* loaded from: input_file:recunn/loss/LossArgMax.class */
/**
 * Loss that compares the positions of the largest entries of two matrices.
 *
 * <p>{@link #measure} returns {@code 0.0} when the index of the maximum value of the first
 * matrix's {@code w} array equals the index of the maximum value of the second matrix's
 * {@code w} array, and {@code 1.0} otherwise. {@link #backward} is not implemented.
 */
public class LossArgMax implements Loss {
    private static final long serialVersionUID = 1L;

    /**
     * Not implemented; this method always throws.
     *
     * @param matrix ignored
     * @param matrix2 ignored
     * @throws UnsupportedOperationException always
     */
    @Override // recunn.loss.Loss
    public void backward(Matrix matrix, Matrix matrix2) throws Exception {
        throw new UnsupportedOperationException("not implemented");
    }

    /**
     * Returns {@code 0.0} if both matrices have their maximum {@code w} entry at the same index,
     * and {@code 1.0} otherwise.
     *
     * <p>Ties resolve to the lowest index, because only a strictly greater value replaces the
     * current maximum. An array whose entries are all NaN or negative infinity, including an
     * empty array, has argmax index -1. Two such arrays therefore compare as equal, giving
     * {@code 0.0}.
     *
     * @param matrix first matrix (the comparison is symmetric)
     * @param matrix2 second matrix
     * @return {@code 0.0} when the argmax indices agree, {@code 1.0} otherwise
     * @throws IllegalArgumentException if the two {@code w} arrays have different lengths
     */
    @Override // recunn.loss.Loss
    public double measure(Matrix matrix, Matrix matrix2) throws Exception {
        if (matrix.w.length != matrix2.w.length) {
            throw new IllegalArgumentException("mismatch");
        }
        return argMax(matrix) == argMax(matrix2) ? 0.0d : 1.0d;
    }

    /**
     * Returns the index of the first strictly greatest entry of {@code m.w}.
     *
     * <p>Returns -1 if no entry exceeds negative infinity. That covers an empty array and an
     * array containing only NaN or negative infinity.
     */
    private static int argMax(Matrix m) {
        double best = Double.NEGATIVE_INFINITY;
        int bestIndex = -1;
        for (int k = 0; k < m.w.length; k++) {
            // Strict '>' keeps the earliest index on ties and never selects NaN.
            if (m.w[k] > best) {
                best = m.w[k];
                bestIndex = k;
            }
        }
        return bestIndex;
    }
}
