/*
 * Decompiled with CFR 0.152.
 */
package org.ea.javacnn.layers;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import org.ea.javacnn.data.BackPropResult;
import org.ea.javacnn.data.DataBlock;
import org.ea.javacnn.layers.Layer;

/**
 * Local response normalization (LRN) across the depth dimension of a {@link DataBlock}.
 *
 * <p>For every spatial position {@code (x, y)} and depth index {@code i} the forward pass computes
 *
 * <pre>
 *   S_i   = k + (alpha / n) * sum_{j = max(0, i - h)}^{min(depth - 1, i + h)} a_j^2
 *   out_i = a_i / S_i^beta
 * </pre>
 *
 * where {@code a} are the input activations at {@code (x, y)} and {@code h = floor(n / 2)}.
 * The hyper-parameters are fixed: {@code k = 2}, {@code n = 5}, {@code alpha = 1e-4},
 * {@code beta = 0.75}. The layer has no trainable parameters.
 *
 * <p>The input, output and the normalizers {@code S_i} of the most recent {@link #forward} call
 * are stored in instance fields; {@link #backward()} operates on that stored state.
 *
 * <p>NOTE: no {@code serialVersionUID} is declared, so the serialization UID is computed from the
 * class's fields and non-private members. Keep those signatures stable or saved models will no
 * longer deserialize.
 */
public class LocalResponseNormalizationLayer
implements Layer,
Serializable {
    /** Additive bias inside the normalizer {@code S_i}. */
    private final double k = 2.0;
    /** Width of the depth window; expected to be odd so the window is centred on channel i. */
    private final double n = 5.0;
    /** Scale of the windowed sum of squares (applied as {@code alpha / n}). */
    private final double alpha = 1.0E-4;
    /** Exponent applied to the normalizer {@code S_i}. */
    private final double beta = 0.75;
    /** Input of the most recent forward pass; backward() writes its gradient. */
    private DataBlock in_act;
    /** Output of the most recent forward pass; backward() reads its gradient. */
    private DataBlock out_act;
    /** Normalizers S_i (before exponentiation) from the most recent forward pass. */
    private DataBlock S_cache_;

    /** Creates the layer, warning on stdout if the window width {@code n} is even. */
    public LocalResponseNormalizationLayer() {
        if (this.n % 2.0 == 0.0) {
            System.out.println("WARNING n should be odd for LRN layer");
        }
    }

    /**
     * Normalizes each activation by the sum of squares of its neighbours along the depth axis.
     *
     * @param db input activations; retained for {@link #backward()}
     * @param training unused: the computation is the same in training and inference
     * @return a block obtained from {@code db.cloneAndZero()} holding the normalized activations
     */
    @Override
    public DataBlock forward(DataBlock db, boolean training) {
        this.in_act = db;
        DataBlock A = db.cloneAndZero();
        this.S_cache_ = db.cloneAndZero();
        int half = this.halfWindow();
        int depth = db.getDepth();
        double scale = this.alpha / this.n;
        for (int x = 0; x < db.getSX(); ++x) {
            for (int y = 0; y < db.getSY(); ++y) {
                for (int i = 0; i < depth; ++i) {
                    // Window of depth indices around i, clipped to the valid range.
                    int lo = Math.max(0, i - half);
                    int hi = Math.min(i + half, depth - 1);
                    double sumSq = 0.0;
                    for (int j = lo; j <= hi; ++j) {
                        double aj = db.getWeight(x, y, j);
                        sumSq += aj * aj;
                    }
                    double s = sumSq * scale + this.k;
                    // Cache S_i (not S_i^beta): backward needs both S^beta and S^(beta-1).
                    this.S_cache_.setWeight(x, y, i, s);
                    A.setWeight(x, y, i, db.getWeight(x, y, i) / Math.pow(s, this.beta));
                }
            }
        }
        this.out_act = A;
        return this.out_act;
    }

    /**
     * Back-propagates the gradient stored on the last output into the last input.
     *
     * <p>With {@code out_i = a_i * S_i^-beta}, for every {@code j} in the window of {@code i}:
     *
     * <pre>
     *   d out_i / d a_j = [delta_ij * S_i^beta - 2 * beta * (alpha / n) * a_i * a_j * S_i^(beta-1)]
     *                     / S_i^(2 * beta)
     * </pre>
     *
     * The input gradient is cleared first and then accumulated over all output positions.
     *
     * @throws IllegalStateException if {@link #forward} has not been called yet
     */
    @Override
    public void backward() {
        if (this.in_act == null || this.out_act == null || this.S_cache_ == null) {
            throw new IllegalStateException("backward() called before forward()");
        }
        DataBlock V = this.in_act;
        V.clearGradient();
        int half = this.halfWindow();
        int depth = V.getDepth();
        double scale = this.alpha / this.n;
        for (int x = 0; x < V.getSX(); ++x) {
            for (int y = 0; y < V.getSY(); ++y) {
                for (int i = 0; i < depth; ++i) {
                    double chainGrad = this.out_act.getGradient(x, y, i);
                    double ai = V.getWeight(x, y, i);
                    double s = this.S_cache_.getWeight(x, y, i);
                    double sb = Math.pow(s, this.beta);
                    double sb2 = sb * sb;
                    // Factor shared by every j: the numerator a_i times dS_i/da_j without a_j.
                    // The cross term must use a_i * a_j (a_j * a_j would be wrong for j != i).
                    double cross = -2.0 * this.beta * scale * ai * Math.pow(s, this.beta - 1.0);
                    int lo = Math.max(0, i - half);
                    int hi = Math.min(i + half, depth - 1);
                    for (int j = lo; j <= hi; ++j) {
                        double g = cross * V.getWeight(x, y, j);
                        if (j == i) {
                            g += sb;
                        }
                        V.addGradient(x, y, j, g / sb2 * chainGrad);
                    }
                }
            }
        }
    }

    /**
     * Returns an empty, mutable list: this layer has no trainable parameters to update.
     *
     * @return a new empty {@link ArrayList}
     */
    @Override
    public List<BackPropResult> getBackPropagationResult() {
        return new ArrayList<BackPropResult>();
    }

    /** Returns {@code floor(n / 2)}, the number of neighbours taken on each side of a channel. */
    private int halfWindow() {
        return (int) Math.floor(this.n / 2.0);
    }
}

