package edu.uci.jforests.eval.ranking;

import edu.uci.jforests.eval.EvaluationMetric;
import edu.uci.jforests.eval.ranking.RankingEvaluationMetric;
import edu.uci.jforests.eval.ranking.URiskAwareEval;
import edu.uci.jforests.sample.RankingSample;
import edu.uci.jforests.sample.Sample;
import edu.uci.jforests.util.CDF_Normal;
import edu.uci.jforests.util.MathUtil;
import java.util.Arrays;

/* loaded from: input_file:edu/uci/jforests/eval/ranking/TRiskAwareSAROEval.class */
/**
 * Risk-aware evaluation metric that scores a model by a paired t-statistic of its per-query
 * difference from a baseline.
 *
 * <p>For every query, the parent metric is evaluated twice:
 * <ul>
 *   <li>on the model's predictions;</li>
 *   <li>on the baseline scores produced by {@code computeNaturalOrderScores}.</li>
 * </ul>
 * The per-query difference (model minus baseline) is kept as-is when the model is strictly better.
 * Otherwise it is multiplied by {@code (1 + alpha)}, so losses weigh more than gains.
 * {@link #measure} returns {@code mean / (std / sqrt(n))} of those adjusted differences.
 */
public class TRiskAwareSAROEval extends URiskAwareEval {

    /**
     * Swap scorer that re-weights the parent metric's swap delta per query.
     *
     * <p>The re-weighting depends on whether the model is currently at/below or above the baseline
     * on that query. The weight is {@code (1 - normp(z)) * alpha} with
     * {@code z = (model - baseline) / currPairedSTD}. {@code currPairedSTD} is fixed at iteration 0
     * from {@link TRiskAwareSAROEval#getEstimates}. NOTE(review): {@code CDF_Normal.normp} is
     * presumably the standard normal CDF — confirm.
     */
    class SAROSwapScorer extends URiskAwareEval.URiskSwapScorer {
        /** Standard deviation of the risk-adjusted per-query deltas, captured at iteration 0. */
        double currPairedSTD;
        /** Initialised to {@code boundaries.length - 1}; not read anywhere in this class. */
        double c;
        /** Mean of {@code baselineEval}, captured at iteration 0; not read anywhere in this class. */
        double baselineMean;

        /**
         * Creates the scorer. All arguments except {@code alpha} and {@code parentSwap} are passed
         * unchanged to the parent scorer. Their names follow the assumed jforests
         * {@code getSwapScorer} convention — TODO confirm against {@code RankingEvaluationMetric}.
         *
         * @param targets per-document target values
         * @param boundaries presumably query boundaries; its length is used to initialise {@link #c}
         * @param trainSize passed through to the parent scorer
         * @param labelCounts passed through to the parent scorer
         * @param alpha risk-aversion parameter (losses are scaled by {@code 1 + alpha})
         * @param parentSwap scorer of the underlying metric whose deltas are re-weighted
         */
        public SAROSwapScorer(double[] targets, int[] boundaries, int trainSize, int[][] labelCounts,
                double alpha, RankingEvaluationMetric.SwapScorer parentSwap) {
            super(targets, boundaries, trainSize, labelCounts, alpha, parentSwap);
            this.currPairedSTD = 0.0d;
            this.c = boundaries.length - 1.0d;
        }

        /**
         * Returns the parent metric's swap delta, re-weighted according to the query's current
         * standing against the baseline.
         *
         * @param queryIndex index into the per-query {@code modelEval}/{@code baselineEval} arrays
         * @param docA index into {@code targets} of the first document of the pair
         * @param rankA value compared against {@code rankB} (presumably a rank position)
         * @param docB index into {@code targets} of the second document of the pair
         * @param rankB value compared against {@code rankA}
         * @return the re-weighted delta
         */
        @Override
        public double getDelta(int queryIndex, int docA, int rankA, int docB, int rankB) {
            final double delta = this.parentSwap.getDelta(queryIndex, docA, rankA, docB, rankB);
            final double modelScore = this.modelEval[queryIndex];
            final double baselineScore = this.baselineEval[queryIndex];

            // A zero gap maps to z = 0 explicitly. Dividing would give 0.0/0.0 = NaN when the
            // paired STD is 0 (e.g. model identical to baseline at iteration 0).
            // With a non-zero STD this is exactly what the division yields anyway.
            final double gap = modelScore - baselineScore;
            final double z = gap == 0.0d ? 0.0d : gap / this.currPairedSTD;
            final double weight = (1.0d - CDF_Normal.normp(z)) * this.alpha;

            // True when docA has the strictly greater target AND the strictly smaller rank value.
            // This is written as the negation of the original condition, so NaN targets take the same branch.
            final boolean aHigherAndBefore =
                    !(this.targets[docA] <= this.targets[docB] || rankA >= rankB);

            if (modelScore <= baselineScore) {
                // Model is at or below the baseline on this query.
                if (aHigherAndBefore) {
                    return (1.0d + weight) * delta;
                }
                if (baselineScore > modelScore + delta) {
                    return (1.0d + weight) * delta;
                }
                return (weight * (baselineScore - modelScore)) + delta;
            }

            // Model is above the baseline on this query.
            if (!aHigherAndBefore) {
                return delta;
            }
            final double absDelta = Math.abs(delta);
            if (baselineScore > modelScore - absDelta) {
                return (weight * (modelScore - baselineScore)) - ((1.0d + weight) * absDelta);
            }
            return delta;
        }

