/*
 * Decompiled with CFR 0.152.
 */
package org.jdmp.core.algorithm.classification.mlp;

import java.util.HashMap;
import java.util.Map;
import org.jdmp.core.algorithm.AlgorithmTwoSources;
import org.jdmp.core.algorithm.classification.mlp.MultiLayerNetwork;
import org.jdmp.core.variable.Variable;
import org.ujmp.core.Matrix;
import org.ujmp.core.calculation.Calculation;
import org.ujmp.core.doublematrix.impl.ArrayDenseDoubleMatrix2D;
import org.ujmp.core.util.MathUtil;

/**
 * Algorithm that scales every column of a weight matrix by the matching input value:
 * {@code target[j][i] = weight[j][i] * x[i]}.
 *
 * <p>Depending on the configured {@link MultiLayerNetwork.BiasType}, bias entries are
 * appended to the input before the element-wise product is formed.
 */
public class Weighting
extends AlgorithmTwoSources {
    private static final long serialVersionUID = -7663880296216560217L;

    /** How bias entries are appended to the input; assigned once by the constructor. */
    private MultiLayerNetwork.BiasType biasType = null;

    /**
     * Creates the algorithm with no pre-bound variables.
     *
     * @param biasType how bias entries are appended to the input. It must be non-null by the time
     *                 {@link #calculateObjects(Map)} runs, because that method switches on it.
     */
    public Weighting(MultiLayerNetwork.BiasType biasType) {
        super(new Variable[0]);
        this.setDescription("target_{ji} = weight_{ji} * x_i");
        this.biasType = biasType;
    }

    /**
     * Multiplies each column {@code c} of the weight matrix by input value {@code x(c, 0)}.
     *
     * @param input map providing the weight matrix under {@code "Source 1"} and the input
     *              under {@code "Source 2"}
     * @return a new map holding the resulting dense matrix (same size as the weight matrix)
     *         under {@code "Target"}
     */
    @Override
    public Map<String, Object> calculateObjects(Map<String, Object> input) {
        Map<String, Object> result = new HashMap<String, Object>();
        Matrix weight = MathUtil.getMatrix(input.get("Source 1"));
        Matrix x = MathUtil.getMatrix(input.get("Source 2")).toRowVector(Calculation.Ret.NEW);
        // NOTE(review): the product loop below reads x(c, 0), so this code relies on
        // toRowVector yielding an n x 1 layout here; confirm against the UJMP version in use.
        switch (this.biasType) {
            case SINGLE: {
                // Append a single row of ones (one bias entry per column of x).
                Matrix bias = Matrix.Factory.ones(1L, x.getColumnCount());
                x = Matrix.Factory.vertCat(new Matrix[]{x, bias});
                break;
            }
            case MULTIPLE: {
                // Ones in x's shape, with NaN wherever x itself is NaN or infinite.
                Matrix bias = Matrix.Factory.ones(x.getSize());
                for (long[] coordinates : x.allCoordinates()) {
                    if (MathUtil.isNaNOrInfinite(x.getAsDouble(coordinates))) {
                        bias.setAsDouble(Double.NaN, coordinates);
                    }
                }
                // NOTE(review): this appends columns, but the loop below only reads column 0 of x,
                // so for an n x 1 input these bias values are never used. Verify this is intended.
                x = Matrix.Factory.horCat(new Matrix[]{x, bias});
                break;
            }
            default:
                // Any other bias type: the input is used unchanged.
                break;
        }
        int rows = (int) weight.getRowCount();
        int columns = (int) weight.getColumnCount();
        double[][] target = new double[rows][columns];
        for (int c = 0; c < columns; c++) {
            // One input value scales an entire weight column.
            double xv = x.getAsDouble(c, 0L);
            for (int r = 0; r < rows; r++) {
                target[r][c] = weight.getAsDouble(r, c) * xv;
            }
        }
        result.put("Target", new ArrayDenseDoubleMatrix2D(target));
        return result;
    }
}

