package org.jdmp.core.algorithm.classification.mlp;

import java.util.HashMap;
import java.util.Map;
import org.jdmp.core.algorithm.AbstractAlgorithm;
import org.jdmp.core.algorithm.Algorithm;
import org.jdmp.core.algorithm.basic.Clone;
import org.jdmp.core.algorithm.basic.LogisticFunction;
import org.jdmp.core.algorithm.basic.Mean;
import org.jdmp.core.algorithm.basic.Sum;
import org.jdmp.core.algorithm.basic.Tanh;
import org.jdmp.core.algorithm.basic.TanhPlusOne;
import org.jdmp.core.algorithm.classification.mlp.MultiLayerNetwork;
import org.jdmp.core.variable.Variable;
import org.ujmp.core.Matrix;

/* loaded from: input_file:org/jdmp/core/algorithm/classification/mlp/NetworkLayerForward.class */
/**
 * One layer of a multi-layer network, forward path.
 *
 * <p>The layer chains three sub-algorithms that are registered under
 * {@link #WEIGHTINGFUNCTION}, {@link #AGGREGATIONFUNCTION} and
 * {@link #TRANSFERFUNCTION}:
 * <ol>
 * <li>the weighting function reads the weight ("Source 1") and input
 * ("Source 2") variables and writes to the weighted-input variable,</li>
 * <li>the aggregation function reads the weighted input and writes to the
 * net-input variable,</li>
 * <li>the transfer function reads the net input and writes to the output
 * variable once one has been supplied via
 * {@link #setOutputVariable(Variable)}.</li>
 * </ol>
 *
 * <p>The input, weight and output variables are not created by this class;
 * they have to be supplied through the corresponding setters.
 */
public class NetworkLayerForward extends AbstractAlgorithm {
    private static final long serialVersionUID = -2738909213636005084L;
    public static final String INPUT = "Input";
    public static final String WEIGHT = "Weight";
    public static final String WEIGHTINGFUNCTION = "WeightingFunction";
    public static final String WEIGHTEDINPUT = "WeightedInput";
    public static final String AGGREGATIONFUNCTION = "AggregationFunction";
    public static final String NETINPUT = "NetInput";
    public static final String TRANSFERFUNCTION = "TransferFunction";
    public static final String OUTPUT = "Output";

    /**
     * Creates a layer that uses {@code MEAN} aggregation and a {@code TANH}
     * transfer function.
     *
     * @param biasType bias type handed to the weighting function
     */
    public NetworkLayerForward(MultiLayerNetwork.BiasType biasType) {
        this(MultiLayerNetwork.Aggregation.MEAN, MultiLayerNetwork.Transfer.TANH, biasType);
    }

    /**
     * Creates a layer with the given aggregation and transfer functions.
     *
     * @param aggregation aggregation applied to the weighted input
     * @param transfer    transfer function applied to the net input
     * @param biasType    bias type handed to the weighting function
     * @throws IllegalArgumentException if {@code aggregation} or
     *                                  {@code transfer} is {@code null} or not
     *                                  supported
     */
    public NetworkLayerForward(MultiLayerNetwork.Aggregation aggregation, MultiLayerNetwork.Transfer transfer, MultiLayerNetwork.BiasType biasType) {
        setDescription("One layer of a multi-layer network, forward path");
        // The intermediate variables must exist before the sub-algorithms are
        // wired, because the setters below look them up in the variable map.
        setVariable(WEIGHTEDINPUT, Variable.Factory.labeledVariable("Weighted Input"));
        setVariable(NETINPUT, Variable.Factory.labeledVariable("Net Input"));
        Weighting weighting = new Weighting(biasType);
        weighting.setVariable("Target", getWeightedInputVariable());
        setAlgorithm(WEIGHTINGFUNCTION, weighting);
        setAggregationFunction(aggregation);
        setTransferFunction(transfer);
    }

    /**
     * Runs the weighting, aggregation and transfer functions in that order.
     * The results are stored in the variables wired to the sub-algorithms;
     * the given map is not read.
     *
     * @param map ignored
     * @return a new, empty map
     */
    @Override // org.jdmp.core.algorithm.AbstractAlgorithm, org.jdmp.core.algorithm.Algorithm
    public Map<String, Object> calculateObjects(Map<String, Object> map) {
        Map<String, Object> result = new HashMap<String, Object>();
        getWeightingFunction().calculate();
        getAggregationFunction().calculate();
        getTransferFunction().calculate();
        return result;
    }

    /** @return the matrix stored under {@link #INPUT} */
    public Matrix getInputMatrix() {
        return getVariableMap().getMatrix(INPUT);
    }

    /** @return the variable stored under {@link #INPUT} */
    public Variable getInputVariable() {
        return (Variable) getVariableMap().get(INPUT);
    }

    /**
     * Registers the input variable and wires it as "Source 2" of the
     * weighting function.
     *
     * @param variable the layer input
     */
    public void setInputVariable(Variable variable) {
        setVariable(INPUT, variable);
        getWeightingFunction().setVariable("Source 2", variable);
    }

    /** @return the matrix stored under {@link #WEIGHT} */
    public Matrix getWeightMatrix() {
        return getVariableMap().getMatrix(WEIGHT);
    }

    /** @return the variable stored under {@link #WEIGHT} */
    public Variable getWeightVariable() {
        return (Variable) getVariableMap().get(WEIGHT);
    }

