/*
 * Decompiled with CFR 0.152.
 */
package org.jdmp.core.algorithm.classification.mlp;

import java.util.HashMap;
import java.util.Map;
import org.jdmp.core.algorithm.AbstractAlgorithm;
import org.jdmp.core.algorithm.Algorithm;
import org.jdmp.core.algorithm.AlgorithmTwoSources;
import org.jdmp.core.algorithm.basic.Clone;
import org.jdmp.core.algorithm.classification.mlp.DimmingFunction;
import org.jdmp.core.algorithm.classification.mlp.MultiLayerNetwork;
import org.jdmp.core.algorithm.classification.mlp.ReturningFunctionOne;
import org.jdmp.core.algorithm.classification.mlp.ReturningFunctionTanh;
import org.jdmp.core.variable.Variable;
import org.ujmp.core.Matrix;
import org.ujmp.core.doublematrix.DenseDoubleMatrix2D;

/**
 * One layer of a multi-layer network, backward path.
 *
 * <p>The layer is assembled from three sub-algorithms that are registered under
 * {@link #RETURNINGFUNCTION}, {@link #SPLITTINGFUNCTION} and {@link #DIMMINGFUNCTION}
 * and executed in that order by {@link #calculateObjects(Map)}:
 * <ul>
 *   <li>the returning function takes the output variable as "Source 1" and the output
 *       deviation variable as "Source 2" and writes to the net deviation variable;</li>
 *   <li>the splitting function is a {@link Clone} from the net deviation variable to the
 *       contact deviation variable;</li>
 *   <li>the dimming function takes the weight variable as "Source 1" and the contact
 *       deviation variable as "Source 2" and writes to the input deviation variable.</li>
 * </ul>
 * The output, output deviation and weight variables are not created here; they have to be
 * supplied through their setters, which also connect them to the matching sub-algorithm.
 */
public class NetworkLayerBackward extends AbstractAlgorithm {
    private static final long serialVersionUID = -8051094919417324985L;

    // Keys under which variables and sub-algorithms of this layer are registered.
    public static final String WEIGHT = "Weight";
    public static final String OUTPUT = "Output";
    public static final String OUTPUTDEVIATION = "OutputDeviation";
    public static final String RETURNINGFUNCTION = "ReturningFunction";
    public static final String NETDEVIATION = "NetDeviation";
    public static final String SPLITTINGFUNCTION = "SplittingFunction";
    public static final String CONTACTDEVIATION = "ContactDeviation";
    public static final String DIMMINGFUNCTION = "DimmingFunction";
    public static final String INPUTDEVIATION = "InputDeviation";

    /**
     * Creates the layer, its three internal variables and its three sub-algorithms.
     *
     * @param transferFunction selects the returning function; TANH and TANHPLUSONE both use
     *                         {@link ReturningFunctionTanh}, LINEAR uses {@link ReturningFunctionOne}
     * @param biasType         handed unchanged to the {@link DimmingFunction} constructor
     * @throws RuntimeException if {@code transferFunction} is any other constant
     */
    public NetworkLayerBackward(MultiLayerNetwork.Transfer transferFunction, MultiLayerNetwork.BiasType biasType) {
        setDescription("One layer of a multi-layer network, backward path");

        setVariable(NETDEVIATION, Variable.Factory.labeledVariable("Net Deviation"));
        setVariable(CONTACTDEVIATION, Variable.Factory.labeledVariable("Contact Deviation"));
        setVariable(INPUTDEVIATION, Variable.Factory.labeledVariable("Input Deviation"));

        // Returning function: its two sources are attached later by the output setters.
        AlgorithmTwoSources returning = createReturningFunction(transferFunction);
        returning.setVariable("Target", getNetDeviationVariable());
        setAlgorithm(RETURNINGFUNCTION, returning);

        // Splitting function: clone from net deviation to contact deviation.
        Clone splitting = new Clone(new Variable[0]);
        splitting.setVariable("Source", getNetDeviationVariable());
        splitting.setVariable("Target", getContactDeviationVariable());
        setAlgorithm(SPLITTINGFUNCTION, splitting);

        // Dimming function: "Source 1" (weight) is attached later by setWeightVariable.
        DimmingFunction dimming = new DimmingFunction(biasType);
        dimming.setVariable("Source 2", getContactDeviationVariable());
        dimming.setVariable("Target", getInputDeviationVariable());
        setAlgorithm(DIMMINGFUNCTION, dimming);
    }

