/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers.calibration;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.DataPoint;
import jsat.classifiers.calibration.BinaryCalibration;
import jsat.classifiers.calibration.BinaryScoreClassifier;

/**
 * Calibrates the scores of a {@link BinaryScoreClassifier} into probabilities using isotonic
 * regression: a non-decreasing fit of the 0/1 labels against the scores, computed with the
 * pool-adjacent-violators (PAV) algorithm.
 * <p>
 * Each fitted block is stored as two breakpoints (its smallest and largest score) sharing the
 * block's mean label. At prediction time a score inside the calibration range is mapped by linear
 * interpolation between the surrounding breakpoints. A score outside that range is extrapolated
 * with a heuristic ramp towards certainty (see {@link #classify(DataPoint)}).
 */
public class IsotonicCalibration
extends BinaryCalibration {
    private static final long serialVersionUID = -1295979238755262335L;
    /** Calibrated probability of the positive class at each entry of {@link #scores}. */
    private double[] outputs;
    /** Sorted breakpoints: the minimum and maximum score of each fitted block, in order. */
    private double[] scores;

    /**
     * Creates a new isotonic calibration wrapper around a score-producing classifier.
     *
     * @param base the classifier whose scores will be calibrated
     * @param mode the calibration mode, passed through to {@link BinaryCalibration}
     */
    public IsotonicCalibration(BinaryScoreClassifier base, BinaryCalibration.CalibrationMode mode) {
        super(base, mode);
    }

    /**
     * Copy constructor. It clones the base classifier and deep-copies any fitted breakpoints.
     *
     * @param toCopy the object to copy
     */
    protected IsotonicCalibration(IsotonicCalibration toCopy) {
        super(toCopy.base.clone(), toCopy.mode);
        if (toCopy.outputs != null) {
            this.outputs = Arrays.copyOf(toCopy.outputs, toCopy.outputs.length);
        }
        if (toCopy.scores != null) {
            this.scores = Arrays.copyOf(toCopy.scores, toCopy.scores.length);
        }
    }

    /**
     * Fits the isotonic regression of the labels against the decision scores.
     * <p>
     * Points with identical scores are pooled first, so every distinct score maps to exactly one
     * output. The pooled points are then merged with a stack-based PAV pass. The result is a
     * sequence of blocks whose mean labels are strictly increasing.
     *
     * @param label the true label of each point ({@code true} = positive class)
     * @param deci  the base classifier's score for each point
     * @param len   the number of leading entries of {@code label} and {@code deci} to use
     */
    @Override
    protected void calibrate(boolean[] label, double[] deci, int len) {
        ArrayList<Point> sorted = new ArrayList<Point>(len);
        for (int i = 0; i < len; ++i) {
            sorted.add(new Point(deci[i], label[i] ? 1.0 : 0.0));
        }
        Collections.sort(sorted);

        ArrayList<Point> blocks = poolAdjacentViolators(poolTiedScores(sorted));

        this.scores = new double[blocks.size() * 2];
        this.outputs = new double[blocks.size() * 2];
        int pos = 0;
        for (Point block : blocks) {
            // Each block contributes its score range; the output is constant across that range.
            this.scores[pos] = block.min;
            this.outputs[pos++] = block.output;
            this.scores[pos] = block.max;
            this.outputs[pos++] = block.output;
        }
    }

    /**
     * Merges consecutive points (already sorted by score) that have exactly the same score.
     *
     * @param sorted fresh, unmerged points in ascending score order
     * @return one point per distinct score, in ascending order
     */
    private static ArrayList<Point> poolTiedScores(ArrayList<Point> sorted) {
        ArrayList<Point> distinct = new ArrayList<Point>(sorted.size());
        for (Point p : sorted) {
            if (!distinct.isEmpty()) {
                Point last = distinct.get(distinct.size() - 1);
                // Compare against max (an exact input score), not the running mean score,
                // which can drift by rounding as points are merged.
                if (last.max == p.score) {
                    last.merge(p);
                    continue;
                }
            }
            distinct.add(p);
        }
        return distinct;
    }

    /**
     * Stack-based pool-adjacent-violators.
     * <p>
     * Each new point is pushed onto the stack. The top two blocks are then merged for as long as
     * the lower one's output is not strictly smaller than the upper one's. Every merge removes a
     * block, so the pass does O(n) merges in total.
     *
     * @param points points in ascending score order
     * @return blocks whose outputs are strictly increasing, in ascending score order
     */
    private static ArrayList<Point> poolAdjacentViolators(ArrayList<Point> points) {
        ArrayList<Point> blocks = new ArrayList<Point>(points.size());
        for (Point p : points) {
            blocks.add(p);
            while (blocks.size() > 1) {
                int top = blocks.size() - 1;
                Point prev = blocks.get(top - 1);
                if (!prev.nextViolates(blocks.get(top))) {
                    break;
                }
                prev.merge(blocks.remove(top));
            }
        }
        return blocks;
    }

    @Override
    public IsotonicCalibration clone() {
        return new IsotonicCalibration(this);
    }

    /**
     * Returns the calibrated class probabilities for {@code data}.
     * <p>
     * The mapping depends on where the base classifier's score falls:
     * <ul>
     * <li>Inside the calibration range, the score is mapped by linear interpolation between the
     * surrounding breakpoints.</li>
     * <li>Above the largest calibration score {@code max}, the probability ramps linearly from the
     * top block's output up to 1 at {@code 3 * max}. It saturates at 1 immediately when
     * {@code max <= 0}.</li>
     * <li>Below the smallest score {@code min}, the probability ramps from the bottom block's
     * output down to 0 at {@code min / 3}. It saturates at 0 immediately when
     * {@code min <= 0}.</li>
     * <li>A NaN score yields NaN probabilities.</li>
     * </ul>
     *
     * @param data the point to classify
     * @return probabilities for class 0 (negative) and class 1 (positive)
     * @throws IllegalStateException if the model has not been calibrated on at least one point
     */
    @Override
    public CategoricalResults classify(DataPoint data) {
        double p = positiveProbability(this.base.getScore(data));
        CategoricalResults cr = new CategoricalResults(2);
        cr.setProb(0, 1.0 - p);
        cr.setProb(1, p);
        return cr;
    }

    /** Maps a raw score to the calibrated probability of the positive class. */
    private double positiveProbability(double score) {
        if (this.scores == null || this.scores.length == 0) {
            throw new IllegalStateException("IsotonicCalibration has not been calibrated");
        }
        if (Double.isNaN(score)) {
            return Double.NaN;
        }
        int indx = Arrays.binarySearch(this.scores, score);
        if (indx >= 0) {
            // Exact breakpoint. Once calibrate() has pooled tied scores, duplicate breakpoints
            // only occur within a single block, so any match carries the same output.
            return this.outputs[indx];
        }
        int hi = -indx - 1; // insertion point: scores[hi - 1] < score < scores[hi]
        int last = this.scores.length - 1;
        if (hi > last) {
            return extrapolateAbove(score, this.scores[last], this.outputs[last]);
        }
        if (hi == 0) {
            return extrapolateBelow(score, this.scores[0], this.outputs[0]);
        }
        double score0 = this.scores[hi - 1];
        double score1 = this.scores[hi];
        double weight = (score1 - score) / (score1 - score0);
        return this.outputs[hi - 1] * weight + this.outputs[hi] * (1.0 - weight);
    }

    /**
     * Heuristic extrapolation above the largest calibration score. It is a linear ramp from
     * {@code edge} at {@code maxScore} to 1 at {@code 3 * maxScore}, which makes it continuous and
     * non-decreasing. The ramp only makes sense for a positive {@code maxScore}; otherwise the
     * result saturates at 1.
     */
    private static double extrapolateAbove(double score, double maxScore, double edge) {
        if (score <= maxScore) {
            // Numerically equal to the top breakpoint (e.g. 0.0 vs -0.0).
            return edge;
        }
        if (maxScore <= 0.0 || score >= maxScore * 3.0) {
            return 1.0;
        }
        return edge + (score - maxScore) / (maxScore * 2.0) * (1.0 - edge);
    }

    /**
     * Heuristic extrapolation below the smallest calibration score. It is a linear ramp from
     * {@code edge} at {@code minScore} down to 0 at {@code minScore / 3}, which makes it continuous
     * and non-decreasing. The ramp only makes sense for a positive {@code minScore}; otherwise the
     * result saturates at 0.
     */
    private static double extrapolateBelow(double score, double minScore, double edge) {
        if (score >= minScore) {
            // Numerically equal to the bottom breakpoint (e.g. -0.0 vs 0.0).
            return edge;
        }
        double floor = minScore / 3.0;
        if (minScore <= 0.0 || score <= floor) {
            return 0.0;
        }
        return (score - floor) / (minScore - floor) * edge;
    }

    @Override
    public boolean supportsWeightedData() {
        return this.base.supportsWeightedData();
    }

    /**
     * A block of one or more calibration points. It stores the weighted mean score and mean label
     * of its members, plus the range of scores it covers.
     */
    private static class Point
    implements Comparable<Point> {
        /** Number of original points pooled into this block. */
        public double weight = 1.0;
        /** Weighted mean score of the members; used only for the initial sort. */
        public double score;
        /** Weighted mean label (fraction of positives) of the members. */
        public double output;
        /** Smallest member score. */
        public double min;
        /** Largest member score. */
        public double max;

        public Point(double score, double output) {
            this.score = score;
            this.min = score;
            this.max = score;
            this.output = output;
        }

        /** Absorbs {@code next} into this block, updating means, weight and score range. */
        public void merge(Point next) {
            double newWeight = this.weight + next.weight;
            this.score = (this.weight * this.score + next.weight * next.score) / newWeight;
            this.output = (this.weight * this.output + next.weight * next.output) / newWeight;
            this.weight = newWeight;
            this.min = Math.min(this.min, next.min);
            this.max = Math.max(this.max, next.max);
        }

        /** True when this block and {@code next} are not in strictly increasing output order. */
        public boolean nextViolates(Point next) {
            return this.output >= next.output;
        }

        @Override
        public int compareTo(Point o) {
            return Double.compare(this.score, o.score);
        }
    }
}

