package org.jdmp.core.algorithm.classification.mlp;

import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import org.jdmp.core.AbstractCoreObject;
import org.jdmp.core.algorithm.AlgorithmFiveSources;
import org.jdmp.core.algorithm.classification.mlp.MultiLayerNetwork;
import org.jdmp.core.variable.Variable;
import org.ujmp.core.Matrix;
import org.ujmp.core.calculation.Calculation;
import org.ujmp.core.util.MathUtil;

/* loaded from: input_file:org/jdmp/core/algorithm/classification/mlp/WeightUpdate.class */
/**
 * Algorithm node that computes an updated weight matrix for one layer of a
 * {@link MultiLayerNetwork}.
 *
 * <p>The five generic source slots of {@link AlgorithmFiveSources} are bound as follows:
 * {@link #WEIGHT} (current weights), {@link #ETA} (learning rate), {@link #CONTACTDEVIATION}
 * (deviation for this layer), {@link #INPUT} (layer input) and {@link #SAMPLEWEIGHT}
 * (per-sample weight). {@link #calculateObjects(Map)} returns the new weights under the key
 * {@code "Target"}.
 */
public class WeightUpdate extends AlgorithmFiveSources {
    private static final long serialVersionUID = -6437801270586210637L;

    /** Source slot holding the current weight matrix. */
    public static final String WEIGHT = "Source 1";
    /** Source slot holding the learning rate (eta). */
    public static final String ETA = "Source 2";
    /** Source slot holding the deviation matrix used in the update. */
    public static final String CONTACTDEVIATION = "Source 3";
    /** Source slot holding the layer input. */
    public static final String INPUT = "Source 4";
    /** Source slot holding the per-sample weight. */
    public static final String SAMPLEWEIGHT = "Source 5";

    /** Selects how bias entries are appended to the input before the update. */
    private final MultiLayerNetwork.BiasType biasType;

    /**
     * Creates a weight update node and registers default learning-rate and sample-weight
     * variables.
     *
     * @param biasType how bias entries are appended to the input; must not be {@code null}
     * @throws NullPointerException if {@code biasType} is {@code null}
     */
    public WeightUpdate(MultiLayerNetwork.BiasType biasType) {
        super(new Variable[0]);
        // Fail fast: calculateObjects() switches on biasType and would otherwise throw an NPE
        // much later, far away from the faulty construction site.
        this.biasType = Objects.requireNonNull(biasType, "biasType");
        // NOTE(review): this description says "+", while calculateObjects() subtracts the
        // scaled product; presumably the deviation carries the sign convention — confirm.
        setDescription("weight = weight + eta * sampleweight * inputdeviation * input");

        // The learning-rate variable gets a 0.001 matrix added to it (presumably the
        // effective default learning rate — confirm against Variable semantics).
        Variable eta = Variable.Factory.singleValue("eta", 1.0d);
        eta.add(Matrix.Factory.linkToValue(0.001d));
        setVariable(ETA, eta);

        Variable sampleWeight = Variable.Factory.singleValue("Sample Weight", 1.0d);
        sampleWeight.add(Matrix.Factory.linkToValue(1));
        setVariable(SAMPLEWEIGHT, sampleWeight);
    }

    /**
     * Replaces the learning rate (eta) stored in the {@link #ETA} slot.
     *
     * @param learningRate the new learning rate
     */
    public void setLearningRate(double learningRate) {
        getVariableMap().setMatrix(ETA, Matrix.Factory.linkToValue(learningRate));
    }

    /**
     * Returns the learning rate (eta) currently stored in the {@link #ETA} slot.
     *
     * @return the learning rate
     */
    public double getLearningRate() {
        return getVariableMap().getMatrix(ETA).doubleValue();
    }

    /**
     * Replaces the per-sample weight stored in the {@link #SAMPLEWEIGHT} slot.
     *
     * @param sampleWeight the new sample weight
     */
    public void setSampleWeight(double sampleWeight) {
        getVariableMap().setMatrix(SAMPLEWEIGHT, Matrix.Factory.linkToValue(sampleWeight));
    }

    /**
     * Returns the per-sample weight currently stored in the {@link #SAMPLEWEIGHT} slot.
     *
     * @return the sample weight
     */
    public double getSampleWeight() {
        return getVariableMap().getMatrix(SAMPLEWEIGHT).doubleValue();
    }