    /**
     * Registers the weight variable and wires it as "Source 1" of the
     * weighting function.
     *
     * @param variable the layer weights
     */
    public void setWeightVariable(Variable variable) {
        setVariable(WEIGHT, variable);
        getWeightingFunction().setVariable("Source 1", variable);
    }

    /** @return the matrix stored under {@link #WEIGHTEDINPUT} */
    public Matrix getWeightedInputMatrix() {
        return getVariableMap().getMatrix(WEIGHTEDINPUT);
    }

    /** @return the variable stored under {@link #WEIGHTEDINPUT} */
    public Variable getWeightedInputVariable() {
        return (Variable) getVariableMap().get(WEIGHTEDINPUT);
    }

    /**
     * Replaces the aggregation function. The new algorithm reads from the
     * weighted-input variable and writes to the net-input variable.
     *
     * @param aggregation the aggregation to use
     * @throws IllegalArgumentException if {@code aggregation} is {@code null}
     *                                  or not supported
     */
    public void setAggregationFunction(MultiLayerNetwork.Aggregation aggregation) {
        if (aggregation == null) {
            throw new IllegalArgumentException("aggregation must not be null");
        }
        Algorithm algorithm;
        switch (aggregation) {
            case MEAN:
                // NOTE(review): the argument 1 is presumably the dimension to aggregate over - confirm
                algorithm = new Mean(1);
                break;
            case SUM:
                algorithm = new Sum(1);
                break;
            default:
                // Fail with a clear message instead of a NullPointerException below.
                throw new IllegalArgumentException("Unsupported aggregation: " + aggregation);
        }
        algorithm.setVariable("Source", getWeightedInputVariable());
        algorithm.setVariable("Target", getNetInputVariable());
        setAlgorithm(AGGREGATIONFUNCTION, algorithm);
    }

    /**
     * Replaces the transfer function. The new algorithm reads from the
     * net-input variable. Its "Target" is not set here; it is only wired by
     * {@link #setOutputVariable(Variable)}, so that method has to be called
     * (again) after the transfer function was replaced.
     *
     * @param transfer the transfer function to use
     * @throws IllegalArgumentException if {@code transfer} is {@code null} or
     *                                  not supported
     */
    public void setTransferFunction(MultiLayerNetwork.Transfer transfer) {
        if (transfer == null) {
            throw new IllegalArgumentException("transfer must not be null");
        }
        Algorithm algorithm;
        switch (transfer) {
            case TANH:
                algorithm = new Tanh(new Variable[0]);
                break;
            case LINEAR:
                algorithm = new Clone(new Variable[0]);
                break;
            case TANHPLUSONE:
                algorithm = new TanhPlusOne(new Variable[0]);
                break;
            case SIGMOID:
                algorithm = new LogisticFunction(new Variable[0]);
                break;
            default:
                // Fail with a clear message instead of a NullPointerException below.
                throw new IllegalArgumentException("Unsupported transfer function: " + transfer);
        }
        algorithm.setVariable("Source", getNetInputVariable());
        setAlgorithm(TRANSFERFUNCTION, algorithm);
    }

    /**
     * Stores the given variable under {@link #WEIGHTEDINPUT}. The
     * sub-algorithms are not re-wired by this method.
     *
     * @param variable the new weighted-input variable
     */
    public void setWeightedInputVariable(Variable variable) {
        setVariable(WEIGHTEDINPUT, variable);
    }

    /** @return the algorithm registered under {@link #WEIGHTINGFUNCTION} */
    public Algorithm getWeightingFunction() {
        return getAlgorithmMap().get(WEIGHTINGFUNCTION);
    }

    /** @return the algorithm registered under {@link #AGGREGATIONFUNCTION} */
    public Algorithm getAggregationFunction() {
        return getAlgorithmMap().get(AGGREGATIONFUNCTION);
    }

    /** @return the algorithm registered under {@link #TRANSFERFUNCTION} */
    public Algorithm getTransferFunction() {
        return getAlgorithmMap().get(TRANSFERFUNCTION);
    }

    /** @return the variable stored under {@link #NETINPUT} */
    public Variable getNetInputVariable() {
        return (Variable) getVariableMap().get(NETINPUT);
    }

    /** @return the matrix stored under {@link #NETINPUT} */
    public Matrix getNetInputMatrix() {
        return getVariableMap().getMatrix(NETINPUT);
    }

    /**
     * Stores the given variable under {@link #NETINPUT}. The sub-algorithms
     * are not re-wired by this method.
     *
     * @param variable the new net-input variable
     */
    public void setNetInputVariable(Variable variable) {
        setVariable(NETINPUT, variable);
    }

    /** @return the variable stored under {@link #OUTPUT} */
    public Variable getOutputVariable() {
        return (Variable) getVariableMap().get(OUTPUT);
    }

    /** @return the matrix stored under {@link #OUTPUT} */
    public Matrix getOutputMatrix() {
        return getVariableMap().getMatrix(OUTPUT);
    }

    /**
     * Registers the output variable and wires it as "Target" of the current
     * transfer function.
     *
     * @param variable the layer output
     */
    public void setOutputVariable(Variable variable) {
        setVariable(OUTPUT, variable);
        getTransferFunction().setVariable("Target", variable);
    }

    /**
     * Sets the label of this algorithm to include the given layer index.
     *
     * @param i index of this layer
     */
    public void setLayer(int i) {
        setLabel("Network Layer Forward (" + i + ")");
    }
}
