package org.ea.javacnn.layers;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import org.ea.javacnn.data.BackPropResult;
import org.ea.javacnn.data.DataBlock;
import org.ea.javacnn.data.OutputDefinition;

/**
 * A 2-D convolution layer with square filters, a configurable stride and symmetric zero padding.
 *
 * <p>The layer owns {@code out_depth} filters of size {@code sx x sy x in_depth} (with
 * {@code sx == sy}) plus one bias per filter, initialised to {@code BIAS_PREF}. Each spatial output
 * dimension is {@code (in + 2 * padding - filterSize) / stride + 1}. The output depth equals the
 * number of filters.
 *
 * <p>Filter taps that land in the padding border are skipped, which is equivalent to convolving
 * over a zero-padded input.
 *
 * <p>{@link #forward} and {@link #backward} cache activations in mutable instance fields. A single
 * instance should therefore not be driven from several threads at once.
 */
public class ConvolutionLayer implements Layer, Serializable {
    // Input of the most recent forward() call; read again by backward().
    private DataBlock in_act;
    // Output of the most recent forward() call; its gradient drives backward().
    private DataBlock out_act;
    // Number of filters == depth of the output volume.
    private int out_depth;
    // Spatial size of the output volume.
    private int out_sx;
    private int out_sy;
    // Shape of the input volume, taken from the previous layer's OutputDefinition.
    private int in_depth;
    private int in_sx;
    private int in_sy;
    // Filter width/height (always equal).
    private int sx;
    private int sy;
    private int stride;
    private int padding;
    // One bias per filter, stored as a 1 x 1 x out_depth block.
    private DataBlock biases;
    // Decay multipliers handed to BackPropResult for the filter weights (biases get 0 / 0).
    private double l1_decay_mul = 0.0d;
    private double l2_decay_mul = 1.0d;
    // Initial value of every bias.
    private final float BIAS_PREF = 0.1f;
    private List<DataBlock> filters = new ArrayList<>();

    /**
     * Creates a convolution layer that consumes the shape described by {@code outputDefinition}.
     *
     * <p>On return, {@code outputDefinition} has been updated in place to describe this layer's
     * output shape, so it can be passed on to the next layer.
     *
     * @param outputDefinition shape produced by the previous layer; overwritten with this layer's
     *     output shape
     * @param filterSize width and height of each (square) filter
     * @param filterCount number of filters, which becomes the output depth
     * @param stride step between successive filter positions, in input pixels
     * @param padding number of padding pixels assumed on every border of the input
     * @throws IllegalArgumentException if {@code filterSize}, {@code filterCount} or {@code stride}
     *     is not positive, {@code padding} is negative, or the filter does not fit inside the padded
     *     input
     */
    public ConvolutionLayer(OutputDefinition outputDefinition, int filterSize, int filterCount,
            int stride, int padding) {
        this.in_depth = outputDefinition.getDepth();
        this.in_sx = outputDefinition.getOutX();
        this.in_sy = outputDefinition.getOutY();
        // Validate before touching outputDefinition so a failed construction leaves it unchanged.
        checkGeometry(in_sx, in_sy, filterSize, filterCount, stride, padding);

        this.sx = filterSize;
        this.sy = filterSize;
        this.out_depth = filterCount;
        this.stride = stride;
        this.padding = padding;
        this.out_sx = outputSize(in_sx, sx, stride, padding);
        this.out_sy = outputSize(in_sy, sy, stride, padding);

        for (int f = 0; f < out_depth; f++) {
            filters.add(new DataBlock(sx, sy, in_depth));
        }
        // Explicit widening keeps the DataBlock(int, int, int, double) constructor selected.
        this.biases = new DataBlock(1, 1, out_depth, (double) BIAS_PREF);

        outputDefinition.setOutX(out_sx);
        outputDefinition.setOutY(out_sy);
        outputDefinition.setDepth(out_depth);
    }

    /** Rejects layer geometries that would yield an empty or negative-sized output. */
    private static void checkGeometry(int inWidth, int inHeight, int filterSize, int filterCount,
            int stride, int padding) {
        if (filterSize < 1 || filterCount < 1 || stride < 1 || padding < 0) {
            throw new IllegalArgumentException(String.format(
                    "Invalid convolution parameters: filterSize=%d, filterCount=%d, stride=%d, padding=%d",
                    filterSize, filterCount, stride, padding));
        }
        if (filterSize > inWidth + 2 * padding || filterSize > inHeight + 2 * padding) {
            throw new IllegalArgumentException(String.format(
                    "Filter of size %d does not fit in input %dx%d with padding %d",
                    filterSize, inWidth, inHeight, padding));
        }
    }

    /**
     * Computes one spatial output dimension.
     *
     * <p>The numerator is non-negative once {@link #checkGeometry} has passed, so integer
     * division here equals floor division.
     */
    private static int outputSize(int inputSize, int filterSize, int stride, int padding) {
        return (inputSize + 2 * padding - filterSize) / stride + 1;
    }

