/*
 * Decompiled with CFR 0.152.
 */
package edu.uci.jforests.eval.ranking;

import edu.uci.jforests.eval.ranking.RankingEvaluationMetric;
import edu.uci.jforests.sample.RankingSample;
import edu.uci.jforests.sample.Sample;
import edu.uci.jforests.util.ArraysUtil;
import edu.uci.jforests.util.ScoreBasedComparator;
import edu.uci.jforests.util.concurrency.BlockingThreadPoolExecutor;
import edu.uci.jforests.util.concurrency.TaskCollection;
import edu.uci.jforests.util.concurrency.TaskItem;
import java.util.Arrays;

/**
 * Ranking metric that computes precision at every cutoff {@code 1..maxLevels},
 * averaged over the queries of a {@link RankingSample}.
 *
 * <p>The queries are split into contiguous chunks; each chunk is handed to one
 * {@link PrecisionWorker} that runs on the shared {@link BlockingThreadPoolExecutor}.
 *
 * <p>NOTE(review): the workers are created once and re-initialised on every call to
 * {@link #getPrecisions}, so concurrent calls on the same instance would overwrite
 * each other's worker state.
 */
public class PrecisionEval extends RankingEvaluationMetric {
    /** One reusable worker per executor thread (pool size read at construction time). */
    private final TaskCollection<PrecisionWorker> mapWorkers;
    /** Number of cutoffs reported; also the length of every result array. */
    private final int maxLevels;
    /** A document counts as relevant when its target is {@code >=} this value. */
    double relevancyThreshold;
    /** Initial capacity of each worker's per-query ranking buffer. */
    private final int maxDocsPerQuery;

    /**
     * Creates the metric and pre-allocates one worker per executor thread.
     *
     * @param maxDocsPerQuery    initial size of the per-query ranking buffer of each worker
     * @param maxLevels          number of cutoffs to evaluate (precision at 1..maxLevels)
     * @param relevancyThreshold minimum target value for a document to count as relevant
     * @throws Exception declared for compatibility with existing callers
     */
    public PrecisionEval(int maxDocsPerQuery, int maxLevels, double relevancyThreshold) throws Exception {
        super(true);
        this.maxDocsPerQuery = maxDocsPerQuery;
        this.maxLevels = maxLevels;
        this.relevancyThreshold = relevancyThreshold;
        int numWorkers = BlockingThreadPoolExecutor.getInstance().getMaximumPoolSize();
        // Explicit type argument instead of a raw type: avoids an unchecked conversion.
        this.mapWorkers = new TaskCollection<PrecisionWorker>();
        // Workers size their buffers from the fields assigned above, so create them last.
        for (int i = 0; i < numWorkers; ++i) {
            this.mapWorkers.addTask(new PrecisionWorker());
        }
    }

    /**
     * Computes the mean precision at each cutoff over all queries of {@code sample}.
     *
     * <p>A query with fewer documents than a cutoff contributes nothing to that cutoff,
     * while the sum is still divided by the total number of queries. For a sample
     * without queries every entry is {@code NaN} (0.0 / 0.0).
     *
     * @param predictions model scores, handed to the comparator of each worker
     * @param sample      must be a {@link RankingSample}
     * @param tieBreaker  tie-breaking policy handed to the comparator of each worker
     * @return array of length {@code maxLevels}; index {@code k} holds the mean
     *         precision at cutoff {@code k + 1}
     * @throws Exception if the executor fails while running or awaiting the workers
     */
    public double[] getPrecisions(double[] predictions, Sample sample, ScoreBasedComparator.TieBreaker tieBreaker) throws Exception {
        RankingSample rankingSample = (RankingSample) sample;
        int numQueries = rankingSample.numQueries;
        int numWorkers = this.mapWorkers.getSize();
        // The "+ 1" makes numWorkers chunks always sufficient to cover all queries.
        int chunkSize = 1 + numQueries / numWorkers;

        int workerCount = 0;
        for (int offset = 0; workerCount < numWorkers && offset < numQueries; offset += chunkSize) {
            int endOffset = offset + Math.min(numQueries - offset, chunkSize);
            PrecisionWorker worker = this.mapWorkers.getTask(workerCount);
            ++workerCount;
            worker.init(rankingSample, predictions, offset, endOffset, tieBreaker);
            BlockingThreadPoolExecutor.getInstance().execute(worker);
        }
        BlockingThreadPoolExecutor.getInstance().await();

        // Only the workers that were actually started hold results for this call.
        double[] result = new double[this.maxLevels];
        for (int w = 0; w < workerCount; ++w) {
            double[] localResult = this.mapWorkers.getTask(w).getResult();
            for (int level = 0; level < this.maxLevels; ++level) {
                result[level] += localResult[level];
            }
        }
        for (int level = 0; level < this.maxLevels; ++level) {
            result[level] /= (double) numQueries;
        }
        return result;
    }