    /**
     * Computes {@code weight - eta * sampleWeight * correction * (deviation x input')}.
     *
     * <p>Here {@code input'} is the input reshaped via {@code toColumnVector} with the bias
     * entries for the configured {@code biasType} appended. {@code correction} is
     * {@code valueCount / (valueCount - missingCount)} of {@code input'}.
     *
     * @param input map containing the five source slots of this algorithm
     * @return a new map holding the updated weight matrix under the key {@code "Target"}
     */
    @Override
    public Map<String, Object> calculateObjects(Map<String, Object> input) {
        Matrix weight = MathUtil.getMatrix(input.get(WEIGHT));
        double eta = MathUtil.getMatrix(input.get(ETA)).doubleValue();
        Matrix deviation = MathUtil.getMatrix(input.get(CONTACTDEVIATION));
        double sampleWeight = MathUtil.getMatrix(input.get(SAMPLEWEIGHT)).doubleValue();
        Matrix in = appendBias(
                MathUtil.getMatrix(input.get(INPUT)).toColumnVector(Calculation.Ret.NEW));

        Matrix delta = deviation.mtimes(in)
                .times(eta * sampleWeight * missingValueCorrection(in));

        Map<String, Object> result = new HashMap<String, Object>();
        // The boolean flag is presumably UJMP's "ignore NaN" switch, so missing entries in
        // delta do not poison the weights — confirm against the UJMP Matrix API.
        result.put("Target", weight.minus(Calculation.Ret.NEW, true, delta));
        return result;
    }

    /**
     * Appends the bias entries selected by {@link #biasType} to the reshaped input.
     *
     * <ul>
     *   <li>{@code SINGLE}: appends a column of ones with the input's row count.</li>
     *   <li>{@code MULTIPLE}: appends a ones matrix of the input's size, with NaN wherever the
     *       input value is NaN or infinite.</li>
     *   <li>any other bias type: returns the input unchanged.</li>
     * </ul>
     *
     * @param in the reshaped input
     * @return the input with bias entries appended, or {@code in} itself when none apply
     */
    private Matrix appendBias(Matrix in) {
        switch (biasType) {
            case SINGLE:
                return Matrix.Factory.horCat(
                        new Matrix[] {in, Matrix.Factory.ones(in.getRowCount(), 1L)});
            case MULTIPLE:
                Matrix bias = Matrix.Factory.ones(in.getSize());
                for (long[] coordinates : in.allCoordinates()) {
                    // Mirror invalid inputs in the bias so they are counted as missing too.
                    if (MathUtil.isNaNOrInfinite(in.getAsDouble(coordinates))) {
                        bias.setAsDouble(Double.NaN, coordinates);
                    }
                }
                return Matrix.Factory.horCat(new Matrix[] {in, bias});
            default:
                return in;
        }
    }

    /**
     * Returns {@code valueCount / (valueCount - missingCount)} for {@code m}. This presumably
     * rescales the update to compensate for missing entries.
     *
     * <p>NOTE(review): when every entry is missing the denominator is zero and the result is
     * infinite; confirm whether callers rely on the NaN-ignoring subtraction to absorb this.
     *
     * @param m the (bias-extended) input matrix
     * @return the missing-value correction factor
     */
    private static double missingValueCorrection(Matrix m) {
        double valueCount = m.getValueCount();
        double missingCount =
                m.countMissing(Calculation.Ret.NEW, AbstractCoreObject.ALL).doubleValue();
        return valueCount / (valueCount - missingCount);
    }

    /**
     * Binds {@code variable} both as the {@link #WEIGHT} source and as the {@code "Target"}
     * output, so the computed weights are presumably written back into the same variable.
     *
     * @param variable the weight variable
     */
    public void setWeightVariable(Variable variable) {
        setVariable(WEIGHT, variable);
        setVariable("Target", variable);
    }

    /**
     * Binds {@code variable} as the {@link #INPUT} source.
     *
     * @param variable the layer input variable
     */
    public void setInputVariable(Variable variable) {
        setVariable(INPUT, variable);
    }

    /**
     * Binds {@code variable} as the {@link #CONTACTDEVIATION} source.
     *
     * @param variable the deviation variable
     */
    public void setContactDeviationVariable(Variable variable) {
        setVariable(CONTACTDEVIATION, variable);
    }

    /**
     * Labels this node with its layer index.
     *
     * @param layer the layer index shown in the label
     */
    public void setLayer(int layer) {
        setLabel("Weight Update (" + layer + ")");
    }
}
