/*
 * Decompiled with CFR 0.152.
 */
package org.ea.javacnn.losslayers;

import java.util.Arrays;
import org.ea.javacnn.data.DataBlock;
import org.ea.javacnn.data.OutputDefinition;
import org.ea.javacnn.losslayers.LossLayer;

/**
 * Softmax classifier used as a loss layer.
 *
 * <p>The layer flattens its input volume into a 1x1xD vector of class scores. In {@link #forward}
 * it turns the scores into a probability distribution with a max-shifted (numerically stable)
 * softmax. In {@link #backward(int)} it writes the cross-entropy gradient with respect to the
 * input scores and returns the negative log-likelihood of the target class.
 *
 * <p>{@link #forward} caches per-call results in instance fields that {@link #backward} reads, so
 * one instance must not be shared between interleaved forward/backward passes.
 */
public class SoftMaxLayer
extends LossLayer {
    /** Number of input values (outX * outY * depth of the incoming volume). */
    private final int num_inputs;
    /** Number of classes; equal to {@code num_inputs} because the input is flattened. */
    private final int out_depth;
    /** Output width; always 1. */
    private final int out_sx;
    /** Output height; always 1. */
    private final int out_sy;
    /** Class probabilities from the most recent forward pass; {@code null} before the first one. */
    private double[] es;
    /**
     * Log class probabilities from the most recent forward pass. They are computed via
     * log-sum-exp so the loss stays finite even when a probability underflows to 0.0.
     */
    private double[] logProbs;

    /**
     * Creates the layer and rewrites {@code def} to describe this layer's 1x1xD output.
     *
     * @param def shape of the incoming volume; it is mutated in place so that outX = outY = 1
     *     and depth = (original outX * outY * depth)
     */
    public SoftMaxLayer(OutputDefinition def) {
        super(def);
        this.num_inputs = def.getOutY() * def.getOutX() * def.getDepth();
        this.out_depth = this.num_inputs;
        this.out_sx = 1;
        this.out_sy = 1;
        def.setOutX(this.out_sx);
        def.setOutY(this.out_sy);
        def.setDepth(this.out_depth);
    }

    /**
     * Computes softmax probabilities for the scores in {@code db}.
     *
     * @param db input block; its first {@code out_depth} weights are treated as class scores
     * @param training unused by this layer
     * @return a new 1x1x{@code out_depth} block holding the class probabilities
     */
    @Override
    public DataBlock forward(DataBlock db, boolean training) {
        this.in_act = db;
        DataBlock A = new DataBlock(1, 1, this.out_depth, 0.0);
        double[] scores = db.getWeights();

        // Shift by the max score so Math.exp never overflows; the result is mathematically unchanged.
        double amax = scores[0];
        for (int i = 1; i < this.out_depth; ++i) {
            if (scores[i] > amax) {
                amax = scores[i];
            }
        }

        double[] es = new double[this.out_depth];
        double esum = 0.0;
        for (int i = 0; i < this.out_depth; ++i) {
            double e = Math.exp(scores[i] - amax);
            esum += e;
            es[i] = e;
        }

        // log p_i = (s_i - max) - log(sum_j exp(s_j - max)). This stays finite even where p_i
        // underflows to 0.0, unlike Math.log(p_i).
        double logEsum = Math.log(esum);
        double[] logProbs = new double[this.out_depth];
        for (int i = 0; i < this.out_depth; ++i) {
            es[i] = es[i] / esum;
            logProbs[i] = (scores[i] - amax) - logEsum;
            A.setWeight(i, es[i]);
        }

        this.es = es;
        this.logProbs = logProbs;
        this.out_act = A;
        return this.out_act;
    }

    /**
     * Writes the cross-entropy gradient (p_i - [i == y]) into the input block's gradients.
     *
     * @param y index of the correct class, in {@code [0, out_depth)}
     * @return the loss {@code -log p_y}, computed in log space so it is finite even when p_y
     *     underflows to 0
     * @throws IllegalStateException if called before {@link #forward}
     * @throws IllegalArgumentException if {@code y} is outside {@code [0, out_depth)}
     */
    @Override
    public double backward(int y) {
        if (this.es == null || this.in_act == null) {
            throw new IllegalStateException("backward() called before forward()");
        }
        // Validate before touching gradients so a bad label cannot leave them half-written.
        if (y < 0 || y >= this.out_depth) {
            throw new IllegalArgumentException(
                    "class index " + y + " out of range [0, " + this.out_depth + ")");
        }
        DataBlock x = this.in_act;
        x.clearGradient();
        for (int i = 0; i < this.out_depth; ++i) {
            double indicator = i == y ? 1.0 : 0.0;
            double mul = -(indicator - this.es[i]);
            x.setGradient(i, mul);
        }
        return -this.logProbs[y];
    }
}

