/*
 * Decompiled with CFR 0.152.
 */
package org.jdmp.core.algorithm.classification.mlp;

import java.util.HashMap;
import java.util.Map;
import org.jdmp.core.algorithm.AlgorithmFiveSources;
import org.jdmp.core.algorithm.classification.mlp.MultiLayerNetwork;
import org.jdmp.core.variable.Variable;
import org.ujmp.core.Matrix;
import org.ujmp.core.calculation.Calculation;
import org.ujmp.core.doublematrix.DenseDoubleMatrix2D;
import org.ujmp.core.intmatrix.DenseIntMatrix2D;
import org.ujmp.core.util.MathUtil;

public class WeightUpdate
extends AlgorithmFiveSources {
    private static final long serialVersionUID = -6437801270586210637L;
    public static final String WEIGHT = "Source 1";
    public static final String ETA = "Source 2";
    public static final String CONTACTDEVIATION = "Source 3";
    public static final String INPUT = "Source 4";
    public static final String SAMPLEWEIGHT = "Source 5";
    private MultiLayerNetwork.BiasType biasType = null;

    public WeightUpdate(MultiLayerNetwork.BiasType biasType) {
        super(new Variable[0]);
        this.biasType = biasType;
        this.setDescription("weight = weight + eta * sampleweight * inputdeviation * input");
        Variable eta = Variable.Factory.singleValue("eta", 1.0);
        DenseDoubleMatrix2D m = Matrix.Factory.linkToValue(0.001);
        eta.add(m);
        this.setVariable(ETA, eta);
        Variable sampleWeight = Variable.Factory.singleValue("Sample Weight", 1.0);
        DenseIntMatrix2D s = Matrix.Factory.linkToValue(1);
        sampleWeight.add(s);
        this.setVariable(SAMPLEWEIGHT, sampleWeight);
    }

    public void setLearningRate(double v) {
        DenseDoubleMatrix2D m = Matrix.Factory.linkToValue(v);
        this.getVariableMap().setMatrix(ETA, m);
    }

    public double getLearningRate() {
        return this.getVariableMap().getMatrix(ETA).doubleValue();
    }

    public void setSampleWeight(double v) {
        DenseDoubleMatrix2D m = Matrix.Factory.linkToValue(v);
        this.getVariableMap().setMatrix(SAMPLEWEIGHT, m);
    }

    @Override
    public Map<String, Object> calculateObjects(Map<String, Object> matrices) {
        HashMap<String, Object> result = new HashMap<String, Object>();
        Matrix weight = MathUtil.getMatrix(matrices.get(WEIGHT));
        double eta = MathUtil.getMatrix(matrices.get(ETA)).doubleValue();
        Matrix contactDeviation = MathUtil.getMatrix(matrices.get(CONTACTDEVIATION));
        double sampleWeight = MathUtil.getMatrix(matrices.get(SAMPLEWEIGHT)).doubleValue();
        Matrix transposedInput = MathUtil.getMatrix(matrices.get(INPUT)).toColumnVector(Calculation.Ret.NEW);
        switch (this.biasType) {
            case SINGLE: {
                Object bias = Matrix.Factory.ones(transposedInput.getRowCount(), 1L);
                transposedInput = Matrix.Factory.horCat(new Matrix[]{transposedInput, bias});
                break;
            }
            case MULTIPLE: {
                Object bias = Matrix.Factory.ones(transposedInput.getSize());
                for (long[] c : transposedInput.allCoordinates()) {
                    if (!MathUtil.isNaNOrInfinite(transposedInput.getAsDouble(c))) continue;
                    bias.setAsDouble(Double.NaN, c);
                }
                transposedInput = Matrix.Factory.horCat(new Matrix[]{transposedInput, bias});
                break;
            }
        }
        double totalValueCount = transposedInput.getValueCount();
        double missingValueCount = transposedInput.countMissing(Calculation.Ret.NEW, Integer.MAX_VALUE).doubleValue();
        double boost = 1.0;
        boost = totalValueCount / (totalValueCount - missingValueCount);
        Matrix product = contactDeviation.mtimes(transposedInput);
        Matrix weightChange = product.times(eta * sampleWeight * boost);
        Matrix newWeight = weight.minus(Calculation.Ret.NEW, true, weightChange);
        result.put("Target", newWeight);
        return result;
    }

    public void setWeightVariable(Variable v) {
        this.setVariable(WEIGHT, v);
        this.setVariable("Target", v);
    }

    public void setInputVariable(Variable v) {
        this.setVariable(INPUT, v);
    }

    public void setContactDeviationVariable(Variable v) {
        this.setVariable(CONTACTDEVIATION, v);
    }

    public void setLayer(int nr) {
        this.setLabel("Weight Update (" + nr + ")");
    }
}

