package org.jdmp.core.algorithm.classification.mlp;

import java.io.Serializable;
import org.jdmp.core.algorithm.classification.mlp.MultiLayerNetwork;
import org.jdmp.core.variable.Variable;
import org.ujmp.core.Matrix;

/* loaded from: input_file:org/jdmp/core/algorithm/classification/mlp/NetworkLayer.class */
/**
 * One layer of a multi-layer network. It bundles three algorithms: the forward pass, the backward
 * (deviation) pass and the weight update. All three share the same input, output, deviation and
 * weight variables.
 *
 * <p>Layers are chained with {@link #setPreviousLayer(NetworkLayer)} and
 * {@link #setNextLayer(NetworkLayer)}. Each call wires one layer's output variable into the other
 * layer's input variable. It also wires that layer's input deviation variable back into this side's
 * output deviation variable.
 */
public class NetworkLayer implements Serializable {
    private static final long serialVersionUID = 3933537008613985003L;

    /** Factor applied to standard-normal samples when weights are (re)initialised. */
    private static final double INITIAL_WEIGHT_SCALE = 1.0E-4d;

    private NetworkLayer previousLayer = null;
    private NetworkLayer nextLayer = null;
    private NetworkLayerForward algorithmForward;
    private NetworkLayerBackward algorithmBackward;
    private final WeightUpdate algorithmWeightUpdate;
    private final MultiLayerNetwork.BiasType biasType;
    // Kept for serialization compatibility; not currently read by this class.
    private final MultiLayerNetwork.Aggregation aggregation;

    /**
     * Creates a layer whose forward, backward and weight-update algorithms share the given
     * configuration.
     *
     * @param aggregation aggregation passed to the forward algorithm
     * @param transfer transfer function passed to the forward and backward algorithms
     * @param biasType bias layout; also determines the weight matrix column count
     */
    public NetworkLayer(MultiLayerNetwork.Aggregation aggregation, MultiLayerNetwork.Transfer transfer,
            MultiLayerNetwork.BiasType biasType) {
        this.algorithmForward = new NetworkLayerForward(aggregation, transfer, biasType);
        this.algorithmBackward = new NetworkLayerBackward(transfer, biasType);
        this.algorithmWeightUpdate = new WeightUpdate(biasType);
        // The weight update reads the same contact-deviation variable the backward pass exposes.
        this.algorithmWeightUpdate.setContactDeviationVariable(this.algorithmBackward.getContactDeviationVariable());
        this.biasType = biasType;
        this.aggregation = aggregation;
    }

    /**
     * Creates a layer and pre-allocates its output and output-deviation variables as zero column
     * vectors.
     *
     * @param aggregation aggregation passed to the forward algorithm
     * @param transfer transfer function passed to the forward and backward algorithms
     * @param biasType bias layout; also determines the weight matrix column count
     * @param i number of rows (neurons) of the pre-allocated output vectors
     */
    public NetworkLayer(MultiLayerNetwork.Aggregation aggregation, MultiLayerNetwork.Transfer transfer,
            MultiLayerNetwork.BiasType biasType, int i) {
        this(aggregation, transfer, biasType);
        Variable outputDeviation = Variable.Factory.labeledVariable("Output Deviation");
        outputDeviation.add(Matrix.Factory.zeros(i, 1L));
        setOutputDeviationVariable(outputDeviation);
        Variable output = Variable.Factory.labeledVariable("Output");
        output.add(Matrix.Factory.zeros(i, 1L));
        setOutputVariable(output);
    }

    /**
     * Propagates the layer index to all three algorithms.
     *
     * @param i layer index
     */
    public void setLayer(int i) {
        this.algorithmBackward.setLayer(i);
        this.algorithmForward.setLayer(i);
        this.algorithmWeightUpdate.setLayer(i);
    }

    /** @return the backward (deviation) algorithm of this layer */
    public NetworkLayerBackward getAlgorithmBackward() {
        return this.algorithmBackward;
    }

    /**
     * Sets the learning rate used by the weight-update algorithm.
     *
     * @param d learning rate
     */
    public void setLearningRate(double d) {
        getAlgorithmWeightUpdate().setLearningRate(d);
    }

