package org.spaceroots.mantissa.optimization;

import java.util.Arrays;
import java.util.Comparator;
import org.spaceroots.mantissa.random.CorrelatedRandomVectorGenerator;
import org.spaceroots.mantissa.random.NotPositiveDefiniteMatrixException;
import org.spaceroots.mantissa.random.RandomVectorGenerator;
import org.spaceroots.mantissa.random.UncorrelatedRandomVectorGenerator;
import org.spaceroots.mantissa.random.UniformRandomGenerator;
import org.spaceroots.mantissa.random.VectorialSampleStatistics;

/* loaded from: input_file:org/spaceroots/mantissa/optimization/DirectSearchOptimizer.class */
/**
 * Base class for direct search minimizers that operate on a simplex of
 * {@link PointCostPair} vertices.
 *
 * <p>This class builds the initial simplex (from two opposite vertices, from
 * an explicit vertex list, or from a random vector generator), counts cost
 * function evaluations, drives the convergence loop and optionally restarts
 * the search several times from randomly generated simplices (multi-start).
 * Subclasses only supply the actual simplex transformation through
 * {@link #iterateSimplex()}.</p>
 */
public abstract class DirectSearchOptimizer {
    /**
     * Orders point/cost pairs by increasing cost, with {@code null} entries last.
     *
     * <p>Costs are compared with {@link Double#compare(double, double)} so the
     * ordering is a consistent total order: pairs with equal costs compare as
     * equal (the stable sort keeps their relative order) and a NaN cost sorts
     * after every real cost instead of being treated as smaller than
     * everything. A comparator that answers "greater" in both directions for
     * ties violates the {@link Comparator} contract and may make
     * {@link Arrays#sort(Object[], Comparator)} throw.</p>
     */
    private static final Comparator<PointCostPair> POINT_COST_PAIR_COMPARATOR = new Comparator<PointCostPair>() {
        @Override
        public int compare(PointCostPair first, PointCostPair second) {
            if (first == null) {
                return second == null ? 0 : 1;
            }
            if (second == null) {
                return -1;
            }
            return Double.compare(first.cost, second.cost);
        }
    };

    /** Current simplex; sorted by increasing cost after {@link #evaluateSimplex()}. */
    protected PointCostPair[] simplex;

    /** Cost function being minimized by the current run. */
    private CostFunction f;

    /** Number of cost function evaluations performed for the current start. */
    private int evaluations;

    /** Number of starts to perform (1 means single-start). */
    private int starts;

    /** Generator used to rebuild the simplex between starts ({@code null} in single-start mode). */
    private RandomVectorGenerator generator;

    /** Minima found by the last run, one slot per start; {@code null} until a run has been made. */
    private PointCostPair[] minima;

    /**
     * Minimizes a cost function with a single start, the simplex being built
     * from two vertices.
     *
     * <p>Vertex {@code k} of the initial simplex takes its first {@code k}
     * coordinates from {@code dArr2} and the remaining ones from {@code dArr}.</p>
     *
     * @param costFunction cost function to minimize
     * @param i maximal number of cost evaluations allowed
     * @param convergenceChecker object deciding when the simplex has converged
     * @param dArr first vertex
     * @param dArr2 second vertex
     * @return the best point found, with its cost
     * @throws CostException if the cost function or the simplex iteration throws it
     * @throws NoConvergenceException if the evaluation budget is exhausted before convergence
     */
    public PointCostPair minimizes(CostFunction costFunction, int i, ConvergenceChecker convergenceChecker, double[] dArr, double[] dArr2) throws CostException, NoConvergenceException {
        buildSimplex(dArr, dArr2);
        setSingleStart();
        return minimizes(costFunction, i, convergenceChecker);
    }