    /** Picks the returning-function implementation that belongs to the given transfer function. */
    private static AlgorithmTwoSources createReturningFunction(MultiLayerNetwork.Transfer transferFunction) {
        switch (transferFunction) {
            case TANH:
            case TANHPLUSONE:
                return new ReturningFunctionTanh();
            case LINEAR:
                return new ReturningFunctionOne();
            default:
                throw new RuntimeException("not implemented: " + transferFunction);
        }
    }

    /**
     * Runs the returning, splitting and dimming functions one after another.
     *
     * @param input not read by this method
     * @return a new, empty map
     */
    @Override
    public Map<String, Object> calculateObjects(Map<String, Object> input) {
        getReturningFunction().calculate();
        getSplittingFunction().calculate();
        getDimmingFunction().calculate();
        return new HashMap<String, Object>();
    }

    /** Returns the variable registered under {@link #OUTPUTDEVIATION}. */
    public Variable getOutputDeviationVariable() {
        return (Variable) getVariableMap().get(OUTPUTDEVIATION);
    }

    /** Registers {@code v} under {@link #OUTPUTDEVIATION} and as "Source 2" of the returning function. */
    public void setOutputDeviationVariable(Variable v) {
        setVariable(OUTPUTDEVIATION, v);
        getReturningFunction().setVariable("Source 2", v);
    }

    /** Returns the matrix the variable map yields for {@link #OUTPUTDEVIATION}. */
    public Matrix getOutputDeviationMatrix() {
        return getVariableMap().getMatrix(OUTPUTDEVIATION);
    }

    /** Returns the algorithm registered under {@link #RETURNINGFUNCTION}. */
    public Algorithm getReturningFunction() {
        return (Algorithm) getAlgorithmMap().get(RETURNINGFUNCTION);
    }

    /** Registers {@code a} under {@link #RETURNINGFUNCTION}. */
    public void setReturningFunction(Algorithm a) {
        setAlgorithm(RETURNINGFUNCTION, a);
    }

    /** Returns the algorithm registered under {@link #SPLITTINGFUNCTION}. */
    public Algorithm getSplittingFunction() {
        return (Algorithm) getAlgorithmMap().get(SPLITTINGFUNCTION);
    }

    /** Registers {@code a} under {@link #SPLITTINGFUNCTION}. */
    public void setSplittingFunction(Algorithm a) {
        setAlgorithm(SPLITTINGFUNCTION, a);
    }

    /** Returns the algorithm registered under {@link #DIMMINGFUNCTION}. */
    public Algorithm getDimmingFunction() {
        return (Algorithm) getAlgorithmMap().get(DIMMINGFUNCTION);
    }

    /** Registers {@code a} under {@link #DIMMINGFUNCTION}. */
    public void setDimmingFunction(Algorithm a) {
        setAlgorithm(DIMMINGFUNCTION, a);
    }

    /** Returns the variable registered under {@link #NETDEVIATION}. */
    public Variable getNetDeviationVariable() {
        return (Variable) getVariableMap().get(NETDEVIATION);
    }

    /** Registers {@code v} under {@link #NETDEVIATION}; sub-algorithm wiring is not updated. */
    public void setNetDeviationVariable(Variable v) {
        setVariable(NETDEVIATION, v);
    }

    /** Returns the matrix the variable map yields for {@link #NETDEVIATION}. */
    public Matrix getNetDeviationMatrix() {
        return getVariableMap().getMatrix(NETDEVIATION);
    }

    /** Returns the variable registered under {@link #CONTACTDEVIATION}. */
    public Variable getContactDeviationVariable() {
        return (Variable) getVariableMap().get(CONTACTDEVIATION);
    }

    /** Registers {@code v} under {@link #CONTACTDEVIATION}; sub-algorithm wiring is not updated. */
    public void setContactDeviationVariable(Variable v) {
        setVariable(CONTACTDEVIATION, v);
    }

