/*
 * Decompiled with CFR 0.152.
 */
package org.encog.neural.networks.training.propagation;

import java.util.Random;
import org.encog.EncogError;
import org.encog.engine.network.activation.ActivationFunction;
import org.encog.engine.network.activation.ActivationSigmoid;
import org.encog.mathutil.IntRange;
import org.encog.ml.MLMethod;
import org.encog.ml.TrainingImplementationType;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.train.BasicTraining;
import org.encog.neural.error.ErrorFunction;
import org.encog.neural.error.LinearErrorFunction;
import org.encog.neural.flat.FlatNetwork;
import org.encog.neural.networks.ContainsFlat;
import org.encog.neural.networks.training.BatchSize;
import org.encog.neural.networks.training.Train;
import org.encog.neural.networks.training.propagation.GradientWorker;
import org.encog.neural.networks.training.propagation.GradientWorkerOwner;
import org.encog.util.EncogValidate;
import org.encog.util.EngineArray;
import org.encog.util.concurrency.DetermineWorkload;
import org.encog.util.concurrency.EngineConcurrency;
import org.encog.util.concurrency.MultiThreadable;
import org.encog.util.concurrency.TaskGroup;
import org.encog.util.logging.EncogLogging;

/**
 * Base class for trainers that update a {@link FlatNetwork}'s weights from
 * gradients accumulated by one or more {@link GradientWorker}s.
 *
 * <p>Subclasses supply the per-weight update rule through the two
 * {@code updateWeight} overloads and any extra state through {@link #initOthers()}.
 * Training runs in one of two modes, chosen on every iteration:
 * <ul>
 *   <li>{@code batchSize == 0}: full batch. All workers compute gradients over their
 *       slice of the training set, and the weights are updated once.</li>
 *   <li>{@code batchSize > 0}: mini-batch. A single worker processes records one at a
 *       time, and the weights are updated after every {@code batchSize} records.</li>
 * </ul>
 */