    /**
     * Convolves every filter over {@code input} and adds that filter's bias.
     *
     * <p>The input and the freshly allocated output are cached for {@link #backward()}. Bounds
     * checks use the input size configured at construction time. {@code input} is therefore
     * expected to have exactly that shape.
     *
     * @param input input activations, expected to be {@code in_sx x in_sy x in_depth}
     * @param z flag from the {@link Layer} contract; not read by this layer
     * @return the output activations, {@code out_sx x out_sy x out_depth}
     */
    @Override // org.ea.javacnn.layers.Layer
    public DataBlock forward(DataBlock input, boolean z) {
        this.in_act = input;
        DataBlock output = new DataBlock(out_sx, out_sy, out_depth, 0.0d);
        for (int f = 0; f < out_depth; f++) {
            DataBlock filter = filters.get(f);
            int filterWidth = filter.getSX();
            int filterHeight = filter.getSY();
            int filterDepth = filter.getDepth();
            double bias = biases.getWeight(f);
            // (x0, y0) is the input coordinate under the filter's top-left tap; it starts inside
            // the padding border, which is why it begins at -padding.
            for (int oy = 0, y0 = -padding; oy < out_sy; oy++, y0 += stride) {
                for (int ox = 0, x0 = -padding; ox < out_sx; ox++, x0 += stride) {
                    double sum = 0.0d;
                    for (int fy = 0; fy < filterHeight; fy++) {
                        int iy = y0 + fy;
                        if (iy < 0 || iy >= in_sy) {
                            continue; // tap lies in the padding border
                        }
                        for (int fx = 0; fx < filterWidth; fx++) {
                            int ix = x0 + fx;
                            if (ix < 0 || ix >= in_sx) {
                                continue; // tap lies in the padding border
                            }
                            for (int d = 0; d < filterDepth; d++) {
                                sum += filter.getWeight(fx, fy, d) * input.getWeight(ix, iy, d);
                            }
                        }
                    }
                    output.setWeight(ox, oy, f, sum + bias);
                }
            }
        }
        this.out_act = output;
        return output;
    }

    /**
     * Back-propagates the gradient stored on the cached output.
     *
     * <p>Clears and then accumulates the gradient of the cached input. Accumulates, without
     * clearing first, into the filter and bias gradients.
     *
     * @throws IllegalStateException if {@link #forward} has not been called yet
     */
    @Override // org.ea.javacnn.layers.Layer
    public void backward() {
        DataBlock input = this.in_act;
        if (input == null || out_act == null) {
            throw new IllegalStateException("backward() called before forward()");
        }
        input.clearGradient();
        int inWidth = input.getSX();
        int inHeight = input.getSY();
        int inDepth = input.getDepth();
        for (int f = 0; f < out_depth; f++) {
            DataBlock filter = filters.get(f);
            int filterWidth = filter.getSX();
            int filterHeight = filter.getSY();
            int filterDepth = filter.getDepth();
            for (int oy = 0, y0 = -padding; oy < out_sy; oy++, y0 += stride) {
                for (int ox = 0, x0 = -padding; ox < out_sx; ox++, x0 += stride) {
                    double outGrad = out_act.getGradient(ox, oy, f);
                    for (int fy = 0; fy < filterHeight; fy++) {
                        int iy = y0 + fy;
                        if (iy < 0 || iy >= inHeight) {
                            continue;
                        }
                        for (int fx = 0; fx < filterWidth; fx++) {
                            int ix = x0 + fx;
                            if (ix < 0 || ix >= inWidth) {
                                continue;
                            }
                            for (int d = 0; d < filterDepth; d++) {
                                // Both flat indices use the same layout: (width * y + x) * depth + d.
                                int inIndex = ((inWidth * iy) + ix) * inDepth + d;
                                int filterIndex = ((filterWidth * fy) + fx) * filterDepth + d;
                                filter.addGradient(filterIndex, input.getWeight(inIndex) * outGrad);
                                input.addGradient(inIndex, filter.getWeight(filterIndex) * outGrad);
                            }
                        }
                    }
                    biases.addGradient(f, outGrad);
                }
            }
        }
    }

    /**
     * Exposes the trainable parameters and their gradients.
     *
     * @return one entry per filter, using this layer's L2/L1 decay multipliers, followed by a
     *     single entry for the biases with both decay multipliers set to {@code 0}
     */
    @Override // org.ea.javacnn.layers.Layer
    public List<BackPropResult> getBackPropagationResult() {
        List<BackPropResult> results = new ArrayList<>(filters.size() + 1);
        for (DataBlock filter : filters) {
            results.add(new BackPropResult(filter.getWeights(), filter.getGradients(),
                    l2_decay_mul, l1_decay_mul));
        }
        results.add(new BackPropResult(biases.getWeights(), biases.getGradients(), 0.0d, 0.0d));
        return results;
    }
}
