package org.ea.javacnn.losslayers;

import java.util.Arrays;
import org.ea.javacnn.data.DataBlock;
import org.ea.javacnn.data.OutputDefinition;

/**
 * Softmax output layer paired with a negative log-likelihood (cross-entropy) loss.
 *
 * <p>The constructor flattens the incoming volume described by the {@link OutputDefinition}
 * (width x height x depth) into a 1 x 1 x N vector, and writes that shape back into the
 * definition.
 */
public class SoftMaxLayer extends LossLayer {
    private int num_inputs;
    private int out_depth;
    private int out_sx;
    private int out_sy;
    /** Class probabilities produced by the most recent {@link #forward} call. */
    private double[] es;
    /**
     * Log-probabilities from the most recent {@link #forward} call. They are computed directly
     * from the shifted activations so the loss stays finite even when a probability underflows.
     */
    private double[] logProbs;

    /**
     * Creates the layer and rewrites {@code outputDefinition} to the flattened 1 x 1 x N shape.
     *
     * @param outputDefinition shape of the incoming volume; mutated in place to the output shape
     */
    public SoftMaxLayer(OutputDefinition outputDefinition) {
        super(outputDefinition);
        this.num_inputs = outputDefinition.getOutY() * outputDefinition.getOutX() * outputDefinition.getDepth();
        this.out_depth = this.num_inputs;
        this.out_sx = 1;
        this.out_sy = 1;
        outputDefinition.setOutX(this.out_sx);
        outputDefinition.setOutY(this.out_sy);
        outputDefinition.setDepth(this.out_depth);
    }

    /**
     * Applies softmax to the first {@code out_depth} activations of {@code dataBlock}.
     *
     * <p>The input block is remembered for {@link #backward}. The probabilities and
     * log-probabilities are cached on this instance.
     *
     * @param dataBlock input activations
     * @param z unused by this layer
     * @return a new 1 x 1 x {@code out_depth} block holding the class probabilities
     */
    @Override // org.ea.javacnn.layers.Layer
    public DataBlock forward(DataBlock dataBlock, boolean z) {
        this.in_act = dataBlock;
        DataBlock result = new DataBlock(1, 1, this.out_depth, 0.0d);
        double[] weights = dataBlock.getWeights();

        // Shift by the maximum activation so Math.exp cannot overflow.
        double max = dataBlock.getWeight(0);
        for (int i = 1; i < this.out_depth; i++) {
            if (weights[i] > max) {
                max = weights[i];
            }
        }

        double[] probs = new double[this.out_depth];
        double sum = 0.0d;
        for (int i = 0; i < this.out_depth; i++) {
            double e = Math.exp(weights[i] - max);
            sum += e;
            probs[i] = e;
        }

        // The sum is >= 1 (the max element contributes exp(0)), so its log is always finite.
        double logSum = Math.log(sum);
        double[] logs = new double[this.out_depth];
        for (int i = 0; i < this.out_depth; i++) {
            probs[i] = probs[i] / sum;
            logs[i] = (weights[i] - max) - logSum;
            result.setWeight(i, probs[i]);
        }

        this.es = probs;
        this.logProbs = logs;
        this.out_act = result;
        return this.out_act;
    }

    /**
     * Writes the softmax/cross-entropy gradient ({@code p_j - [j == i]}) into the input block's
     * gradient and returns the loss for target class {@code i}.
     *
     * <p>NOTE(review): {@code i} is not range-checked. An out-of-range label throws
     * {@link ArrayIndexOutOfBoundsException}, and calling this before {@link #forward} fails on
     * the uninitialised cache.
     *
     * @param i index of the correct class
     * @return {@code -log p_i}, computed from cached log-probabilities so it stays finite
     */
    @Override // org.ea.javacnn.losslayers.LossLayer
    public double backward(int i) {
        DataBlock input = this.in_act;
        input.clearGradient();
        for (int j = 0; j < this.out_depth; j++) {
            double indicator = (j == i) ? 1.0d : 0.0d;
            input.setGradient(j, -(indicator - this.es[j]));
        }
        return -this.logProbs[i];
    }
}
