/*
 * Decompiled with CFR 0.152.
 */
package org.ea.javacnn.layers;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import org.ea.javacnn.data.BackPropResult;
import org.ea.javacnn.data.DataBlock;
import org.ea.javacnn.data.OutputDefinition;
import org.ea.javacnn.layers.Layer;

/**
 * A 2-D convolutional layer. It slides {@code out_depth} filters of size
 * {@code sx x sy x in_depth} over the zero-padded input with a fixed stride. Each filter
 * produces one output channel, and a per-filter bias is added to every output cell.
 *
 * <p>NOTE(review): this class is {@code Serializable} but declares no {@code serialVersionUID}.
 * The default one is derived from the field list, field modifiers and non-private signatures.
 * Changing any of them may make previously serialized models unreadable.
 */
public class ConvolutionLayer implements Layer, Serializable {
    // Regularization multipliers reported for each filter in getBackPropagationResult();
    // the bias entry is always reported with 0.0 for both.
    private double l1_decay_mul = 0.0;
    private double l2_decay_mul = 1.0;
    // Input and output of the most recent forward() call; backward() reads both.
    private DataBlock in_act;
    private DataBlock out_act;
    // Initial value written into every bias cell.
    private final float BIAS_PREF = 0.1f;
    // Output volume dimensions (computed in the constructor).
    private int out_depth;
    private int out_sx;
    private int out_sy;
    // Input volume dimensions, taken from the OutputDefinition given to the constructor.
    private int in_depth;
    private int in_sx;
    private int in_sy;
    // Filter width/height, step between filter positions, and zero padding per side.
    private int sx;
    private int sy;
    private int stride;
    private int padding;
    // One block of sx x sy x in_depth weights per output channel.
    private List<DataBlock> filters;
    // 1 x 1 x out_depth block holding one bias per filter.
    private DataBlock biases;

    /**
     * Creates a convolution layer with square {@code sx x sx} filters.
     *
     * <p>Updates {@code def} in place so that it describes this layer's output volume.
     *
     * @param def     dimensions of the incoming volume; on return it holds this layer's output
     *                width, height and depth
     * @param sx      filter width (also used as the filter height)
     * @param filters number of filters, i.e. the output depth
     * @param stride  step between successive filter positions, applied along both x and y
     * @param padding number of implicit zero cells on every side of the input
     * @throws IllegalArgumentException if {@code sx}, {@code filters} or {@code stride} is not
     *         positive, if {@code padding} is negative, or if the filter does not fit inside
     *         the padded input
     */
    public ConvolutionLayer(OutputDefinition def, int sx, int filters, int stride, int padding) {
        if (sx <= 0) {
            throw new IllegalArgumentException("filter size must be positive: " + sx);
        }
        if (filters <= 0) {
            throw new IllegalArgumentException("filter count must be positive: " + filters);
        }
        if (stride <= 0) {
            throw new IllegalArgumentException("stride must be positive: " + stride);
        }
        if (padding < 0) {
            throw new IllegalArgumentException("padding must not be negative: " + padding);
        }
        this.out_depth = filters;
        this.sx = sx;
        this.sy = sx;
        this.in_depth = def.getDepth();
        this.in_sx = def.getOutX();
        this.in_sy = def.getOutY();
        this.stride = stride;
        this.padding = padding;
        this.out_sx = outputSize(this.in_sx, this.sx, stride, padding);
        this.out_sy = outputSize(this.in_sy, this.sy, stride, padding);
        this.filters = new ArrayList<DataBlock>(this.out_depth);
        for (int i = 0; i < this.out_depth; ++i) {
            this.filters.add(new DataBlock(this.sx, this.sy, this.in_depth));
        }
        this.biases = new DataBlock(1, 1, this.out_depth, BIAS_PREF);
        def.setOutX(this.out_sx);
        def.setOutY(this.out_sy);
        def.setDepth(this.out_depth);
    }

    /**
     * Returns the number of filter positions along one axis: (in + 2*pad - filter) / stride + 1.
     *
     * <p>The numerator is checked to be non-negative and {@code stride} is positive, so the
     * integer division here is an exact floor.
     */
    private static int outputSize(int inputSize, int filterSize, int stride, int padding) {
        int span = inputSize + 2 * padding - filterSize;
        if (span < 0) {
            throw new IllegalArgumentException("filter size " + filterSize
                    + " exceeds padded input size " + (inputSize + 2 * padding));
        }
        return span / stride + 1;
    }