    /** Returns the matrix the variable map yields for {@link #CONTACTDEVIATION}. */
    public Matrix getContactDeviationMatrix() {
        return getVariableMap().getMatrix(CONTACTDEVIATION);
    }

    /** Returns the variable registered under {@link #INPUTDEVIATION}. */
    public Variable getInputDeviationVariable() {
        return (Variable) getVariableMap().get(INPUTDEVIATION);
    }

    /** Registers {@code v} under {@link #INPUTDEVIATION}; sub-algorithm wiring is not updated. */
    public void setInputDeviationVariable(Variable v) {
        setVariable(INPUTDEVIATION, v);
    }

    /** Returns the matrix the variable map yields for {@link #INPUTDEVIATION}. */
    public Matrix getInputDeviationMatrix() {
        return getVariableMap().getMatrix(INPUTDEVIATION);
    }

    /** Returns the variable registered under {@link #WEIGHT}. */
    public Variable getWeightVariable() {
        return (Variable) getVariableMap().get(WEIGHT);
    }

    /** Registers {@code v} under {@link #WEIGHT} and as "Source 1" of the dimming function. */
    public void setWeightVariable(Variable v) {
        setVariable(WEIGHT, v);
        getDimmingFunction().setVariable("Source 1", v);
    }

    /** Returns the matrix the variable map yields for {@link #WEIGHT}. */
    public Matrix getWeightMatrix() {
        return getVariableMap().getMatrix(WEIGHT);
    }

    /** Returns the variable registered under {@link #OUTPUT}. */
    public Variable getOutputVariable() {
        return (Variable) getVariableMap().get(OUTPUT);
    }

    /** Registers {@code v} under {@link #OUTPUT} and as "Source 1" of the returning function. */
    public void setOutputVariable(Variable v) {
        setVariable(OUTPUT, v);
        getReturningFunction().setVariable("Source 1", v);
    }

    /** Returns the matrix the variable map yields for {@link #OUTPUT}. */
    public Matrix getOutputMatrix() {
        return getVariableMap().getMatrix(OUTPUT);
    }

    /**
     * Sets this layer's label to "Network Layer Backward (nr)".
     *
     * @param nr number inserted into the label
     */
    public void setLayer(int nr) {
        String label = "Network Layer Backward (" + nr + ")";
        setLabel(label);
    }

    /**
     * Small manual demo: builds a TANH layer with a single bias, feeds it a 6x1 output
     * deviation, a 6x1 output and a 6x4 weight matrix, and prints the result of
     * {@code calculate()}.
     */
    public static void main(String[] args) throws Exception {
        NetworkLayerBackward layer = new NetworkLayerBackward(MultiLayerNetwork.Transfer.TANH, MultiLayerNetwork.BiasType.SINGLE);

        Variable deviationVar = Variable.Factory.labeledVariable("Output Deviation");
        DenseDoubleMatrix2D deviationMatrix = Matrix.Factory.linkToArray(new double[][]{{1.0}, {2.0}, {3.0}, {4.0}, {5.0}, {6.0}});
        deviationVar.add(deviationMatrix);
        layer.setOutputDeviationVariable(deviationVar);

        Variable outputVar = Variable.Factory.labeledVariable(OUTPUT);
        DenseDoubleMatrix2D outputMatrix = Matrix.Factory.linkToArray(new double[][]{{5.0}, {6.0}, {7.0}, {8.0}, {9.0}, {10.0}});
        outputVar.add(outputMatrix);
        layer.setOutputVariable(outputVar);

        Variable weightVar = Variable.Factory.labeledVariable(WEIGHT);
        DenseDoubleMatrix2D weightMatrix = Matrix.Factory.linkToArray(new double[][]{{0.1, 0.2, 0.3, 0.4}, {0.1, 0.2, 0.3, 0.4}, {0.1, 0.2, 0.3, 0.4}, {0.1, 0.2, 0.3, 0.4}, {0.1, 0.2, 0.3, 0.4}, {0.1, 0.2, 0.3, 0.4}});
        weightVar.add(weightMatrix);
        layer.setWeightVariable(weightVar);

        System.out.println(layer.calculate());
    }
}

