/*
 * Decompiled with CFR 0.152.
 */
package org.joone.engine.extenders;

import org.joone.engine.RpropParameters;
import org.joone.engine.extenders.DeltaRuleExtender;
import org.joone.log.ILogger;
import org.joone.log.LoggerFactory;

/**
 * Delta-rule extender that replaces the incoming delta with a sign-based step
 * whose per-weight (or per-bias) size is adapted from the agreement between
 * the previously stored gradient and the currently accumulated gradient.
 *
 * <p>Gradients are accumulated on every call to {@code getDelta}; a step is
 * only produced when the learner's update-weight extender reports that
 * weights/biases are being stored in this cycle. In all other calls the
 * returned delta is {@code 0.0}.
 *
 * <p>The same three arrays serve both the bias case (one column per row) and
 * the weight case (input dimension x output dimension); which shape is used is
 * decided in {@link #reinit()}.
 */
public class RpropExtender extends DeltaRuleExtender {
    private static final ILogger log = LoggerFactory.getLogger(RpropExtender.class);

    /** Current step size for every bias/weight. */
    protected double[][] theDeltas;

    /** Gradient stored at the end of the previous step (0 after a sign change). */
    protected double[][] thePreviousGradients;

    /** Parameters supplying initial/min/max step sizes and the eta factors; created lazily. */
    protected RpropParameters theRpropParameters;

    /** Gradients accumulated since the last step was produced. */
    protected double[][] theSummedGradients;

    /**
     * (Re)allocates the gradient and step-size arrays and fills the step sizes
     * with the initial values taken from the parameters.
     *
     * <p>The arrays are sized from the learner's layer (rows x 1) when a layer
     * is present, otherwise from the learner's synapse (input dimension x
     * output dimension). A warning is logged when the monitor's learning rate
     * is not exactly 1.
     *
     * @throws IllegalStateException if the learner has neither a layer nor a
     *         synapse and no arrays have been allocated before, so there is
     *         nothing to size the arrays from
     */
    public void reinit() {
        if (this.getLearner().getMonitor().getLearningRate() != 1.0) {
            log.warn("RPROP learning rate should be equal to 1.");
        }
        if (this.getLearner().getLayer() != null) {
            this.allocate(this.getLearner().getLayer().getRows(), 1);
        } else if (this.getLearner().getSynapse() != null) {
            int myRows = this.getLearner().getSynapse().getInputDimension();
            int myCols = this.getLearner().getSynapse().getOutputDimension();
            this.allocate(myRows, myCols);
        }
        if (this.theDeltas == null) {
            // Previously this surfaced as a bare NullPointerException in the loop below.
            throw new IllegalStateException(
                    "RpropExtender cannot be initialised: the learner has neither a layer nor a synapse.");
        }
        for (int i = 0; i < this.theDeltas.length; ++i) {
            // Use the row's own length so the loop never indexes a row that does not exist.
            for (int j = 0; j < this.theDeltas[i].length; ++j) {
                this.theDeltas[i][j] = this.getParameters().getInitialDelta(i, j);
            }
        }
    }

    /**
     * Computes the delta for bias {@code j}.
     *
     * @param currentGradientOuts the current gradient outputs (not read here)
     * @param j                   index of the bias
     * @param aPreviousDelta      delta computed so far; it is subtracted from
     *                            the accumulated gradient
     * @return the step to apply, or {@code 0.0} when biases are not stored in this cycle
     */
    @Override
    public double getDelta(double[] currentGradientOuts, int j, double aPreviousDelta) {
        return this.accumulateAndStep(j, 0, aPreviousDelta, true);
    }

    /**
     * Computes the delta for the weight at row {@code j}, column {@code k}.
     *
     * @param currentInps     the current inputs (not read here)
     * @param j               row index of the weight
     * @param currentPattern  the current pattern (not read here)
     * @param k               column index of the weight
     * @param aPreviousDelta  delta computed so far; it is subtracted from the
     *                        accumulated gradient
     * @return the step to apply, or {@code 0.0} when weights are not stored in this cycle
     */
    @Override
    public double getDelta(double[] currentInps, int j, double[] currentPattern, int k, double aPreviousDelta) {
        return this.accumulateAndStep(j, k, aPreviousDelta, false);
    }

    /** No work is needed after a bias update. */
    @Override
    public void postBiasUpdate(double[] currentGradientOuts) {
    }

    /** No work is needed after a weight update. */
    @Override
    public void postWeightUpdate(double[] currentPattern, double[] currentInps) {
    }

    /**
     * Ensures the arrays exist and match the layer's row count before biases
     * are updated, re-initialising them otherwise.
     */
    @Override
    public void preBiasUpdate(double[] currentGradientOuts) {
        if (this.theDeltas == null || this.theDeltas.length != this.getLearner().getLayer().getRows()) {
            this.reinit();
        }
    }

