package com.joptimizer.optimizers;

import cern.colt.matrix.DoubleFactory1D;
import cern.colt.matrix.DoubleFactory2D;
import cern.colt.matrix.DoubleMatrix1D;
import cern.colt.matrix.DoubleMatrix2D;
import cern.colt.matrix.linalg.Algebra;
import cern.colt.matrix.linalg.SeqBlas;
import cern.jet.math.Functions;
import cern.jet.math.Mult;
import com.joptimizer.functions.FunctionsUtils;
import com.joptimizer.solvers.BasicKKTSolver;
import com.joptimizer.solvers.KKTSolver;
import com.joptimizer.util.Utils;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

/* loaded from: input_file:com/joptimizer/optimizers/PrimalDualMethod.class */
/**
 * Primal-dual interior-point iteration for a problem with inequality constraints
 * {@code fi(x) < 0} and, when {@code getA()} is non-null, linear equality constraints.
 *
 * <p>Each iteration evaluates the primal, centrality and dual residuals, solves a reduced
 * KKT system for the step in X and V, derives the step for the inequality multipliers L,
 * and then runs two backtracking searches on the step length: one that keeps the
 * inequalities strictly negative and one that enforces a sufficient decrease of the
 * overall residual norm.
 */
public class PrimalDualMethod extends OptimizationRequestHandler {
    /** Upper bound on the number of step-length reductions tried by each backtracking search. */
    private static final int MAX_BACKTRACKING_STEPS = 500;

    private KKTSolver kktSolver;
    private Algebra ALG = Algebra.DEFAULT;
    private DoubleFactory1D F1 = DoubleFactory1D.dense;
    private DoubleFactory2D F2 = DoubleFactory2D.dense;
    private Log log = LogFactory.getLog(getClass().getName());

