/*
 * Decompiled with CFR 0.152.
 */
package org.ddogleg.optimization.impl;

import org.ddogleg.optimization.impl.TrustRegionStep;
import org.ejml.alg.dense.mult.VectorVectorMult;
import org.ejml.data.DenseMatrix64F;
import org.ejml.factory.LinearSolverFactory;
import org.ejml.interfaces.linsol.LinearSolver;
import org.ejml.ops.CommonOps;
import org.ejml.ops.NormOps;

/**
 * Dogleg trust-region step whose quadratic model uses the Gauss-Newton approximation of the
 * Hessian, {@code B = J'J}.
 *
 * <p>{@link #setInputs} precomputes, for the current point, the Gauss-Newton step (the solution of
 * {@code B*p = -g} from the supplied linear solver) and the scale of the Cauchy point along
 * {@code -g}. {@link #computeStep} then picks one of three steps for a given trust-region radius:
 * <ul>
 *   <li>the Gauss-Newton step, if it fits inside the region;</li>
 *   <li>otherwise a steepest-descent step, truncated to the boundary if the boundary comes first;</li>
 *   <li>otherwise the point where the segment from the Cauchy point to the Gauss-Newton step crosses
 *       the boundary.</li>
 * </ul>
 *
 * <p>Not thread-safe: instances keep mutable work buffers and a reference to the latest gradient.
 */
