/*
 * Decompiled with CFR 0.152.
 */
package edu.uci.jforests.eval.ranking;

import edu.uci.jforests.eval.ranking.RankingEvaluationMetric;
import edu.uci.jforests.sample.RankingSample;
import edu.uci.jforests.sample.Sample;
import edu.uci.jforests.util.ArraysUtil;
import edu.uci.jforests.util.ScoreBasedComparator;
import edu.uci.jforests.util.concurrency.BlockingThreadPoolExecutor;
import edu.uci.jforests.util.concurrency.TaskCollection;
import edu.uci.jforests.util.concurrency.TaskItem;

/**
 * Mean Average Precision (MAP) evaluation metric for ranking samples.
 *
 * <p>The queries of a {@link RankingSample} are split into contiguous chunks and each
 * chunk is scored by one {@code MAPWorker} submitted to the shared
 * {@link BlockingThreadPoolExecutor}. A document counts as relevant when its target
 * value is strictly greater than zero.
 *
 * <p>The worker objects are created once in the constructor and re-initialised on every
 * call, so a single instance must not evaluate two samples concurrently.
 */
public class MAPEval
extends RankingEvaluationMetric {
    /** One reusable worker per thread of the shared executor. */
    private final TaskCollection<MAPWorker> mapWorkers;
    /** Initial capacity of each worker's per-query permutation buffer. */
    private final int maxDocsPerQuery;

    /**
     * Creates the metric and pre-allocates one worker per executor thread.
     *
     * @param maxDocsPerQuery expected upper bound on the number of documents in a single
     *        query; used to size each worker's scratch buffer (the buffer grows if a
     *        larger query is encountered)
     * @throws Exception declared for compatibility with the original signature
     */
    public MAPEval(int maxDocsPerQuery) throws Exception {
        // NOTE(review): the flag presumably marks the metric as "larger is better" - confirm
        // against RankingEvaluationMetric.
        super(true);
        this.maxDocsPerQuery = maxDocsPerQuery;
        int numWorkers = BlockingThreadPoolExecutor.getInstance().getMaximumPoolSize();
        this.mapWorkers = new TaskCollection<MAPWorker>();
        for (int i = 0; i < numWorkers; ++i) {
            this.mapWorkers.addTask(new MAPWorker());
        }
    }

    /**
     * Computes the average precision of every query in {@code sample}.
     *
     * @param predictions model scores, indexed like {@code sample.targets}
     * @param sample the sample to evaluate; must be a {@link RankingSample}
     * @param tieBreaker tie-breaking policy handed to the {@link ScoreBasedComparator}
     * @return an array of length {@code numQueries} whose element {@code q} is the average
     *         precision of query {@code q}; queries without any relevant document score
     *         {@code 0.0}
     * @throws Exception if dispatching to or waiting on the executor fails
     */
    public double[] getMAP(double[] predictions, Sample sample, ScoreBasedComparator.TieBreaker tieBreaker) throws Exception {
        RankingSample rankingSample = (RankingSample)sample;
        int numQueries = rankingSample.numQueries;
        int numWorkers = this.mapWorkers.getSize();
        // Allocated before dispatch because workers write their per-query values into it.
        double[] result = new double[numQueries];
        // The "+ 1" guarantees that numWorkers chunks always cover all queries.
        int chunkSize = 1 + numQueries / numWorkers;
        int offset = 0;
        for (int i = 0; i < numWorkers && offset < numQueries; offset += chunkSize, ++i) {
            int endOffset = offset + Math.min(numQueries - offset, chunkSize);
            MAPWorker worker = this.mapWorkers.getTask(i);
            worker.init(rankingSample, predictions, result, offset, endOffset, tieBreaker);
            BlockingThreadPoolExecutor.getInstance().execute(worker);
        }
        // Presumably blocks until all submitted workers are done; each worker owns a
        // disjoint index range of 'result', so no further synchronisation is added here.
        BlockingThreadPoolExecutor.getInstance().await();
        return result;
    }

    /**
     * Returns the per-query average precision, using the
     * {@link ScoreBasedComparator.TieBreaker#ReverseLabels} tie-breaking policy.
     *
     * @param predictions model scores, indexed like {@code sample.targets}
     * @param sample the sample to evaluate; must be a {@link RankingSample}
     * @return one average-precision value per query
     * @throws Exception if the underlying computation fails
     */
    @Override
    public double[] measureByQuery(double[] predictions, Sample sample) throws Exception {
        return this.getMAP(predictions, sample, ScoreBasedComparator.TieBreaker.ReverseLabels);
    }

    /**
     * Swap scoring is not implemented for MAP.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public RankingEvaluationMetric.SwapScorer getSwapScorer(double[] targets, int[] boundaries, int trunc, int[][] labelCounts) throws Exception {
        throw new UnsupportedOperationException("MAP does not yet support SwapScoring for LambdaMART");
    }

    /**
     * Scores a contiguous range of queries and stores each query's average precision in
     * the shared output array supplied through {@link #init}.
     */
    private class MAPWorker
    extends TaskItem {
        /** Scratch buffer holding document positions (relative to the query start). */
        private int[] permutation;
        private RankingSample sample;
        /** First query index handled by this worker (inclusive). */
        private int beginIdx;
        /** Last query index handled by this worker (exclusive). */
        private int endIdx;
        /** Shared output array; this worker writes only indices in [beginIdx, endIdx). */
        private double[] results;
        private final ScoreBasedComparator comparator;

        public MAPWorker() {
            this.permutation = new int[MAPEval.this.maxDocsPerQuery];
            this.comparator = new ScoreBasedComparator();
        }

        /**
         * Prepares the worker for one evaluation pass.
         *
         * @param sample the ranking sample being evaluated
         * @param scores model scores, indexed like {@code sample.targets}
         * @param results output array with one slot per query
         * @param beginIdx first query to score (inclusive)
         * @param endIdx last query to score (exclusive)
         * @param tieBreaker tie-breaking policy for equal scores
         */
        public void init(RankingSample sample, double[] scores, double[] results, int beginIdx, int endIdx, ScoreBasedComparator.TieBreaker tieBreaker) {
            this.sample = sample;
            this.results = results;
            this.beginIdx = beginIdx;
            this.endIdx = endIdx;
            this.comparator.labels = sample.targets;
            this.comparator.scores = scores;
            this.comparator.tieBreaker = tieBreaker;
        }

        @Override
        public void run() {
            for (int q = this.beginIdx; q < this.endIdx; ++q) {
                int begin = this.sample.queryBoundaries[q];
                int numDocs = this.sample.queryBoundaries[q + 1] - begin;
                // Grow the scratch buffer instead of overrunning it when a query holds more
                // documents than the configured maxDocsPerQuery.
                if (numDocs > this.permutation.length) {
                    this.permutation = new int[numDocs];
                }
                this.comparator.offset = begin;
                for (int d = 0; d < numDocs; ++d) {
                    this.permutation[d] = d;
                }
                ArraysUtil.sort(this.permutation, numDocs, this.comparator);
                try {
                    int numRelevant = 0;
                    double precisionSum = 0.0;
                    for (int pos = 0; pos < numDocs; ++pos) {
                        if (this.sample.targets[begin + this.permutation[pos]] > 0.0) {
                            // Precision at this rank = relevant documents seen so far / rank.
                            ++numRelevant;
                            precisionSum += (double)numRelevant / (double)(pos + 1);
                        }
                    }
                    // A query without relevant documents scores 0 rather than dividing by zero.
                    this.results[q] = numRelevant > 0 ? precisionSum / (double)numRelevant : 0.0;
                }
                catch (Exception e) {
                    // Kept from the original: a failure while scoring one query is reported
                    // and that query's slot stays at 0.0, so the remaining queries are
                    // still evaluated.
                    e.printStackTrace();
                }
            }
        }
    }
}

