/*
 * Decompiled with CFR 0.152.
 */
package org.jdmp.core.algorithm.classification.mlp;

import java.io.Serializable;
import org.jdmp.core.algorithm.classification.mlp.MultiLayerNetwork;
import org.jdmp.core.algorithm.classification.mlp.NetworkLayerBackward;
import org.jdmp.core.algorithm.classification.mlp.NetworkLayerForward;
import org.jdmp.core.algorithm.classification.mlp.WeightUpdate;
import org.jdmp.core.variable.Variable;
import org.ujmp.core.DenseMatrix;
import org.ujmp.core.Matrix;

/**
 * One layer of a {@link MultiLayerNetwork}.
 * <p>
 * A layer bundles three algorithms that share {@link Variable}s:
 * <ul>
 * <li>{@link NetworkLayerForward}, which is given the input, output and weight variables;</li>
 * <li>{@link NetworkLayerBackward}, which is given the output, output-deviation and weight variables;</li>
 * <li>{@link WeightUpdate}, which is given the input and weight variables, plus the backward
 * algorithm's contact-deviation variable.</li>
 * </ul>
 * Linking two layers with {@link #setNextLayer(NetworkLayer)} or
 * {@link #setPreviousLayer(NetworkLayer)} makes the earlier layer's output variable the later
 * layer's input variable. It also makes the later layer's input-deviation variable the earlier
 * layer's output-deviation variable.
 */
