/*
 * Decompiled with CFR 0.152.
 */
package org.encog.neural.networks.training.propagation;

import java.util.Random;
import org.encog.engine.network.activation.ActivationFunction;
import org.encog.mathutil.error.ErrorCalculation;
import org.encog.ml.data.MLDataPair;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.basic.BasicMLDataPair;
import org.encog.neural.error.ErrorFunction;
import org.encog.neural.flat.FlatNetwork;
import org.encog.neural.networks.training.propagation.GradientWorkerOwner;
import org.encog.util.EngineArray;
import org.encog.util.concurrency.EngineTask;

/**
 * Task that back-propagates a contiguous range of training records through a
 * {@link FlatNetwork} and reports the accumulated weight gradients (and the
 * batch error) to a {@link GradientWorkerOwner}.
 *
 * <p>The worker keeps its own gradient accumulator and reusable buffers, but
 * shares the network's weight, output, sum and index arrays (they are obtained
 * from the network's getters and not copied).
 */
public class GradientWorker
implements EngineTask {
    /**
     * Threshold above which an owner's L1 / L2 coefficient is considered
     * active; at or below it the regularization pass is skipped.
     */
    private static final double REGULARIZATION_EPSILON = 1.0E-13;

    /** Source of randomness used to decide which neurons are dropped out. */
    protected Random dropoutRandomSource = new Random();
    /** Network being trained. */
    private final FlatNetwork network;
    /** Accumulates the error over the records processed in one run. */
    private final ErrorCalculation errorCalculation = new ErrorCalculation();
    /** Buffer receiving the network output for the current record. */
    private final double[] actual;
    /** Per-neuron deltas, laid out like {@code layerOutput}. */
    private final double[] layerDelta;
    private final int[] layerCounts;
    private final int[] layerFeedCounts;
    private final int[] layerIndex;
    private final int[] weightIndex;
    private final double[] layerOutput;
    private final double[] layerSums;
    /** Gradient accumulator, one entry per network weight. */
    private final double[] gradients;
    /** The network's weight array (not a copy). */
    private final double[] weights;
    /** Reusable record buffer filled from the training set. */
    private final MLDataPair pair;
    private final MLDataSet training;
    /** First record index processed by {@link #run()} (inclusive). */
    private final int low;
    /** Last record index processed by {@link #run()} (inclusive). */
    private final int high;
    private final GradientWorkerOwner owner;
    /** Per-layer constant added to the activation derivative. */
    private final double[] flatSpot;
    private final ErrorFunction errorFunction;
    /** Per-level dropout rates; levels beyond the array's length use no dropout. */
    private final double[] layerDropoutRates;

    /**
     * Creates a worker for the records {@code theLow..theHigh} (inclusive).
     *
     * @param theNetwork  network whose gradients are computed
     * @param theOwner    receiver of the gradients, error and any failure
     * @param theTraining training set to read records from
     * @param theLow      first record index (inclusive)
     * @param theHigh     last record index (inclusive)
     * @param flatSpot    per-layer constant added to activation derivatives
     * @param ef          error function used to seed the output-layer deltas
     */
    public GradientWorker(FlatNetwork theNetwork, GradientWorkerOwner theOwner, MLDataSet theTraining, int theLow, int theHigh, double[] flatSpot, ErrorFunction ef) {
        this.network = theNetwork;
        this.training = theTraining;
        this.low = theLow;
        this.high = theHigh;
        this.owner = theOwner;
        this.flatSpot = flatSpot;
        this.errorFunction = ef;
        this.layerDelta = new double[this.network.getLayerOutput().length];
        this.gradients = new double[this.network.getWeights().length];
        this.actual = new double[this.network.getOutputCount()];
        this.weights = this.network.getWeights();
        this.layerIndex = this.network.getLayerIndex();
        this.layerCounts = this.network.getLayerCounts();
        this.layerDropoutRates = this.network.getLayerDropoutRates();
        this.weightIndex = this.network.getWeightIndex();
        this.layerOutput = this.network.getLayerOutput();
        this.layerSums = this.network.getLayerSums();
        this.layerFeedCounts = this.network.getLayerFeedCounts();
        this.pair = BasicMLDataPair.createPair(this.network.getInputCount(), this.network.getOutputCount());
    }

    /** @return the network this worker trains */
    public FlatNetwork getNetwork() {
        return this.network;
    }

    /** @return the network's weight array (shared, not a copy) */
    public double[] getWeights() {
        return this.weights;
    }

    /**
     * Runs one record forward through the network, updates the error, seeds
     * the output deltas and back-propagates them, adding this record's
     * contribution to {@link #getGradients()}.
     *
     * @param pair the record to process
     */
    public void process(MLDataPair pair) {
        this.network.compute(pair.getInputArray(), this.actual);
        this.errorCalculation.updateError(this.actual, pair.getIdealArray(), pair.getSignificance());
        this.errorFunction.calculateError(this.network.getActivationFunctions()[0], this.layerSums, this.layerOutput, pair.getIdeal().getData(), this.actual, this.layerDelta, this.flatSpot[0], pair.getSignificance());
        // Either coefficient alone must enable regularization (previously L1 was tested twice).
        if (this.owner.getL1() > REGULARIZATION_EPSILON || this.owner.getL2() > REGULARIZATION_EPSILON) {
            this.applyRegularizationPenalty();
        }
        for (int level = this.network.getBeginTraining(); level < this.network.getEndTraining(); level++) {
            this.processLevel(level);
        }
    }

    /**
     * Adds the owner-weighted L1/L2 weight penalty to the first
     * {@code actual.length} entries of {@code layerDelta} (the entries seeded
     * by the error function for the output — assumed from the layout; confirm
     * against FlatNetwork).
     */
    private void applyRegularizationPenalty() {
        double[] lp = new double[2];
        this.calculateRegularizationPenalty(lp);
        double penalty = lp[0] * this.owner.getL1() + lp[1] * this.owner.getL2();
        for (int i = 0; i < this.actual.length; i++) {
            this.layerDelta[i] += penalty;
        }
    }

    /**
     * Back-propagates deltas from level {@code currentLevel} to level
     * {@code currentLevel + 1}, accumulating gradients for the weights that
     * connect them.
     *
     * <p>The deltas of level {@code currentLevel} must already be populated.
     * The weight between from-neuron {@code y} and to-neuron {@code x} is
     * stored at {@code weightIndex[currentLevel] + x * fromLayerSize + y}.
     *
     * @param currentLevel level whose deltas are propagated backward
     */
    private void processLevel(int currentLevel) {
        int fromLayerIndex = this.layerIndex[currentLevel + 1];
        int toLayerIndex = this.layerIndex[currentLevel];
        int fromLayerSize = this.layerCounts[currentLevel + 1];
        int toLayerSize = this.layerFeedCounts[currentLevel];
        double dropoutRate = 0.0;
        if (this.layerDropoutRates.length > currentLevel && this.layerDropoutRates[currentLevel] != 0.0) {
            dropoutRate = this.layerDropoutRates[currentLevel];
        }
        int index = this.weightIndex[currentLevel];
        // The derivative is evaluated for the neurons of level currentLevel + 1,
        // so use that level's activation, matching flatSpot[currentLevel + 1]
        // (process() likewise pairs activation[0] with flatSpot[0]).
        ActivationFunction activation = this.network.getActivationFunctions()[currentLevel + 1];
        double currentFlatSpot = this.flatSpot[currentLevel + 1];
        // Array references are kept local to avoid a field indirection in the hot loop.
        double[] layerDelta = this.layerDelta;
        double[] weights = this.weights;
        double[] gradients = this.gradients;
        double[] layerOutput = this.layerOutput;
        double[] layerSums = this.layerSums;
        int loopEnd = toLayerIndex + toLayerSize;
        int yi = fromLayerIndex;
        for (int y = 0; y < fromLayerSize; y++, yi++) {
            if (dropoutRate == 0.0 || this.dropoutRandomSource.nextDouble() > dropoutRate) {
                double output = layerOutput[yi];
                double sum = 0.0;
                int wi = index + y;
                for (int xi = toLayerIndex; xi < loopEnd; xi++, wi += fromLayerSize) {
                    gradients[wi] += output * layerDelta[xi];
                    sum += weights[wi] * layerDelta[xi];
                }
                layerDelta[yi] = sum * (activation.derivativeFunction(layerSums[yi], layerOutput[yi]) + currentFlatSpot);
            } else {
                // Dropped neuron: contributes no gradient and no delta this pass.
                layerDelta[yi] = 0.0;
            }
        }
    }

    /**
     * Processes records {@code low..high} (inclusive) and reports the gradients
     * and error to the owner. Any failure is reported to the owner instead of
     * being thrown. The gradient accumulator is always cleared afterwards so a
     * failed run cannot leak partial gradients into the next one.
     */
    @Override
    public final void run() {
        try {
            this.errorCalculation.reset();
            for (int i = this.low; i <= this.high; i++) {
                this.training.getRecord(i, this.pair);
                this.process(this.pair);
            }
            double error = this.errorCalculation.calculate();
            this.owner.report(this.gradients, error, null);
        }
        catch (Throwable ex) {
            this.owner.report(null, 0.0, ex);
        }
        finally {
            EngineArray.fill(this.gradients, 0.0);
        }
    }

    /**
     * Processes a single record and reports its gradients to the owner (with an
     * error value of 0.0). Exceptions propagate to the caller; the gradient
     * accumulator is cleared either way.
     *
     * @param index index of the training record to process
     */
    public final void run(int index) {
        try {
            this.training.getRecord(index, this.pair);
            this.process(this.pair);
            this.owner.report(this.gradients, 0.0, null);
        }
        finally {
            EngineArray.fill(this.gradients, 0.0);
        }
    }

    /** @return the error accumulator used by this worker */
    public ErrorCalculation getErrorCalculation() {
        return this.errorCalculation;
    }

    /** @return the gradient accumulator (one entry per weight) */
    public double[] getGradients() {
        return this.gradients;
    }

    /**
     * Adds the sum of absolute weights to {@code l[0]} and the sum of squared
     * weights to {@code l[1]}, over every pair of adjacent layers. The array
     * is accumulated into, not reset.
     *
     * @param l accumulator of length at least 2
     */
    public void calculateRegularizationPenalty(double[] l) {
        int lastFromLayer = this.network.getLayerCounts().length - 1;
        for (int layer = 0; layer < lastFromLayer; layer++) {
            this.layerRegularizationPenalty(layer, l);
        }
    }

    /**
     * Adds {@code |w|} to {@code l[0]} and {@code w*w} to {@code l[1]} for each
     * weight from layer {@code fromLayer} to layer {@code fromLayer + 1}.
     *
     * @param fromLayer source layer of the weights
     * @param l         accumulator of length at least 2
     */
    public void layerRegularizationPenalty(int fromLayer, double[] l) {
        int fromCount = this.network.getLayerTotalNeuronCount(fromLayer);
        int toCount = this.network.getLayerNeuronCount(fromLayer + 1);
        for (int fromNeuron = 0; fromNeuron < fromCount; fromNeuron++) {
            for (int toNeuron = 0; toNeuron < toCount; toNeuron++) {
                double w = this.network.getWeight(fromLayer, fromNeuron, toNeuron);
                l[0] += Math.abs(w);
                l[1] += w * w;
            }
        }
    }
}

