package org.ddogleg.optimization.impl;

import org.ejml.alg.dense.mult.VectorVectorMult;
import org.ejml.data.DenseMatrix64F;
import org.ejml.ops.CommonOps;
import org.ejml.ops.MatrixFeatures;
import org.ejml.ops.NormOps;
import org.junit.Assert;
import org.junit.Test;

/* loaded from: input_file:org/ddogleg/optimization/impl/TestDoglegStepF.class */
/**
 * Unit tests for {@link DoglegStepF}.
 *
 * <p>All tests share one small least-squares problem: a 3x2 Jacobian {@code J}, a state {@code x}
 * and residuals {@code f}. Three trust-region radii are used so that each branch of the dogleg
 * step is exercised: a small radius (Cauchy step), a large radius (full Gauss-Newton step) and an
 * intermediate radius (combined step).
 */
public class TestDoglegStepF {
    /** Radius for which the tests expect {@code cauchyStep} to be used. */
    double cauchyRadius = 0.5d;
    /** Radius for which the tests expect neither {@code cauchyStep} nor {@code combinedStep}. */
    double gaussRadius = 10.0d;
    /** Radius for which the tests expect {@code combinedStep} to be used. */
    double combinedRadius = 0.9d;
    /** Jacobian J, 3 rows (residuals) by 2 columns (parameters). */
    DenseMatrix64F J = new DenseMatrix64F(3, 2, true, new double[]{1.0d, 0.5d, 2.0d, Math.sqrt(2.0d), -2.0d, 4.0d});
    /** Current state x. */
    DenseMatrix64F x = new DenseMatrix64F(2, 1, true, new double[]{0.5d, 1.5d});
    /** Residual vector f. */
    DenseMatrix64F residuals = new DenseMatrix64F(3, 1, true, new double[]{-1.0d, -2.0d, -3.0d});
    /** J^T * f (the gradient of 0.5*|f|^2); filled in by the constructor. */
    DenseMatrix64F gradient = new DenseMatrix64F(2, 1);

    /**
     * {@link DoglegStepF} that records which internal step routine was invoked, so that tests can
     * check which branch {@code computeStep} took.
     */
    public static class WrappedDog extends DoglegStepF {
        /** Set to true once {@code combinedStep} has been called. */
        boolean calledCombined = false;
        /** Set to true once {@code cauchyStep} has been called. */
        boolean calledCauchy = false;

        protected WrappedDog() {
        }

        @Override
        public void cauchyStep(double regionRadius, DenseMatrix64F step) {
            super.cauchyStep(regionRadius, step);
            calledCauchy = true;
        }

        @Override
        public void combinedStep(double regionRadius, DenseMatrix64F step) {
            super.combinedStep(regionRadius, step);
            calledCombined = true;
        }
    }

    public TestDoglegStepF() {
        // gradient = J^T * f
        CommonOps.multTransA(J, residuals, gradient);
    }

    /**
     * With {@link #cauchyRadius} the dogleg step must equal the step produced by
     * {@link CauchyStep}, be reported as a max step, and be computed through {@code cauchyStep}.
     */
    @Test
    public void computeStep_cauchy() {
        CauchyStep reference = new CauchyStep();
        WrappedDog dogleg = new WrappedDog();
        reference.init(2, 3);
        dogleg.init(2, 3);
        // NOTE(review): the last argument appears to be the cost at x; -1 is used as a
        // placeholder because predictedReduction() is not checked here.
        reference.setInputs(x, residuals, J, gradient, -1.0d);
        dogleg.setInputs(x, residuals, J, gradient, -1.0d);

        DenseMatrix64F expectedStep = new DenseMatrix64F(2, 1);
        DenseMatrix64F foundStep = new DenseMatrix64F(2, 1);
        reference.computeStep(cauchyRadius, expectedStep);
        dogleg.computeStep(cauchyRadius, foundStep);

        Assert.assertTrue(reference.isMaxStep());
        Assert.assertTrue(dogleg.isMaxStep());
        Assert.assertTrue(dogleg.calledCauchy);
        Assert.assertTrue(MatrixFeatures.isIdentical(expectedStep, foundStep, 1.0E-8d));
    }