    /**
     * Minimizes a cost function with several starts, the first simplex being
     * built from two vertices.
     *
     * <p>The simplices of the following starts are drawn from an
     * {@link UncorrelatedRandomVectorGenerator} built with the midpoint of the
     * two vertices and half the absolute difference between them (presumably
     * used as mean and standard deviation; confirm against the generator's
     * documentation).</p>
     *
     * @param costFunction cost function to minimize
     * @param i maximal number of cost evaluations allowed for each start
     * @param convergenceChecker object deciding when the simplex has converged
     * @param dArr first vertex
     * @param dArr2 second vertex
     * @param i2 number of starts; values below 2 fall back to a single start
     * @param iArr value handed to the {@link UniformRandomGenerator} constructor
     *        (presumably a seed; confirm against that class)
     * @return the best of the minima found by the starts that converged
     * @throws CostException if the cost function or the simplex iteration throws it
     * @throws NoConvergenceException if none of the starts converged
     */
    public PointCostPair minimizes(CostFunction costFunction, int i, ConvergenceChecker convergenceChecker, double[] dArr, double[] dArr2, int i2, int[] iArr) throws CostException, NoConvergenceException {
        buildSimplex(dArr, dArr2);
        final int dimension = dArr.length;
        double[] midpoint = new double[dimension];
        double[] halfRange = new double[dimension];
        for (int k = 0; k < dimension; k++) {
            midpoint[k] = 0.5d * (dArr[k] + dArr2[k]);
            halfRange[k] = 0.5d * Math.abs(dArr[k] - dArr2[k]);
        }
        setMultiStart(i2, new UncorrelatedRandomVectorGenerator(midpoint, halfRange, new UniformRandomGenerator(iArr)));
        return minimizes(costFunction, i, convergenceChecker);
    }

    /**
     * Minimizes a cost function with a single start, the simplex being given
     * explicitly.
     *
     * @param costFunction cost function to minimize
     * @param i maximal number of cost evaluations allowed
     * @param convergenceChecker object deciding when the simplex has converged
     * @param dArr simplex vertices, one per row; the rows are used as is, not copied
     * @return the best point found, with its cost
     * @throws CostException if the cost function or the simplex iteration throws it
     * @throws NoConvergenceException if the evaluation budget is exhausted before convergence
     */
    public PointCostPair minimizes(CostFunction costFunction, int i, ConvergenceChecker convergenceChecker, double[][] dArr) throws CostException, NoConvergenceException {
        buildSimplex(dArr);
        setSingleStart();
        return minimizes(costFunction, i, convergenceChecker);
    }

    /**
     * Minimizes a cost function with several starts, the first simplex being
     * given explicitly.
     *
     * <p>The simplices of the following starts are drawn from a
     * {@link CorrelatedRandomVectorGenerator} built with the mean and the
     * covariance matrix of the given vertices.</p>
     *
     * @param costFunction cost function to minimize
     * @param i maximal number of cost evaluations allowed for each start
     * @param convergenceChecker object deciding when the simplex has converged
     * @param dArr simplex vertices, one per row; the rows are used as is, not copied
     * @param i2 number of starts; values below 2 fall back to a single start
     * @param iArr value handed to the {@link UniformRandomGenerator} constructor
     *        (presumably a seed; confirm against that class)
     * @return the best of the minima found by the starts that converged
     * @throws NotPositiveDefiniteMatrixException declared by the construction of the
     *         correlated generator (presumably when the vertices covariance matrix
     *         is not positive definite)
     * @throws CostException if the cost function or the simplex iteration throws it
     * @throws NoConvergenceException if none of the starts converged
     */
    public PointCostPair minimizes(CostFunction costFunction, int i, ConvergenceChecker convergenceChecker, double[][] dArr, int i2, int[] iArr) throws NotPositiveDefiniteMatrixException, CostException, NoConvergenceException {
        buildSimplex(dArr);
        VectorialSampleStatistics statistics = new VectorialSampleStatistics();
        for (double[] vertex : dArr) {
            statistics.add(vertex);
        }
        setMultiStart(i2, new CorrelatedRandomVectorGenerator(statistics.getMean(), statistics.getCovarianceMatrix(null), new UniformRandomGenerator(iArr)));
        return minimizes(costFunction, i, convergenceChecker);
    }

    /**
     * Minimizes a cost function with a single start, the simplex being drawn
     * from a random vector generator.
     *
     * @param costFunction cost function to minimize
     * @param i maximal number of cost evaluations allowed
     * @param convergenceChecker object deciding when the simplex has converged
     * @param randomVectorGenerator generator providing the simplex vertices
     * @return the best point found, with its cost
     * @throws CostException if the cost function or the simplex iteration throws it
     * @throws NoConvergenceException if the evaluation budget is exhausted before convergence
     */
    public PointCostPair minimizes(CostFunction costFunction, int i, ConvergenceChecker convergenceChecker, RandomVectorGenerator randomVectorGenerator) throws CostException, NoConvergenceException {
        buildSimplex(randomVectorGenerator);
        setSingleStart();
        return minimizes(costFunction, i, convergenceChecker);
    }

