/*
 * Decompiled with CFR 0.152.
 */
package org.ea.javacnn.layers;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.ea.javacnn.data.BackPropResult;
import org.ea.javacnn.data.DataBlock;
import org.ea.javacnn.data.OutputDefinition;
import org.ea.javacnn.layers.Layer;

/**
 * Max-pooling layer: slides a square {@code sx x sx} window over every depth slice of the input
 * (each slice independently) and emits the largest value found inside each window.
 *
 * <p>Padding cells are never read: a window position that falls outside the input is skipped
 * rather than treated as zero. During {@link #forward} the input coordinates of each window's
 * maximum (the "switch") are recorded so that {@link #backward} can route every output gradient
 * back to exactly that input cell.
 */
public class PoolingLayer
implements Layer,
Serializable {
    /** Output value for a window that contains no in-bounds input cell (only padding). */
    private static final double EMPTY_WINDOW_VALUE = -99999.0;
    /** Switch coordinate meaning "this window contained no in-bounds input cell". */
    private static final int NO_SWITCH = -1;

    private int in_depth;
    private int in_sx;
    private int in_sy;
    private int out_depth;
    private int out_sx;
    private int out_sy;
    private int sx;
    private int sy;
    private int stride;
    private int padding;
    private int[] switchx;
    private int[] switchy;
    private DataBlock in_act;
    private DataBlock out_act;

    /**
     * Creates a pooling layer for inputs described by {@code def} and then overwrites {@code def}
     * with this layer's output geometry, so the next layer can be built from it.
     *
     * @param def input geometry (width, height, depth); updated in place to the output geometry
     * @param sx side length of the square pooling window
     * @param stride step between consecutive windows, in both x and y
     * @param padding number of virtual border cells on each side; these cells are never pooled
     * @throws IllegalArgumentException if {@code sx} or {@code stride} is not positive,
     *     {@code padding} is negative, or the resulting output width/height would not be positive
     */
    public PoolingLayer(OutputDefinition def, int sx, int stride, int padding) {
        if (sx <= 0) {
            throw new IllegalArgumentException("pool size must be positive, got " + sx);
        }
        if (stride <= 0) {
            throw new IllegalArgumentException("stride must be positive, got " + stride);
        }
        if (padding < 0) {
            throw new IllegalArgumentException("padding must be non-negative, got " + padding);
        }
        this.sx = sx;
        this.sy = sx;
        this.stride = stride;
        this.padding = padding;
        this.in_depth = def.getDepth();
        this.in_sx = def.getOutX();
        this.in_sy = def.getOutY();
        this.out_depth = this.in_depth;
        // Integer division (truncating toward zero), exactly as the formula has always been evaluated.
        this.out_sx = (this.in_sx + this.padding * 2 - this.sx) / this.stride + 1;
        this.out_sy = (this.in_sy + this.padding * 2 - this.sy) / this.stride + 1;
        if (this.out_sx <= 0 || this.out_sy <= 0) {
            throw new IllegalArgumentException("pooling window " + sx + " with stride " + stride
                    + " and padding " + padding + " does not fit input " + this.in_sx + "x"
                    + this.in_sy);
        }
        // One switch entry per output cell; Java zero-initialises the arrays.
        this.switchx = new int[this.out_sx * this.out_sy * this.out_depth];
        this.switchy = new int[this.out_sx * this.out_sy * this.out_depth];
        def.setOutX(this.out_sx);
        def.setOutY(this.out_sy);
        def.setDepth(this.out_depth);
    }

    /**
     * Max-pools {@code db} and remembers both the input and the produced output for
     * {@link #backward}.
     *
     * <p>A window whose cells all fall in the padding yields {@code -99999.0} and records no
     * switch. NaN input values are never selected as the maximum.
     *
     * @param db input activations
     * @param training unused by this layer
     * @return a new block of size {@code out_sx x out_sy x out_depth} holding the window maxima
     */
    @Override
    public DataBlock forward(DataBlock db, boolean training) {
        this.in_act = db;
        DataBlock A = new DataBlock(this.out_sx, this.out_sy, this.out_depth, 0.0);
        int inWidth = db.getSX();
        int inHeight = db.getSY();
        int n = 0;
        for (int d = 0; d < this.out_depth; ++d) {
            int x = -this.padding;
            for (int ax = 0; ax < this.out_sx; ++ax, x += this.stride) {
                int y = -this.padding;
                for (int ay = 0; ay < this.out_sy; ++ay, y += this.stride) {
                    double best = EMPTY_WINDOW_VALUE;
                    int winx = NO_SWITCH;
                    int winy = NO_SWITCH;
                    for (int fx = 0; fx < this.sx; ++fx) {
                        int ox = x + fx;
                        if (ox < 0 || ox >= inWidth) {
                            continue; // padding column
                        }
                        for (int fy = 0; fy < this.sy; ++fy) {
                            int oy = y + fy;
                            if (oy < 0 || oy >= inHeight) {
                                continue; // padding row
                            }
                            double v = db.getWeight(ox, oy, d);
                            if (Double.isNaN(v)) {
                                continue; // NaN never wins a max comparison
                            }
                            // The first in-bounds cell always wins, so the true maximum is found
                            // even when every value is below the empty-window sentinel. Strict '>'
                            // keeps the earliest cell (fx-major order) on ties.
                            if (winx == NO_SWITCH || v > best) {
                                best = v;
                                winx = ox;
                                winy = oy;
                            }
                        }
                    }
                    this.switchx[n] = winx;
                    this.switchy[n] = winy;
                    ++n;
                    A.setWeight(ax, ay, d, best);
                }
            }
        }
        this.out_act = A;
        return this.out_act;
    }

    /**
     * Clears the input gradient and adds each output cell's gradient to the input cell that won
     * that window in the last {@link #forward} call. Windows that contained no in-bounds input
     * cell contribute nothing.
     *
     * @throws IllegalStateException if called before {@link #forward}
     */
    @Override
    public void backward() {
        if (this.in_act == null || this.out_act == null) {
            throw new IllegalStateException("backward() called before forward()");
        }
        DataBlock V = this.in_act;
        V.clearGradient();
        int n = 0;
        // Same d/ax/ay traversal order as forward(), so index n addresses the matching switch.
        for (int d = 0; d < this.out_depth; ++d) {
            for (int ax = 0; ax < this.out_sx; ++ax) {
                for (int ay = 0; ay < this.out_sy; ++ay, ++n) {
                    if (this.switchx[n] == NO_SWITCH) {
                        continue; // window covered only padding: no input cell to credit
                    }
                    double chainGrad = this.out_act.getGradient(ax, ay, d);
                    V.addGradient(this.switchx[n], this.switchy[n], d, chainGrad);
                }
            }
        }
    }

    /**
     * Returns the parameters to update for this layer.
     *
     * @return an empty list; pooling holds no trainable weights
     */
    @Override
    public List<BackPropResult> getBackPropagationResult() {
        return new ArrayList<BackPropResult>();
    }
}

