package org.ea.javacnn.layers;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.ea.javacnn.data.BackPropResult;
import org.ea.javacnn.data.DataBlock;
import org.ea.javacnn.data.OutputDefinition;

/**
 * Spatial max-pooling layer.
 *
 * <p>Each depth slice of the input is scanned independently with a square window of side
 * {@code sx}. The window moves {@code stride} cells at a time and starts {@code padding} cells
 * before the input edge. Window cells that fall outside the input (i.e. in the padding) are
 * skipped, so padding never takes part in the maximum. The depth is passed through unchanged.
 *
 * <p>{@link #forward} records, for every output cell, the input coordinates of the winning cell
 * (the "switches"). {@link #backward} routes each output gradient back to exactly that cell.
 * The layer has no trainable parameters.
 *
 * <p>The last input, output and switches are kept in unsynchronized instance fields, so one
 * instance must not be used from several threads at once.
 */
public class PoolingLayer implements Layer, Serializable {
    // Field names are kept as-is: they are part of the default serialized form.
    private int in_depth;
    private int in_sx;
    private int in_sy;
    private int out_depth;
    private int out_sx;
    private int out_sy;
    private int sx;
    private int sy;
    private int stride;
    private int padding;
    private int[] switchx;
    private int[] switchy;
    private DataBlock in_act;
    private DataBlock out_act;

    /**
     * Creates a max-pooling layer for inputs shaped as described by {@code outputDefinition}.
     * The definition is then overwritten in place with this layer's output shape.
     *
     * @param outputDefinition shape of the previous layer's output; updated to this layer's
     *     output shape (same depth, pooled width/height)
     * @param windowSize side length of the square pooling window
     * @param stride number of cells the window advances between positions
     * @param padding number of virtual cells added on each side of the input; they are
     *     ignored when taking the maximum
     * @throws IllegalArgumentException if {@code windowSize} or {@code stride} is not positive,
     *     {@code padding} is negative, or the window fits the padded input zero times along
     *     either spatial axis
     */
    public PoolingLayer(OutputDefinition outputDefinition, int windowSize, int stride, int padding) {
        if (windowSize <= 0) {
            throw new IllegalArgumentException("windowSize must be positive: " + windowSize);
        }
        if (stride <= 0) {
            throw new IllegalArgumentException("stride must be positive: " + stride);
        }
        if (padding < 0) {
            throw new IllegalArgumentException("padding must not be negative: " + padding);
        }
        this.sx = windowSize;
        this.sy = windowSize;
        this.stride = stride;
        this.padding = padding;
        this.in_depth = outputDefinition.getDepth();
        this.in_sx = outputDefinition.getOutX();
        this.in_sy = outputDefinition.getOutY();
        this.out_depth = this.in_depth;
        this.out_sx = outputSize(this.in_sx, this.sx, stride, padding);
        this.out_sy = outputSize(this.in_sy, this.sy, stride, padding);
        if (this.out_sx < 1 || this.out_sy < 1) {
            throw new IllegalArgumentException("pooling window " + windowSize + " (stride " + stride
                    + ", padding " + padding + ") does not fit input " + this.in_sx + "x" + this.in_sy);
        }
        int cells = this.out_sx * this.out_sy * this.out_depth;
        // New int arrays are already zero-filled by the JVM; no explicit fill is required.
        this.switchx = new int[cells];
        this.switchy = new int[cells];
        outputDefinition.setOutX(this.out_sx);
        outputDefinition.setOutY(this.out_sy);
        outputDefinition.setDepth(this.out_depth);
    }

    /**
     * Number of window positions along one axis: floor((in + 2*padding - window) / stride) + 1.
     * Real (floored) division is used on purpose. Java's integer '/' truncates toward zero and
     * would report one position for a window that does not fit (negative span).
     */
    private static int outputSize(int inSize, int window, int stride, int padding) {
        int span = inSize + 2 * padding - window;
        return (int) Math.floor((double) span / stride) + 1;
    }

    /**
     * Max-pools {@code dataBlock} and remembers the winning input cell of every window so that
     * {@link #backward()} can route gradients.
     *
     * <p>Output cells are visited depth-major, then x, then y. The switch arrays use the same
     * order. Within a window, cells are scanned x-major and the first maximum found wins ties.
     * A window lying entirely in the padding has no input cell. It yields 0.0 and switch
     * coordinates of -1; such windows are only possible when {@code padding >= windowSize}.
     *
     * @param dataBlock input activations
     * @param training presumably the training-mode flag from {@link Layer}; unused by pooling
     * @return the pooled activations (also kept as this layer's output)
     */
    @Override
    public DataBlock forward(DataBlock dataBlock, boolean training) {
        this.in_act = dataBlock;
        DataBlock pooled = new DataBlock(this.out_sx, this.out_sy, this.out_depth, 0.0d);
        int switchIndex = 0;
        for (int d = 0; d < this.out_depth; d++) {
            for (int ox = 0; ox < this.out_sx; ox++) {
                int startX = ox * this.stride - this.padding;
                for (int oy = 0; oy < this.out_sy; oy++) {
                    int startY = oy * this.stride - this.padding;
                    double best = Double.NEGATIVE_INFINITY;
                    int bestX = -1;
                    int bestY = -1;
                    for (int wx = 0; wx < this.sx; wx++) {
                        int x = startX + wx;
                        if (x < 0 || x >= dataBlock.getSX()) {
                            continue; // column lies in the padding
                        }
                        for (int wy = 0; wy < this.sy; wy++) {
                            int y = startY + wy;
                            if (y < 0 || y >= dataBlock.getSY()) {
                                continue; // row lies in the padding
                            }
                            double value = dataBlock.getWeight(x, y, d);
                            // The first real cell always seeds the maximum, so no magic sentinel
                            // can beat (or be beaten by) genuine data.
                            if (bestX < 0 || value > best) {
                                best = value;
                                bestX = x;
                                bestY = y;
                            }
                        }
                    }
                    this.switchx[switchIndex] = bestX;
                    this.switchy[switchIndex] = bestY;
                    switchIndex++;
                    pooled.setWeight(ox, oy, d, bestX < 0 ? 0.0d : best);
                }
            }
        }
        this.out_act = pooled;
        return this.out_act;
    }

    /**
     * Routes the output gradient back to the input: the input's gradient is cleared, then each
     * output cell's gradient is added to the input cell that won its window in the last
     * {@link #forward} call. Windows without a winner (entirely in the padding) contribute nothing.
     *
     * @throws IllegalStateException if {@link #forward} has not been called yet
     */
    @Override
    public void backward() {
        if (this.in_act == null || this.out_act == null) {
            throw new IllegalStateException("backward() called before forward()");
        }
        DataBlock input = this.in_act;
        input.clearGradient();
        int switchIndex = 0;
        for (int d = 0; d < this.out_depth; d++) {
            for (int ox = 0; ox < this.out_sx; ox++) {
                for (int oy = 0; oy < this.out_sy; oy++) {
                    int x = this.switchx[switchIndex];
                    int y = this.switchy[switchIndex];
                    switchIndex++;
                    if (x < 0) {
                        continue; // empty window: no input cell received this output
                    }
                    input.addGradient(x, y, d, this.out_act.getGradient(ox, oy, d));
                }
            }
        }
    }

    /**
     * Pooling has no trainable parameters.
     *
     * @return a new, empty, mutable list
     */
    @Override
    public List<BackPropResult> getBackPropagationResult() {
        return new ArrayList<BackPropResult>();
    }
}