        /**
         * Forwards the evaluation to the parent and logs it to stderr.
         *
         * <p>At iteration 0 it also fixes {@link #currPairedSTD} and {@link #baselineMean}.
         *
         * @param iteration boosting iteration number
         * @param evaluation per-query evaluation values for this iteration
         */
        @Override
        public void setCurrentIterationEvaluation(int iteration, double[] evaluation) {
            super.setCurrentIterationEvaluation(iteration, evaluation);
            if (iteration == 0) {
                // The paired STD is computed once and reused by getDelta for all later iterations.
                // Qualified call so an inherited getEstimates cannot shadow this class's version.
                this.currPairedSTD = Math.sqrt(
                        TRiskAwareSAROEval.getEstimates(this.baselineEval, this.modelEval, this.alpha)[1]);
                System.err.println("Iteration 0 Paired STD=" + this.currPairedSTD);
                this.baselineMean = MathUtil.getAvg(this.baselineEval);
                System.err.println("Iteration 0 NDCG=" + Arrays.toString(evaluation));
                return;
            }
            // NOTE(review): the result is discarded. The call is kept only in case getAvg has side
            // effects; it is likely removable.
            MathUtil.getAvg(evaluation);
            System.err.println("Iteration " + iteration + " NDCG=" + Arrays.toString(evaluation));
        }
    }

    /**
     * @param evaluationMetric underlying (parent) ranking metric, evaluated per query
     * @param alpha risk-aversion parameter; losses against the baseline are scaled by
     *     {@code 1 + alpha}
     */
    public TRiskAwareSAROEval(EvaluationMetric evaluationMetric, double alpha) {
        super(evaluationMetric, alpha);
    }

    /**
     * Risk-adjusted difference for one query.
     *
     * <p>It is {@code model - baseline} when the model is strictly better, otherwise
     * {@code (1 + alpha) * (model - baseline)}.
     */
    private static double riskAdjustedDelta(double baseline, double model, double alpha) {
        final double diff = model - baseline;
        return model > baseline ? diff : (1.0d + alpha) * diff;
    }

    /**
     * Computes the mean and sample variance of the risk-adjusted per-query deltas.
     *
     * <p>The variance uses a two-pass computation, so it is never negative. With fewer than two
     * queries the variance is 0, and for an empty input both estimates are 0.
     *
     * @param baselinePerQuery per-query metric values of the baseline
     * @param modelPerQuery per-query metric values of the model (at least as long as the baseline
     *     array)
     * @param alpha risk-aversion parameter
     * @return {@code {mean, sampleVariance}}
     */
    public static double[] getEstimates(double[] baselinePerQuery, double[] modelPerQuery, double alpha) {
        final int n = baselinePerQuery.length;
        if (n == 0) {
            return new double[] {0.0d, 0.0d};
        }

        double sum = 0.0d;
        for (int q = 0; q < n; q++) {
            sum += riskAdjustedDelta(baselinePerQuery[q], modelPerQuery[q], alpha);
        }
        final double mean = sum / n;

        // The (n - 1) denominator is undefined for a single query. The previous guard
        // (sumSquares == sum * sum) also wrongly zeroed the variance whenever exactly one delta
        // was non-zero.
        if (n < 2) {
            return new double[] {mean, 0.0d};
        }

        double sumSquaredDeviations = 0.0d;
        for (int q = 0; q < n; q++) {
            final double deviation = riskAdjustedDelta(baselinePerQuery[q], modelPerQuery[q], alpha) - mean;
            sumSquaredDeviations += deviation * deviation;
        }
        return new double[] {mean, sumSquaredDeviations / (n - 1.0d)};
    }

    /**
     * Paired t-statistic of the risk-adjusted deltas, computed as {@code mean * sqrt(n / variance)}.
     *
     * @param baselinePerQuery per-query metric values of the baseline
     * @param modelPerQuery per-query metric values of the model
     * @param alpha risk-aversion parameter
     * @return the t-statistic, or 0 when the variance is 0
     */
    public static double T_measure(double[] baselinePerQuery, double[] modelPerQuery, double alpha) {
        final double n = baselinePerQuery.length;
        final double[] estimates = getEstimates(baselinePerQuery, modelPerQuery, alpha);
        final double mean = estimates[0];
        final double variance = estimates[1];
        if (variance == 0.0d) {
            return 0.0d;
        }
        return Math.sqrt(n / variance) * mean;
    }

    /**
     * Evaluates {@code predictions} against the natural-order baseline of the same sample.
     *
     * @param predictions model scores, one per document in {@code sample}
     * @param sample must be a {@link RankingSample}
     * @return the risk-adjusted paired t-statistic, see {@link #T_measure}
     */
    @Override
    public double measure(double[] predictions, Sample sample) throws Exception {
        final RankingSample rankingSample = (RankingSample) sample;
        assert rankingSample.queryBoundaries.length - 1 == rankingSample.numQueries;

        final RankingEvaluationMetric parentMetric = (RankingEvaluationMetric) this.parent;
        final double[] baselinePerQuery = parentMetric.measureByQuery(
                computeNaturalOrderScores(predictions.length, rankingSample.queryBoundaries), sample);
        final double[] modelPerQuery = parentMetric.measureByQuery(predictions, sample);
        return T_measure(baselinePerQuery, modelPerQuery, this.ALPHA);
    }

    /**
     * Wraps the parent metric's swap scorer in a {@link SAROSwapScorer}.
     */
    @Override
    public RankingEvaluationMetric.SwapScorer getSwapScorer(double[] targets, int[] boundaries,
            int trainSize, int[][] labelCounts) throws Exception {
        final RankingEvaluationMetric.SwapScorer parentSwap =
                ((RankingEvaluationMetric) this.parent).getSwapScorer(targets, boundaries, trainSize, labelCounts);
        return new SAROSwapScorer(targets, boundaries, trainSize, labelCounts, this.ALPHA, parentSwap);
    }
}
