/*
 * Decompiled with CFR 0.152.
 */
package jsat.math.optimization.stochastic;

import java.util.Arrays;
import jsat.linear.IndexValue;
import jsat.linear.Vec;
import jsat.math.optimization.stochastic.GradientUpdater;

/**
 * Gradient updater that moves each coordinate by a per-coordinate step size
 * in the direction opposite to the sign of its gradient; the gradient's
 * magnitude is not used (the class name suggests the Rprop scheme).
 *
 * <p>For every coordinate the step size is
 * <ul>
 *   <li>left unchanged when the current or the previously stored gradient
 *       has sign zero,</li>
 *   <li>multiplied by {@code eta_pos} and capped at {@code eta_max} when both
 *       gradients have the same sign,</li>
 *   <li>multiplied by {@code eta_neg} and floored at {@code eta_min}
 *       otherwise (a sign flip), after which the stored gradient is reset
 *       to zero.</li>
 * </ul>
 *
 * <p>{@link #setup(int)} must be called before any {@code update} call.
 */
public class Rprop
implements GradientUpdater {
    /** Growth factor applied to a step size when consecutive gradient signs agree. */
    private double eta_pos = 1.2;
    /** Shrink factor applied to a step size when consecutive gradient signs differ. */
    private double eta_neg = 0.5;
    /** Step size every coordinate (and the bias) starts with after {@link #setup(int)}. */
    private double eta_start = 0.1;
    /** Upper bound for any step size. */
    private double eta_max = 50.0;
    /** Lower bound for any step size. */
    private double eta_min = 1.0E-6;
    /** Value each weight had at the start of its most recent update. */
    private double[] prev_w;
    /** Gradient stored by the most recent update of each coordinate (0 after a sign flip). */
    private double[] prev_grad;
    /** Current step size of each coordinate. */
    private double[] cur_eta;
    /** Bias counterpart of {@link #prev_grad}. */
    private double prev_grad_bias;
    /** Bias counterpart of {@link #cur_eta}. */
    private double cur_eta_bias;
    /** Bias counterpart of {@link #prev_w}. */
    private double prev_bias;

    /**
     * Creates an updater with the default hyperparameters. No per-coordinate
     * state is allocated until {@link #setup(int)} is called.
     */
    public Rprop() {
    }

    /**
     * Copy constructor: duplicates the hyperparameters and all accumulated
     * state so that the new object shares no arrays with {@code toCopy}.
     *
     * @param toCopy the updater to copy
     */
    public Rprop(Rprop toCopy) {
        this.eta_pos = toCopy.eta_pos;
        this.eta_neg = toCopy.eta_neg;
        this.eta_start = toCopy.eta_start;
        this.eta_max = toCopy.eta_max;
        this.eta_min = toCopy.eta_min;
        // The arrays stay null when the source has not been set up yet.
        if (toCopy.prev_grad != null) {
            this.prev_grad = Arrays.copyOf(toCopy.prev_grad, toCopy.prev_grad.length);
        }
        if (toCopy.cur_eta != null) {
            this.cur_eta = Arrays.copyOf(toCopy.cur_eta, toCopy.cur_eta.length);
        }
        if (toCopy.prev_w != null) {
            this.prev_w = Arrays.copyOf(toCopy.prev_w, toCopy.prev_w.length);
        }
        this.prev_grad_bias = toCopy.prev_grad_bias;
        this.cur_eta_bias = toCopy.cur_eta_bias;
        this.prev_bias = toCopy.prev_bias;
    }

    /**
     * Updates {@code w} in place without touching any bias state; equivalent
     * to calling the five-argument overload with a bias and bias gradient of
     * zero and discarding the result.
     *
     * @param w the weight vector to mutate
     * @param grad the gradient; only the entries it iterates over are processed
     * @param eta multiplier applied on top of each per-coordinate step size
     * @throws IllegalStateException if {@link #setup(int)} has not been called
     */
    @Override
    public void update(Vec w, Vec grad, double eta) {
        this.update(w, grad, eta, 0.0, 0.0);
    }

    /**
     * Updates {@code w} in place and computes the adjustment for the bias.
     *
     * <p>The bias is processed only when both {@code bias} and
     * {@code biasGrad} are non-zero; otherwise the bias state is left
     * untouched and {@code 0.0} is returned.
     *
     * @param w the weight vector to mutate
     * @param grad the gradient; only the entries it iterates over are processed
     * @param eta multiplier applied on top of each per-coordinate step size
     * @param bias the current bias value
     * @param biasGrad the gradient of the bias
     * @return the bias adjustment, or {@code 0.0} when the bias was skipped
     * @throws IllegalStateException if {@link #setup(int)} has not been called
     */
    @Override
    public double update(Vec w, Vec grad, double eta, double bias, double biasGrad) {
        // Fail with a clear message instead of an NPE from inside the loop
        // (or a silent bias update with a zero step size).
        if (this.prev_grad == null || this.cur_eta == null || this.prev_w == null) {
            throw new IllegalStateException("setup(int) must be called before update");
        }
        for (IndexValue iv : grad) {
            int i = iv.getIndex();
            double g = iv.getValue();
            double signNow = Math.signum(g);
            double signPrev = Math.signum(this.prev_grad[i]);
            // Read the weight before it is modified; it becomes prev_w[i] below.
            double wBefore = w.get(i);
            this.prev_grad[i] = g;
            double step = this.cur_eta[i] = this.nextStepSize(this.cur_eta[i], signNow, signPrev);
            if (Rprop.isSignFlip(signNow, signPrev)) {
                // NOTE(review): the decrement is scaled by the previous weight
                // *value*, not by the previous weight change — confirm intended.
                w.increment(i, -this.prev_w[i] * step * eta);
                // Zeroing the stored gradient makes the next update of this
                // coordinate keep its step size unchanged.
                this.prev_grad[i] = 0.0;
            } else {
                w.increment(i, -signNow * step * eta);
            }
            this.prev_w[i] = wBefore;
        }
        // NOTE(review): a bias of exactly 0.0 is never updated, even when
        // biasGrad is non-zero — confirm this is intended.
        if (bias != 0.0 && biasGrad != 0.0) {
            double signNow = Math.signum(biasGrad);
            double signPrev = Math.signum(this.prev_grad_bias);
            this.prev_grad_bias = biasGrad;
            double step = this.cur_eta_bias = this.nextStepSize(this.cur_eta_bias, signNow, signPrev);
            double delta;
            if (Rprop.isSignFlip(signNow, signPrev)) {
                this.prev_grad_bias = 0.0;
                // NOTE(review): for weights the non-flip increment is
                // -sign*step and the flip increment is -prev_w*step, while
                // here the non-flip result is +sign*step and the flip result
                // is -prev_bias*step. Presumably the caller subtracts the
                // returned value, in which case the flip case moves the bias
                // opposite to how weights move — confirm.
                delta = -this.prev_bias * step;
            } else {
                delta = signNow * step;
            }
            this.prev_bias = bias;
            return delta * eta;
        }
        return 0.0;
    }

    /**
     * Allocates and resets all state for a problem with {@code d}
     * coordinates: stored gradients and previous weights become zero and
     * every step size (including the bias one) becomes {@code eta_start}.
     *
     * @param d the number of coordinates
     */
    @Override
    public void setup(int d) {
        this.prev_grad = new double[d];
        this.cur_eta = new double[d];
        Arrays.fill(this.cur_eta, this.eta_start);
        this.prev_w = new double[d];
        this.cur_eta_bias = this.eta_start;
        this.prev_grad_bias = 0.0;
        this.prev_bias = 0.0;
    }

    /**
     * Returns an independent copy of this updater, created through the copy
     * constructor.
     *
     * @return a copy of this updater
     */
    @Override
    public Rprop clone() {
        return new Rprop(this);
    }

    /** Returns the step size to use next: unchanged, grown (capped) or shrunk (floored). */
    private double nextStepSize(double current, double signNow, double signPrev) {
        if (signNow == 0.0 || signPrev == 0.0) {
            return current;
        }
        if (signNow == signPrev) {
            return Math.min(current * this.eta_pos, this.eta_max);
        }
        return Math.max(current * this.eta_neg, this.eta_min);
    }

    /** Tells whether two signum values are both non-zero and not equal to each other. */
    private static boolean isSignFlip(double signNow, double signPrev) {
        return signNow != 0.0 && signPrev != 0.0 && signNow != signPrev;
    }
}