    /**
     * Ensures the arrays exist and match the synapse's dimensions before
     * weights are updated, re-initialising them otherwise.
     */
    @Override
    public void preWeightUpdate(double[] currentPattern, double[] currentInps) {
        if (this.theDeltas == null
                || this.theDeltas.length != this.getLearner().getSynapse().getInputDimension()
                // Only inspect the first row when there is one; a zero-row array has no columns to compare.
                || (this.theDeltas.length > 0
                        && this.theDeltas[0].length != this.getLearner().getSynapse().getOutputDimension())) {
            this.reinit();
        }
    }

    /**
     * Returns the parameters in use, creating a default instance on first access.
     *
     * @return the parameters, never {@code null}
     */
    public RpropParameters getParameters() {
        if (this.theRpropParameters == null) {
            this.theRpropParameters = new RpropParameters();
        }
        return this.theRpropParameters;
    }

    /**
     * Sets the parameters to use.
     *
     * @param aParameters the new parameters; when {@code null}, a default
     *                    instance is created on the next {@link #getParameters()} call
     */
    public void setParameters(RpropParameters aParameters) {
        this.theRpropParameters = aParameters;
    }

    /**
     * Returns the sign of a value.
     *
     * @param d the value to inspect
     * @return {@code 1.0} if {@code d > 0}, {@code -1.0} if {@code d < 0},
     *         otherwise {@code 0.0} (which includes NaN)
     */
    protected double sign(double d) {
        if (d > 0.0) {
            return 1.0;
        }
        if (d < 0.0) {
            return -1.0;
        }
        return 0.0;
    }

    /** Allocates the three working arrays with the given shape. */
    private void allocate(int aRows, int aCols) {
        this.thePreviousGradients = new double[aRows][aCols];
        this.theSummedGradients = new double[aRows][aCols];
        this.theDeltas = new double[aRows][aCols];
    }

    /**
     * Shared implementation of both {@code getDelta} overloads: accumulates the
     * gradient for one cell and, when weights/biases are stored in this cycle,
     * adapts the cell's step size and returns the step to apply.
     *
     * @param aRow           row of the cell
     * @param aCol           column of the cell (always 0 for biases)
     * @param aPreviousDelta value subtracted from the accumulated gradient
     * @param aForBias       {@code true} to read the last applied delta from the
     *                       layer's bias matrix, {@code false} to read it from
     *                       the synapse's weight matrix
     * @return the step to apply, or {@code 0.0} when nothing is stored in this cycle
     */
    private double accumulateAndStep(int aRow, int aCol, double aPreviousDelta, boolean aForBias) {
        this.theSummedGradients[aRow][aCol] -= aPreviousDelta;
        if (!this.getLearner().getUpdateWeightExtender().storeWeightsBiases()) {
            // Not a storing cycle: keep accumulating and apply no step.
            return 0.0;
        }

        double myPrevious = this.thePreviousGradients[aRow][aCol];
        double mySummed = this.theSummedGradients[aRow][aCol];
        double myProduct = myPrevious * mySummed;
        double myDelta;

        if (myProduct < 0.0) {
            // Gradient changed sign: shrink the step size (bounded below) and
            // undo the last applied delta instead of taking a new step.
            this.theDeltas[aRow][aCol] = Math.max(
                    this.theDeltas[aRow][aCol] * this.getParameters().getEtaDec(),
                    this.getParameters().getMinDelta());
            // Read lazily so the matrix is only touched in this branch.
            double myLastApplied = aForBias
                    ? this.getLearner().getLayer().getBias().delta[aRow][aCol]
                    : this.getLearner().getSynapse().getWeights().delta[aRow][aCol];
            myDelta = -1.0 * myLastApplied;
            // Stored as zero, so the next comparison yields a zero product and leaves the step size unchanged.
            this.thePreviousGradients[aRow][aCol] = 0.0;
        } else {
            if (myProduct > 0.0) {
                // Same sign as before: grow the step size (bounded above).
                this.theDeltas[aRow][aCol] = Math.min(
                        this.theDeltas[aRow][aCol] * this.getParameters().getEtaInc(),
                        this.getParameters().getMaxDelta());
            }
            // Zero (or NaN) product keeps the current step size unchanged.
            myDelta = -1.0 * this.sign(mySummed) * this.theDeltas[aRow][aCol];
            this.thePreviousGradients[aRow][aCol] = mySummed;
        }

        // A step was produced, so start accumulating afresh.
        this.theSummedGradients[aRow][aCol] = 0.0;
        return myDelta;
    }
}

