/*
 * Decompiled with CFR 0.152.
 */
package jsat.math.optimization;

import jsat.linear.Vec;
import jsat.math.Function;
import jsat.math.FunctionVec;
import jsat.math.optimization.LineSearch;

/**
 * Backtracking line search that starts at the maximum step size and shrinks the
 * step until the Armijo (sufficient decrease) condition
 * {@code f(x + alpha*p) <= f(x) + c1 * alpha * (grad f(x) . p)} holds.
 * <p>
 * The first backtracking step uses the minimizer of a quadratic interpolant of
 * the one-dimensional function {@code phi(alpha) = f(x + alpha*p)}. Later steps
 * use the minimizer of a cubic interpolant built from the last two trial steps.
 * Every interpolated step is safeguarded to lie in {@code [0.1*alpha, 0.9*alpha]}.
 * If it falls outside that range, or is NaN, the step is instead multiplied by
 * {@code rho}.
 */
public class BacktrackingArmijoLineSearch
implements LineSearch {
    /** Steps smaller than this make the search give up and return the previous step. */
    private static final double MIN_ALPHA = 1.0E-20;
    /** Lower end of the safeguard interval, as a fraction of the current step. */
    private static final double SAFEGUARD_LOW = 0.1;
    /** Upper end of the safeguard interval, as a fraction of the current step. */
    private static final double SAFEGUARD_HIGH = 0.9;

    /** Shrink factor used when the interpolated step is rejected; in (0, 1). */
    private final double rho;
    /** Sufficient decrease constant of the Armijo condition; in (0, 1/2). */
    private double c1;

    /**
     * Creates a line search with {@code rho = 0.5} and {@code c1 = 0.1}.
     */
    public BacktrackingArmijoLineSearch() {
        this(0.5, 0.1);
    }

    /**
     * Creates a line search with the given parameters.
     *
     * @param rho factor, in (0, 1), by which the step shrinks when interpolation
     *            gives an unusable step
     * @param c1 sufficient decrease constant, in (0, 1/2)
     * @throws IllegalArgumentException if {@code rho} or {@code c1} is out of
     *         range (NaN included)
     */
    public BacktrackingArmijoLineSearch(double rho, double c1) {
        if (!(rho > 0.0) || !(rho < 1.0)) {
            throw new IllegalArgumentException("rho must be in (0,1), not " + rho);
        }
        this.rho = rho;
        this.setC1(c1);
    }

    /**
     * Sets the sufficient decrease constant of the Armijo condition.
     *
     * @param c1 value in the open interval (0, 1/2)
     * @throws IllegalArgumentException if {@code c1} is outside (0, 1/2) or is NaN
     */
    public void setC1(double c1) {
        // Negated comparisons so that NaN is rejected, matching the rho check.
        if (!(c1 > 0.0) || !(c1 < 0.5)) {
            throw new IllegalArgumentException("c1 must be in (0, 1/2) not " + c1);
        }
        this.c1 = c1;
    }

    /**
     * Returns the sufficient decrease constant of the Armijo condition.
     *
     * @return the current value of c1
     */
    public double getC1() {
        return this.c1;
    }

    /**
     * Backtracks from {@code alpha_max} along {@code p_k} until the Armijo
     * condition is satisfied.
     * <p>
     * {@code fp} and {@code grad_x_alpha_pk} are not used by this implementation,
     * and no gradient is computed at the new point (see {@link #updatesGrad()}).
     *
     * @param alpha_max initial (largest) step size to try
     * @param x_k current point
     * @param x_grad gradient at {@code x_k}; only read when {@code gradP} is NaN
     * @param p_k search direction
     * @param f objective function
     * @param fp gradient function (unused)
     * @param f_x {@code f(x_k)}, or NaN to have it computed here
     * @param gradP {@code x_grad . p_k}, or NaN to have it computed here
     * @param x_alpha_pk if non-null, overwritten with {@code x_k + alpha*p_k} for
     *                   the returned step
     * @param fxApRet if non-null, element 0 receives {@code f} at the returned step
     * @param grad_x_alpha_pk unused
     * @param parallel passed through to {@code f}
     * @return the accepted step size. If the step would drop below 1e-20, returns
     *         the last evaluated step instead, even though it did not satisfy the
     *         Armijo condition; {@code x_alpha_pk} and {@code fxApRet} then
     *         correspond to that step.
     */
    @Override
    public double lineSearch(double alpha_max, Vec x_k, Vec x_grad, Vec p_k, Function f, FunctionVec fp, double f_x, double gradP, Vec x_alpha_pk, double[] fxApRet, Vec grad_x_alpha_pk, boolean parallel) {
        if (Double.isNaN(f_x)) {
            f_x = f.f(x_k, parallel);
        }
        if (Double.isNaN(gradP)) {
            gradP = x_grad.dot(p_k);
        }
        double alpha = alpha_max;
        if (x_alpha_pk == null) {
            x_alpha_pk = x_k.clone();
        } else {
            x_k.copyTo(x_alpha_pk);
        }
        x_alpha_pk.mutableAdd(alpha, p_k);
        double f_xap = f.f(x_alpha_pk, parallel);
        if (fxApRet != null) {
            fxApRet[0] = f_xap;
        }
        // Previous trial step and its function value; needed by the cubic model.
        double oldAlpha = 0.0;
        double oldF_xap = f_x;
        boolean firstBacktrack = true;
        while (f_xap > f_x + this.c1 * alpha * gradP) {
            double candidate;
            if (firstBacktrack) {
                // Only phi(0), phi'(0) and phi(alpha) are known: quadratic model.
                candidate = quadraticStep(alpha, f_x, f_xap, gradP);
                firstBacktrack = false;
            } else {
                // Two previous trials are known: cubic model.
                candidate = cubicStep(oldAlpha, oldF_xap, alpha, f_xap, f_x, gradP);
            }
            oldAlpha = alpha;
            alpha = safeguard(candidate, oldAlpha);
            if (alpha < MIN_ALPHA) {
                // x_alpha_pk and fxApRet still describe oldAlpha, so return it.
                return oldAlpha;
            }
            // Move from x_k + oldAlpha*p_k to x_k + alpha*p_k without recopying x_k.
            x_alpha_pk.mutableSubtract(oldAlpha - alpha, p_k);
            oldF_xap = f_xap;
            f_xap = f.f(x_alpha_pk, parallel);
            if (fxApRet != null) {
                fxApRet[0] = f_xap;
            }
        }
        return alpha;
    }

    /**
     * Minimizer of the quadratic that matches phi(0), phi'(0) and phi(alpha).
     */
    private static double quadraticStep(double alpha, double phi0, double phiAlpha, double dphi0) {
        return -dphi0 * alpha * alpha / (2.0 * (phiAlpha - phi0 - dphi0 * alpha));
    }

    /**
     * Minimizer of the cubic that matches phi(0), phi'(0), phi(alpha0) and
     * phi(alpha1). May be NaN or infinite when the model is degenerate; the
     * safeguard rejects such values.
     */
    private static double cubicStep(double alpha0, double phiAlpha0, double alpha1, double phiAlpha1, double phi0, double dphi0) {
        double g = phiAlpha1 - phi0 - dphi0 * alpha1;
        double h = phiAlpha0 - phi0 - dphi0 * alpha0;
        double a0Sqrd = alpha0 * alpha0;
        double a1Sqrd = alpha1 * alpha1;
        double scale = a0Sqrd * a1Sqrd * (alpha1 - alpha0);
        double a = (a0Sqrd * g - a1Sqrd * h) / scale;
        double b = (-a0Sqrd * alpha0 * g + a1Sqrd * alpha1 * h) / scale;
        return (-b + Math.sqrt(b * b - 3.0 * a * dphi0)) / (3.0 * a);
    }

    /**
     * Accepts {@code candidate} only when it lies in
     * {@code [0.1*alpha, 0.9*alpha]}; otherwise (including NaN) returns
     * {@code rho*alpha}.
     */
    private double safeguard(double candidate, double alpha) {
        if (candidate >= SAFEGUARD_LOW * alpha && candidate <= SAFEGUARD_HIGH * alpha) {
            return candidate;
        }
        return this.rho * alpha;
    }

    /**
     * This line search never computes the gradient at the new point.
     *
     * @return always {@code false}
     */
    @Override
    public boolean updatesGrad() {
        return false;
    }

    /**
     * Returns a new instance with the same {@code rho} and {@code c1}.
     *
     * @return an independent copy of this line search
     */
    @Override
    public BacktrackingArmijoLineSearch clone() {
        return new BacktrackingArmijoLineSearch(this.rho, this.c1);
    }
}