    /**
     * Minimizes a cost function with several starts, every simplex being drawn
     * from the same random vector generator.
     *
     * @param costFunction cost function to minimize
     * @param i maximal number of cost evaluations allowed for each start
     * @param convergenceChecker object deciding when the simplex has converged
     * @param randomVectorGenerator generator providing the simplex vertices
     * @param i2 number of starts; values below 2 fall back to a single start
     * @return the best of the minima found by the starts that converged
     * @throws CostException if the cost function or the simplex iteration throws it
     * @throws NoConvergenceException if none of the starts converged
     */
    public PointCostPair minimizes(CostFunction costFunction, int i, ConvergenceChecker convergenceChecker, RandomVectorGenerator randomVectorGenerator, int i2) throws CostException, NoConvergenceException {
        buildSimplex(randomVectorGenerator);
        setMultiStart(i2, randomVectorGenerator);
        return minimizes(costFunction, i, convergenceChecker);
    }

    /**
     * Builds a simplex of {@code n + 1} unevaluated vertices from two vertices
     * of dimension {@code n}: vertex {@code k} copies its first {@code k}
     * coordinates from {@code vertexB} and the remaining ones from {@code vertexA}.
     *
     * @param vertexA vertex providing the trailing coordinates (and the dimension)
     * @param vertexB vertex providing the leading coordinates
     */
    private void buildSimplex(double[] vertexA, double[] vertexB) {
        final int dimension = vertexA.length;
        this.simplex = new PointCostPair[dimension + 1];
        for (int k = 0; k <= dimension; k++) {
            double[] vertex = new double[dimension];
            if (k > 0) {
                System.arraycopy(vertexB, 0, vertex, 0, k);
            }
            if (k < dimension) {
                System.arraycopy(vertexA, k, vertex, k, dimension - k);
            }
            // NaN marks the vertex as not evaluated yet
            this.simplex[k] = new PointCostPair(vertex, Double.NaN);
        }
    }

    /**
     * Builds a simplex of unevaluated vertices from an explicit vertex list.
     *
     * @param vertices simplex vertices, one per row (the rows are not copied)
     */
    private void buildSimplex(double[][] vertices) {
        this.simplex = new PointCostPair[vertices.length];
        for (int k = 0; k < vertices.length; k++) {
            // NaN marks the vertex as not evaluated yet
            this.simplex[k] = new PointCostPair(vertices[k], Double.NaN);
        }
    }

    /**
     * Builds a simplex of unevaluated vertices drawn from a generator. The
     * length {@code n} of the first drawn vector fixes the number of vertices
     * to {@code n + 1}.
     *
     * @param randomVectorGenerator generator providing the vertices
     */
    private void buildSimplex(RandomVectorGenerator randomVectorGenerator) {
        double[] firstVertex = randomVectorGenerator.nextVector();
        final int dimension = firstVertex.length;
        this.simplex = new PointCostPair[dimension + 1];
        this.simplex[0] = new PointCostPair(firstVertex, Double.NaN);
        for (int k = 1; k <= dimension; k++) {
            this.simplex[k] = new PointCostPair(randomVectorGenerator.nextVector(), Double.NaN);
        }
    }

    /** Switches to single-start mode and discards previously found minima. */
    private void setSingleStart() {
        this.starts = 1;
        this.generator = null;
        this.minima = null;
    }

    /**
     * Configures the number of starts and discards previously found minima.
     *
     * @param i number of starts; values below 2 select single-start mode, in
     *        which case the generator is ignored
     * @param randomVectorGenerator generator used to build the simplex of every
     *        start after the first one
     */
    public void setMultiStart(int i, RandomVectorGenerator randomVectorGenerator) {
        if (i < 2) {
            this.starts = 1;
            this.generator = null;
        } else {
            this.starts = i;
            this.generator = randomVectorGenerator;
        }
        this.minima = null;
    }

