/*
 * Decompiled with CFR 0.152.
 */
package edu.uci.jforests.learning.boosting;

import edu.uci.jforests.dataset.RankingDataset;
import edu.uci.jforests.eval.EvaluationMetric;
import edu.uci.jforests.eval.ranking.NDCGEval;
import edu.uci.jforests.eval.ranking.RankingEvaluationMetric;
import edu.uci.jforests.learning.boosting.GradientBoosting;
import edu.uci.jforests.learning.boosting.GradientBoostingConfig;
import edu.uci.jforests.learning.boosting.LambdaMARTConfig;
import edu.uci.jforests.learning.trees.LeafInstances;
import edu.uci.jforests.learning.trees.Tree;
import edu.uci.jforests.learning.trees.TreeLeafInstances;
import edu.uci.jforests.learning.trees.regression.RegressionTree;
import edu.uci.jforests.sample.RankingSample;
import edu.uci.jforests.sample.Sample;
import edu.uci.jforests.util.ArraysUtil;
import edu.uci.jforests.util.ConfigHolder;
import edu.uci.jforests.util.ScoreBasedComparator;
import edu.uci.jforests.util.concurrency.BlockingThreadPoolExecutor;
import edu.uci.jforests.util.concurrency.TaskCollection;
import edu.uci.jforests.util.concurrency.TaskItem;
import java.util.Arrays;

/**
 * Gradient-boosting learner registered with its superclass under the name {@code "LambdaMART"}.
 *
 * <p>In every iteration {@link #getSubLearnerSample()} fills {@code residuals} and
 * {@link #denomWeights} from pairs of documents that belong to the same query (the pairwise
 * work is done by {@link LambdaWorker} tasks), and returns a sub-sample whose targets are those
 * residuals. After a tree has been fitted, {@link #adjustOutputs(Tree, TreeLeafInstances)}
 * overwrites every leaf output with the ratio of summed residuals to summed denominator weights.
 */
