package org.jdmp.core.algorithm.classification.mlp;

import java.util.HashMap;
import java.util.Map;
import org.jdmp.core.algorithm.AbstractAlgorithm;
import org.jdmp.core.algorithm.Algorithm;
import org.jdmp.core.algorithm.basic.Clone;
import org.jdmp.core.algorithm.classification.mlp.MultiLayerNetwork;
import org.jdmp.core.variable.Variable;
import org.ujmp.core.Matrix;

/* loaded from: input_file:org/jdmp/core/algorithm/classification/mlp/NetworkLayerBackward.class */
/**
 * One layer of a {@link MultiLayerNetwork}, backward path.
 *
 * <p>The layer is wired from three sub-algorithms, which {@link #calculateObjects(Map)} runs in
 * this order:
 * <ol>
 * <li>{@value #RETURNINGFUNCTION}: reads the layer output ("Source 1") and the output deviation
 * ("Source 2") and writes the net deviation ("Target").</li>
 * <li>{@value #SPLITTINGFUNCTION}: a {@link Clone} that copies the net deviation ("Source") into
 * the contact deviation ("Target").</li>
 * <li>{@value #DIMMINGFUNCTION}: reads the weights ("Source 1") and the contact deviation
 * ("Source 2") and writes the input deviation ("Target").</li>
 * </ol>
 *
 * <p>The output, output-deviation and weight variables are not created here; they must be
 * supplied through {@link #setOutputVariable(Variable)},
 * {@link #setOutputDeviationVariable(Variable)} and {@link #setWeightVariable(Variable)}, which
 * also connect them to the corresponding sub-algorithm inputs.
 */
public class NetworkLayerBackward extends AbstractAlgorithm {
    private static final long serialVersionUID = -8051094919417324985L;

    /** Variable key of the layer's weight matrix. */
    public static final String WEIGHT = "Weight";
    /** Variable key of the layer's output (from the forward pass). */
    public static final String OUTPUT = "Output";
    /** Variable key of the deviation arriving at the layer output. */
    public static final String OUTPUTDEVIATION = "OutputDeviation";
    /** Algorithm key of the function that turns the output deviation into the net deviation. */
    public static final String RETURNINGFUNCTION = "ReturningFunction";
    /** Variable key of the net deviation. */
    public static final String NETDEVIATION = "NetDeviation";
    /** Algorithm key of the function that copies the net deviation to the contact deviation. */
    public static final String SPLITTINGFUNCTION = "SplittingFunction";
    /** Variable key of the contact deviation. */
    public static final String CONTACTDEVIATION = "ContactDeviation";
    /** Algorithm key of the function that maps the contact deviation back through the weights. */
    public static final String DIMMINGFUNCTION = "DimmingFunction";
    /** Variable key of the deviation propagated to the layer input. */
    public static final String INPUTDEVIATION = "InputDeviation";

    /**
     * Creates the backward path of one layer and wires its internal variables and sub-algorithms.
     *
     * @param transfer the transfer function used by the forward path; selects the returning function
     * @param biasType bias handling, passed on to the {@link DimmingFunction}
     * @throws UnsupportedOperationException if {@code transfer} has no backward implementation here
     */
    public NetworkLayerBackward(MultiLayerNetwork.Transfer transfer, MultiLayerNetwork.BiasType biasType) {
        setDescription("One layer of a multi-layer network, backward path");
        setVariable(NETDEVIATION, Variable.Factory.labeledVariable("Net Deviation"));
        setVariable(CONTACTDEVIATION, Variable.Factory.labeledVariable("Contact Deviation"));
        setVariable(INPUTDEVIATION, Variable.Factory.labeledVariable("Input Deviation"));

        final Algorithm returningFunction;
        switch (transfer) {
            case TANH:
            case TANHPLUSONE:
                // NOTE(review): TANHPLUSONE reuses the plain tanh returning function. If that
                // function derives the slope from the layer output, the +1 offset would need to be
                // accounted for — confirm against ReturningFunctionTanh.
                returningFunction = new ReturningFunctionTanh();
                break;
            case LINEAR:
                returningFunction = new ReturningFunctionOne();
                break;
            default:
                throw new UnsupportedOperationException("not implemented: " + transfer);
        }
        returningFunction.setVariable("Target", getNetDeviationVariable());
        setAlgorithm(RETURNINGFUNCTION, returningFunction);

        Clone splittingFunction = new Clone(new Variable[0]);
        splittingFunction.setVariable("Source", getNetDeviationVariable());
        splittingFunction.setVariable("Target", getContactDeviationVariable());
        setAlgorithm(SPLITTINGFUNCTION, splittingFunction);

        DimmingFunction dimmingFunction = new DimmingFunction(biasType);
        dimmingFunction.setVariable("Source 2", getContactDeviationVariable());
        dimmingFunction.setVariable("Target", getInputDeviationVariable());
        setAlgorithm(DIMMINGFUNCTION, dimmingFunction);
    }

