/*
 * Decompiled with CFR 0.152.
 */
package edu.uci.jforests.eval.ranking;

import edu.uci.jforests.eval.EvaluationMetric;
import edu.uci.jforests.eval.ranking.RankingEvaluationMetric;
import edu.uci.jforests.eval.ranking.URiskAwareEval;
import edu.uci.jforests.sample.RankingSample;
import edu.uci.jforests.sample.Sample;
import edu.uci.jforests.util.CDF_Normal;
import edu.uci.jforests.util.MathUtil;
import java.util.Arrays;

/**
 * Risk-sensitive wrapper around a ranking metric that reports a TRisk-style statistic.
 *
 * <p>The TRisk value is the URisk mean of per-query differences against a baseline, divided by
 * its standard error: {@code URisk / sqrt(pairedVar / n)}.
 *
 * <p>The baseline is the parent metric evaluated on the scores returned by
 * {@code computeNaturalOrderScores}. Presumably this is the document order in the sample; that
 * helper is defined outside this file, so confirm.
 */
public class TRiskAwareSAROEval extends URiskAwareEval {

    /**
     * @param _parent the per-query ranking metric being made risk-sensitive; it is cast to
     *     {@link RankingEvaluationMetric} when used
     * @param alpha the extra weight applied to losses against the baseline
     */
    public TRiskAwareSAROEval(EvaluationMetric _parent, double alpha) {
        super(_parent, alpha);
    }

    /**
     * Computes the URisk mean and the paired sample variance of the risk-weighted per-query
     * differences {@code d_i}.
     *
     * <p>Each difference is {@code perQuery[i] - baselinePerQuery[i]}. A win counts at face
     * value; a loss is multiplied by {@code 1 + ALPHA}. A tie is 0 either way.
     *
     * @param baselinePerQuery per-query baseline scores; their count defines the number of
     *     queries {@code n}
     * @param perQuery per-query scores of the evaluated model; must have at least {@code n}
     *     entries
     * @param ALPHA loss amplification factor
     * @return {@code {URisk, pairedVar}}. {@code pairedVar} uses the {@code n - 1} denominator.
     *     It is 0 when fewer than two queries are given. Both values are 0 for an empty input.
     */
    public static double[] getEstimates(double[] baselinePerQuery, double[] perQuery, double ALPHA) {
        final int n = baselinePerQuery.length;
        if (n == 0) {
            // Avoid 0/0 = NaN for the mean; no queries means no risk signal.
            return new double[]{0.0, 0.0};
        }
        double sum = 0.0;
        double sumSq = 0.0;
        for (int i = 0; i < n; i++) {
            double diff = perQuery[i] - baselinePerQuery[i];
            double d = perQuery[i] > baselinePerQuery[i] ? diff : (1.0 + ALPHA) * diff;
            sum += d;
            sumSq += d * d;
        }
        double uRisk = sum / n;
        double pairedVar = 0.0;
        // The sample variance is undefined for a single observation (0/0). Guard on the count
        // itself: comparing sumSq with sum*sum would also zero out genuine variance,
        // e.g. {d, 0, 0, ...}.
        if (n > 1) {
            // The one-pass formula can dip just below zero through floating-point cancellation;
            // clamp so callers never take the square root of a negative number (NaN).
            pairedVar = Math.max(0.0, (sumSq - sum * sum / n) / (n - 1));
        }
        return new double[]{uRisk, pairedVar};
    }

    /**
     * Returns the TRisk statistic {@code URisk / sqrt(pairedVar / n)} for the given per-query
     * scores.
     *
     * @param baselinePerQuery per-query baseline scores
     * @param perQuery per-query model scores
     * @param ALPHA loss amplification factor
     * @return the statistic, or 0 when the paired variance is 0
     */
    public static double T_measure(double[] baselinePerQuery, double[] perQuery, double ALPHA) {
        double[] params = TRiskAwareSAROEval.getEstimates(baselinePerQuery, perQuery, ALPHA);
        double uRisk = params[0];
        double pairedVar = params[1];
        if (pairedVar == 0.0) {
            return 0.0;
        }
        return Math.sqrt(baselinePerQuery.length / pairedVar) * uRisk;
    }

    /**
     * Evaluates {@code predictions} against the natural-order baseline and returns the TRisk
     * statistic.
     */
    @Override
    public double measure(double[] predictions, Sample sample) throws Exception {
        RankingSample rankingSample = (RankingSample) sample;
        assert (rankingSample.queryBoundaries.length - 1 == rankingSample.numQueries);
        RankingEvaluationMetric parentMetric = (RankingEvaluationMetric) this.parent;
        double[] naturalOrder =
                TRiskAwareSAROEval.computeNaturalOrderScores(predictions.length, rankingSample.queryBoundaries);
        double[] baselinePerQuery = parentMetric.measureByQuery(naturalOrder, sample);
        double[] perQuery = parentMetric.measureByQuery(predictions, sample);
        return TRiskAwareSAROEval.T_measure(baselinePerQuery, perQuery, this.ALPHA);
    }

