package com.joptimizer.solvers;

import cern.colt.matrix.DoubleFactory2D;
import cern.colt.matrix.DoubleMatrix2D;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.ArrayRealVector;
import org.apache.commons.math3.linear.CholeskyDecomposition;
import org.apache.commons.math3.linear.MatrixUtils;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;

/**
 * Base class for solvers of the KKT linear system
 *
 * <pre>
 * [ H  A^T ] [ v ]     [ g ]
 * [ A   0  ] [ w ] = - [ h ]
 * </pre>
 *
 * <p>Subclasses implement {@link #solve()}. This class stores the system data, offers
 * {@link #solveFullKKT(KKTSolver)} which delegates an equivalent system with an augmented
 * (1,1) block to another solver, and offers a residual-based accuracy check.
 */
public abstract class KKTSolver {
    /** The (1,1) block of the KKT matrix. */
    protected RealMatrix H;
    /** Lower-left block; {@code null} when unset, in which case only {@code H v + g} is checked. */
    protected RealMatrix A;
    /** Upper-right block, used as the transpose of {@link #A}; supplied separately by the caller. */
    protected RealMatrix AT;
    /** Negated right-hand side of the first block row. */
    protected RealVector g;
    /** Negated right-hand side of the second block row; {@code null} when unset. */
    protected RealVector h;
    /** Residual norm threshold used by {@link #checkKKTSolutionAccuracy(RealVector, RealVector)}. */
    protected double toleranceKKT;
    /** Flag presumably consulted by subclasses to decide whether to verify; not read here. */
    protected boolean checkKKTSolutionAccuracy;
    private final Log log = LogFactory.getLog(getClass().getName());

    /**
     * Solves the KKT system currently held by this instance.
     *
     * @return the solution blocks; presumably {@code {v, w}} (TODO confirm against subclasses)
     * @throws Exception if the system cannot be solved
     */
    public abstract double[][] solve() throws Exception;

    /**
     * Sets {@code H}. The array is wrapped without copying, so later changes to it are visible
     * to this solver.
     *
     * @param dArr row-major matrix data
     */
    public void setHMatrix(double[][] dArr) {
        this.H = new Array2DRowRealMatrix(dArr, false);
    }

    /**
     * Sets {@code A}. A {@code null} or empty array clears any previously set value (rather than
     * silently keeping a stale one); otherwise the array is wrapped without copying.
     *
     * @param dArr row-major matrix data, or {@code null}/empty for "no A"
     */
    public void setAMatrix(double[][] dArr) {
        this.A = (dArr == null || dArr.length == 0) ? null : new Array2DRowRealMatrix(dArr, false);
    }

    /**
     * Sets {@code AT}. A {@code null} or empty array clears any previously set value; otherwise
     * the array is wrapped without copying.
     *
     * @param dArr row-major matrix data, or {@code null}/empty for "no AT"
     */
    public void setATMatrix(double[][] dArr) {
        this.AT = (dArr == null || dArr.length == 0) ? null : new Array2DRowRealMatrix(dArr, false);
    }

    /**
     * Sets {@code g} (the vector is copied).
     *
     * @param dArr vector data
     */
    public void setGVector(double[] dArr) {
        this.g = new ArrayRealVector(dArr);
    }

    /**
     * Sets {@code h} (the vector is copied). A {@code null} or empty array clears any previously
     * set value.
     *
     * @param dArr vector data, or {@code null}/empty for "no h"
     */
    public void setHVector(double[] dArr) {
        this.h = (dArr == null || dArr.length == 0) ? null : new ArrayRealVector(dArr);
    }

    /**
     * Sets the residual-norm tolerance used by the accuracy check.
     *
     * @param d tolerance; a solution is accepted when its residual norm is strictly below it
     */
    public void setToleranceKKT(double d) {
        this.toleranceKKT = d;
    }

    /**
     * Sets the accuracy-check flag.
     *
     * @param z new flag value
     */
    public void setCheckKKTSolutionAccuracy(boolean z) {
        this.checkKKTSolutionAccuracy = z;
    }

    /**
     * Solves the KKT system by passing an equivalent, augmented system to {@code kKTSolver}.
     *
     * <p>Any solution satisfies {@code A v + h = 0}. Adding {@code A^T (A v + h) = 0} to the first
     * block row gives {@code (H + A^T A) v + A^T w = -(g + A^T h)}. So the augmented system
     * (H replaced by {@code H + A^T A}, g replaced by {@code g + A^T h}) has the same solution,
     * and its (1,1) block may be positive definite even when {@code H} is not.
     *
     * @param kKTSolver solver that receives and solves the augmented system
     * @return the result of {@code kKTSolver.solve()}
     * @throws IllegalStateException if {@code H}, {@code A} or {@code AT} has not been set
     * @throws Exception "singular KKT system" (with the original failure as cause) if the
     *     augmented (1,1) block is not symmetric positive definite or the delegated solve fails
     */
    public double[][] solveFullKKT(KKTSolver kKTSolver) throws Exception {
        this.log.debug("solveFullKKT");
        if (this.H == null || this.A == null || this.AT == null) {
            throw new IllegalStateException("H, A and AT must be set before solving the augmented KKT system");
        }
        RealMatrix hAugmented = this.H.add(this.AT.multiply(this.A));
        try {
            // Used only as a gate: the constructor throws unless hAugmented is symmetric
            // positive definite; the factorization itself is not needed here.
            new CholeskyDecomposition(hAugmented);
            kKTSolver.setHMatrix(hAugmented.getData());
            kKTSolver.setAMatrix(this.A.getData());
            kKTSolver.setATMatrix(this.AT.getData());
            if (this.h != null) {
                kKTSolver.setGVector(this.g.add(this.AT.operate(this.h)).toArray());
                kKTSolver.setHVector(this.h.toArray());
            } else {
                kKTSolver.setGVector(this.g.toArray());
            }
            return kKTSolver.solve();
        } catch (Exception e) {
            // Keep the underlying failure so callers can tell a non-SPD matrix from a solver error.
            throw new Exception("singular KKT system", e);
        }
    }

    /**
     * Checks a candidate solution by computing the Euclidean norm of the KKT residual.
     *
     * <ul>
     *   <li>{@code A} and {@code h} set: {@code || [H v + A^T w + g ; A v + h] ||}
     *   <li>{@code A} set, {@code h} null: {@code || H v + A^T w + g ||}
     *   <li>{@code A} null: {@code || H v + g ||}
     * </ul>
     *
     * @param realVector candidate {@code v}
     * @param realVector2 candidate {@code w} (ignored when {@code A} is null)
     * @return {@code true} if the residual norm is strictly below {@link #toleranceKKT}
     */
    public boolean checkKKTSolutionAccuracy(RealVector realVector, RealVector realVector2) {
        // First block row.
        RealVector residual = this.H.operate(realVector);
        if (this.A != null) {
            residual = residual.add(this.AT.operate(realVector2));
        }
        residual = residual.add(this.g);
        // Second block row. NOTE(review): when h is null the A v row is not included in the
        // residual -- confirm this is intended.
        if (this.A != null && this.h != null) {
            residual = residual.append(this.A.operate(realVector).add(this.h));
        }
        double norm = residual.getNorm();
        this.log.debug("KKT solution error: " + norm);
        return norm < this.toleranceKKT;
    }
}