    /**
     * Small demo: builds a TANH layer with hard-coded 6x1 output / output-deviation matrices and a
     * 6x4 weight matrix, runs it and prints the result of {@code calculate()}.
     */
    public static void main(String[] args) throws Exception {
        NetworkLayerBackward layer = new NetworkLayerBackward(MultiLayerNetwork.Transfer.TANH,
                MultiLayerNetwork.BiasType.SINGLE);

        Variable outputDeviation = Variable.Factory.labeledVariable("Output Deviation");
        outputDeviation.add(Matrix.Factory.linkToArray(new double[][] {
                { 1.0 }, { 2.0 }, { 3.0 }, { 4.0 }, { 5.0 }, { 6.0 } }));
        layer.setOutputDeviationVariable(outputDeviation);

        Variable output = Variable.Factory.labeledVariable("Output");
        output.add(Matrix.Factory.linkToArray(new double[][] {
                { 5.0 }, { 6.0 }, { 7.0 }, { 8.0 }, { 9.0 }, { 10.0 } }));
        layer.setOutputVariable(output);

        // Distinct row arrays on purpose: linkToArray links rather than copies the data.
        Variable weight = Variable.Factory.labeledVariable("Weight");
        weight.add(Matrix.Factory.linkToArray(new double[][] {
                { 0.1, 0.2, 0.3, 0.4 },
                { 0.1, 0.2, 0.3, 0.4 },
                { 0.1, 0.2, 0.3, 0.4 },
                { 0.1, 0.2, 0.3, 0.4 },
                { 0.1, 0.2, 0.3, 0.4 },
                { 0.1, 0.2, 0.3, 0.4 } }));
        layer.setWeightVariable(weight);

        System.out.println(layer.calculate());
    }

    /**
     * Runs the returning, splitting and dimming functions, in that order. Results are written into
     * this layer's variables by the sub-algorithms themselves.
     *
     * @param map ignored
     * @return a new, empty, mutable map
     */
    @Override // org.jdmp.core.algorithm.AbstractAlgorithm, org.jdmp.core.algorithm.Algorithm
    public Map<String, Object> calculateObjects(Map<String, Object> map) {
        Map<String, Object> result = new HashMap<String, Object>();
        getReturningFunction().calculate();
        getSplittingFunction().calculate();
        getDimmingFunction().calculate();
        return result;
    }

    /** @return the variable registered under {@link #OUTPUTDEVIATION} */
    public Variable getOutputDeviationVariable() {
        return (Variable) getVariableMap().get(OUTPUTDEVIATION);
    }

    /**
     * Registers the output deviation and connects it as "Source 2" of the current returning function.
     *
     * @param variable the deviation at the layer output
     */
    public void setOutputDeviationVariable(Variable variable) {
        setVariable(OUTPUTDEVIATION, variable);
        getReturningFunction().setVariable("Source 2", variable);
    }

    /** @return the matrix of the {@link #OUTPUTDEVIATION} variable */
    public Matrix getOutputDeviationMatrix() {
        return getVariableMap().getMatrix(OUTPUTDEVIATION);
    }

    /** @return the algorithm registered under {@link #RETURNINGFUNCTION} */
    public Algorithm getReturningFunction() {
        return getAlgorithmMap().get(RETURNINGFUNCTION);
    }

    /**
     * Replaces the returning function. Note: the replacement is not rewired to the existing
     * variables.
     */
    public void setReturningFunction(Algorithm algorithm) {
        setAlgorithm(RETURNINGFUNCTION, algorithm);
    }

