/*
 * Decompiled with CFR 0.152.
 */
package org.neuroph.contrib.matrixmlp;

import java.util.logging.Level;
import java.util.logging.Logger;

import org.neuroph.contrib.matrixmlp.MatrixLayer;
import org.neuroph.contrib.matrixmlp.MatrixMlpLayer;
import org.neuroph.contrib.matrixmlp.MatrixMultiLayerPerceptron;
import org.neuroph.core.NeuralNetwork;
import org.neuroph.core.transfer.TransferFunction;
import org.neuroph.nnet.learning.MomentumBackpropagation;

/**
 * Momentum backpropagation that works directly on the array-based layers of a
 * {@link MatrixMultiLayerPerceptron} (per-layer {@code double[]} errors and
 * {@code double[][]} weights) instead of on individual neuron objects.
 */
public class MatrixMomentumBackpropagation
extends MomentumBackpropagation {
    /** Diagnostic logger; the running error is reported at {@code FINE}. */
    private static final Logger LOGGER = Logger.getLogger(MatrixMomentumBackpropagation.class.getName());

    /** The network being trained, cached with its concrete type. */
    private MatrixMultiLayerPerceptron matrixMlp;
    /** Layers of {@link #matrixMlp}, cached when the network is set. */
    private MatrixLayer[] matrixLayers;

    /**
     * Attaches this learning rule to a network and caches its matrix layers.
     *
     * @param neuralNetwork the network to train; it is cast to
     *        {@link MatrixMultiLayerPerceptron}
     * @throws ClassCastException if the network is not a
     *         {@link MatrixMultiLayerPerceptron}
     */
    @Override
    public void setNeuralNetwork(NeuralNetwork neuralNetwork) {
        super.setNeuralNetwork(neuralNetwork);
        this.matrixMlp = (MatrixMultiLayerPerceptron)this.getNeuralNetwork();
        this.matrixLayers = this.matrixMlp.getMatrixLayers();
    }

    /**
     * Computes the error term of every output neuron and updates the weights
     * of the output layer (the last entry of the cached layer array).
     *
     * <p>For output {@code i} the error term is
     * {@code patternError[i] * f'(netInput[i])}, where {@code f'} is the
     * derivative of the layer's transfer function. The terms are written into
     * the layer's own error array.
     *
     * @param patternError per-output error for the current pattern; it must
     *        have at least as many entries as the output layer has outputs
     */
    @Override
    protected void calculateErrorAndUpdateOutputNeurons(double[] patternError) {
        MatrixMlpLayer outputLayer = (MatrixMlpLayer)this.matrixLayers[this.matrixLayers.length - 1];
        TransferFunction transferFunction = outputLayer.getTransferFunction();
        double[] outputs = outputLayer.getOutputs();
        double[] netInputs = outputLayer.getNetInput();
        double[] neuronErrors = outputLayer.getErrors();
        for (int i = 0; i < outputs.length; ++i) {
            neuronErrors[i] = patternError[i] * transferFunction.getDerivative(netInputs[i]);
        }
        this.updateLayerWeights(outputLayer, neuronErrors);
        // Report through the logger instead of printing to stdout on every call;
        // the guard avoids building the message when FINE is disabled.
        if (LOGGER.isLoggable(Level.FINE)) {
            LOGGER.fine("MSE:" + this.getErrorFunction().getTotalError());
        }
    }

    /**
     * Applies the momentum update to every weight of a layer.
     *
     * <p>For neuron {@code n} and weight {@code w} the step is
     * {@code learningRate * errors[n] * inputs[w] + momentum * previousDelta[n][w]}.
     * The step is stored back into the layer's delta-weight matrix (so it serves
     * as the "previous delta" of the next call) and added to the weight.
     *
     * @param layer the layer whose weights and delta weights are modified in place
     * @param errors error term of each neuron in {@code layer}
     */
    protected void updateLayerWeights(MatrixMlpLayer layer, double[] errors) {
        double[] inputs = layer.getInputs();
        double[][] weights = layer.getWeights();
        double[][] deltaWeights = layer.getDeltaWeights();
        int neuronsCount = layer.getNeuronsCount();
        for (int neuronIdx = 0; neuronIdx < neuronsCount; ++neuronIdx) {
            double[] neuronWeights = weights[neuronIdx];
            double[] neuronDeltas = deltaWeights[neuronIdx];
            for (int weightIdx = 0; weightIdx < neuronWeights.length; ++weightIdx) {
                double deltaWeight = this.learningRate * errors[neuronIdx] * inputs[weightIdx]
                        + this.momentum * neuronDeltas[weightIdx];
                neuronDeltas[weightIdx] = deltaWeight;
                neuronWeights[weightIdx] += deltaWeight;
            }
        }
    }

    /**
     * Propagates error terms backwards through the hidden layers and updates
     * their weights, walking from the layer before the output layer down to
     * layer index 1. Index 0 is skipped (presumably the input layer, which has
     * no trainable weights; confirm against MatrixMultiLayerPerceptron).
     *
     * <p>For each hidden neuron the error term is the transfer-function
     * derivative at its net input multiplied by the sum, over the neurons of
     * the following layer, of that neuron's error times the weight connecting
     * the two.
     */
    @Override
    protected void calculateErrorAndUpdateHiddenNeurons() {
        int layersCount = this.matrixMlp.getLayersCount();
        for (int layerIdx = layersCount - 2; layerIdx > 0; --layerIdx) {
            MatrixMlpLayer currentLayer = (MatrixMlpLayer)this.matrixLayers[layerIdx];
            TransferFunction transferFunction = currentLayer.getTransferFunction();
            int neuronsCount = currentLayer.getNeuronsCount();
            double[] neuronErrors = currentLayer.getErrors();
            double[] netInputs = currentLayer.getNetInput();

            // NOTE(review): if getNextLayer() returns the layer at layerIdx + 1, its
            // weights were already updated earlier in this pass, so the sum below
            // uses post-update weights. Kept as-is; confirm this is intended.
            MatrixMlpLayer nextLayer = (MatrixMlpLayer)currentLayer.getNextLayer();
            double[] nextLayerErrors = nextLayer.getErrors();
            double[][] nextLayerWeights = nextLayer.getWeights();
            int nextLayerNeuronsCount = nextLayer.getNeuronsCount();

            for (int neuronIdx = 0; neuronIdx < neuronsCount; ++neuronIdx) {
                double weightedErrorsSum = 0.0;
                for (int nextLayerNeuronIdx = 0; nextLayerNeuronIdx < nextLayerNeuronsCount; ++nextLayerNeuronIdx) {
                    weightedErrorsSum += nextLayerErrors[nextLayerNeuronIdx] * nextLayerWeights[nextLayerNeuronIdx][neuronIdx];
                }
                neuronErrors[neuronIdx] = transferFunction.getDerivative(netInputs[neuronIdx]) * weightedErrorsSum;
            }
            this.updateLayerWeights(currentLayer, neuronErrors);
        }
    }
}