    /**
     * Runs the optimization and stores the outcome through {@code setOptimizationResponse}.
     *
     * @return the response return code: 0 when the tolerances (or the custom exit
     *         conditions) are met, 1 when the iteration limit is reached or the progress
     *         check detects stagnation
     * @throws IllegalArgumentException if a user-supplied initial multiplier is not strictly positive
     * @throws Exception if the starting point is not strictly feasible, or no step length
     *         keeps the iterate strictly inside the inequality constraints
     */
    @Override // com.joptimizer.optimizers.OptimizationRequestHandler
    public int optimize() throws Exception {
        this.log.info("optimize");
        long startTime = System.currentTimeMillis();
        OptimizationResponse response = new OptimizationResponse();

        // --- starting point -------------------------------------------------------------
        DoubleMatrix1D x0 = getInitialPoint();
        if (x0 == null) {
            DoubleMatrix1D x0NotFeasible = getNotFeasibleInitialPoint();
            if (x0NotFeasible != null) {
                // The "not feasible" candidate may in fact already be strictly feasible.
                double rPriX0NFNorm = Math.sqrt(this.ALG.norm2(rPri(x0NotFeasible)));
                DoubleMatrix1D fiX0NF = getFi(x0NotFeasible);
                double maxFiX0NF = fiX0NF.get(Utils.getMaxIndex(fiX0NF.toArray()));
                if (this.log.isDebugEnabled()) {
                    this.log.debug("rPriX0NFNorm :  " + rPriX0NFNorm);
                    this.log.debug("X0NF         :  " + ArrayUtils.toString(x0NotFeasible.toArray()));
                    this.log.debug("fiX0NF       :  " + ArrayUtils.toString(fiX0NF.toArray()));
                }
                if (maxFiX0NF < 0.0d && rPriX0NFNorm <= getToleranceFeas()) {
                    this.log.debug("the provided initial point is already feasible");
                    x0 = x0NotFeasible;
                }
            }
            if (x0 == null) {
                x0 = new BasicPhaseIPDM(this).findFeasibleInitialPoint();
            }
        }

        // The method requires fi(x0) < 0 for every constraint and a small primal residual.
        DoubleMatrix1D fiX0 = getFi(x0);
        int maxIneqIndex = Utils.getMaxIndex(fiX0.toArray());
        double maxIneqValue = fiX0.get(maxIneqIndex);
        double rPriX0Norm = Math.sqrt(this.ALG.norm2(rPri(x0)));
        if (maxIneqValue >= 0.0d || rPriX0Norm > getToleranceFeas()) {
            this.log.debug("rPriX0Norm  : " + rPriX0Norm);
            this.log.debug("ineqX0      : " + ArrayUtils.toString(fiX0.toArray()));
            this.log.debug("max ineq index: " + maxIneqIndex);
            this.log.debug("max ineq value: " + maxIneqValue);
            throw new Exception("initial point must be strictly feasible");
        }

        // --- starting multipliers -------------------------------------------------------
        DoubleMatrix1D v0 = getA() != null ? this.F1.make(getA().rows()) : this.F1.make(0);
        DoubleMatrix1D l0 = getInitialLagrangian();
        if (l0 != null) {
            for (int i = 0; i < l0.size(); i++) {
                if (l0.get(i) <= 0.0d) {
                    throw new IllegalArgumentException("initial lagrangian must be strictly > 0");
                }
            }
        } else {
            // Floating-point division: an int division would yield 0 whenever there are more
            // constraints than variables, violating the strict positivity required of L.
            l0 = this.F1.make(getFi().length, Math.min(1.0d, (double) getDim() / getFi().length));
        }
        if (this.log.isDebugEnabled()) {
            this.log.debug("X0:  " + ArrayUtils.toString(x0.toArray()));
            this.log.debug("V0:  " + ArrayUtils.toString(v0.toArray()));
            this.log.debug("L0:  " + ArrayUtils.toString(l0.toArray()));
            this.log.debug("toleranceFeas:  " + getToleranceFeas());
            this.log.debug("tolerance    :  " + getTolerance());
        }

        // --- main iteration -------------------------------------------------------------
        DoubleMatrix1D x = x0;
        DoubleMatrix1D v = v0;
        DoubleMatrix1D l = l0;
        double previousRPriNorm = Double.NaN;
        double previousRDualNorm = Double.NaN;
        double previousSurrDG = Double.NaN;
        int iteration = 0;
        while (true) {
            iteration++;
            if (iteration == getMaxIteration() + 1) {
                response.setReturnCode(1);
                this.log.warn("Max iterations limit reached");
                break;
            }
            double f0X = getF0(x);
            if (this.log.isDebugEnabled()) {
                this.log.debug("iteration: " + iteration);
                this.log.debug("X=" + ArrayUtils.toString(x.toArray()));
                this.log.debug("L=" + ArrayUtils.toString(l.toArray()));
                this.log.debug("V=" + ArrayUtils.toString(v.toArray()));
                this.log.debug("f0(X)=" + f0X);
            }
            DoubleMatrix1D gradF0X = getGradF0(x);
            DoubleMatrix1D fiX = getFi(x);
            DoubleMatrix2D gradFiX = getGradFi(x);
            DoubleMatrix2D[] hessFiX = getHessFi(x);

            // t = mu * m / surrogate gap, with m the number of inequality constraints.
            double surrDG = getSurrogateDualityGap(fiX, l);
            double t = (getMu() * getFi().length) / surrDG;
            this.log.debug("t:  " + t);

            DoubleMatrix1D rPriX = rPri(x);
            DoubleMatrix1D rCentXL = rCent(fiX, l, t);
            DoubleMatrix1D rDualXLV = rDual(gradFiX, gradF0X, l, v);
            // Colt's Algebra.norm2 on a vector is the squared Euclidean norm, hence the sqrt.
            double rPriNorm = Math.sqrt(this.ALG.norm2(rPriX));
            double rCentNorm = Math.sqrt(this.ALG.norm2(rCentXL));
            double rDualNorm = Math.sqrt(this.ALG.norm2(rDualXLV));
            double normRX0 = Math.sqrt(Math.pow(rPriNorm, 2.0d) + Math.pow(rCentNorm, 2.0d) + Math.pow(rDualNorm, 2.0d));
            this.log.debug("rPri  norm: " + rPriNorm);
            this.log.debug("rCent norm: " + rCentNorm);
            this.log.debug("rDual norm: " + rDualNorm);
            this.log.debug("surrDG    : " + surrDG);

            // --- exit conditions --------------------------------------------------------
            if (checkCustomExitConditions(x)) {
                response.setReturnCode(0);
                break;
            }
            if (rPriNorm <= getToleranceFeas() && rDualNorm <= getToleranceFeas() && surrDG <= getTolerance()) {
                response.setReturnCode(0);
                break;
            }
            if (isCheckProgressConditions()) {
                if (Double.isNaN(previousRPriNorm) || Double.isNaN(previousRDualNorm) || Double.isNaN(previousSurrDG)
                        || ((previousRPriNorm > rPriNorm || rPriNorm < getToleranceFeas())
                                && (previousRDualNorm > rDualNorm || rDualNorm < getToleranceFeas()))) {
                    // First iteration, or both residuals improved / are already within tolerance.
                    previousRPriNorm = rPriNorm;
                    previousRDualNorm = rDualNorm;
                    previousSurrDG = surrDG;
                } else {
                    // A residual above tolerance failed to decrease: stop with a warning code
                    // instead of iterating without progress.
                    this.log.warn("No progress achieved, exit iterations loop without desired accuracy");
                    response.setReturnCode(1);
                    break;
                }
            }

            // --- reduced KKT system -----------------------------------------------------
            // H = hessF0 + sum_i L_i * hessFi_i + sum_i (-L_i / fi_i) * gradFi_i * gradFi_i'
            // NOTE(review): the matrix returned by getHessF0 is modified in place below;
            // this assumes the getter hands out a fresh (or disposable) instance - confirm.
            DoubleMatrix2D hessSum = getHessF0(x);
            for (int i = 0; i < getFi().length; i++) {
                if (hessFiX[i] != FunctionsUtils.ZEROES_MATRIX_PLACEHOLDER) {
                    hessSum.assign(hessFiX[i].copy().assign(Mult.mult(l.get(i))), Functions.plus);
                }
            }
            DoubleMatrix2D gradSum = this.F2.make(getDim(), getDim());
            for (int i = 0; i < getFi().length; i++) {
                double weight = (-l.getQuick(i)) / fiX.getQuick(i);
                DoubleMatrix1D gradFiRow = gradFiX.viewRow(i);
                // dger: gradSum += weight * gradFiRow * gradFiRow'
                SeqBlas.seqBlas.dger(weight, gradFiRow, gradFiRow, gradSum);
            }
            DoubleMatrix2D hMatrix = hessSum.assign(gradSum, Functions.plus);

            // g = gradF0 + sum_i gradFi_i / (-t * fi_i)  [+ A' * V when equalities exist]
            DoubleMatrix1D barrierGradSum = this.F1.make(getDim());
            for (int i = 0; i < getFi().length; i++) {
                barrierGradSum.assign(gradFiX.viewRow(i).copy().assign(Mult.div((-t) * fiX.get(i))), Functions.plus);
            }
            DoubleMatrix1D gVector = gradF0X.copy().assign(barrierGradSum, Functions.plus);
            if (getAT() != null) {
                gVector.assign(this.ALG.mult(getAT(), v), Functions.plus);
            }

            if (this.kktSolver == null) {
                this.kktSolver = new BasicKKTSolver();
            }
            if (isCheckKKTSolutionAccuracy()) {
                this.kktSolver.setCheckKKTSolutionAccuracy(true);
                this.kktSolver.setToleranceKKT(getToleranceKKT());
            }
            this.kktSolver.setHMatrix(hMatrix.toArray());
            this.kktSolver.setGVector(gVector.toArray());
            if (getA() != null) {
                this.kktSolver.setAMatrix(getA().toArray());
                this.kktSolver.setATMatrix(getAT().toArray());
                this.kktSolver.setHVector(rPriX.toArray());
            }
            double[][] kktSolution = this.kktSolver.solve();
            DoubleMatrix1D stepX = this.F1.make(kktSolution[0]);
            DoubleMatrix1D stepV = kktSolution[1] != null ? this.F1.make(kktSolution[1]) : this.F1.make(0);
            if (this.log.isDebugEnabled()) {
                this.log.debug("stepX: " + ArrayUtils.toString(stepX.toArray()));
                this.log.debug("stepV: " + ArrayUtils.toString(stepV.toArray()));
            }

            // stepL = -diag(1/fi) * diag(L) * gradFi * stepX + diag(1/fi) * rCent
            DoubleMatrix2D invFiDiag = this.F2.diagonal(fiX.copy().assign(Functions.inv));
            DoubleMatrix1D stepL = this.ALG.mult(invFiDiag, this.ALG.mult(this.F2.diagonal(l), this.ALG.mult(gradFiX, stepX)))
                    .assign(Mult.mult(-1.0d))
                    .assign(this.ALG.mult(invFiDiag, rCentXL), Functions.plus);
            if (this.log.isDebugEnabled()) {
                this.log.debug("stepL: " + ArrayUtils.toString(stepL.toArray()));
            }

            // --- step length ------------------------------------------------------------
            // Largest step keeping L + s * stepL >= 0, capped at 1 and damped by 0.99.
            double sMax = Double.MAX_VALUE;
            for (int i = 0; i < getFi().length; i++) {
                if (stepL.get(i) < 0.0d) {
                    sMax = Math.min((-l.get(i)) / stepL.get(i), sMax);
                }
            }
            double s = 0.99d * Math.min(1.0d, sMax);

            DoubleMatrix1D x1 = this.F1.make(x.size());
            DoubleMatrix1D l1 = this.F1.make(l.size());
            DoubleMatrix1D v1 = this.F1.make(v.size());

            // First backtracking: shrink s until every fi(x + s * stepX) is strictly negative.
            int feasibilityTries = 0;
            boolean strictlyFeasible = true;
            while (feasibilityTries < MAX_BACKTRACKING_STEPS) {
                feasibilityTries++;
                x1 = stepX.copy().assign(Mult.mult(s)).assign(x, Functions.plus);
                DoubleMatrix1D fiX1 = getFi(x1);
                strictlyFeasible = true;
                for (int i = 0; strictlyFeasible && i < getFi().length; i++) {
                    strictlyFeasible = Double.compare(fiX1.get(i), 0.0d) < 0;
                }
                if (strictlyFeasible) {
                    break;
                }
                s = getBeta() * s;
            }
            if (!strictlyFeasible) {
                throw new Exception("Optimization failed: impossible to remain within the faesible region");
            }
            this.log.debug("s: " + s);

            // Second backtracking: shrink s until the residual norm decreases sufficiently,
            // i.e. ||r(x1, l1, v1)|| <= (1 - alpha * s) * ||r(x, l, v)||. The loop is bounded
            // so that it always terminates; after the last try the most recent candidate
            // is accepted as is.
            double previousNormRX1 = Double.NaN;
            int normTries = 0;
            while (normTries < MAX_BACKTRACKING_STEPS) {
                normTries++;
                x1 = stepX.copy().assign(Mult.mult(s)).assign(x, Functions.plus);
                l1 = stepL.copy().assign(Mult.mult(s)).assign(l, Functions.plus);
                v1 = stepV.copy().assign(Mult.mult(s)).assign(v, Functions.plus);
                if (isInDomainF0(x1)) {
                    double rPriX1Sq = this.ALG.norm2(rPri(x1));
                    double rCentX1Sq = this.ALG.norm2(rCent(getFi(x1), l1, t));
                    double rDualX1Sq = this.ALG.norm2(rDual(getGradFi(x1), getGradF0(x1), l1, v1));
                    double normRX1 = Math.sqrt(rPriX1Sq + rCentX1Sq + rDualX1Sq);
                    if (normRX1 <= (1.0d - (getAlpha() * s)) * normRX0) {
                        // Sufficient decrease reached: accept this step length.
                        break;
                    }
                    if (!Double.isNaN(previousNormRX1) && previousNormRX1 <= normRX1) {
                        this.log.warn("No progress achieved in backtracking with norm");
                        break;
                    }
                    previousNormRX1 = normRX1;
                }
                s = getBeta() * s;
            }

            x = x1;
            v = v1;
            l = l1;
        }

        // The return code was set at the point where the loop was left; do not override it.
        this.log.debug("time: " + (System.currentTimeMillis() - startTime));
        this.log.debug("sol : " + ArrayUtils.toString(x.toArray()));
        this.log.debug("ret code: " + response.getReturnCode());
        response.setSolution(x.toArray());
        setOptimizationResponse(response);
        return response.getReturnCode();
    }