    /** @return the algorithm registered under {@link #SPLITTINGFUNCTION} */
    public Algorithm getSplittingFunction() {
        return getAlgorithmMap().get(SPLITTINGFUNCTION);
    }

    /**
     * Replaces the splitting function. Note: the replacement is not rewired to the existing
     * variables.
     */
    public void setSplittingFunction(Algorithm algorithm) {
        setAlgorithm(SPLITTINGFUNCTION, algorithm);
    }

    /** @return the algorithm registered under {@link #DIMMINGFUNCTION} */
    public Algorithm getDimmingFunction() {
        return getAlgorithmMap().get(DIMMINGFUNCTION);
    }

    /**
     * Replaces the dimming function. Note: the replacement is not rewired to the existing
     * variables.
     */
    public void setDimmingFunction(Algorithm algorithm) {
        setAlgorithm(DIMMINGFUNCTION, algorithm);
    }

    /** @return the variable registered under {@link #NETDEVIATION} */
    public Variable getNetDeviationVariable() {
        return (Variable) getVariableMap().get(NETDEVIATION);
    }

    /**
     * Registers a net-deviation variable. Note: the sub-algorithms keep referencing the variable
     * they were wired to in the constructor.
     */
    public void setNetDeviationVariable(Variable variable) {
        setVariable(NETDEVIATION, variable);
    }

    /** @return the matrix of the {@link #NETDEVIATION} variable */
    public Matrix getNetDeviationMatrix() {
        return getVariableMap().getMatrix(NETDEVIATION);
    }

    /** @return the variable registered under {@link #CONTACTDEVIATION} */
    public Variable getContactDeviationVariable() {
        return (Variable) getVariableMap().get(CONTACTDEVIATION);
    }

    /**
     * Registers a contact-deviation variable. Note: the sub-algorithms keep referencing the
     * variable they were wired to in the constructor.
     */
    public void setContactDeviationVariable(Variable variable) {
        setVariable(CONTACTDEVIATION, variable);
    }

    /** @return the matrix of the {@link #CONTACTDEVIATION} variable */
    public Matrix getContactDeviationMatrix() {
        return getVariableMap().getMatrix(CONTACTDEVIATION);
    }

    /** @return the variable registered under {@link #INPUTDEVIATION} */
    public Variable getInputDeviationVariable() {
        return (Variable) getVariableMap().get(INPUTDEVIATION);
    }

    /**
     * Registers an input-deviation variable. Note: the dimming function keeps writing to the
     * variable it was wired to in the constructor.
     */
    public void setInputDeviationVariable(Variable variable) {
        setVariable(INPUTDEVIATION, variable);
    }

    /** @return the matrix of the {@link #INPUTDEVIATION} variable */
    public Matrix getInputDeviationMatrix() {
        return getVariableMap().getMatrix(INPUTDEVIATION);
    }

    /** @return the variable registered under {@link #WEIGHT} */
    public Variable getWeightVariable() {
        return (Variable) getVariableMap().get(WEIGHT);
    }

    /**
     * Registers the weight variable and connects it as "Source 1" of the current dimming function.
     *
     * @param variable the layer weights
     */
    public void setWeightVariable(Variable variable) {
        setVariable(WEIGHT, variable);
        getDimmingFunction().setVariable("Source 1", variable);
    }

    /** @return the matrix of the {@link #WEIGHT} variable */
    public Matrix getWeightMatrix() {
        return getVariableMap().getMatrix(WEIGHT);
    }

    /** @return the variable registered under {@link #OUTPUT} */
    public Variable getOutputVariable() {
        return (Variable) getVariableMap().get(OUTPUT);
    }

    /**
     * Registers the layer output and connects it as "Source 1" of the current returning function.
     *
     * @param variable the layer output from the forward pass
     */
    public void setOutputVariable(Variable variable) {
        setVariable(OUTPUT, variable);
        getReturningFunction().setVariable("Source 1", variable);
    }

    /** @return the matrix of the {@link #OUTPUT} variable */
    public Matrix getOutputMatrix() {
        return getVariableMap().getMatrix(OUTPUT);
    }

    /**
     * Sets the label to {@code "Network Layer Backward (<i>)"}.
     *
     * @param i the layer index shown in the label
     */
    public void setLayer(int i) {
        setLabel("Network Layer Backward (" + i + ")");
    }
}