public class DoglegStepFtF
implements TrustRegionStep {
    /** Linear solver used to compute the Gauss-Newton step from B. */
    private final LinearSolver<DenseMatrix64F> pinv;
    /** Approximate Hessian, B = J'J. */
    private final DenseMatrix64F B = new DenseMatrix64F(1, 1);
    /** Gradient from the latest call to setInputs (a reference, not a copy). */
    private DenseMatrix64F gradient;
    /** Negated gradient; right-hand side of the Gauss-Newton system. */
    private final DenseMatrix64F gradientNeg = new DenseMatrix64F(1, 1);
    /** Predicted reduction of the model for the most recently computed step. */
    private double predicted;
    /** True if the most recently computed step lies on the trust-region boundary. */
    private boolean maxStep;
    /** Work buffer holding the Cauchy point, -distanceCauchy * g. */
    protected DenseMatrix64F stepCauchy = new DenseMatrix64F(1, 1);
    /**
     * Scale t of the model's minimizer along -g, ||g||^2 / (g'Bg); the Cauchy point's length is
     * t*||g||. Zero when g is zero, positive infinity when g'Bg <= 0 (model unbounded along -g).
     */
    private double distanceCauchy;
    /** Gauss-Newton step, the solver's solution of B * p = -g. */
    protected DenseMatrix64F stepGN = new DenseMatrix64F(1, 1);
    /** Euclidean length of the Gauss-Newton step. */
    private double distanceGN;
    /** Curvature of the model along the gradient, g'Bg. */
    private double gBg;
    /** Euclidean norm of the gradient, ||g||. */
    private double gnorm;

    /**
     * Creates a dogleg step that uses the given solver for the Gauss-Newton system.
     *
     * @param pinv solver applied to B = J'J; presumably it should cope with singular B
     *     (least-squares / pseudo-inverse behavior) -- confirm for custom solvers
     */
    public DoglegStepFtF(LinearSolver<DenseMatrix64F> pinv) {
        this.pinv = pinv;
    }

    /** Creates a dogleg step that uses a pivoted-QR least-squares solver. */
    public DoglegStepFtF() {
        this(LinearSolverFactory.leastSquaresQrPivot(true, false));
    }

    /**
     * Resizes the internal work matrices for a problem with {@code numParam} parameters.
     *
     * @param numParam number of parameters
     * @param numFunctions number of residual functions (not used by this class)
     */
    @Override
    public void init(int numParam, int numFunctions) {
        B.reshape(numParam, numParam);
        stepCauchy.reshape(numParam, 1);
        stepGN.reshape(numParam, 1);
        gradientNeg.reshape(numParam, 1);
    }

    /**
     * Precomputes B = J'J, the Cauchy scale along -g, and the Gauss-Newton step for the current
     * point.
     *
     * @param x current parameters (not used by this class)
     * @param residuals current residuals (not used by this class)
     * @param J Jacobian at the current point
     * @param gradient gradient at the current point; only a reference is kept, so it must not be
     *     modified before computeStep is called. Presumably J'*residuals -- not checked here.
     * @param fx cost at the current point (not used by this class)
     * @throws IllegalStateException if the linear solver rejects B
     */
    @Override
    public void setInputs(DenseMatrix64F x, DenseMatrix64F residuals, DenseMatrix64F J,
                          DenseMatrix64F gradient, double fx) {
        this.gradient = gradient;
        CommonOps.scale(-1.0, gradient, gradientNeg);

        // B = J'J
        CommonOps.multInner(J, B);

        gBg = VectorVectorMult.innerProdA(gradient, B, gradient);
        gnorm = NormOps.normF(gradient);
        if (gnorm == 0.0) {
            distanceCauchy = 0.0;
        } else if (gBg <= 0.0) {
            // Non-positive curvature along -g (e.g. round-off when J*g is ~0): the model keeps
            // decreasing along -g, so the Cauchy point is on the trust-region boundary. A zero or
            // negative scale here would give a null or uphill Cauchy point.
            distanceCauchy = Double.POSITIVE_INFINITY;
        } else {
            distanceCauchy = gnorm * gnorm / gBg;
        }

        if (!pinv.setA(B)) {
            throw new IllegalStateException(
                    "Linear solver failed to decompose B = J'J (" + B.numRows + "x" + B.numCols + ")");
        }
        pinv.solve(gradientNeg, stepGN);
        distanceGN = NormOps.normF(stepGN);
    }

    /**
     * Computes the dogleg step for the given radius, and records its predicted reduction and
     * whether it lies on the trust-region boundary.
     *
     * @param regionRadius trust-region radius
     * @param step output; receives the step
     */
    @Override
    public void computeStep(double regionRadius, DenseMatrix64F step) {
        if (distanceGN <= regionRadius) {
            // The Gauss-Newton step fits inside the region; take it unchanged.
            step.set(stepGN);
            maxStep = distanceGN == regionRadius;
            // -(g'p + 0.5 p'Bp) reduces to -0.5 g'p when B*p = -g
            predicted = -0.5 * VectorVectorMult.innerProd(stepGN, gradient);
        } else if (distanceCauchy * gnorm >= regionRadius) {
            // The boundary is reached at or before the Cauchy point.
            cauchyStep(regionRadius, step);
        } else {
            // Cauchy point inside, Gauss-Newton step outside: step lies on the dogleg segment.
            combinedStep(regionRadius, step);
            maxStep = true;
        }
    }

    /**
     * Steps along -g to the Cauchy point, truncated to the trust-region boundary when the Cauchy
     * point lies on or beyond it.
     *
     * @param regionRadius trust-region radius
     * @param step output; receives the step
     */
    protected void cauchyStep(double regionRadius, DenseMatrix64F step) {
        // scale along -g at which the step length equals the radius
        double normRadius = regionRadius / gnorm;
        double dist;
        if (distanceCauchy >= normRadius) {
            maxStep = true;
            dist = normRadius;
        } else {
            maxStep = false;
            dist = distanceCauchy;
        }
        CommonOps.scale(-dist, gradient, step);
        predicted = predictCauchy(dist);
    }

    /**
     * Steps to where the segment from the Cauchy point to the Gauss-Newton step crosses the
     * trust-region boundary, and computes the predicted reduction of that step.
     *
     * <p>Expects the Cauchy point strictly inside the region and the Gauss-Newton step outside it,
     * as arranged by {@link #computeStep}.
     *
     * @param regionRadius trust-region radius
     * @param step output; receives the step
     */
    protected void combinedStep(double regionRadius, DenseMatrix64F step) {
        CommonOps.scale(-distanceCauchy, gradient, stepCauchy);
        double beta = combinedStep(stepCauchy, stepGN, regionRadius, step);

        // -(g'p + 0.5 p'Bp) for p = (1-beta)*stepCauchy + beta*stepGN, expanded with
        // stepCauchy = -t*g and B*stepGN = -g
        double dotGandGN = VectorVectorMult.innerProd(stepGN, gradient);
        double oneMb = 1.0 - beta;
        double left = -0.5 * distanceCauchy * distanceCauchy * oneMb * oneMb * gBg;
        double middle = -distanceCauchy * oneMb * (beta - 1.0) * gnorm * gnorm;
        double right = (beta * beta / 2.0 - beta) * dotGandGN;
        predicted = left + middle + right;
    }

    /**
     * Finds beta such that ||a + beta*(b - a)|| = r, with a = stepCauchy and b = stepGN, and
     * writes that point into {@code step}.
     *
     * <p>Takes the non-negative root of beta^2 ||b-a||^2 + 2 beta a'(b-a) + ||a||^2 - r^2 = 0. When
     * a'(b-a) > 0 the rationalized form of the root is used to avoid cancellation. Expects
     * ||a|| < r and b != a; otherwise the result may be NaN or infinite.
     *
     * @param stepCauchy start of the segment (a)
     * @param stepGN end of the segment (b)
     * @param regionRadius trust-region radius r
     * @param step output; zeroed, then receives a + beta*(b - a)
     * @return beta
     */
    protected static double combinedStep(DenseMatrix64F stepCauchy, DenseMatrix64F stepGN, double regionRadius, DenseMatrix64F step) {
        int n = stepCauchy.numRows;
        double c = 0.0;    // a'(b-a)
        double bma2 = 0.0; // ||b-a||^2
        double a2 = 0.0;   // ||a||^2
        for (int i = 0; i < n; ++i) {
            double a = stepCauchy.data[i];
            double d = stepGN.data[i] - a;
            c += a * d;
            bma2 += d * d;
            a2 += a * a;
        }
        double r2 = regionRadius * regionRadius;
        double beta = c <= 0.0 ? (-c + Math.sqrt(c * c + bma2 * (r2 - a2))) / bma2 : (r2 - a2) / (c + Math.sqrt(c * c + bma2 * (r2 - a2)));
        step.zero();
        for (int i = 0; i < n; ++i) {
            step.data[i] = stepCauchy.data[i] + beta * (stepGN.data[i] - stepCauchy.data[i]);
        }
        return beta;
    }

    /** Predicted reduction -(g'p + 0.5 p'Bp) of the step p = -dist * g. */
    private double predictCauchy(double dist) {
        return dist * gnorm * gnorm - 0.5 * dist * dist * gBg;
    }

    /** Returns the predicted reduction of the model for the most recently computed step. */
    @Override
    public double predictedReduction() {
        return predicted;
    }

    /** Returns true if the most recently computed step lies on the trust-region boundary. */
    @Override
    public boolean isMaxStep() {
        return maxStep;
    }
}