    /** @return the learning rate of the weight-update algorithm */
    public double getLearningRate() {
        return getAlgorithmWeightUpdate().getLearningRate();
    }

    /**
     * Sets the sample weight used by the weight-update algorithm.
     *
     * @param d sample weight
     */
    public void setSampleWeight(double d) {
        getAlgorithmWeightUpdate().setSampleWeight(d);
    }

    /** @return the weight-update algorithm of this layer */
    public WeightUpdate getAlgorithmWeightUpdate() {
        return this.algorithmWeightUpdate;
    }

    /**
     * Replaces the backward algorithm. No variables are rewired; the caller is responsible for that.
     *
     * @param networkLayerBackward new backward algorithm
     */
    public void setAlgorithmBackward(NetworkLayerBackward networkLayerBackward) {
        this.algorithmBackward = networkLayerBackward;
    }

    /** @return the forward algorithm of this layer */
    public NetworkLayerForward getAlgorithmForward() {
        return this.algorithmForward;
    }

    /**
     * Replaces the forward algorithm. No variables are rewired; the caller is responsible for that.
     *
     * @param networkLayerForward new forward algorithm
     */
    public void setAlgorithmForward(NetworkLayerForward networkLayerForward) {
        this.algorithmForward = networkLayerForward;
    }

    /** @return the following layer, or {@code null} if none has been set */
    public NetworkLayer getNextLayer() {
        return this.nextLayer;
    }

    /**
     * Links {@code networkLayer} as the following layer. This layer's output becomes its input, and
     * its input deviation becomes this layer's output deviation. Passing {@code null} only clears
     * the link, mirroring {@link #setPreviousLayer(NetworkLayer)}.
     *
     * @param networkLayer the following layer, or {@code null}
     */
    public void setNextLayer(NetworkLayer networkLayer) {
        this.nextLayer = networkLayer;
        if (networkLayer != null) {
            networkLayer.setInputVariable(getOutputVariable());
            setOutputDeviationVariable(networkLayer.getInputDeviationVariable());
        }
    }

    /** @return the preceding layer, or {@code null} if none has been set */
    public NetworkLayer getPreviousLayer() {
        return this.previousLayer;
    }

    /**
     * Links {@code networkLayer} as the preceding layer. Its output becomes this layer's input, and
     * this layer's input deviation becomes its output deviation. Passing {@code null} only clears
     * the link.
     *
     * @param networkLayer the preceding layer, or {@code null}
     */
    public void setPreviousLayer(NetworkLayer networkLayer) {
        this.previousLayer = networkLayer;
        if (networkLayer != null) {
            setInputVariable(networkLayer.getOutputVariable());
            networkLayer.setOutputDeviationVariable(getInputDeviationVariable());
        }
    }

    /** @return the input deviation variable owned by the backward algorithm */
    public Variable getInputDeviationVariable() {
        return this.algorithmBackward.getInputDeviationVariable();
    }

    /**
     * Shares one weight variable between the forward, backward and weight-update algorithms.
     *
     * @param variable weight variable
     */
    public void setWeightVariable(Variable variable) {
        this.algorithmForward.setWeightVariable(variable);
        this.algorithmBackward.setWeightVariable(variable);
        this.algorithmWeightUpdate.setWeightVariable(variable);
    }

    /** Runs the forward algorithm, creating the weight variable first if it does not exist yet. */
    public void calculateForward() {
        createVariablesIfNecessary();
        this.algorithmForward.calculate();
    }

    /** Runs the backward algorithm, creating the weight variable first if it does not exist yet. */
    public void calculateBackward() {
        createVariablesIfNecessary();
        this.algorithmBackward.calculate();
    }

    /** Runs the weight-update algorithm. */
    public void calculateWeightUpdate() {
        this.algorithmWeightUpdate.calculate();
    }

    /** @return the weight variable as seen by the forward algorithm, possibly {@code null} */
    public Variable getWeightVariable() {
        return this.algorithmForward.getWeightVariable();
    }

