package org.ea.javacnn.layers;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import org.ea.javacnn.data.BackPropResult;
import org.ea.javacnn.data.DataBlock;

/* loaded from: input_file:org/ea/javacnn/layers/LocalResponseNormalizationLayer.class */
/**
 * Local response normalization (LRN) across channels.
 *
 * <p>For every spatial position {@code (x, y)} and channel {@code c} the forward pass computes
 *
 * <pre>
 *   S(x,y,c)   = k + (alpha / n) * sum_{j in window(c)} a(x,y,j)^2
 *   out(x,y,c) = a(x,y,c) / S(x,y,c)^beta
 * </pre>
 *
 * where {@code window(c)} is the channel range {@code [c - floor(n/2), c + floor(n/2)]} clipped to
 * {@code [0, depth - 1]}. The hyper-parameters are fixed: k = 2, n = 5, alpha = 1e-4, beta = 0.75.
 *
 * <p>The layer has no trainable parameters. It remembers the last input, output and the computed
 * {@code S} values so that {@link #backward()} can reuse them. Because this is mutable per-call
 * state, one instance should not be shared between concurrent forward/backward passes.
 */
public class LocalResponseNormalizationLayer implements Layer, Serializable {
    // NOTE: the field set and names below are part of this class's default serialized form
    // (no explicit serialVersionUID is declared), so they are intentionally left unchanged.

    /** Additive bias inside the normalizer. */
    private final double k = 2.0d;
    /** Size of the channel window (should be odd so it is centered on the channel). */
    private final double n = 5.0d;
    /** Scale applied to the windowed sum of squares. */
    private final double alpha = 1.0E-4d;
    /** Exponent applied to the normalizer. */
    private final double beta = 0.75d;
    /** Input of the last forward pass; receives the gradient in backward(). */
    private DataBlock in_act;
    /** Output of the last forward pass; supplies the upstream gradient in backward(). */
    private DataBlock out_act;
    /** Per-element normalizer S computed in the last forward pass. */
    private DataBlock S_cache_;

    /** Creates the layer, warning on stdout if the configured window size is even. */
    public LocalResponseNormalizationLayer() {
        if (n % 2 == 0) {
            System.out.println("WARNING n should be odd for LRN layer");
        }
    }

    /** Returns how many neighbouring channels on each side belong to the normalization window. */
    private int halfWindow() {
        return (int) Math.floor(n / 2);
    }

    /**
     * Normalizes {@code dataBlock} across channels and caches what {@link #backward()} needs.
     *
     * @param dataBlock the input activations; kept by reference for the backward pass
     * @param isTraining unused by this layer
     * @return a new block of the same shape holding the normalized activations
     */
    @Override // org.ea.javacnn.layers.Layer
    public DataBlock forward(DataBlock dataBlock, boolean isTraining) {
        this.in_act = dataBlock;
        DataBlock output = dataBlock.cloneAndZero();
        this.S_cache_ = dataBlock.cloneAndZero();
        int half = halfWindow();
        int depth = dataBlock.getDepth();
        for (int x = 0; x < dataBlock.getSX(); x++) {
            for (int y = 0; y < dataBlock.getSY(); y++) {
                for (int c = 0; c < depth; c++) {
                    double sumSq = 0.0d;
                    int hi = Math.min(c + half, depth - 1);
                    for (int j = Math.max(0, c - half); j <= hi; j++) {
                        double aj = dataBlock.getWeight(x, y, j);
                        sumSq += aj * aj;
                    }
                    double s = sumSq * (alpha / n) + k;
                    this.S_cache_.setWeight(x, y, c, s);
                    output.setWeight(x, y, c, dataBlock.getWeight(x, y, c) / Math.pow(s, beta));
                }
            }
        }
        this.out_act = output;
        return this.out_act;
    }

    /**
     * Back-propagates the gradient stored in the last output into the last input.
     *
     * <p>For {@code out(c) = a(c) / S(c)^beta} the partial derivative with respect to an input
     * channel {@code j} inside the window of {@code c} is
     *
     * <pre>
     *   d out(c) / d a(j) = [j == c] * S(c)^-beta
     *                       - beta * a(c) * S(c)^(-beta-1) * (2 * alpha / n) * a(j)
     * </pre>
     *
     * The input's gradient is cleared first and then accumulated from every output channel.
     *
     * @throws IllegalStateException if {@link #forward(DataBlock, boolean)} has not been called
     */
    @Override // org.ea.javacnn.layers.Layer
    public void backward() {
        if (this.in_act == null || this.out_act == null || this.S_cache_ == null) {
            throw new IllegalStateException("forward() must be called before backward()");
        }
        DataBlock input = this.in_act;
        input.clearGradient();
        int half = halfWindow();
        int depth = input.getDepth();
        for (int x = 0; x < input.getSX(); x++) {
            for (int y = 0; y < input.getSY(); y++) {
                for (int c = 0; c < depth; c++) {
                    double chainGrad = this.out_act.getGradient(x, y, c);
                    double s = this.S_cache_.getWeight(x, y, c);
                    double sb = Math.pow(s, beta);
                    double sb2 = sb * sb;
                    // The cross term's leading factor is the activation of output channel c
                    // (the numerator of out(c)), not the neighbour a(j).
                    double ac = input.getWeight(x, y, c);
                    double coeff = -ac * beta * Math.pow(s, beta - 1.0d) * alpha / n * 2.0d;
                    int hi = Math.min(c + half, depth - 1);
                    for (int j = Math.max(0, c - half); j <= hi; j++) {
                        double g = coeff * input.getWeight(x, y, j);
                        if (j == c) {
                            g += sb;
                        }
                        input.addGradient(x, y, j, (g / sb2) * chainGrad);
                    }
                }
            }
        }
    }

    /**
     * Returns the trainable parameters of this layer; LRN has none.
     *
     * @return a new, empty, mutable list
     */
    @Override // org.ea.javacnn.layers.Layer
    public List<BackPropResult> getBackPropagationResult() {
        return new ArrayList<>();
    }
}