public class LambdaMART
extends GradientBoosting {
    /** Tiny constant added to both sides of the leaf-output ratio so that 0/0 cannot occur. */
    private static final double EPSILON = 1.4E-45;
    /** Divided by {@link #sigmoidParam} to obtain the lower bound of the cached score range. */
    private static final double MIN_EXP_POWER = -50.0;

    private TaskCollection<LambdaWorker> workers;
    private RankingEvaluationMetric.SwapScorer swapScorer;
    private double sigmoidParam;
    /** Pre-computed cost-function values, one per bin of the score-difference range. */
    private double[] sigmoidCache;
    private double minScore;
    private double maxScore;
    private double sigmoidBinWidth;
    /** Per-training-instance denominator weights accumulated by the workers. */
    protected double[] denomWeights;
    /** Maps position {@code i} of the current sub-learner sample to an index used for {@code residuals}/{@code denomWeights}. */
    private int[] subLearnerSampleIndicesInTrainSet;

    public LambdaMART() {
        super("LambdaMART");
    }

    /**
     * Initializes the learner: obtains the swap scorer from the evaluation metric, builds the
     * sigmoid cache, creates one {@link LambdaWorker} per thread of the shared executor and
     * allocates the per-instance work arrays.
     *
     * @param configHolder source of {@link LambdaMARTConfig} and {@link GradientBoostingConfig}
     * @param dataset ranking dataset whose targets and query boundaries seed the swap scorer
     * @param maxNumTrainInstances capacity of the per-training-instance arrays
     * @param maxNumValidInstances forwarded to the superclass
     * @param evaluationMetric must be a {@link RankingEvaluationMetric} (it is cast unconditionally)
     * @throws Exception if the superclass initialization fails or the configured cost function is unknown
     */
    public void init(ConfigHolder configHolder, RankingDataset dataset, int maxNumTrainInstances, int maxNumValidInstances, EvaluationMetric evaluationMetric) throws Exception {
        super.init(configHolder, maxNumTrainInstances, maxNumValidInstances, evaluationMetric);
        LambdaMARTConfig lambdaMartConfig = configHolder.getConfig(LambdaMARTConfig.class);
        GradientBoostingConfig gradientBoostingConfig = configHolder.getConfig(GradientBoostingConfig.class);
        int[][] labelCountsPerQuery = NDCGEval.getLabelCountsForQueries(dataset.targets, dataset.queryBoundaries);
        this.swapScorer = ((RankingEvaluationMetric)evaluationMetric).getSwapScorer(dataset.targets, dataset.queryBoundaries, lambdaMartConfig.maxDCGTruncation, labelCountsPerQuery);
        // The sigmoid parameter is taken directly from the configured learning rate.
        this.sigmoidParam = gradientBoostingConfig.learningRate;
        this.initSigmoidCache(lambdaMartConfig.sigmoidBins, lambdaMartConfig.costFunction);
        this.workers = new TaskCollection<LambdaWorker>();
        int numWorkers = BlockingThreadPoolExecutor.getInstance().getMaximumPoolSize();
        for (int i = 0; i < numWorkers; ++i) {
            this.workers.addTask(new LambdaWorker(dataset.maxDocsPerQuery));
        }
        this.denomWeights = new double[maxNumTrainInstances];
        this.subLearnerSampleIndicesInTrainSet = new int[maxNumTrainInstances];
    }

    /**
     * Tabulates the selected cost function over {@code sigmoidBins} equally wide bins spanning
     * {@code [minScore, maxScore)}, where {@code minScore = -50 / sigmoidParam} and
     * {@code maxScore = -minScore}.
     *
     * @param sigmoidBins number of cache entries
     * @param costFunction either {@code "cross-entropy"} or {@code "fidelity"}
     * @throws Exception if {@code costFunction} is neither of the two supported names
     */
    private void initSigmoidCache(int sigmoidBins, String costFunction) throws Exception {
        this.minScore = MIN_EXP_POWER / this.sigmoidParam;
        this.maxScore = -this.minScore;
        this.sigmoidCache = new double[sigmoidBins];
        this.sigmoidBinWidth = (this.maxScore - this.minScore) / (double)sigmoidBins;
        if (costFunction.equals("cross-entropy")) {
            for (int i = 0; i < sigmoidBins; ++i) {
                double score = this.minScore + (double)i * this.sigmoidBinWidth;
                // Both forms are algebraically 1 / (1 + e^(sigmoidParam * score)); the form is
                // chosen by the sign of score so the exponent passed to exp() is never positive.
                this.sigmoidCache[i] = score > 0.0
                        ? 1.0 - 1.0 / (1.0 + Math.exp(-this.sigmoidParam * score))
                        : 1.0 / (1.0 + Math.exp(this.sigmoidParam * score));
            }
        } else if (costFunction.equals("fidelity")) {
            for (int i = 0; i < sigmoidBins; ++i) {
                double score = this.minScore + (double)i * this.sigmoidBinWidth;
                // NOTE(review): the positive branch scales the exponent by -2 while the
                // non-positive branch scales it by +1; preserved exactly as found - confirm intent.
                double exp = score > 0.0
                        ? Math.exp(-2.0 * this.sigmoidParam * score)
                        : Math.exp(this.sigmoidParam * score);
                this.sigmoidCache[i] = -this.sigmoidParam / 2.0 * Math.sqrt(exp / Math.pow(1.0 + exp, 3.0));
            }
        } else {
            throw new Exception("Unknown cost function: " + costFunction);
        }
    }

    /**
     * Returns the cached cost-function value for a score difference.
     *
     * <p>Differences at or below {@code minScore} map to the first bin, differences at or above
     * {@code maxScore} map to the last bin. For values in between, the computed bin index is
     * clamped to the last bin: floating-point rounding of the division can otherwise yield an
     * index equal to the cache length for a difference just below {@code maxScore}.
     *
     * @param scoreDiff prediction of the better document minus prediction of the worse document
     * @return the tabulated value for the bin containing {@code scoreDiff}
     */
    private double lookupSigmoid(double scoreDiff) {
        if (scoreDiff <= this.minScore) {
            return this.sigmoidCache[0];
        }
        int lastBin = this.sigmoidCache.length - 1;
        if (scoreDiff >= this.maxScore) {
            return this.sigmoidCache[lastBin];
        }
        int bin = (int)((scoreDiff - this.minScore) / this.sigmoidBinWidth);
        return this.sigmoidCache[Math.min(bin, lastBin)];
    }

    /**
     * Zeroes the train (and, if present, validation) predictions and seeds the swap scorer with
     * the per-query evaluation of the "natural order" scores for iteration 0.
     */
    @Override
    protected void preprocess() {
        Arrays.fill(this.trainPredictions, 0, this.curTrainSet.size, 0.0);
        if (this.curValidSet != null) {
            Arrays.fill(this.validPredictions, 0, this.curValidSet.size, 0.0);
        }
        RankingEvaluationMetric rankingMetric = (RankingEvaluationMetric)((RankingEvaluationMetric)this.evaluationMetric).getParentMetric();
        double[] nDCG = null;
        try {
            nDCG = ((RankingSample)this.curTrainSet).evaluateByQuery(RankingEvaluationMetric.computeNaturalOrderScores(this.curTrainSet.size, this.swapScorer.getQueryBoundaries()), rankingMetric);
        }
        catch (Exception e) {
            // NOTE(review): the override cannot declare checked exceptions, so the failure is only
            // printed and a null evaluation is handed to the swap scorer below.
            e.printStackTrace();
        }
        this.swapScorer.setCurrentIterationEvaluation(0, nDCG);
    }

    /** Intentionally does nothing. */
    @Override
    protected void postProcessScores() {
    }

    /**
     * Computes the output value for one leaf as
     * {@code (sum of residuals + EPSILON) / (sum of denomWeights + EPSILON)} over the leaf's
     * instances.
     *
     * @param leafInstances the half-open range {@code [begin, end)} of {@code indices} that
     *        belongs to the leaf; each index is translated through
     *        {@code subLearnerSampleIndicesInTrainSet} before the per-instance arrays are read
     * @return the adjusted leaf output
     */
    protected double getAdjustedOutput(LeafInstances leafInstances) {
        double numerator = 0.0;
        double denominator = 0.0;
        for (int i = leafInstances.begin; i < leafInstances.end; ++i) {
            int instance = this.subLearnerSampleIndicesInTrainSet[leafInstances.indices[i]];
            numerator += this.residuals[instance];
            denominator += this.denomWeights[instance];
        }
        return (numerator + EPSILON) / (denominator + EPSILON);
    }

    /**
     * Replaces the output of every leaf of {@code tree} with the value returned by
     * {@link #getAdjustedOutput(LeafInstances)}.
     *
     * @param tree the fitted tree; cast to {@link RegressionTree} to set leaf outputs
     * @param treeLeafInstances provides the instances that fell into each leaf
     */
    @Override
    protected void adjustOutputs(Tree tree, TreeLeafInstances treeLeafInstances) {
        // A single holder is reused for all leaves; loadLeafInstances refills it each time.
        LeafInstances leafInstances = new LeafInstances();
        for (int leaf = 0; leaf < tree.numLeaves; ++leaf) {
            treeLeafInstances.loadLeafInstances(leaf, leafInstances);
            double adjustedOutput = this.getAdjustedOutput(leafInstances);
            ((RegressionTree)tree).setLeafOutput(leaf, adjustedOutput);
        }
    }

    /**
     * Hook invoked from {@link #getSubLearnerSample()} after the residuals have been installed
     * as targets of the cloned sample. This implementation does nothing.
     *
     * @param sample the cloned training sample whose targets are the current residuals
     */
    protected void setSubLearnerSampleWeights(RankingSample sample) {
    }

    /**
     * Recomputes residuals and denominator weights for the current training set and returns the
     * random sub-sample on which the next tree is fitted.
     *
     * <p>The queries are split into contiguous chunks, one per worker, and processed on the
     * shared executor; the method blocks until all workers have finished.
     *
     * @return a random sub-sample (rate {@code samplingRate}) whose targets are the residuals
     */
    @Override
    protected Sample getSubLearnerSample() {
        Arrays.fill(this.residuals, 0, this.curTrainSet.size, 0.0);
        Arrays.fill(this.denomWeights, 0, this.curTrainSet.size, 0.0);
        RankingSample trainSample = (RankingSample)this.curTrainSet;
        // "1 +" guarantees a chunk size of at least one even with more workers than queries.
        int chunkSize = 1 + trainSample.numQueries / this.workers.getSize();
        int offset = 0;
        for (int i = 0; i < this.workers.getSize() && offset < trainSample.numQueries; offset += chunkSize, ++i) {
            int endOffset = offset + Math.min(trainSample.numQueries - offset, chunkSize);
            this.workers.getTask(i).init(offset, endOffset);
            BlockingThreadPoolExecutor.getInstance().execute(this.workers.getTask(i));
        }
        BlockingThreadPoolExecutor.getInstance().await();
        // Work on a clone so that the targets of curTrainSet itself are not replaced.
        trainSample = trainSample.getClone();
        trainSample.targets = this.residuals;
        this.setSubLearnerSampleWeights(trainSample);
        RankingSample zeroFilteredSample = trainSample.getClone();
        RankingSample subLearnerSample = zeroFilteredSample.getRandomSubSample(this.samplingRate, this.rnd);
        // Remember, for each sub-sample position, the corresponding parent index so that
        // getAdjustedOutput can address residuals and denomWeights.
        for (int i = 0; i < subLearnerSample.size; ++i) {
            this.subLearnerSampleIndicesInTrainSet[i] = zeroFilteredSample.indicesInParentSample[subLearnerSample.indicesInParentSample[i]];
        }
        return subLearnerSample;
    }

    /**
     * Evaluates the current train predictions per query, passes the result to the swap scorer
     * for the current iteration and then delegates to the superclass.
     */
    @Override
    protected void onIterationEnd() {
        RankingEvaluationMetric rankingMetric = (RankingEvaluationMetric)((RankingEvaluationMetric)this.evaluationMetric).getParentMetric();
        double[] nDCG = null;
        try {
            nDCG = ((RankingSample)this.curTrainSet).evaluateByQuery(this.trainPredictions, rankingMetric);
        }
        catch (Exception e) {
            // NOTE(review): as in preprocess(), a failure leaves nDCG null for the swap scorer.
            e.printStackTrace();
        }
        this.swapScorer.setCurrentIterationEvaluation(this.curIteration, nDCG);
        super.onIterationEnd();
    }

    /**
     * Task that accumulates residuals and denominator weights for the queries in
     * {@code [beginIdx, endIdx)}. Each worker writes only to entries of documents that belong to
     * its own queries.
     */
    private class LambdaWorker
    extends TaskItem {
        /** Scratch buffer holding the sorted document order of the query being processed. */
        private int[] permutation;
        private int beginIdx;
        private int endIdx;
        private ScoreBasedComparator comparator;

        /**
         * @param maxDocsPerQuery capacity of the permutation buffer
         */
        public LambdaWorker(int maxDocsPerQuery) {
            this.permutation = new int[maxDocsPerQuery];
            this.comparator = new ScoreBasedComparator();
        }

        /**
         * Assigns the query range for the next {@link #run()} and points the comparator at the
         * labels of the current training set.
         *
         * @param beginIdx first query index (inclusive)
         * @param endIdx last query index (exclusive)
         */
        public void init(int beginIdx, int endIdx) {
            this.beginIdx = beginIdx;
            this.endIdx = endIdx;
            this.comparator.labels = LambdaMART.this.curTrainSet.targets;
        }

        /**
         * For every query in range: sorts the documents with the score-based comparator, then
         * for every ordered pair (better, worse) with {@code target(better) > target(worse)} and
         * {@code target(better) > 0} adds {@code rho * |delta|} to the better document's
         * residual, subtracts it from the worse document's residual, and adds
         * {@code rho * (1 - rho) * |delta|} to the denominator weight of both.
         */
        @Override
        public void run() {
            RankingSample trainSet = (RankingSample)LambdaMART.this.curTrainSet;
            double[] targets = trainSet.targets;
            this.comparator.scores = LambdaMART.this.trainPredictions;
            try {
                for (int query = this.beginIdx; query < this.endIdx; ++query) {
                    int begin = trainSet.queryBoundaries[query];
                    int numDocuments = trainSet.queryBoundaries[query + 1] - begin;
                    this.comparator.offset = begin;
                    for (int d = 0; d < numDocuments; ++d) {
                        this.permutation[d] = d;
                    }
                    ArraysUtil.insertionSort(this.permutation, numDocuments, this.comparator);
                    for (int i = 0; i < numDocuments; ++i) {
                        int betterDoc = begin + this.permutation[i];
                        // Documents whose label is not strictly positive never act as "better".
                        if (!(targets[betterDoc] > 0.0)) {
                            continue;
                        }
                        for (int j = 0; j < numDocuments; ++j) {
                            if (i == j) {
                                continue;
                            }
                            int worseDoc = begin + this.permutation[j];
                            if (!(targets[betterDoc] > targets[worseDoc])) {
                                continue;
                            }
                            double scoreDiff = LambdaMART.this.trainPredictions[betterDoc] - LambdaMART.this.trainPredictions[worseDoc];
                            double rho = LambdaMART.this.lookupSigmoid(scoreDiff);
                            // i and j are the positions of the two documents in the sorted order.
                            double pairWeight = Math.abs(LambdaMART.this.swapScorer.getDelta(trainSet.queryIndices[query], betterDoc, i, worseDoc, j));
                            LambdaMART.this.residuals[betterDoc] += rho * pairWeight;
                            LambdaMART.this.residuals[worseDoc] -= rho * pairWeight;
                            double deltaWeight = rho * (1.0 - rho) * pairWeight;
                            LambdaMART.this.denomWeights[betterDoc] += deltaWeight;
                            LambdaMART.this.denomWeights[worseDoc] += deltaWeight;
                        }
                    }
                }
            }
            catch (Exception e) {
                // NOTE(review): a failure aborts the remaining queries of this worker and is only
                // printed; the caller is not informed.
                e.printStackTrace();
            }
        }
    }
}