    /**
     * Convolves {@code db} with every filter and adds the per-filter bias.
     *
     * <p>Both the input and the result are cached for {@link #backward()}.
     *
     * @param db       input volume; NOTE(review): assumed to be in_sx x in_sy x in_depth as
     *                 configured in the constructor — this is not checked
     * @param training not used by this layer
     * @return a new {@code out_sx x out_sy x out_depth} block
     */
    @Override
    public DataBlock forward(DataBlock db, boolean training) {
        this.in_act = db;
        DataBlock out = new DataBlock(this.out_sx, this.out_sy, this.out_depth, 0.0);
        for (int d = 0; d < this.out_depth; ++d) {
            DataBlock f = this.filters.get(d);
            int fSx = f.getSX();
            int fSy = f.getSY();
            int fDepth = f.getDepth();
            double bias = this.biases.getWeight(d);
            int y = -this.padding;
            for (int ay = 0; ay < this.out_sy; ++ay, y += this.stride) {
                int x = -this.padding;
                for (int ax = 0; ax < this.out_sx; ++ax, x += this.stride) {
                    double a = 0.0;
                    for (int fy = 0; fy < fSy; ++fy) {
                        int oy = y + fy;
                        // Taps that fall into the zero padding contribute nothing.
                        if (oy < 0 || oy >= this.in_sy) {
                            continue;
                        }
                        for (int fx = 0; fx < fSx; ++fx) {
                            int ox = x + fx;
                            if (ox < 0 || ox >= this.in_sx) {
                                continue;
                            }
                            for (int fd = 0; fd < fDepth; ++fd) {
                                a += f.getWeight(fx, fy, fd) * db.getWeight(ox, oy, fd);
                            }
                        }
                    }
                    out.setWeight(ax, ay, d, a + bias);
                }
            }
        }
        this.out_act = out;
        return out;
    }

    /**
     * Back-propagates the gradient held in the last output through this layer.
     *
     * <p>The input gradient is cleared first and then accumulated. Filter and bias gradients
     * are accumulated on top of whatever they already hold. This must be called after
     * {@link #forward}.
     */
    @Override
    public void backward() {
        DataBlock db = this.in_act;
        db.clearGradient();
        int inSx = db.getSX();
        int inSy = db.getSY();
        int inDepth = db.getDepth();
        for (int d = 0; d < this.out_depth; ++d) {
            DataBlock f = this.filters.get(d);
            int fSx = f.getSX();
            int fSy = f.getSY();
            int fDepth = f.getDepth();
            int y = -this.padding;
            for (int ay = 0; ay < this.out_sy; ++ay, y += this.stride) {
                int x = -this.padding;
                for (int ax = 0; ax < this.out_sx; ++ax, x += this.stride) {
                    double chainGrad = this.out_act.getGradient(ax, ay, d);
                    for (int fy = 0; fy < fSy; ++fy) {
                        int oy = y + fy;
                        if (oy < 0 || oy >= inSy) {
                            continue;
                        }
                        for (int fx = 0; fx < fSx; ++fx) {
                            int ox = x + fx;
                            if (ox < 0 || ox >= inSx) {
                                continue;
                            }
                            // Flat indices; presumably DataBlock stores (x, y, z) at
                            // ((width * y) + x) * depth + z. Both use the block's width (SX)
                            // as the row length, so the filter index stays correct for
                            // non-square filters.
                            int inBase = (inSx * oy + ox) * inDepth;
                            int fBase = (fSx * fy + fx) * fDepth;
                            for (int fd = 0; fd < fDepth; ++fd) {
                                int ix1 = inBase + fd;
                                int ix2 = fBase + fd;
                                f.addGradient(ix2, db.getWeight(ix1) * chainGrad);
                                db.addGradient(ix1, f.getWeight(ix2) * chainGrad);
                            }
                        }
                    }
                    this.biases.addGradient(d, chainGrad);
                }
            }
        }
    }

    /**
     * Returns the trainable parameters together with their gradients.
     *
     * <p>There is one entry per filter, carrying the filter decay multipliers (passed as
     * l2_decay_mul then l1_decay_mul — NOTE(review): confirm this matches BackPropResult's
     * parameter order). A final entry holds the biases, with both multipliers set to 0.0.
     */
    @Override
    public List<BackPropResult> getBackPropagationResult() {
        List<BackPropResult> results = new ArrayList<BackPropResult>(this.out_depth + 1);
        for (DataBlock f : this.filters) {
            results.add(new BackPropResult(f.getWeights(), f.getGradients(), this.l2_decay_mul, this.l1_decay_mul));
        }
        results.add(new BackPropResult(this.biases.getWeights(), this.biases.getGradients(), 0.0, 0.0));
        return results;
    }
}

