/*
 * Decompiled with CFR 0.152.
 */
package org.jdmp.core.algorithm.classification.mlp;

import java.util.HashMap;
import java.util.Map;
import org.jdmp.core.algorithm.AbstractAlgorithm;
import org.jdmp.core.algorithm.Algorithm;
import org.jdmp.core.algorithm.basic.Clone;
import org.jdmp.core.algorithm.basic.LogisticFunction;
import org.jdmp.core.algorithm.basic.Mean;
import org.jdmp.core.algorithm.basic.Sum;
import org.jdmp.core.algorithm.basic.Tanh;
import org.jdmp.core.algorithm.basic.TanhPlusOne;
import org.jdmp.core.algorithm.classification.mlp.MultiLayerNetwork;
import org.jdmp.core.algorithm.classification.mlp.Weighting;
import org.jdmp.core.variable.Variable;
import org.ujmp.core.Matrix;

public class NetworkLayerForward
extends AbstractAlgorithm {
    private static final long serialVersionUID = -2738909213636005084L;
    public static final String INPUT = "Input";
    public static final String WEIGHT = "Weight";
    public static final String WEIGHTINGFUNCTION = "WeightingFunction";
    public static final String WEIGHTEDINPUT = "WeightedInput";
    public static final String AGGREGATIONFUNCTION = "AggregationFunction";
    public static final String NETINPUT = "NetInput";
    public static final String TRANSFERFUNCTION = "TransferFunction";
    public static final String OUTPUT = "Output";

    public NetworkLayerForward(MultiLayerNetwork.BiasType biasType) {
        this(MultiLayerNetwork.Aggregation.MEAN, MultiLayerNetwork.Transfer.TANH, biasType);
    }

    public NetworkLayerForward(MultiLayerNetwork.Aggregation aggregationFunction, MultiLayerNetwork.Transfer transferFunction, MultiLayerNetwork.BiasType biasType) {
        this.setDescription("One layer of a multi-layer network, forward path");
        this.setVariable(WEIGHTEDINPUT, Variable.Factory.labeledVariable("Weighted Input"));
        this.setVariable(NETINPUT, Variable.Factory.labeledVariable("Net Input"));
        Weighting aw = new Weighting(biasType);
        aw.setVariable("Target", this.getWeightedInputVariable());
        this.setAlgorithm(WEIGHTINGFUNCTION, aw);
        this.setAggregationFunction(aggregationFunction);
        this.setTransferFunction(transferFunction);
    }

    @Override
    public Map<String, Object> calculateObjects(Map<String, Object> input) {
        HashMap<String, Object> result = new HashMap<String, Object>();
        Algorithm weightingFunction = this.getWeightingFunction();
        weightingFunction.calculate();
        Algorithm aggregationFunction = this.getAggregationFunction();
        aggregationFunction.calculate();
        Algorithm transferFunction = this.getTransferFunction();
        transferFunction.calculate();
        return result;
    }

    public Matrix getInputMatrix() {
        return this.getVariableMap().getMatrix(INPUT);
    }

    public Variable getInputVariable() {
        Variable v = (Variable)this.getVariableMap().get(INPUT);
        return v;
    }

    public void setInputVariable(Variable v) {
        this.setVariable(INPUT, v);
        this.getWeightingFunction().setVariable("Source 2", v);
    }

    public Matrix getWeightMatrix() {
        return this.getVariableMap().getMatrix(WEIGHT);
    }

    public Variable getWeightVariable() {
        Variable v = (Variable)this.getVariableMap().get(WEIGHT);
        return v;
    }

    public void setWeightVariable(Variable v) {
        this.setVariable(WEIGHT, v);
        this.getWeightingFunction().setVariable("Source 1", v);
    }

    public Matrix getWeightedInputMatrix() {
        return this.getVariableMap().getMatrix(WEIGHTEDINPUT);
    }

    public Variable getWeightedInputVariable() {
        Variable v = (Variable)this.getVariableMap().get(WEIGHTEDINPUT);
        return v;
    }

    public void setAggregationFunction(MultiLayerNetwork.Aggregation aggregationFunction) {
        AbstractAlgorithm a = null;
        switch (aggregationFunction) {
            case MEAN: {
                a = new Mean(1);
                break;
            }
            case SUM: {
                a = new Sum(1);
            }
        }
        a.setVariable("Source", this.getWeightedInputVariable());
        a.setVariable("Target", this.getNetInputVariable());
        this.setAlgorithm(AGGREGATIONFUNCTION, a);
    }

    public void setTransferFunction(MultiLayerNetwork.Transfer transferFunction) {
        Algorithm a = null;
        switch (transferFunction) {
            case TANH: {
                a = new Tanh(new Variable[0]);
                break;
            }
            case LINEAR: {
                a = new Clone(new Variable[0]);
                break;
            }
            case TANHPLUSONE: {
                a = new TanhPlusOne(new Variable[0]);
                break;
            }
            case SIGMOID: {
                a = new LogisticFunction(new Variable[0]);
            }
        }
        a.setVariable("Source", this.getNetInputVariable());
        this.setAlgorithm(TRANSFERFUNCTION, a);
    }

    public void setWeightedInputVariable(Variable v) {
        this.setVariable(WEIGHTEDINPUT, v);
    }

    public Algorithm getWeightingFunction() {
        Algorithm a = (Algorithm)this.getAlgorithmMap().get(WEIGHTINGFUNCTION);
        return a;
    }

    public Algorithm getAggregationFunction() {
        Algorithm a = (Algorithm)this.getAlgorithmMap().get(AGGREGATIONFUNCTION);
        return a;
    }

    public Algorithm getTransferFunction() {
        Algorithm a = (Algorithm)this.getAlgorithmMap().get(TRANSFERFUNCTION);
        return a;
    }

    public Variable getNetInputVariable() {
        Variable v = (Variable)this.getVariableMap().get(NETINPUT);
        return v;
    }

    public Matrix getNetInputMatrix() {
        return this.getVariableMap().getMatrix(NETINPUT);
    }

    public void setNetInputVariable(Variable v) {
        this.setVariable(NETINPUT, v);
    }

    public Variable getOutputVariable() {
        Variable v = (Variable)this.getVariableMap().get(OUTPUT);
        return v;
    }

    public Matrix getOutputMatrix() {
        return this.getVariableMap().getMatrix(OUTPUT);
    }

    public void setOutputVariable(Variable v) {
        this.setVariable(OUTPUT, v);
        ((Algorithm)this.getAlgorithmMap().get(TRANSFERFUNCTION)).setVariable("Target", v);
    }

    public void setLayer(int nr) {
        this.setLabel("Network Layer Forward (" + nr + ")");
    }
}

