package org.ea.javacnn.layers;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.ea.javacnn.data.BackPropResult;
import org.ea.javacnn.data.DataBlock;
import org.ea.javacnn.data.OutputDefinition;

/* loaded from: input_file:org/ea/javacnn/layers/MaxoutLayer.class */
/**
 * Maxout activation layer.
 *
 * <p>The input channels are split into consecutive, non-overlapping groups of
 * {@code group_size} (currently 2) channels. Each output channel at a given (x, y) position
 * is the maximum over one such group. The output depth is therefore
 * {@code inputDepth / group_size}, computed with integer division. Any trailing input
 * channels that do not fill a whole group are never read.
 *
 * <p>{@link #forward} records the index of the winning input channel for every output value
 * in {@code switches}. {@link #backward} clears the input gradient and then writes each
 * output gradient only to its recorded winning input channel.
 *
 * <p>The layer has no trainable parameters, so {@link #getBackPropagationResult} returns an
 * empty list. Forward and backward share state through instance fields, so a single instance
 * must not be used concurrently.
 */
public class MaxoutLayer implements Layer, Serializable {
    /** Output width (same as the input width; maxout does not change spatial size). */
    private int out_sx;
    // NOTE: field order/names/modifiers are kept as-is so the default serialVersionUID
    // (and thus compatibility with previously serialized models) is not disturbed.
    /** Output depth: number of channel groups, i.e. input depth / group_size. */
    private int out_depth;
    /** Output height (same as the input height). */
    private int out_sy;
    /** Input block from the most recent forward pass; receives gradients in backward(). */
    private DataBlock in_act;
    /** Output block from the most recent forward pass; supplies gradients in backward(). */
    private DataBlock out_act;
    /** Number of consecutive input channels reduced by max into one output channel. */
    private final int group_size = 2;
    /**
     * For every output value (ordered x, then y, then channel), the input channel index that
     * produced the maximum in the last forward pass.
     */
    private int[] switches;

    /**
     * Creates a maxout layer for the given input shape.
     *
     * @param outputDefinition shape this layer receives, read via {@code getOutX()},
     *     {@code getOutY()} and {@code getDepth()} (presumably the previous layer's output
     *     definition — confirm against the network builder)
     */
    public MaxoutLayer(OutputDefinition outputDefinition) {
        this.out_sx = outputDefinition.getOutX();
        this.out_sy = outputDefinition.getOutY();
        // Integer division: channels that do not complete a full group are dropped.
        this.out_depth = outputDefinition.getDepth() / this.group_size;
        // Java zero-initialises new int arrays, so no explicit fill is required.
        this.switches = new int[this.out_sx * this.out_sy * this.out_depth];
    }

    /**
     * Computes the maxout activation of {@code dataBlock}.
     *
     * <p>For each group of {@code group_size} consecutive input channels, the largest value is
     * written to the corresponding output channel. Ties keep the first (lowest-index) channel.
     * The winning channel index is stored in {@code switches} for {@link #backward}.
     *
     * @param dataBlock input activations; must provide at least
     *     {@code out_depth * group_size} channels at every (x, y) of the configured output size
     * @param z not used by this layer (presumably a training-mode flag — TODO confirm
     *     against {@code Layer})
     * @return a new block of size {@code out_sx x out_sy x out_depth}; it is also retained
     *     for the backward pass
     */
    @Override
    public DataBlock forward(DataBlock dataBlock, boolean z) {
        this.in_act = dataBlock;
        final int groups = this.out_depth;
        final int groupSize = this.group_size;
        DataBlock result = new DataBlock(this.out_sx, this.out_sy, groups, 0.0d);

        if (this.out_sx == 1 && this.out_sy == 1) {
            // 1x1 spatial case: channels are addressed with the single-index accessors.
            for (int g = 0; g < groups; g++) {
                int first = g * groupSize;
                int winner = first;
                double best = dataBlock.getWeight(first);
                for (int c = first + 1; c < first + groupSize; c++) {
                    double candidate = dataBlock.getWeight(c);
                    if (candidate > best) {
                        best = candidate;
                        winner = c;
                    }
                }
                result.setWeight(g, best);
                this.switches[g] = winner;
            }
        } else {
            // Iterate over the configured output size so the switch indices line up with the
            // allocation in the constructor and with the traversal order used in backward().
            int slot = 0;
            for (int x = 0; x < this.out_sx; x++) {
                for (int y = 0; y < this.out_sy; y++) {
                    for (int g = 0; g < groups; g++) {
                        int first = g * groupSize;
                        int winner = first;
                        double best = dataBlock.getWeight(x, y, first);
                        for (int c = first + 1; c < first + groupSize; c++) {
                            double candidate = dataBlock.getWeight(x, y, c);
                            if (candidate > best) {
                                best = candidate;
                                winner = c;
                            }
                        }
                        result.setWeight(x, y, g, best);
                        this.switches[slot++] = winner;
                    }
                }
            }
        }

        this.out_act = result;
        return result;
    }

    /**
     * Propagates the output gradient back to the input of the last forward pass.
     *
     * <p>The input gradient is first cleared via {@code clearGradient()}. Then each output
     * gradient is written to the input channel that won the max in the forward pass, as
     * recorded in {@code switches}.
     *
     * @throws IllegalStateException if called before {@link #forward}
     */
    @Override
    public void backward() {
        DataBlock input = this.in_act;
        DataBlock output = this.out_act;
        if (input == null || output == null) {
            throw new IllegalStateException("backward() called before forward()");
        }
        final int groups = this.out_depth;
        input.clearGradient();

        if (this.out_sx == 1 && this.out_sy == 1) {
            for (int g = 0; g < groups; g++) {
                input.setGradient(this.switches[g], output.getGradient(g));
            }
            return;
        }

        // Same x -> y -> channel order as forward(), so slot indexes the matching switch.
        int slot = 0;
        for (int x = 0; x < this.out_sx; x++) {
            for (int y = 0; y < this.out_sy; y++) {
                for (int g = 0; g < groups; g++) {
                    input.setGradient(x, y, this.switches[slot++], output.getGradient(x, y, g));
                }
            }
        }
    }

    /**
     * Returns the trainable-parameter results for this layer.
     *
     * @return a new, empty mutable list, because maxout has no weights or biases
     */
    @Override
    public List<BackPropResult> getBackPropagationResult() {
        return new ArrayList<BackPropResult>();
    }
}
