/*
 * Decompiled with CFR 0.152.
 */
package org.ea.javacnn.layers;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import org.ea.javacnn.data.BackPropResult;
import org.ea.javacnn.data.DataBlock;
import org.ea.javacnn.data.OutputDefinition;
import org.ea.javacnn.layers.Layer;

/**
 * Dropout layer.
 *
 * <p>In training mode each element of the input block's value array
 * ({@code getWeights()}) is independently set to zero with probability
 * {@code drop_prob}, and the resulting mask is remembered for
 * {@link #backward()}. In inference mode no element is dropped; instead every
 * value is multiplied by the keep probability {@code 1 - drop_prob}. This makes
 * prediction-time outputs match the expected magnitude seen during training.
 *
 * <p>The layer has no trainable parameters, so
 * {@link #getBackPropagationResult()} always returns an empty list.
 *
 * <p>NOTE(review): instances hold per-call state ({@code in_act},
 * {@code out_act}, {@code dropped}) and are not safe for concurrent use.
 */
public class DropoutLayer
implements Layer,
Serializable {
    // Field names, types and modifiers are intentionally unchanged so the
    // default serialized form (and serialVersionUID) of saved networks stays
    // compatible.
    private int out_depth;
    private int out_sx;
    private int out_sy;
    /** Input of the most recent forward pass; its gradient is written by backward(). */
    private DataBlock in_act;
    /** Output of the most recent forward pass; its gradient is read by backward(). */
    private DataBlock out_act;
    /** Probability that a single element is zeroed during training. */
    private final double drop_prob = 0.5;
    /** Mask from the most recent forward pass: {@code true} = element was zeroed. */
    private boolean[] dropped;

    /**
     * Creates a dropout layer whose mask is sized for the given output shape.
     *
     * @param def output definition supplying width, height and depth
     */
    public DropoutLayer(OutputDefinition def) {
        this.out_sx = def.getOutX();
        this.out_sy = def.getOutY();
        this.out_depth = def.getDepth();
        this.dropped = new boolean[this.out_sx * this.out_sy * this.out_depth];
    }

    /**
     * Applies dropout (training) or keep-probability scaling (inference) to a
     * copy of {@code db}.
     *
     * @param db       input block; it is not modified, but is retained for
     *                 {@link #backward()}
     * @param training {@code true} to randomly drop elements, {@code false} to
     *                 scale all elements by {@code 1 - drop_prob}
     * @return a new block (a clone of {@code db}) holding the layer output
     */
    @Override
    public DataBlock forward(DataBlock db, boolean training) {
        this.in_act = db;
        DataBlock output = db.clone();
        double[] values = output.getWeights();
        int n = values.length;

        // The mask is sized from the OutputDefinition; adapt it if the actual
        // input differs so training never indexes past its end.
        if (this.dropped == null || this.dropped.length != n) {
            this.dropped = new boolean[n];
        }

        if (training) {
            for (int i = 0; i < n; ++i) {
                boolean drop = Math.random() < this.drop_prob;
                this.dropped[i] = drop;
                if (drop) {
                    output.setWeight(i, 0.0);
                }
            }
        } else {
            // Scale the output values (not the gradients) by the keep
            // probability so inference matches the training-time expectation.
            // The values are read from the snapshot taken before any writes,
            // so this is correct whether getWeights() returns the live array
            // or a copy. Clearing the mask keeps a stale training mask from
            // leaking into a later backward() call.
            double keepProb = 1.0 - this.drop_prob;
            for (int i = 0; i < n; ++i) {
                output.setWeight(i, values[i] * keepProb);
                this.dropped[i] = false;
            }
        }
        this.out_act = output;
        return this.out_act;
    }

    /**
     * Propagates the output gradient back to the input of the last forward
     * pass.
     *
     * <p>The input gradient is first reset via {@code clearGradient()}
     * (presumably zeroing it — confirm against DataBlock). Then the output
     * gradient is copied only for elements that were not dropped.
     *
     * @throws IllegalStateException if called before {@link #forward}
     */
    @Override
    public void backward() {
        if (this.in_act == null || this.out_act == null) {
            throw new IllegalStateException("backward() called before forward()");
        }
        DataBlock input = this.in_act;
        DataBlock chainGrad = this.out_act;
        int n = input.getWeights().length;
        input.clearGradient();
        for (int i = 0; i < n; ++i) {
            if (!this.dropped[i]) {
                input.setGradient(i, chainGrad.getGradient(i));
            }
        }
    }

    /**
     * Returns the trainable parameters of this layer; dropout has none.
     *
     * @return a new, empty list
     */
    @Override
    public List<BackPropResult> getBackPropagationResult() {
        return new ArrayList<BackPropResult>();
    }
}