    /**
     * With {@link #gaussRadius} the full Gauss-Newton step is taken. It must not be reported as a
     * max step and must not go through the Cauchy or combined routines. It must also be a local
     * minimum of the linearized cost: perturbing either component in either direction increases
     * {@link #cost}.
     */
    @Test
    public void computeStep_GaussNewton() {
        // Use WrappedDog (a DoglegStepF) so the class under test is exercised and the branch
        // taken can be verified.
        WrappedDog dogleg = new WrappedDog();
        dogleg.init(2, 3);
        dogleg.setInputs(x, residuals, J, gradient, -1.0d);
        DenseMatrix64F step = new DenseMatrix64F(2, 1);
        dogleg.computeStep(gaussRadius, step);

        Assert.assertFalse(dogleg.isMaxStep());
        Assert.assertFalse(dogleg.calledCauchy);
        Assert.assertFalse(dogleg.calledCombined);

        double delta = 0.01d;
        double best = cost(residuals, J, step, 0.0d, 0.0d);
        Assert.assertTrue(best < cost(residuals, J, step, delta, 0.0d));
        Assert.assertTrue(best < cost(residuals, J, step, -delta, 0.0d));
        Assert.assertTrue(best < cost(residuals, J, step, 0.0d, delta));
        Assert.assertTrue(best < cost(residuals, J, step, 0.0d, -delta));
    }

    /**
     * With {@link #combinedRadius} the step must come from {@code combinedStep}, be reported as a
     * max step, and have a Frobenius norm equal to the trust-region radius.
     */
    @Test
    public void computeStep_Hybrid() {
        WrappedDog dogleg = new WrappedDog();
        dogleg.init(2, 3);
        dogleg.setInputs(x, residuals, J, gradient, -1.0d);
        DenseMatrix64F step = new DenseMatrix64F(2, 1);
        dogleg.computeStep(combinedRadius, step);

        Assert.assertTrue(dogleg.calledCombined);
        Assert.assertTrue(dogleg.isMaxStep());
        Assert.assertEquals(combinedRadius, NormOps.normF(step), 1.0E-8d);
    }

    @Test
    public void predict_cauchy() {
        checkPredictedCost(cauchyRadius, true, false);
    }

    @Test
    public void predict_GaussNewton() {
        checkPredictedCost(gaussRadius, false, false);
    }

    @Test
    public void predict_Hybrid() {
        checkPredictedCost(combinedRadius, false, true);
    }

    /**
     * Runs {@link DoglegStepF} with the given radius and checks two things. First, that the expected
     * step routine was used. Second, that {@code predictedReduction()} equals
     * {@code 0.5*|f|^2 - cost(step)}.
     *
     * @param regionRadius   trust-region radius passed to {@code computeStep}
     * @param expectCauchy   whether {@code cauchyStep} is expected to be called
     * @param expectCombined whether {@code combinedStep} is expected to be called
     */
    private void checkPredictedCost(double regionRadius, boolean expectCauchy, boolean expectCombined) {
        double costAtX = VectorVectorMult.innerProd(residuals, residuals) * 0.5d;
        WrappedDog dogleg = new WrappedDog();
        dogleg.init(2, 3);
        dogleg.setInputs(x, residuals, J, gradient, costAtX);
        DenseMatrix64F step = new DenseMatrix64F(2, 1);
        dogleg.computeStep(regionRadius, step);

        Assert.assertEquals("cauchyStep called", expectCauchy, dogleg.calledCauchy);
        Assert.assertEquals("combinedStep called", expectCombined, dogleg.calledCombined);

        double expectedReduction = costAtX - cost(residuals, J, step, 0.0d, 0.0d);
        Assert.assertEquals(expectedReduction, dogleg.predictedReduction(), 1.0E-8d);
    }

    /**
     * Evaluates the linearized (Gauss-Newton) cost at a perturbed step:
     * <pre>
     *   c(p) = 0.5*f'f + f'J p + 0.5*p'(J'J)p   ( = 0.5*|f + J p|^2 )
     * </pre>
     * where {@code p = step + perturbation}.
     *
     * @param residuals    residual vector f
     * @param jacobian     Jacobian J; its row count must match the length of {@code residuals}
     * @param step         column vector p; it is copied and is not modified
     * @param perturbation offsets added element-wise to {@code step}; must supply at least
     *                     {@code step.numRows} values
     * @return the linearized cost at {@code step + perturbation}
     */
    public static double cost(DenseMatrix64F residuals, DenseMatrix64F jacobian, DenseMatrix64F step, double... perturbation) {
        DenseMatrix64F p = step.copy();
        for (int i = 0; i < p.numRows; i++) {
            p.data[i] += perturbation[i];
        }

        DenseMatrix64F jtj = new DenseMatrix64F(jacobian.numCols, jacobian.numCols);
        CommonOps.multTransA(jacobian, jacobian, jtj);

        return (0.5d * VectorVectorMult.innerProd(residuals, residuals))
                + VectorVectorMult.innerProdA(residuals, jacobian, p)
                + (0.5d * VectorVectorMult.innerProdA(p, jtj, p));
    }
}