    /**
     * Wraps the parent metric's swap scorer in a {@link SAROSwapScorer}. The wrapper rescales
     * each swap delta by a risk weight derived from the TRisk statistic.
     */
    @Override
    public RankingEvaluationMetric.SwapScorer getSwapScorer(double[] targets, int[] boundaries, int trunc, int[][] labelCounts) throws Exception {
        RankingEvaluationMetric.SwapScorer parentMeasure =
                ((RankingEvaluationMetric) this.parent).getSwapScorer(targets, boundaries, trunc, labelCounts);
        return new SAROSwapScorer(targets, boundaries, trunc, labelCounts, this.ALPHA, parentMeasure);
    }

    /**
     * Swap scorer whose loss weight {@code beta} adapts per query.
     *
     * <p>{@code beta = alpha * (1 - Phi(t))}, where {@code t} is the query's model-minus-baseline
     * difference divided by the paired standard deviation estimated at iteration 0.
     * {@code CDF_Normal.normp} is presumably the standard normal CDF (defined elsewhere;
     * confirm).
     */
    class SAROSwapScorer extends URiskAwareEval.URiskSwapScorer {
        /** Paired standard deviation; 0 until iteration 0, and only recomputed there. */
        double currPairedSTD;
        /** Number of queries ({@code boundaries.length - 1}). */
        double c;
        /** Mean baseline score recorded at iteration 0; not read within this class. */
        double baselineMean;

        public SAROSwapScorer(double[] targets, int[] boundaries, int trunc, int[][] labelCounts, double _alpha, RankingEvaluationMetric.SwapScorer _parent) {
            super(targets, boundaries, trunc, labelCounts, _alpha, _parent);
            this.currPairedSTD = 0.0;
            this.c = (double) boundaries.length - 1.0;
        }

        /**
         * Returns the parent metric's swap delta, rescaled by the adaptive risk weight
         * {@code beta}.
         *
         * <p>The part of the change that lies below the query's baseline score is weighted by
         * {@code 1 + beta}. The part at or above the baseline keeps weight 1.
         */
        @Override
        public double getDelta(int queryIndex, int betterIdx, int rank_i, int worseIdx, int rank_j) {
            double deltaM = this.parentSwap.getDelta(queryIndex, betterIdx, rank_i, worseIdx, rank_j);
            double modelScore = this.modelEval[queryIndex];
            double baselineScore = this.baselineEval[queryIndex];
            double rel_i = this.targets[betterIdx];
            double rel_j = this.targets[worseIdx];
            double diff = modelScore - baselineScore;
            // With a zero paired STD (e.g. model identical to the baseline at iteration 0), a zero
            // diff would give 0/0 = NaN and poison beta. Treat that tie as t = 0. A non-zero diff
            // still yields +/-Infinity exactly as before.
            double tRisk = diff == 0.0 ? 0.0 : diff / this.currPairedSTD;
            double beta = (1.0 - CDF_Normal.normp(tRisk)) * this.alpha;
            // True when the higher-labelled document currently has the smaller rank than the lower
            // one, i.e. the pair is presumably already in the correct order.
            boolean orderedPair = rel_i > rel_j && rank_i < rank_j;
            double deltaT;
            if (modelScore <= baselineScore) {
                if (orderedPair) {
                    deltaT = (1.0 + beta) * deltaM;
                } else if (baselineScore > modelScore + deltaM) {
                    // The change stays entirely below the baseline.
                    deltaT = (1.0 + beta) * deltaM;
                } else {
                    // Only the portion up to the baseline gets the extra beta weight.
                    deltaT = beta * (baselineScore - modelScore) + deltaM;
                }
            } else if (orderedPair && baselineScore > modelScore - Math.abs(deltaM)) {
                // A drop from above the baseline to below it: the portion under the baseline is
                // weighted by (1 + beta).
                deltaT = beta * (modelScore - baselineScore) - (1.0 + beta) * Math.abs(deltaM);
            } else {
                deltaT = deltaM;
            }
            return deltaT;
        }

        /**
         * Forwards the per-query evaluation to the superclass and logs it to stderr.
         *
         * <p>On iteration 0 it also estimates the paired standard deviation and the baseline mean
         * used by {@link #getDelta}.
         */
        @Override
        public void setCurrentIterationEvaluation(int iteration, double[] nDCG) {
            super.setCurrentIterationEvaluation(iteration, nDCG);
            if (iteration == 0) {
                double[] params = TRiskAwareSAROEval.getEstimates(this.baselineEval, this.modelEval, this.alpha);
                this.currPairedSTD = Math.sqrt(params[1]);
                System.err.println("Iteration 0 Paired STD=" + this.currPairedSTD);
                this.baselineMean = MathUtil.getAvg(this.baselineEval);
                System.err.println("Iteration 0 NDCG=" + Arrays.toString(nDCG));
            } else {
                System.err.println("Iteration " + iteration + " NDCG=" + Arrays.toString(nDCG));
            }
        }
    }
}

