package org.jdmp.core.algorithm.classification.mlp;

import java.util.HashMap;
import java.util.Map;
import org.jdmp.core.algorithm.AlgorithmTwoSources;
import org.jdmp.core.algorithm.classification.mlp.MultiLayerNetwork;
import org.jdmp.core.variable.Variable;
import org.ujmp.core.Matrix;
import org.ujmp.core.calculation.Calculation;
import org.ujmp.core.doublematrix.impl.ArrayDenseDoubleMatrix2D;
import org.ujmp.core.util.MathUtil;

/* loaded from: input_file:org/jdmp/core/algorithm/classification/mlp/Weighting.class */
/**
 * Algorithm step that scales each column of a weight matrix by one input value, following the
 * description {@code target_{ji} = weight_{ji} * x_i}.
 *
 * <p>Inputs are read from the keys {@code "Source 1"} (weights) and {@code "Source 2"} (input
 * values). The result is stored under {@code "Target"}. Depending on the configured
 * {@link MultiLayerNetwork.BiasType}, a bias input is appended to the input values before
 * weighting.
 */
public class Weighting extends AlgorithmTwoSources {
    private static final long serialVersionUID = -7663880296216560217L;

    /** Selects how (and whether) a bias input is appended to the input values. */
    private MultiLayerNetwork.BiasType biasType;

    /**
     * Creates a weighting step.
     *
     * @param biasType how a bias input is appended to the input from {@code "Source 2"}:
     *     {@code SINGLE} vertically concatenates a row of ones, {@code MULTIPLE} horizontally
     *     concatenates a ones matrix of the same size (NaN wherever the input is NaN or infinite),
     *     and any other constant leaves the input unchanged. A {@code null} value is accepted here
     *     but makes {@link #calculateObjects(Map)} throw {@link NullPointerException}.
     */
    public Weighting(MultiLayerNetwork.BiasType biasType) {
        super(new Variable[0]);
        setDescription("target_{ji} = weight_{ji} * x_i");
        this.biasType = biasType;
    }

    /**
     * Computes {@code target[r][c] = weights[r][c] * x[c, 0]} for every cell of the weight matrix.
     *
     * @param map inputs; {@code "Source 1"} holds the weights and {@code "Source 2"} the input
     *     values, both convertible by {@code MathUtil.getMatrix}
     * @return a new map containing the weighted matrix under {@code "Target"}
     * @throws NullPointerException if the configured bias type is {@code null} (switch selector)
     */
    @Override
    public Map<String, Object> calculateObjects(Map<String, Object> map) {
        Map<String, Object> result = new HashMap<String, Object>();
        Matrix weights = MathUtil.getMatrix(map.get("Source 1"));
        Matrix x = MathUtil.getMatrix(map.get("Source 2")).toRowVector(Calculation.Ret.NEW);

        switch (this.biasType) {
            case SINGLE:
                x = Matrix.Factory.vertCat(
                        new Matrix[] {x, Matrix.Factory.ones(1L, x.getColumnCount())});
                break;
            case MULTIPLE:
                x = Matrix.Factory.horCat(new Matrix[] {x, multipleBiasInput(x)});
                break;
            default:
                // No bias input for any other bias type.
                break;
        }

        int rows = (int) weights.getRowCount();
        int cols = (int) weights.getColumnCount();
        double[][] target = new double[rows][cols];
        for (int c = 0; c < cols; c++) {
            // NOTE(review): the input is read at coordinate (c, 0), i.e. it is expected to be laid
            // out along the first dimension after toRowVector/bias augmentation — confirm against
            // the UJMP version in use.
            double xi = x.getAsDouble(new long[] {c, 0});
            for (int r = 0; r < rows; r++) {
                target[r][c] = weights.getAsDouble(new long[] {r, c}) * xi;
            }
        }
        result.put("Target", new ArrayDenseDoubleMatrix2D(target));
        return result;
    }

    /**
     * Builds the per-input bias matrix for {@code BiasType.MULTIPLE}: ones of the same size as
     * {@code x}, with NaN in every position where {@code x} is NaN or infinite.
     */
    private static Matrix multipleBiasInput(Matrix x) {
        Matrix bias = Matrix.Factory.ones(x.getSize());
        for (long[] coordinates : x.allCoordinates()) {
            if (MathUtil.isNaNOrInfinite(x.getAsDouble(coordinates))) {
                bias.setAsDouble(Double.NaN, coordinates);
            }
        }
        return bias;
    }
}
