/*
 * Decompiled with CFR 0.152.
 */
package org.ddogleg.optimization.impl;

import org.ddogleg.optimization.functions.CoupledJacobian;
import org.ddogleg.optimization.impl.LevenbergDenseBase;
import org.ejml.UtilEjml;
import org.ejml.alg.dense.linsol.LinearSolverSafe;
import org.ejml.alg.dense.mult.MatrixMultProduct;
import org.ejml.alg.dense.mult.VectorVectorMult;
import org.ejml.data.DenseMatrix64F;
import org.ejml.factory.LinearSolverFactory;
import org.ejml.interfaces.linsol.LinearSolver;
import org.ejml.ops.CommonOps;

/**
 * Levenberg dampened least-squares optimizer.
 *
 * <p>Each step solves {@code (J'J + lambda*I) * step = -J'r} with a symmetric positive-definite
 * linear solver. Here {@code J} is the Jacobian and {@code r} is the residual vector.
 */
public class LevenbergDampened
extends LevenbergDenseBase {
    /** Solver for {@code (J'J + lambda*I) * step = -J'r}; created in {@link #setFunction}. */
    protected LinearSolver<DenseMatrix64F> solver;

    /**
     * Creates the optimizer.
     *
     * @param initialDampParam initial value of the dampening parameter, forwarded to the base class
     */
    public LevenbergDampened(double initialDampParam) {
        super(initialDampParam);
    }

    /**
     * Evaluates the Jacobian at the current point and builds the helper matrices.
     *
     * <p>The helper matrices are {@code B = J'J}, {@code gradient = J'r}, and
     * {@code Bdiag = diag(B)}.
     *
     * @param residuals residual vector {@code r} at the current point
     * @param gradient output, overwritten with {@code J'r}
     */
    @Override
    protected void computeJacobian(DenseMatrix64F residuals, DenseMatrix64F gradient) {
        this.function.computeJacobian(this.jacobianVals.data);
        // B = J'*J. Only the upper triangle is filled in. NOTE(review): this relies on the
        // symmetric positive-definite solver reading only the upper triangle of its input.
        MatrixMultProduct.inner_reorder_upper(this.jacobianVals, this.B);
        // gradient = J'*r
        CommonOps.multTransA(this.jacobianVals, residuals, gradient);
        // Keep the undamped diagonal so computeStep() can apply several lambdas to one Jacobian.
        CommonOps.extractDiag(this.B, this.Bdiag);
    }

    /**
     * Solves {@code (B + lambda*I) * step = gradientNegative} for the step.
     *
     * @param lambda dampening parameter added to the diagonal of {@code B}
     * @param gradientNegative right-hand side, the negative gradient {@code -J'r}
     * @param step output, written only when this method returns {@code true}
     * @return {@code false} if the decomposition failed or its quality is not above
     *     {@code UtilEjml.EPS}; otherwise {@code true}
     */
    @Override
    protected boolean computeStep(double lambda, DenseMatrix64F gradientNegative, DenseMatrix64F step) {
        // Re-apply the damping from the saved undamped diagonal. The off-diagonal entries are
        // NOT restored here, so the solver must never overwrite B (see setFunction()).
        for (int i = 0; i < this.N; ++i) {
            int index = this.B.getIndex(i, i);
            this.B.data[index] = this.Bdiag.data[i] + lambda;
        }
        if (this.solver.setA(this.B) && this.solver.quality() > UtilEjml.EPS) {
            this.solver.solve(gradientNegative, step);
            return true;
        }
        return false;
    }

    /**
     * Sets the function being optimized and creates a linear solver sized for it.
     *
     * @param function function whose residuals and Jacobian are evaluated
     */
    @Override
    public void setFunction(CoupledJacobian function) {
        super.setFunction(function);
        this.solver = LinearSolverFactory.symmPosDef(this.N);
        // B is reused by computeStep() for every lambda tried against the same Jacobian, and
        // gradientNegative is reused by predictedReduction(). Neither input may be modified.
        // The original check only guarded the right-hand side (modifiesB). An in-place
        // decomposition of B (modifiesA) would corrupt the off-diagonal terms on retries.
        if (this.solver.modifiesA() || this.solver.modifiesB()) {
            this.solver = new LinearSolverSafe<DenseMatrix64F>(this.solver);
        }
    }

    /**
     * Computes the reduction in cost predicted by the linear model for a step.
     *
     * <p>The model reduction is {@code -g'p - 0.5*p'J'Jp}. If {@code p} solves
     * {@code (J'J + mu*I)p = -g}, then {@code p'J'Jp = -g'p - mu*p'p}, and the reduction
     * simplifies to {@code 0.5*(mu*p'p + p'(-g))}.
     *
     * @param param the step {@code p}
     * @param gradientNegative the negative gradient {@code -g}
     * @param mu dampening parameter used to compute the step
     * @return predicted reduction in cost
     */
    @Override
    protected double predictedReduction(DenseMatrix64F param, DenseMatrix64F gradientNegative, double mu) {
        double p_dot_p = VectorVectorMult.innerProd(param, param);
        double p_dot_g = VectorVectorMult.innerProd(param, gradientNegative);
        return 0.5 * (mu * p_dot_p + p_dot_g);
    }
}