    /**
     * Returns a copy of the minima found by the last run, one entry per start,
     * sorted by increasing cost; starts that did not converge are represented
     * by {@code null} entries placed at the end.
     *
     * @return a shallow copy of the minima array
     * @throws NullPointerException if no run has been completed since the
     *         start mode was last configured
     */
    public PointCostPair[] getMinima() {
        return this.minima.clone();
    }

    /**
     * Runs the configured number of starts on the current simplex.
     *
     * <p>The evaluation counter is reset at the beginning of each start, so
     * the evaluation budget applies to every start separately.</p>
     *
     * @param costFunction cost function to minimize
     * @param maxEvaluations maximal number of cost evaluations for each start
     * @param convergenceChecker object deciding when the simplex has converged
     * @return the best of the minima found by the starts that converged
     * @throws CostException if the cost function or the simplex iteration throws it
     * @throws NoConvergenceException if none of the starts converged
     */
    private PointCostPair minimizes(CostFunction costFunction, int maxEvaluations, ConvergenceChecker convergenceChecker) throws CostException, NoConvergenceException {
        this.f = costFunction;
        this.minima = new PointCostPair[this.starts];
        for (int start = 0; start < this.starts; start++) {
            this.evaluations = 0;
            evaluateSimplex();

            boolean budgetExhausted = false;
            while (!convergenceChecker.converged(this.simplex)) {
                if (this.evaluations >= maxEvaluations) {
                    budgetExhausted = true;
                    break;
                }
                iterateSimplex();
            }
            // a start that ran out of evaluations is recorded as null
            this.minima[start] = budgetExhausted ? null : this.simplex[0];

            if (start < this.starts - 1) {
                // draw a fresh simplex for the next start
                buildSimplex(this.generator);
            }
        }

        // best minimum first, failed starts (null) last
        Arrays.sort(this.minima, POINT_COST_PAIR_COMPARATOR);
        if (this.minima[0] == null) {
            throw new NoConvergenceException("none of the {0} start points lead to convergence", new String[]{Integer.toString(this.starts)});
        }
        return this.minima[0];
    }

    /**
     * Computes the next simplex of the algorithm; implemented by subclasses.
     *
     * @throws CostException if the cost function throws it during evaluation
     */
    protected abstract void iterateSimplex() throws CostException;

    /**
     * Evaluates the cost function at one point and increments the evaluation
     * counter.
     *
     * @param dArr point at which the cost is evaluated
     * @return cost at the given point
     * @throws CostException if the cost function throws it
     */
    /* JADX INFO: Access modifiers changed from: protected */
    public double evaluateCost(double[] dArr) throws CostException {
        this.evaluations++;
        return this.f.cost(dArr);
    }

    /**
     * Evaluates every vertex whose cost is still NaN, then sorts the simplex
     * by increasing cost. A vertex whose evaluated cost is NaN is sorted after
     * all vertices with a real cost.
     *
     * @throws CostException if the cost function throws it
     */
    /* JADX INFO: Access modifiers changed from: protected */
    public void evaluateSimplex() throws CostException {
        for (int k = 0; k < this.simplex.length; k++) {
            PointCostPair vertex = this.simplex[k];
            if (Double.isNaN(vertex.cost)) {
                this.simplex[k] = new PointCostPair(vertex.point, evaluateCost(vertex.point));
            }
        }
        Arrays.sort(this.simplex, POINT_COST_PAIR_COMPARATOR);
    }

    /**
     * Inserts a point into the simplex and drops the last vertex.
     *
     * <p>The new point is swapped in at the first vertex having a strictly
     * larger cost and the displaced vertices are shifted towards the end; the
     * vertex previously stored in the last slot is discarded. When the simplex
     * is sorted by increasing cost on entry, it is still sorted on exit.</p>
     *
     * @param pointCostPair point to insert
     */
    /* JADX INFO: Access modifiers changed from: protected */
    public void replaceWorstPoint(PointCostPair pointCostPair) {
        final int last = this.simplex.length - 1;
        PointCostPair carried = pointCostPair;
        for (int k = 0; k < last; k++) {
            if (this.simplex[k].cost > carried.cost) {
                PointCostPair displaced = this.simplex[k];
                this.simplex[k] = carried;
                carried = displaced;
            }
        }
        this.simplex[last] = carried;
    }
}