    /**
     * Surrogate duality gap: the negated dot product of the inequality values and
     * their multipliers.
     *
     * @param fiX values of the inequality functions
     * @param l   inequality multipliers
     * @return {@code -(fiX . l)}
     */
    private double getSurrogateDualityGap(DoubleMatrix1D fiX, DoubleMatrix1D l) {
        return -this.ALG.mult(fiX, l);
    }

    /**
     * Dual residual {@code gradF0 + gradFi' * l}, plus {@code A' * v} when {@code getA()}
     * is non-null. The gradient vector passed in is copied, not modified.
     *
     * @param gradFiX matrix whose rows are the inequality gradients
     * @param gradF0X gradient of the objective
     * @param l       inequality multipliers
     * @param v       equality multipliers (unused when there is no A matrix)
     * @return a new vector holding the dual residual
     */
    private DoubleMatrix1D rDual(DoubleMatrix2D gradFiX, DoubleMatrix1D gradF0X, DoubleMatrix1D l, DoubleMatrix1D v) {
        // zMult(y, z, 1, 1, true) computes z = A' * y + z, accumulating into the copy.
        DoubleMatrix1D residual = gradFiX.zMult(l, gradF0X.copy(), 1.0d, 1.0d, true);
        if (getA() == null) {
            return residual;
        }
        return getA().zMult(v, residual, 1.0d, 1.0d, true);
    }

    /**
     * Centrality residual, component-wise {@code -l_i * fi_i - 1/t}.
     *
     * @param fiX values of the inequality functions
     * @param l   inequality multipliers
     * @param t   current barrier parameter
     * @return a new vector with one entry per inequality constraint
     */
    private DoubleMatrix1D rCent(DoubleMatrix1D fiX, DoubleMatrix1D l, double t) {
        DoubleMatrix1D inverseT = this.F1.make(getFi().length, 1.0d / t);
        // z = -1 * diag(l) * fiX + (-1) * z, with z initialised to 1/t.
        return this.F2.diagonal(l).zMult(fiX, inverseT, -1.0d, -1.0d, false);
    }

    /**
     * Sets the solver used for the reduced KKT system; when none is set, a
     * {@link BasicKKTSolver} is created on first use by {@link #optimize()}.
     *
     * @param kKTSolver the solver to use
     */
    public void setKKTSolver(KKTSolver kKTSolver) {
        this.kktSolver = kKTSolver;
    }
}
