/*
 * Decompiled with CFR 0.152.
 */
package org.encog.neural.networks.training.propagation.resilient;

import org.encog.mathutil.EncogMath;
import org.encog.ml.data.MLDataSet;
import org.encog.neural.networks.ContainsFlat;
import org.encog.neural.networks.training.TrainingError;
import org.encog.neural.networks.training.propagation.Propagation;
import org.encog.neural.networks.training.propagation.TrainingContinuation;
import org.encog.neural.networks.training.propagation.resilient.RPROPType;
import org.encog.util.EngineArray;

/**
 * Resilient propagation (RPROP) trainer for networks that expose a flat weight array.
 *
 * <p>Every weight owns an adaptive step size; only the sign of the gradient decides the direction
 * of a weight update. The concrete update rule is chosen per trainer instance with
 * {@link #setRPROPType(RPROPType)} and defaults to {@link RPROPType#RPROPp}.
 *
 * <p>{@link #pause()} captures the last gradients and the per-weight update values, which
 * {@link #resume(TrainingContinuation)} copies back in.
 */
public class ResilientPropagation
extends Propagation {
    /** Continuation key under which the previous gradients are stored. */
    public static final String LAST_GRADIENTS = "LAST_GRADIENTS";
    /** Continuation key under which the per-weight update values are stored. */
    public static final String UPDATE_VALUES = "UPDATE_VALUES";

    /** Initial per-weight step size used by the two-argument constructor. */
    private static final double DEFAULT_INITIAL_UPDATE = 0.1;
    /** Upper bound on a step size used by the two-argument constructor. */
    private static final double DEFAULT_MAX_STEP = 50.0;
    /** Growth factor applied to a step size while the gradient keeps its sign. */
    private static final double POSITIVE_ETA = 1.2;
    /** Shrink factor applied to a step size when the gradient changes sign. */
    private static final double NEGATIVE_ETA = 0.5;
    /** Lower bound for a shrunken step size. */
    private static final double DELTA_MIN = 1.0E-6;

    /** Per-weight step sizes used by the RPROP+, iRPROP+ and ARPROP rules. */
    private final double[] updateValues;
    /** Per-weight step sizes used by the RPROP- and iRPROP- rules. */
    private final double[] lastDelta;
    /** Weight change produced for each weight on the previous call to updateWeight. */
    private final double[] lastWeightChange;
    /** Upper bound on any step size. */
    private final double maxStep;
    // Instance state: this used to be static, so setRPROPType on one trainer changed all of them.
    private RPROPType rpropType = RPROPType.RPROPp;
    /** Training error recorded at the end of the previous iteration. */
    private double lastError = Double.POSITIVE_INFINITY;
    // ARPROP back-off counter; instance state for the same reason as rpropType.
    private double q = 1.0;

    /**
     * Creates a trainer with the default initial step size (0.1) and maximum step (50.0).
     *
     * @param network  the network to train
     * @param training the training data
     */
    public ResilientPropagation(ContainsFlat network, MLDataSet training) {
        this(network, training, DEFAULT_INITIAL_UPDATE, DEFAULT_MAX_STEP);
    }

    /**
     * Creates a trainer with explicit step-size settings.
     *
     * @param network       the network to train
     * @param training      the training data
     * @param initialUpdate initial per-weight step size for the RPROP+, iRPROP+ and ARPROP rules
     * @param maxStep       upper bound on any step size
     */
    public ResilientPropagation(ContainsFlat network, MLDataSet training, double initialUpdate, double maxStep) {
        super(network, training);
        int weightCount = network.getFlat().getWeights().length;
        this.updateValues = new double[weightCount];
        // NOTE(review): lastDelta starts at 0, so RPROP-/iRPROP- begin at DELTA_MIN and ignore
        // initialUpdate. Behavior kept as-is; confirm whether that is intended.
        this.lastDelta = new double[weightCount];
        this.lastWeightChange = new double[weightCount];
        this.maxStep = maxStep;
        for (int i = 0; i < weightCount; i++) {
            this.updateValues[i] = initialUpdate;
        }
    }

    /** This trainer supports pause/resume. */
    @Override
    public boolean canContinue() {
        return true;
    }

    /**
     * Checks whether a continuation can be resumed by this trainer.
     *
     * @param state the saved training state
     * @return true when the state was produced by this trainer type and both saved arrays are
     *         {@code double[]} whose length matches the network's weight count
     */
    public boolean isValidResume(TrainingContinuation state) {
        if (!state.getContents().containsKey(LAST_GRADIENTS) || !state.getContents().containsKey(UPDATE_VALUES)) {
            return false;
        }
        // Compare from the non-null side so a missing training type yields false, not an NPE.
        if (!this.getClass().getSimpleName().equals(state.getTrainingType())) {
            return false;
        }
        int weightCount = ((ContainsFlat)this.getMethod()).getFlat().getWeights().length;
        return hasLength(state.get(LAST_GRADIENTS), weightCount)
                && hasLength(state.get(UPDATE_VALUES), weightCount);
    }

    /** Returns true when {@code value} is a {@code double[]} of exactly {@code length} elements. */
    private static boolean hasLength(Object value, int length) {
        return value instanceof double[] && ((double[])value).length == length;
    }

    /**
     * Captures the current gradients and update values.
     *
     * <p>The arrays are copied so that further training does not alter the saved snapshot.
     *
     * @return a continuation holding copies of the trainer's state
     */
    @Override
    public TrainingContinuation pause() {
        TrainingContinuation result = new TrainingContinuation();
        result.setTrainingType(this.getClass().getSimpleName());
        result.set(LAST_GRADIENTS, this.getLastGradient().clone());
        result.set(UPDATE_VALUES, this.getUpdateValues().clone());
        return result;
    }

    /**
     * Restores gradients and update values from a continuation.
     *
     * @param state the saved training state
     * @throws TrainingError if {@link #isValidResume(TrainingContinuation)} rejects the state
     */
    @Override
    public void resume(TrainingContinuation state) {
        if (!this.isValidResume(state)) {
            throw new TrainingError("Invalid training resume data length");
        }
        EngineArray.arrayCopy((double[])state.get(LAST_GRADIENTS), this.getLastGradient());
        EngineArray.arrayCopy((double[])state.get(UPDATE_VALUES), this.getUpdateValues());
    }

    /**
     * Selects the update rule used by this trainer instance only.
     *
     * @param t the RPROP variant to use
     */
    public void setRPROPType(RPROPType t) {
        this.rpropType = t;
    }

    /** @return the RPROP variant used by this trainer instance */
    public RPROPType getRPROPType() {
        return this.rpropType;
    }

    /** No additional initialization is needed. */
    @Override
    public void initOthers() {
    }

    /**
     * Computes the weight change for one weight using the selected RPROP variant.
     *
     * @return the change to apply to weight {@code index}
     */
    @Override
    public double updateWeight(double[] gradients, double[] lastGradient, int index) {
        return this.applyUpdateRule(gradients, lastGradient, index);
    }

    /**
     * Like {@link #updateWeight(double[], double[], int)}, but randomly skips the weight.
     *
     * <p>With probability {@code dropoutRate} the weight is skipped: 0 is returned and no
     * trainer state is touched.
     */
    @Override
    public double updateWeight(double[] gradients, double[] lastGradient, int index, double dropoutRate) {
        if (dropoutRate > 0.0 && this.dropoutRandomSource.nextDouble() < dropoutRate) {
            return 0.0;
        }
        return this.applyUpdateRule(gradients, lastGradient, index);
    }

    /** Dispatches to the selected rule and records the resulting change for later reverts. */
    private double applyUpdateRule(double[] gradients, double[] lastGradient, int index) {
        double weightChange;
        switch (this.rpropType) {
            case RPROPp:
                weightChange = this.updateWeightPlus(gradients, lastGradient, index);
                break;
            case RPROPm:
                weightChange = this.updateWeightMinus(gradients, lastGradient, index);
                break;
            case iRPROPp:
                weightChange = this.updateiWeightPlus(gradients, lastGradient, index);
                break;
            case iRPROPm:
                weightChange = this.updateiWeightMinus(gradients, lastGradient, index);
                break;
            case ARPROP:
                weightChange = this.updateJacobiWeight(gradients, lastGradient, index);
                break;
            default:
                throw new TrainingError("Unknown RPROP type: " + this.rpropType);
        }
        this.lastWeightChange[index] = weightChange;
        return weightChange;
    }

    /**
     * RPROP+ rule: grow the step while the gradient sign holds.
     *
     * <p>On a sign change it shrinks the step and reverts the previous weight change.
     */
    public double updateWeightPlus(double[] gradients, double[] lastGradient, int index) {
        int change = EncogMath.sign(gradients[index] * lastGradient[index]);
        double weightChange;
        if (change > 0) {
            double delta = Math.min(this.updateValues[index] * POSITIVE_ETA, this.maxStep);
            weightChange = EncogMath.sign(gradients[index]) * delta;
            this.updateValues[index] = delta;
            lastGradient[index] = gradients[index];
        } else if (change < 0) {
            this.updateValues[index] = Math.max(this.updateValues[index] * NEGATIVE_ETA, DELTA_MIN);
            weightChange = -this.lastWeightChange[index];
            // Zeroing the stored gradient makes the next call take the change == 0 branch.
            lastGradient[index] = 0.0;
        } else {
            weightChange = EncogMath.sign(gradients[index]) * this.updateValues[index];
            lastGradient[index] = gradients[index];
        }
        return weightChange;
    }

    /**
     * RPROP- rule: grow the step while the gradient sign holds, shrink it otherwise.
     *
     * <p>Previous weight changes are never reverted.
     */
    public double updateWeightMinus(double[] gradients, double[] lastGradient, int index) {
        int change = EncogMath.sign(gradients[index] * lastGradient[index]);
        double delta;
        if (change > 0) {
            delta = Math.min(this.lastDelta[index] * POSITIVE_ETA, this.maxStep);
        } else {
            delta = Math.max(this.lastDelta[index] * NEGATIVE_ETA, DELTA_MIN);
        }
        lastGradient[index] = gradients[index];
        this.lastDelta[index] = delta;
        return EncogMath.sign(gradients[index]) * delta;
    }

    /**
     * iRPROP+ rule: like RPROP+, but the previous change is only reverted when the error rose.
     */
    public double updateiWeightPlus(double[] gradients, double[] lastGradient, int index) {
        int change = EncogMath.sign(gradients[index] * lastGradient[index]);
        double weightChange = 0.0;
        if (change > 0) {
            double delta = Math.min(this.updateValues[index] * POSITIVE_ETA, this.maxStep);
            weightChange = EncogMath.sign(gradients[index]) * delta;
            this.updateValues[index] = delta;
            lastGradient[index] = gradients[index];
        } else if (change < 0) {
            this.updateValues[index] = Math.max(this.updateValues[index] * NEGATIVE_ETA, DELTA_MIN);
            if (this.getError() > this.lastError) {
                weightChange = -this.lastWeightChange[index];
            }
            lastGradient[index] = 0.0;
        } else {
            weightChange = EncogMath.sign(gradients[index]) * this.updateValues[index];
            lastGradient[index] = gradients[index];
        }
        return weightChange;
    }

    /**
     * iRPROP- rule: like RPROP-, but a sign change suppresses the update.
     *
     * <p>After a sign change the current gradient is treated as zero. No weight change is made,
     * and a zero gradient is stored for the next iteration.
     */
    public double updateiWeightMinus(double[] gradients, double[] lastGradient, int index) {
        int change = EncogMath.sign(gradients[index] * lastGradient[index]);
        double delta;
        if (change > 0) {
            delta = Math.min(this.lastDelta[index] * POSITIVE_ETA, this.maxStep);
        } else {
            delta = Math.max(this.lastDelta[index] * NEGATIVE_ETA, DELTA_MIN);
        }
        this.lastDelta[index] = delta;
        if (change < 0) {
            // Previously this zero was immediately overwritten, making iRPROP- identical to RPROP-.
            lastGradient[index] = 0.0;
            return 0.0;
        }
        lastGradient[index] = gradients[index];
        return EncogMath.sign(gradients[index]) * delta;
    }

    /**
     * ARPROP rule: the RPROP+ update, overridden by a damped step whenever the error rose.
     *
     * <p>On an error increase the step is scaled by {@code 1 / (2q)} and {@code q} is incremented;
     * otherwise {@code q} resets to 1.
     */
    public double updateJacobiWeight(double[] gradients, double[] lastGradient, int index) {
        int change = EncogMath.sign(gradients[index] * lastGradient[index]);
        double weightChange = 0.0;
        double delta = this.updateValues[index];
        if (change > 0) {
            delta = Math.min(this.updateValues[index] * POSITIVE_ETA, this.maxStep);
            weightChange = EncogMath.sign(gradients[index]) * delta;
            this.updateValues[index] = delta;
            lastGradient[index] = gradients[index];
        } else if (change < 0) {
            delta = Math.max(this.updateValues[index] * NEGATIVE_ETA, DELTA_MIN);
            this.updateValues[index] = delta;
            weightChange = -this.lastWeightChange[index];
            lastGradient[index] = 0.0;
        } else {
            delta = this.updateValues[index];
            weightChange = EncogMath.sign(gradients[index]) * delta;
            lastGradient[index] = gradients[index];
        }
        if (this.getError() > this.lastError) {
            weightChange = 1.0 / (2.0 * this.q) * delta;
            this.q += 1.0;
        } else {
            this.q = 1.0;
        }
        return weightChange;
    }

    /** Records this iteration's error so the next iteration can detect an error increase. */
    @Override
    public void postIteration() {
        super.postIteration();
        this.lastError = this.getError();
    }

    /** @return the live per-weight update values (not a copy) */
    public double[] getUpdateValues() {
        return this.updateValues;
    }
}