public class NetworkLayer
implements Serializable {
    private static final long serialVersionUID = 3933537008613985003L;

    /**
     * Factor applied to the {@code randn}-generated initial weights.
     * <p>
     * NOTE(review): an earlier version computed a {@code 1 / inputSize} scale for
     * {@link MultiLayerNetwork.Aggregation#SUM}, but immediately overwrote it with this fixed
     * value. The fixed value is kept to preserve existing behavior.
     */
    private static final double INITIAL_WEIGHT_SCALE = 1.0E-4;

    private NetworkLayer previousLayer;
    private NetworkLayer nextLayer;
    private NetworkLayerForward algorithmForward;
    private NetworkLayerBackward algorithmBackward;
    private WeightUpdate algorithmWeightUpdate;
    private final MultiLayerNetwork.BiasType biasType;
    private final MultiLayerNetwork.Aggregation aggregation;

    /**
     * Creates a layer whose input/output variables are not yet assigned.
     *
     * @param aggregationFunction aggregation passed to the forward algorithm
     * @param transferFunction    transfer function passed to the forward and backward algorithms
     * @param biasType            bias layout; also determines the weight-matrix shape
     */
    public NetworkLayer(MultiLayerNetwork.Aggregation aggregationFunction, MultiLayerNetwork.Transfer transferFunction, MultiLayerNetwork.BiasType biasType) {
        this.algorithmForward = new NetworkLayerForward(aggregationFunction, transferFunction, biasType);
        this.algorithmBackward = new NetworkLayerBackward(transferFunction, biasType);
        this.algorithmWeightUpdate = new WeightUpdate(biasType);
        // The weight update reads the deviation produced by the backward pass.
        this.algorithmWeightUpdate.setContactDeviationVariable(this.algorithmBackward.getContactDeviationVariable());
        this.biasType = biasType;
        this.aggregation = aggregationFunction;
    }

    /**
     * Creates a layer with fresh "Output" and "Output Deviation" variables.
     * Each variable holds a {@code neuronCount x 1} zero matrix.
     *
     * @param neuronCount number of neurons (rows of the output vectors)
     */
    public NetworkLayer(MultiLayerNetwork.Aggregation aggregationFunction, MultiLayerNetwork.Transfer transferFunction, MultiLayerNetwork.BiasType biasType, int neuronCount) {
        this(aggregationFunction, transferFunction, biasType);
        Variable outputDeviation = Variable.Factory.labeledVariable("Output Deviation");
        outputDeviation.add(Matrix.Factory.zeros((long) neuronCount, 1L));
        this.setOutputDeviationVariable(outputDeviation);
        Variable output = Variable.Factory.labeledVariable("Output");
        output.add(Matrix.Factory.zeros((long) neuronCount, 1L));
        this.setOutputVariable(output);
    }

    /** Passes the layer index to all three algorithms. */
    public void setLayer(int nr) {
        this.algorithmBackward.setLayer(nr);
        this.algorithmForward.setLayer(nr);
        this.algorithmWeightUpdate.setLayer(nr);
    }

    public NetworkLayerBackward getAlgorithmBackward() {
        return this.algorithmBackward;
    }

    /** Sets the learning rate of the weight-update algorithm. */
    public void setLearningRate(double v) {
        this.getAlgorithmWeightUpdate().setLearningRate(v);
    }

    /** Returns the learning rate of the weight-update algorithm. */
    public double getLearningRate() {
        return this.getAlgorithmWeightUpdate().getLearningRate();
    }

    /** Sets the sample weight used by the weight-update algorithm. */
    public void setSampleWeight(double v) {
        this.getAlgorithmWeightUpdate().setSampleWeight(v);
    }

    public WeightUpdate getAlgorithmWeightUpdate() {
        return this.algorithmWeightUpdate;
    }

    /**
     * Replaces the backward algorithm.
     * <p>
     * NOTE(review): the weight update keeps the contact-deviation variable of the previous
     * backward algorithm (see the constructor). It also is not re-wired to this layer's
     * output or weight variables. Confirm whether callers rely on that.
     */
    public void setAlgorithmBackward(NetworkLayerBackward algorithmBackward) {
        this.algorithmBackward = algorithmBackward;
    }

    public NetworkLayerForward getAlgorithmForward() {
        return this.algorithmForward;
    }

    public void setAlgorithmForward(NetworkLayerForward algorithmForward) {
        this.algorithmForward = algorithmForward;
    }

    public NetworkLayer getNextLayer() {
        return this.nextLayer;
    }

    /**
     * Links {@code nextLayer} after this layer.
     * <p>
     * The next layer reads this layer's output as its input. This layer uses the next layer's
     * input deviation as its own output deviation. Passing {@code null} only clears the link.
     *
     * @param nextLayer the following layer, or {@code null}
     */
    public void setNextLayer(NetworkLayer nextLayer) {
        this.nextLayer = nextLayer;
        if (nextLayer != null) {
            nextLayer.setInputVariable(this.getOutputVariable());
            this.setOutputDeviationVariable(nextLayer.getInputDeviationVariable());
        }
    }

    public NetworkLayer getPreviousLayer() {
        return this.previousLayer;
    }

    /**
     * Links {@code previousLayer} before this layer. This mirrors {@link #setNextLayer}.
     * Passing {@code null} only clears the link.
     *
     * @param previousLayer the preceding layer, or {@code null}
     */
    public void setPreviousLayer(NetworkLayer previousLayer) {
        this.previousLayer = previousLayer;
        if (previousLayer != null) {
            this.setInputVariable(previousLayer.getOutputVariable());
            previousLayer.setOutputDeviationVariable(this.getInputDeviationVariable());
        }
    }

    /** Returns the input-deviation variable of the backward algorithm. */
    public Variable getInputDeviationVariable() {
        return this.algorithmBackward.getInputDeviationVariable();
    }

    /** Shares {@code weight} with the forward, backward and weight-update algorithms. */
    public void setWeightVariable(Variable weight) {
        this.algorithmForward.setWeightVariable(weight);
        this.algorithmBackward.setWeightVariable(weight);
        this.algorithmWeightUpdate.setWeightVariable(weight);
    }

    /** Runs the forward pass, creating the weights first if they do not exist yet. */
    public void calculateForward() {
        this.createVariablesIfNecessary();
        this.algorithmForward.calculate();
    }

    /** Runs the backward pass, creating the weights first if they do not exist yet. */
    public void calculateBackward() {
        this.createVariablesIfNecessary();
        this.algorithmBackward.calculate();
    }

    /** Runs the weight update, creating the weights first if they do not exist yet. */
    public void calculateWeightUpdate() {
        this.createVariablesIfNecessary();
        this.algorithmWeightUpdate.calculate();
    }

    /** Returns the weight variable as seen by the forward algorithm (may be {@code null}). */
    public Variable getWeightVariable() {
        return this.algorithmForward.getWeightVariable();
    }

    /**
     * Lazily creates a "Weight" variable if none is set yet.
     * <p>
     * The variable holds random weights whose shape follows the bias type
     * (see {@link #createRandomWeights()}). It requires the input and output variables to be set.
     */
    public void createVariablesIfNecessary() {
        if (this.getWeightVariable() == null) {
            Variable weight = Variable.Factory.labeledVariable("Weight");
            weight.add(this.createRandomWeights());
            this.setWeightVariable(weight);
        }
    }

    public String getLabel() {
        return this.algorithmForward.getLabel();
    }

    /**
     * Appends a freshly randomized weight matrix to the weight variable.
     * <p>
     * The new matrix has the same shape that {@link #createVariablesIfNecessary()} would
     * create, so the bias column(s) are included.
     */
    public void reset() {
        this.createVariablesIfNecessary();
        this.getWeightVariable().add(this.createRandomWeights());
    }

    /**
     * Builds a {@code randn} weight matrix scaled by {@link #INITIAL_WEIGHT_SCALE}.
     * <p>
     * The matrix has one row per output element. Its column count depends on the bias type:
     * {@code inputSize + 1} for SINGLE, {@code 2 * inputSize} for MULTIPLE, and
     * {@code inputSize} for NONE.
     */
    private Matrix createRandomWeights() {
        long outputSize = elementCount(this.getOutputVariable());
        long inputSize = elementCount(this.getInputVariable());
        long columns;
        switch (this.biasType) {
            case SINGLE:
                columns = inputSize + 1L;
                break;
            case MULTIPLE:
                columns = inputSize * 2L;
                break;
            case NONE:
                columns = inputSize;
                break;
            default:
                throw new IllegalStateException("unsupported bias type: " + this.biasType);
        }
        return ((DenseMatrix) Matrix.Factory.randn(outputSize, columns)).times(INITIAL_WEIGHT_SCALE);
    }

    /** Returns rows * columns of the most recent matrix stored in {@code v}. */
    private static long elementCount(Variable v) {
        Matrix last = (Matrix) v.getLast();
        return last.getRowCount() * last.getColumnCount();
    }

    /** Formats as {@code "(prev -->) n (--> next)"}, showing only the neighbours that exist. */
    @Override
    public String toString() {
        StringBuilder s = new StringBuilder();
        if (this.getPreviousLayer() != null) {
            s.append("(" + this.getPreviousLayer().getNeuronCount() + " -->) ");
        }
        s.append(this.getNeuronCount());
        if (this.getNextLayer() != null) {
            s.append(" (--> " + this.getNextLayer().getNeuronCount() + ")");
        }
        return s.toString();
    }

    /**
     * Returns the row count of the latest matrix in the output-deviation variable.
     * <p>
     * NOTE(review): this throws if that variable is unset or empty. After
     * {@link #setNextLayer} it is the next layer's input-deviation variable.
     */
    public int getNeuronCount() {
        return (int) ((Matrix) this.getOutputDeviationVariable().getLast()).getRowCount();
    }

    /** Sets the input variable of the forward and weight-update algorithms. */
    public void setInputVariable(Variable v) {
        this.algorithmForward.setInputVariable(v);
        this.algorithmWeightUpdate.setInputVariable(v);
    }

    /** Sets the output-deviation variable of the backward algorithm. */
    public void setOutputDeviationVariable(Variable v) {
        this.algorithmBackward.setOutputDeviationVariable(v);
    }

    /** Sets the output variable of the forward and backward algorithms. */
    public void setOutputVariable(Variable v) {
        this.algorithmForward.setOutputVariable(v);
        this.algorithmBackward.setOutputVariable(v);
    }

    public Variable getInputVariable() {
        return this.algorithmForward.getInputVariable();
    }

    public Variable getOutputDeviationVariable() {
        return this.algorithmBackward.getOutputDeviationVariable();
    }

    public Variable getOutputVariable() {
        return this.algorithmForward.getOutputVariable();
    }

    /** Stores {@code matrix} under the key "Input" in the forward algorithm's variable map. */
    public void addInputMatrix(Matrix matrix) {
        this.algorithmForward.getVariableMap().setMatrix("Input", matrix);
    }
}