public abstract class Propagation extends BasicTraining
        implements Train, MultiThreadable, BatchSize, GradientWorkerOwner {

    /** Random source available to subclasses for dropout decisions. */
    protected Random dropoutRandomSource = new Random();

    /** Dropout rate; values {@code > 0} select the dropout-aware update overload. */
    private double dropoutRate = 0.0;

    /** The network whose weights are being trained. */
    private FlatNetwork currentFlatNetwork;

    /** Requested thread count; forced to 1 when workers are created in mini-batch mode. */
    private int numThreads;

    /** Gradients accumulated from worker reports; cleared after each weight update. */
    protected double[] gradients;

    /** Per-weight previous gradients, handed to {@code updateWeight}. */
    private final double[] lastGradient;

    /** The model being trained. */
    protected final ContainsFlat network;

    /** Training data that workers open additional views on. */
    private final MLDataSet indexable;

    /** Lazily created gradient workers; {@code null} until the first iteration. */
    private GradientWorker[] workers;

    /** Sum of errors reported by workers during the current full-batch pass. */
    private double totalError;

    /** Last failure reported by a worker; rethrown at the end of an iteration. */
    private Throwable reportedException;

    /** Number of iterations started so far. */
    private int iteration;

    /** Flat-spot constants, one per activation function, passed to the workers. */
    private double[] flatSpot;

    /** Whether sigmoid layers get a flat-spot constant of 0.1 when workers are created. */
    private boolean shouldFixFlatSpot;

    /** Error function handed to workers when they are created. */
    private ErrorFunction ef = new LinearErrorFunction();

    /** 0 for full-batch training, otherwise the mini-batch size. */
    private int batchSize = 0;

    private double l1;
    private double l2;

    /** Guards {@link #finishTraining(double)} so that dropout scaling is applied at most once. */
    private boolean finalized = false;

    /**
     * Creates a trainer for {@code network} on {@code training}.
     *
     * @param network  the model to train; its flat network's weights are updated in place
     * @param training the training data
     */
    public Propagation(ContainsFlat network, MLDataSet training) {
        super(TrainingImplementationType.Iterative);
        this.network = network;
        this.currentFlatNetwork = network.getFlat();
        this.setTraining(training);
        this.gradients = new double[this.currentFlatNetwork.getWeights().length];
        this.lastGradient = new double[this.currentFlatNetwork.getWeights().length];
        this.indexable = training;
        this.numThreads = 0;
        this.reportedException = null;
        this.shouldFixFlatSpot = true;
    }

    /**
     * Sets the dropout rate. The misspelled name is kept so that existing callers still compile.
     *
     * @param rate the dropout rate
     * @deprecated use {@link #setDropoutRate(double)} instead
     */
    @Deprecated
    public void setDroupoutRate(double rate) {
        this.setDropoutRate(rate);
    }

    /**
     * Sets the dropout rate. A value {@code > 0} makes weight updates use the
     * dropout-aware {@code updateWeight} overload. It also makes {@link #finishTraining()}
     * scale every weight by {@code 1 - rate}.
     *
     * @param rate the dropout rate
     */
    public void setDropoutRate(double rate) {
        this.dropoutRate = rate;
    }

    /** @return the current dropout rate */
    public double getDropoutRate() {
        return this.dropoutRate;
    }

    /** Finishes training using the configured dropout rate. */
    @Override
    public void finishTraining() {
        this.finishTraining(this.dropoutRate);
    }

    /**
     * Finishes training. On the first call only, when {@code dropoutRate > 0}, every
     * weight is multiplied by {@code 1 - dropoutRate}. The superclass hook is always
     * invoked.
     *
     * @param dropoutRate the rate used to rescale the weights
     */
    public void finishTraining(double dropoutRate) {
        if (!this.finalized) {
            if (dropoutRate > 0.0) {
                final double keep = 1.0 - dropoutRate;
                final double[] weights = this.currentFlatNetwork.getWeights();
                for (int i = 0; i < weights.length; i++) {
                    weights[i] *= keep;
                }
            }
            this.finalized = true;
        }
        super.finishTraining();
    }

    /** @return the flat network whose weights are being trained */
    public FlatNetwork getCurrentFlatNetwork() {
        return this.currentFlatNetwork;
    }

    @Override
    public MLMethod getMethod() {
        return this.network;
    }

    /** Performs a single training iteration. */
    @Override
    public void iteration() {
        this.iteration(1);
    }

    /** Increments the internal iteration counter. */
    public void rollIteration() {
        ++this.iteration;
    }

    /** Full-batch mode: computes gradients over the whole training set, then updates once. */
    private void processPureBatch() {
        this.calculateGradients();
        this.applyLearning();
    }

    /**
     * Mini-batch mode: worker 0 processes the records one by one, and the weights
     * are updated after every {@code batchSize} records and once more for any
     * trailing partial batch.
     */
    private void processBatches() {
        if (this.workers == null) {
            this.init();
        }
        if (this.currentFlatNetwork.getHasContext()) {
            this.workers[0].getNetwork().clearContext();
        }
        this.workers[0].getErrorCalculation().reset();

        int sinceLastLearn = 0;
        for (int i = 0; i < this.getTraining().size(); i++) {
            this.workers[0].run(i);
            // Count each record exactly once. The limited and unlimited paths both
            // reset the counter, so every batch really holds batchSize records.
            sinceLastLearn++;
            if (sinceLastLearn >= this.batchSize) {
                this.learnMiniBatch();
                sinceLastLearn = 0;
            }
        }
        if (sinceLastLearn > 0) {
            // Flush the trailing partial batch, honouring the connection limit.
            this.learnMiniBatch();
        }
        this.setError(this.workers[0].getErrorCalculation().calculate());
    }

    /** Applies the accumulated gradients, using the limited update when the network has a connection limit. */
    private void applyLearning() {
        if (this.currentFlatNetwork.isLimited()) {
            this.learnLimited();
        } else {
            this.learn();
        }
    }

    /**
     * Applies a mini-batch update, then copies the new weights into worker 0.
     * Worker 0 was built from a clone of the network in init(), so without this
     * copy the later mini-batches of the same iteration would use stale weights.
     * This assumes {@code GradientWorker.getWeights()} exposes the array that the
     * worker evaluates with; the end-of-iteration copy in {@link #iteration(int)}
     * relies on the same thing.
     */
    private void learnMiniBatch() {
        this.applyLearning();
        final double[] weights = this.currentFlatNetwork.getWeights();
        EngineArray.arrayCopy(weights, 0, this.workers[0].getWeights(), 0, weights.length);
    }

    /**
     * Performs {@code count} training iterations. After each one, the trained weights
     * (and, for context networks, the layer outputs) are copied into the workers.
     *
     * @param count number of iterations to run
     * @throws EncogError if a worker reported a failure. Also thrown, after
     *         {@code EncogValidate.validateNetworkForTraining} is given a chance to
     *         report a more specific problem, when an array index goes out of bounds.
     */
    @Override
    public void iteration(int count) {
        try {
            for (int i = 0; i < count; i++) {
                this.preIteration();
                this.rollIteration();
                if (this.batchSize == 0) {
                    this.processPureBatch();
                } else {
                    this.processBatches();
                }
                final double[] weights = this.currentFlatNetwork.getWeights();
                for (GradientWorker worker : this.workers) {
                    EngineArray.arrayCopy(weights, 0, worker.getWeights(), 0, weights.length);
                }
                if (this.currentFlatNetwork.getHasContext()) {
                    this.copyContexts();
                }
                if (this.reportedException != null) {
                    throw new EncogError(this.reportedException);
                }
                this.postIteration();
                EncogLogging.log(1, "Training iteration done, error: " + this.getError());
            }
        } catch (ArrayIndexOutOfBoundsException ex) {
            EncogValidate.validateNetworkForTraining(this.network, this.getTraining());
            throw new EncogError(ex);
        }
    }

    /**
     * Sets the requested thread count. It is read only when the workers are first
     * created, and it is overridden to 1 in mini-batch mode.
     */
    @Override
    public void setThreadCount(int numThreads) {
        this.numThreads = numThreads;
    }

    @Override
    public int getThreadCount() {
        return this.numThreads;
    }

    /**
     * Enables or disables the sigmoid flat-spot fix. This takes effect only if it is
     * called before the workers are created on the first iteration.
     */
    public void fixFlatSpot(boolean b) {
        this.shouldFixFlatSpot = b;
    }

    /**
     * Sets the error function passed to the workers. This takes effect only if it is
     * called before the workers are created on the first iteration.
     */
    public void setErrorFunction(ErrorFunction ef) {
        this.ef = ef;
    }

    /**
     * Runs every worker over its slice of the training set, in parallel when there
     * is more than one worker. The gradients are accumulated through
     * {@link #report}, and the error is set to the mean of the reported errors.
     */
    public void calculateGradients() {
        if (this.workers == null) {
            this.init();
        }
        if (this.currentFlatNetwork.getHasContext()) {
            this.workers[0].getNetwork().clearContext();
        }
        this.totalError = 0.0;
        if (this.workers.length > 1) {
            final TaskGroup group = EngineConcurrency.getInstance().createTaskGroup();
            for (GradientWorker worker : this.workers) {
                EngineConcurrency.getInstance().processTask(worker, group);
            }
            group.waitForComplete();
        } else {
            this.workers[0].run();
        }
        this.setError(this.totalError / (double) this.workers.length);
    }

    /**
     * Chains context state: each worker receives the layer output of the previous
     * worker, and the last worker's output is copied back into the trained network.
     */
    private void copyContexts() {
        for (int i = 0; i + 1 < this.workers.length; i++) {
            final double[] src = this.workers[i].getNetwork().getLayerOutput();
            final double[] dst = this.workers[i + 1].getNetwork().getLayerOutput();
            EngineArray.arrayCopy(src, dst);
        }
        final GradientWorker last = this.workers[this.workers.length - 1];
        EngineArray.arrayCopy(last.getNetwork().getLayerOutput(), this.currentFlatNetwork.getLayerOutput());
    }

    /**
     * Builds the flat-spot table and one worker per workload range. Each worker gets
     * its own clone of the network and its own view of the training data.
     */
    private void init() {
        final ActivationFunction[] activations = this.currentFlatNetwork.getActivationFunctions();
        this.flatSpot = new double[activations.length];
        if (this.shouldFixFlatSpot) {
            for (int i = 0; i < activations.length; i++) {
                this.flatSpot[i] = activations[i] instanceof ActivationSigmoid ? 0.1 : 0.0;
            }
        } else {
            EngineArray.fill(this.flatSpot, 0.0);
        }
        if (this.batchSize != 0) {
            // Mini-batch mode drives a single worker sequentially.
            this.numThreads = 1;
        }
        final DetermineWorkload determine =
                new DetermineWorkload(this.numThreads, (int) this.indexable.getRecordCount());
        this.workers = new GradientWorker[determine.getThreadCount()];
        int index = 0;
        for (IntRange r : determine.calculateWorkers()) {
            this.workers[index++] = new GradientWorker(this.currentFlatNetwork.clone(), this,
                    this.indexable.openAdditional(), r.getLow(), r.getHigh(), this.flatSpot, this.ef);
        }
        this.initOthers();
    }

    /**
     * Receives a worker's result. Gradients are added to the shared accumulator and
     * the error to the running total. A non-null {@code ex} is stored so that the
     * failure can be rethrown at the end of the iteration. Synchronized on
     * {@code this} because workers may report concurrently.
     */
    @Override
    public void report(double[] gradients, double error, Throwable ex) {
        synchronized (this) {
            if (ex == null) {
                for (int i = 0; i < gradients.length; i++) {
                    this.gradients[i] += gradients[i];
                }
                this.totalError += error;
            } else {
                this.reportedException = ex;
            }
        }
    }

    /**
     * Adds the subclass-computed update to every weight and clears the accumulated
     * gradients. The dropout overload is used when {@code dropoutRate > 0}.
     */
    protected void learn() {
        final double[] weights = this.currentFlatNetwork.getWeights();
        final boolean useDropout = this.dropoutRate > 0.0;
        for (int i = 0; i < this.gradients.length; i++) {
            weights[i] += useDropout
                    ? this.updateWeight(this.gradients, this.lastGradient, i, this.dropoutRate)
                    : this.updateWeight(this.gradients, this.lastGradient, i);
            this.gradients[i] = 0.0;
        }
    }

    /**
     * Like {@link #learn()}, except that any weight whose magnitude is below the
     * network's connection limit is set to zero instead of being updated. The
     * accumulated gradients are cleared either way.
     */
    protected void learnLimited() {
        final double limit = this.currentFlatNetwork.getConnectionLimit();
        final double[] weights = this.currentFlatNetwork.getWeights();
        final boolean useDropout = this.dropoutRate > 0.0;
        for (int i = 0; i < this.gradients.length; i++) {
            if (Math.abs(weights[i]) < limit) {
                weights[i] = 0.0;
            } else {
                weights[i] += useDropout
                        ? this.updateWeight(this.gradients, this.lastGradient, i, this.dropoutRate)
                        : this.updateWeight(this.gradients, this.lastGradient, i);
            }
            this.gradients[i] = 0.0;
        }
    }

    /** Hook for subclasses to initialise their own state after the workers have been created. */
    public abstract void initOthers();

    /**
     * Computes the delta to add to weight {@code var3}.
     *
     * @param var1 accumulated gradients
     * @param var2 previous gradients
     * @param var3 weight index
     * @return the amount to add to the weight
     */
    public abstract double updateWeight(double[] var1, double[] var2, int var3);

    /**
     * Dropout-aware variant of {@link #updateWeight(double[], double[], int)}.
     *
     * @param var4 the dropout rate
     */
    public abstract double updateWeight(double[] var1, double[] var2, int var3, double var4);

    /** @return the per-weight previous-gradient array (live, not a copy) */
    public double[] getLastGradient() {
        return this.lastGradient;
    }

    @Override
    public int getBatchSize() {
        return this.batchSize;
    }

    /** Sets the mini-batch size; 0 selects full-batch training. */
    @Override
    public void setBatchSize(int theBatchSize) {
        this.batchSize = theBatchSize;
    }

    @Override
    public double getL1() {
        return this.l1;
    }

    public void setL1(double l1) {
        this.l1 = l1;
    }

    @Override
    public double getL2() {
        return this.l2;
    }

    public void setL2(double l2) {
        this.l2 = l2;
    }
}