    /**
     * Returns the first entry of {@link #getPrecisions}, i.e. the mean precision at
     * cutoff 1, using {@code TieBreaker.ReverseLabels}.
     *
     * @param predictions model scores
     * @param sample      must be a {@link RankingSample}
     * @return mean precision at cutoff 1
     * @throws Exception if {@link #getPrecisions} fails
     */
    @Override
    public double measure(double[] predictions, Sample sample) throws Exception {
        return this.getPrecisions(predictions, sample, ScoreBasedComparator.TieBreaker.ReverseLabels)[0];
    }

    /**
     * Not supported for this metric.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public RankingEvaluationMetric.SwapScorer getSwapScorer(double[] targets, int[] boundaries, int trunc, int[][] labelCounts) throws Exception {
        throw new UnsupportedOperationException("Precision not supported for LambdaMART yet!");
    }

    /**
     * Not supported for this metric.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public double[] measureByQuery(double[] predictions, Sample sample) throws Exception {
        throw new UnsupportedOperationException("Precision not supported yet!");
    }

    /**
     * Accumulates, for the queries in {@code [beginIdx, endIdx)}, the per-query
     * precision at each cutoff into {@link #result}.
     */
    private class PrecisionWorker extends TaskItem {
        /** Scratch buffer of query-local document indices; grown when a query needs more room. */
        private int[] permutation;
        private RankingSample sample;
        private int beginIdx;
        private int endIdx;
        /** Sum of per-query precisions for each cutoff; reset by {@link #init}. */
        private final double[] result;
        private final ScoreBasedComparator comparator;

        public PrecisionWorker() {
            this.permutation = new int[PrecisionEval.this.maxDocsPerQuery];
            this.comparator = new ScoreBasedComparator();
            this.result = new double[PrecisionEval.this.maxLevels];
        }

        /**
         * Prepares the worker for one run and clears the previous result.
         *
         * @param sample     sample whose queries are evaluated
         * @param scores     model scores used by the comparator
         * @param beginIdx   first query index (inclusive)
         * @param endIdx     last query index (exclusive)
         * @param tieBreaker tie-breaking policy for the comparator
         */
        public void init(RankingSample sample, double[] scores, int beginIdx, int endIdx, ScoreBasedComparator.TieBreaker tieBreaker) {
            this.sample = sample;
            this.beginIdx = beginIdx;
            this.endIdx = endIdx;
            this.comparator.labels = sample.targets;
            this.comparator.scores = scores;
            this.comparator.tieBreaker = tieBreaker;
            Arrays.fill(this.result, 0.0);
        }

        /** @return the live accumulator array (not a copy) */
        public double[] getResult() {
            return this.result;
        }

        @Override
        public void run() {
            for (int q = this.beginIdx; q < this.endIdx; ++q) {
                int begin = this.sample.queryBoundaries[q];
                int numDocs = this.sample.queryBoundaries[q + 1] - begin;
                // A query larger than maxDocsPerQuery used to overflow the buffer;
                // grow it instead so such queries are evaluated normally.
                if (numDocs > this.permutation.length) {
                    this.permutation = new int[numDocs];
                }
                this.comparator.offset = begin;
                for (int d = 0; d < numDocs; ++d) {
                    this.permutation[d] = d;
                }
                // Ranks the query-local indices with the score-based comparator
                // (presumably best score first -- confirm in ArraysUtil/ScoreBasedComparator).
                ArraysUtil.sort(this.permutation, numDocs, this.comparator);
                try {
                    int cutoffs = Math.min(numDocs, PrecisionEval.this.maxLevels);
                    int numRelevant = 0;
                    for (int pos = 0; pos < cutoffs; ++pos) {
                        if (this.sample.targets[begin + this.permutation[pos]] >= PrecisionEval.this.relevancyThreshold) {
                            ++numRelevant;
                        }
                        // Precision at cutoff (pos + 1) for this query.
                        this.result[pos] += (double) numRelevant / (double) (pos + 1);
                    }
                }
                catch (Exception e) {
                    // Best-effort, as before: report the failure and move on to the next
                    // query. How the executor treats an exception escaping run() is not
                    // visible from this file, so it is not propagated.
                    e.printStackTrace();
                }
            }
        }
    }
}

