/*
 * Decompiled with CFR 0.152.
 */
package org.neuroph.nnet.learning;

import java.util.List;
import org.neuroph.core.Connection;
import org.neuroph.core.Layer;
import org.neuroph.core.Neuron;
import org.neuroph.core.transfer.TransferFunction;
import org.neuroph.nnet.learning.LMS;

/**
 * Backpropagation learning rule.
 *
 * <p>Extends {@link LMS} by propagating error deltas backwards through the
 * network. First, the output neurons' deltas are derived from the supplied
 * output error. Next, every hidden neuron's delta is derived from the deltas of
 * the neurons its outgoing connections lead to. Once a neuron's delta has been
 * set, the inherited {@code calculateWeightChanges(Neuron)} is invoked for that
 * neuron. That method is defined in {@code LMS}; presumably it turns the delta
 * into weight changes, so confirm against {@code LMS}.
 */
public class BackPropagation extends LMS {
    private static final long serialVersionUID = 1L;

    /**
     * Runs the backward pass for one error vector.
     *
     * <p>The output layer is handled first and the hidden layers second, because
     * hidden deltas are computed from the deltas of downstream neurons.
     *
     * @param outputError per-output-neuron error, in the same order as
     *                    {@code neuralNetwork.getOutputNeurons()}
     */
    @Override
    protected void calculateWeightChanges(double[] outputError) {
        calculateErrorAndUpdateOutputNeurons(outputError);
        calculateErrorAndUpdateHiddenNeurons();
    }

    /**
     * Sets the delta of every output neuron and triggers its weight-change
     * calculation.
     *
     * <p>For output neuron {@code k} the delta is
     * {@code outputError[k] * f'(netInput)}, where {@code f'} is the derivative
     * of the neuron's transfer function.
     *
     * <p>If the error for a neuron compares equal to zero (this includes
     * {@code -0.0}), the neuron's delta is set to {@code 0.0} and
     * {@code calculateWeightChanges(Neuron)} is not called for it.
     *
     * @param outputError per-output-neuron error. It must contain at least one
     *                    entry per output neuron; otherwise an
     *                    {@link ArrayIndexOutOfBoundsException} is thrown when
     *                    the missing entry is read.
     */
    protected void calculateErrorAndUpdateOutputNeurons(double[] outputError) {
        int k = 0;
        for (Neuron outputNeuron : neuralNetwork.getOutputNeurons()) {
            double error = outputError[k++];
            if (error != 0.0) {
                double derivative = outputNeuron.getTransferFunction()
                        .getDerivative(outputNeuron.getNetInput());
                outputNeuron.setDelta(error * derivative);
                calculateWeightChanges(outputNeuron);
            } else {
                // Zero error: record a zero delta and skip this neuron's weight-change step.
                outputNeuron.setDelta(0.0);
            }
        }
    }

    /**
     * Back-propagates deltas through the hidden layers and triggers each hidden
     * neuron's weight-change calculation.
     *
     * <p>Layers are visited from index {@code size - 2} down to index {@code 1}.
     * Two layers are never visited here:
     * <ul>
     *   <li>the last layer, presumably the output layer, which
     *       {@link #calculateErrorAndUpdateOutputNeurons(double[])} handles;</li>
     *   <li>the layer at index {@code 0}, presumably the input layer.</li>
     * </ul>
     * In a layered feed-forward network, this order means each layer is
     * processed after the layer it feeds into.
     */
    protected void calculateErrorAndUpdateHiddenNeurons() {
        List<Layer> layers = neuralNetwork.getLayers();
        int layerIndex = layers.size() - 1;
        // Pre-decrement: first visited index is size - 2; the loop stops before index 0.
        while (--layerIndex > 0) {
            for (Neuron hiddenNeuron : layers.get(layerIndex).getNeurons()) {
                hiddenNeuron.setDelta(calculateHiddenNeuronError(hiddenNeuron));
                calculateWeightChanges(hiddenNeuron);
            }
        }
    }

    /**
     * Computes the delta for a hidden neuron.
     *
     * <p>The delta is {@code f'(netInput) * sum(delta_j * w_j)}. The terms are
     * defined as follows:
     * <ul>
     *   <li>{@code f'} is the derivative of the neuron's transfer function;</li>
     *   <li>the sum runs over the neuron's outgoing connections;</li>
     *   <li>{@code delta_j} is the current delta of each connection's target
     *       neuron;</li>
     *   <li>{@code w_j} is the connection's weight value.</li>
     * </ul>
     * The target neurons must therefore already have their deltas set.
     *
     * @param neuron the hidden neuron whose delta is computed
     * @return the computed delta. This method does not store it; the caller is
     *         responsible for setting it on the neuron.
     */
    protected double calculateHiddenNeuronError(Neuron neuron) {
        double weightedSum = 0.0;
        for (Connection outgoing : neuron.getOutConnections()) {
            weightedSum += outgoing.getToNeuron().getDelta() * outgoing.getWeight().value;
        }
        TransferFunction activation = neuron.getTransferFunction();
        return activation.getDerivative(neuron.getNetInput()) * weightedSum;
    }
}

