/*
 * Decompiled with CFR 0.152.
 */
package org.ea.javacnn.layers;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.ea.javacnn.data.BackPropResult;
import org.ea.javacnn.data.DataBlock;
import org.ea.javacnn.data.OutputDefinition;
import org.ea.javacnn.layers.Layer;

/**
 * Maxout layer: splits the input's depth channels into consecutive groups of
 * {@code group_size} (fixed at 2). At every spatial position it emits the maximum
 * value of each group.
 *
 * <p>The output depth is the input depth divided by {@code group_size}, rounded down.
 * When the input depth is not a multiple of the group size, the trailing channels are
 * ignored. Ties are resolved in favour of the lowest channel in the group, because a
 * later channel only wins on a strictly greater value.
 *
 * <p>During {@link #forward} the absolute input channel that won each output cell is
 * recorded in {@code switches}. {@link #backward} calls {@code clearGradient()} on the
 * input and then routes each output gradient to exactly that winning channel.
 *
 * <p>The layer has no trainable parameters, so {@link #getBackPropagationResult()}
 * returns an empty list.
 *
 * <p>NOTE: the class relies on default Java serialization. The names, types and
 * modifiers of the private fields (and the set of non-private constructors and
 * methods) determine its serialized form, so keep them stable.
 */
public class MaxoutLayer
implements Layer,
Serializable {
    private int out_depth;
    private int out_sx;
    private int out_sy;
    private DataBlock in_act;
    private DataBlock out_act;
    private final int group_size = 2;
    private int[] switches;

    /**
     * Creates a maxout layer sized from the previous layer's output definition.
     *
     * @param def output definition of the previous layer. Its X/Y extents become this
     *            layer's spatial extents, and its depth is divided by the group size.
     */
    public MaxoutLayer(OutputDefinition def) {
        this.out_sx = def.getOutX();
        this.out_sy = def.getOutY();
        // Rounds down: with an odd input depth the last channel belongs to no group.
        this.out_depth = (int)Math.floor(def.getDepth() / this.group_size);
        // Java zero-initialises int arrays, so no explicit fill is required.
        this.switches = new int[this.out_sx * this.out_sy * this.out_depth];
    }

    /**
     * Computes the per-group maximum for every spatial position.
     *
     * @param db       input activations; cached for use by {@link #backward()}
     * @param training ignored; the layer behaves the same in training and inference
     * @return a new block of size {@code out_sx x out_sy x out_depth} holding the maxima
     */
    @Override
    public DataBlock forward(DataBlock db, boolean training) {
        this.in_act = db;
        DataBlock result = new DataBlock(this.out_sx, this.out_sy, this.out_depth, 0.0);
        if (this.out_sx == 1 && this.out_sy == 1) {
            this.forwardSingleCell(db, result);
        } else {
            this.forwardSpatial(db, result);
        }
        this.out_act = result;
        return this.out_act;
    }

    /** Fast path for a 1x1 spatial extent: addresses input channels by flat index. */
    private void forwardSingleCell(DataBlock in, DataBlock out) {
        for (int i = 0; i < this.out_depth; ++i) {
            int base = i * this.group_size;
            int best = base;
            double max = in.getWeight(base);
            for (int c = base + 1; c < base + this.group_size; ++c) {
                double v = in.getWeight(c);
                if (v > max) {
                    max = v;
                    best = c;
                }
            }
            out.setWeight(i, max);
            this.switches[i] = best;
        }
    }

    /** General path: scans every (x, y) position; switches are stored in x, y, group order. */
    private void forwardSpatial(DataBlock in, DataBlock out) {
        int sx = in.getSX();
        int sy = in.getSY();
        int n = 0;
        for (int x = 0; x < sx; ++x) {
            for (int y = 0; y < sy; ++y) {
                for (int i = 0; i < this.out_depth; ++i) {
                    int base = i * this.group_size;
                    int best = base;
                    double max = in.getWeight(x, y, base);
                    for (int c = base + 1; c < base + this.group_size; ++c) {
                        double v = in.getWeight(x, y, c);
                        if (v > max) {
                            max = v;
                            best = c;
                        }
                    }
                    out.setWeight(x, y, i, max);
                    // Must follow the same x, y, i order that backward() replays.
                    this.switches[n++] = best;
                }
            }
        }
    }

    /**
     * Routes each output gradient back to the input channel that won the forward pass.
     *
     * @throws IllegalStateException if {@link #forward} has not completed yet
     */
    @Override
    public void backward() {
        if (this.in_act == null || this.out_act == null) {
            throw new IllegalStateException("backward() called before forward()");
        }
        DataBlock in = this.in_act;
        DataBlock out = this.out_act;
        in.clearGradient();
        if (this.out_sx == 1 && this.out_sy == 1) {
            for (int i = 0; i < this.out_depth; ++i) {
                in.setGradient(this.switches[i], out.getGradient(i));
            }
            return;
        }
        int sx = out.getSX();
        int sy = out.getSY();
        int n = 0;
        for (int x = 0; x < sx; ++x) {
            for (int y = 0; y < sy; ++y) {
                for (int i = 0; i < this.out_depth; ++i) {
                    in.setGradient(x, y, this.switches[n++], out.getGradient(x, y, i));
                }
            }
        }
    }

    /**
     * Returns the trainable-parameter results for this layer.
     *
     * @return a new, empty, mutable list; maxout has no trainable parameters
     */
    @Override
    public List<BackPropResult> getBackPropagationResult() {
        return new ArrayList<BackPropResult>();
    }
}

