package org.ea.javacnn.layers;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import org.ea.javacnn.data.BackPropResult;
import org.ea.javacnn.data.DataBlock;
import org.ea.javacnn.data.OutputDefinition;

/* loaded from: input_file:org/ea/javacnn/layers/FullyConnectedLayer.class */
/**
 * Layer in which every output value is a weighted sum of all input values plus a bias.
 *
 * <p>For each output index {@code i} in {@code [0, out_depth)}, {@link #forward} computes
 * {@code sum_j input[j] * filters[i][j] + biases[i]} over the first {@code num_inputs} input
 * values and stores the results in a {@code 1 x 1 x out_depth} {@link DataBlock}.
 *
 * <p>This class is {@link Serializable} without an explicit {@code serialVersionUID}, so the
 * default UID is derived from its field declarations and non-private signatures. Those are kept
 * unchanged on purpose so previously serialized layers remain readable.
 */
public class FullyConnectedLayer implements Layer, Serializable {
    /** Input block seen by the most recent {@link #forward} call; read by {@link #backward}. */
    private DataBlock in_act;
    /** Output block produced by the most recent {@link #forward} call. */
    private DataBlock out_act;
    /** Number of output values; there is one filter and one bias per output value. */
    private int out_depth;
    /** Number of input values consumed: outX * outY * depth of the preceding output definition. */
    private int num_inputs;
    /** One bias per output value, initialised to {@link #BIAS_PREF}. */
    private DataBlock biases;
    /** L1 decay multiplier reported for the filter weights (biases always report 0). */
    private double l1_decay_mul = 0.0d;
    /** L2 decay multiplier reported for the filter weights (biases always report 0). */
    private double l2_decay_mul = 1.0d;
    /** Initial value given to every bias. */
    private final float BIAS_PREF = 0.0f;
    /** Output width handed to the next layer. */
    private int out_sx = 1;
    /** Output height handed to the next layer. */
    private int out_sy = 1;
    /** One weight vector of length {@code num_inputs} per output value. */
    private List<DataBlock> filters = new ArrayList<>();

    /**
     * Creates the layer and rewrites {@code outputDefinition} to describe this layer's output.
     *
     * @param outputDefinition shape of the preceding layer's output; its X, Y and depth are read to
     *     size each filter, then overwritten with {@code 1}, {@code 1} and {@code outDepth}
     * @param outDepth number of output values produced by this layer
     * @throws IllegalArgumentException if {@code outDepth} is negative
     */
    public FullyConnectedLayer(OutputDefinition outputDefinition, int outDepth) {
        if (outDepth < 0) {
            throw new IllegalArgumentException("outDepth must be non-negative, got " + outDepth);
        }
        this.out_depth = outDepth;
        this.num_inputs = outputDefinition.getOutX() * outputDefinition.getOutY() * outputDefinition.getDepth();
        for (int f = 0; f < this.out_depth; f++) {
            // Initial filter values are whatever DataBlock's 3-argument constructor produces.
            this.filters.add(new DataBlock(1, 1, this.num_inputs));
        }
        this.biases = new DataBlock(1, 1, this.out_depth, BIAS_PREF);
        // The next layer sees this layer's output as a 1 x 1 x out_depth volume.
        outputDefinition.setOutX(this.out_sx);
        outputDefinition.setOutY(this.out_sy);
        outputDefinition.setDepth(this.out_depth);
    }

    /**
     * Computes one weighted sum plus bias per output value.
     *
     * <p>Remembers {@code dataBlock} and the returned block so that {@link #backward} can use them.
     * Only the first {@code num_inputs} values of {@code dataBlock} are read.
     *
     * @param dataBlock input values; must hold at least {@code num_inputs} weights
     * @param training flag supplied by the caller; ignored by this layer (NOTE(review): presumably a
     *     training-mode flag defined by {@code Layer} — confirm)
     * @return a new {@code 1 x 1 x out_depth} block with the computed outputs
     * @throws IllegalArgumentException if {@code dataBlock} holds fewer than {@code num_inputs} values
     */
    @Override
    public DataBlock forward(DataBlock dataBlock, boolean training) {
        double[] inputs = dataBlock.getWeights();
        if (inputs.length < this.num_inputs) {
            throw new IllegalArgumentException(
                    "expected at least " + this.num_inputs + " input values, got " + inputs.length);
        }
        this.in_act = dataBlock;
        DataBlock output = new DataBlock(1, 1, this.out_depth, 0.0d);
        for (int i = 0; i < this.out_depth; i++) {
            double[] filterWeights = this.filters.get(i).getWeights();
            double sum = 0.0d;
            for (int j = 0; j < this.num_inputs; j++) {
                sum += inputs[j] * filterWeights[j];
            }
            output.setWeight(i, sum + this.biases.getWeight(i));
        }
        this.out_act = output;
        return this.out_act;
    }

    /**
     * Propagates the gradients stored on the last output block back to the input, filters and biases.
     *
     * <p>The input block's gradients are cleared first and then accumulated. Filter and bias
     * gradients are only added to, never reset here. NOTE(review): presumably the trainer clears
     * them between batches — confirm.
     *
     * @throws IllegalStateException if {@link #forward} has not been called yet
     */
    @Override
    public void backward() {
        if (this.in_act == null || this.out_act == null) {
            throw new IllegalStateException("backward() called before forward()");
        }
        DataBlock input = this.in_act;
        input.clearGradient();
        double[] outGradients = this.out_act.getGradients();
        for (int i = 0; i < this.out_depth; i++) {
            DataBlock filter = this.filters.get(i);
            double upstream = outGradients[i];
            for (int j = 0; j < this.num_inputs; j++) {
                // d(out_i)/d(in_j) = filter_i[j]; d(out_i)/d(filter_i[j]) = in_j.
                input.addGradient(j, filter.getWeight(j) * upstream);
                filter.addGradient(j, input.getWeight(j) * upstream);
            }
            this.biases.addGradient(i, upstream);
        }
    }

    /**
     * Returns the trainable parameters together with their gradients.
     *
     * @return one entry per filter (in filter order, carrying this layer's L1/L2 decay multipliers),
     *     followed by a final entry for the biases with both decay multipliers set to 0
     */
    @Override
    public List<BackPropResult> getBackPropagationResult() {
        List<BackPropResult> results = new ArrayList<>(this.out_depth + 1);
        for (int i = 0; i < this.out_depth; i++) {
            DataBlock filter = this.filters.get(i);
            results.add(new BackPropResult(filter.getWeights(), filter.getGradients(), this.l1_decay_mul, this.l2_decay_mul));
        }
        // Biases are excluded from weight decay.
        results.add(new BackPropResult(this.biases.getWeights(), this.biases.getGradients(), 0.0d, 0.0d));
        return results;
    }
}