    /**
     * Creates a randomly initialised "Weight" variable if none is set yet. Its dimensions are
     * derived from the current input and output matrices and the bias type; see
     * {@link #createRandomWeightMatrix()}.
     */
    public void createVariablesIfNecessary() {
        if (getWeightVariable() == null) {
            Variable weight = Variable.Factory.labeledVariable("Weight");
            weight.add(createRandomWeightMatrix());
            setWeightVariable(weight);
        }
    }

    /** @return the label of the forward algorithm */
    public String getLabel() {
        return this.algorithmForward.getLabel();
    }

    /**
     * Re-randomises the weights by adding a freshly drawn matrix to the weight variable.
     *
     * <p>The new matrix uses the same bias-aware dimensions as the initial weights. Previously it
     * ignored the bias type and produced a mismatching shape for SINGLE/MULTIPLE bias.
     */
    public void reset() {
        createVariablesIfNecessary();
        getWeightVariable().add(createRandomWeightMatrix());
    }

    /**
     * Draws a standard-normal weight matrix scaled by {@link #INITIAL_WEIGHT_SCALE}.
     *
     * <p>The matrix has one row per output element. Its column count depends on the bias type: the
     * input element count for NONE, plus one column for SINGLE, or twice the input element count
     * for MULTIPLE.
     *
     * @return a new weight matrix
     * @throws IllegalStateException if the bias type is not one of NONE, SINGLE or MULTIPLE
     */
    private Matrix createRandomWeightMatrix() {
        long outputCount = elementCount(getOutputVariable());
        long inputCount = elementCount(getInputVariable());
        long columns;
        switch (this.biasType) {
            case SINGLE:
                columns = inputCount + 1;
                break;
            case MULTIPLE:
                columns = inputCount * 2;
                break;
            case NONE:
                columns = inputCount;
                break;
            default:
                throw new IllegalStateException("Unsupported bias type: " + this.biasType);
        }
        return Matrix.Factory.randn(outputCount, columns).times(INITIAL_WEIGHT_SCALE);
    }

    /** Returns rows * columns of the most recent matrix stored in {@code variable}. */
    private static long elementCount(Variable variable) {
        Matrix last = (Matrix) variable.getLast();
        return last.getRowCount() * last.getColumnCount();
    }

    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        if (getPreviousLayer() != null) {
            sb.append("(" + getPreviousLayer().getNeuronCount() + " -->) ");
        }
        sb.append(getNeuronCount());
        if (getNextLayer() != null) {
            sb.append(" (--> " + getNextLayer().getNeuronCount() + ")");
        }
        return sb.toString();
    }

    /** @return the row count of the most recent output-deviation matrix */
    public int getNeuronCount() {
        return (int) ((Matrix) getOutputDeviationVariable().getLast()).getRowCount();
    }

    /**
     * Sets the input variable of the forward and weight-update algorithms.
     *
     * @param variable input variable
     */
    public void setInputVariable(Variable variable) {
        this.algorithmForward.setInputVariable(variable);
        this.algorithmWeightUpdate.setInputVariable(variable);
    }

    /**
     * Sets the output deviation variable of the backward algorithm.
     *
     * @param variable output deviation variable
     */
    public void setOutputDeviationVariable(Variable variable) {
        this.algorithmBackward.setOutputDeviationVariable(variable);
    }

    /**
     * Sets the output variable of the forward and backward algorithms.
     *
     * @param variable output variable
     */
    public void setOutputVariable(Variable variable) {
        this.algorithmForward.setOutputVariable(variable);
        this.algorithmBackward.setOutputVariable(variable);
    }

    /** @return the input variable of the forward algorithm */
    public Variable getInputVariable() {
        return this.algorithmForward.getInputVariable();
    }

    /** @return the output deviation variable of the backward algorithm */
    public Variable getOutputDeviationVariable() {
        return this.algorithmBackward.getOutputDeviationVariable();
    }

    /** @return the output variable of the forward algorithm */
    public Variable getOutputVariable() {
        return this.algorithmForward.getOutputVariable();
    }

    /**
     * Stores {@code matrix} under the key "Input" in the forward algorithm's variable map.
     *
     * @param matrix input matrix
     */
    public void addInputMatrix(Matrix matrix) {
        this.algorithmForward.getVariableMap().setMatrix("Input", matrix);
    }
}
