/*
 * Decompiled with CFR 0.152.
 */
package jsat.regression;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.PriorityQueue;
import java.util.Random;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.logging.Level;
import java.util.logging.Logger;
import jsat.classifiers.DataPoint;
import jsat.classifiers.DataPointPair;
import jsat.exceptions.FailedToFitException;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.regression.RegressionDataSet;
import jsat.regression.Regressor;
import jsat.utils.FakeExecutor;
import jsat.utils.IntSet;
import jsat.utils.SystemInfo;
import jsat.utils.concurrent.ParallelUtils;
import jsat.utils.random.RandomUtil;

/**
 * RANSAC (RANdom SAmple Consensus) wrapper around another {@link Regressor}.
 * <p>
 * Each iteration trains a clone of the base regressor on a small random
 * sample, grows that sample with every other point whose absolute prediction
 * error is below {@link #getMaxPointError()}, and - if the grown set reaches
 * {@link #getMinResultSize()} points - re-trains on the grown set. The model
 * with the smallest summed absolute error over its own consensus set is kept.
 */
public class RANSAC
implements Regressor,
Parameterized {
    private static final long serialVersionUID = -5015748604828907703L;
    /** Number of distinct points drawn at random to fit each candidate model. */
    private int initialTrainSize;
    /** Total number of candidate models to try, split between the workers. */
    private int iterations;
    /** A point joins the consensus set only if its absolute error is strictly below this value. */
    private double maxPointError;
    /** Minimum consensus set size for a candidate model to be considered. */
    private int minResultSize;
    @Parameter.ParameterHolder
    private Regressor baseRegressor;
    /** Inlier mask of the best model; {@code null} until a successful call to train. */
    private boolean[] consensusSet;
    /** Summed absolute error of the best model over its consensus set. */
    private double modelError;

    /**
     * Creates a new RANSAC learner.
     *
     * @param baseRegressor the regressor that is cloned and fit on each candidate subset
     * @param iterations the number of candidate models to try, must be at least 1
     * @param initialTrainSize the size of the random initial sample, must be at least 1
     * @param minResultSize the minimum consensus set size, must not be smaller than
     * {@code initialTrainSize}
     * @param maxPointError the error threshold for adding a point to the consensus set,
     * must be finite and non-negative
     */
    public RANSAC(Regressor baseRegressor, int iterations, int initialTrainSize, int minResultSize, double maxPointError) {
        // initialTrainSize has to be set first: setMinResultSize validates against it
        this.setInitialTrainSize(initialTrainSize);
        this.setIterations(iterations);
        this.setMaxPointError(maxPointError);
        this.setMinResultSize(minResultSize);
        this.baseRegressor = baseRegressor;
    }

    /**
     * @return the number of points sampled to fit each initial candidate model
     */
    public int getInitialTrainSize() {
        return this.initialTrainSize;
    }

    /**
     * Sets the number of points sampled to fit each initial candidate model.
     *
     * @param initialTrainSize the sample size, must be at least 1
     * @throws RuntimeException if {@code initialTrainSize} is less than 1
     */
    public void setInitialTrainSize(int initialTrainSize) {
        if (initialTrainSize < 1) {
            throw new RuntimeException("Can not train on an empty data set");
        }
        this.initialTrainSize = initialTrainSize;
    }

    /**
     * @return the total number of candidate models tried during training
     */
    public int getIterations() {
        return this.iterations;
    }

    /**
     * Sets the total number of candidate models tried during training.
     *
     * @param iterations the number of iterations, must be at least 1
     * @throws RuntimeException if {@code iterations} is less than 1
     */
    public void setIterations(int iterations) {
        if (iterations < 1) {
            throw new RuntimeException("Must perform a positive number of iterations");
        }
        this.iterations = iterations;
    }

    /**
     * @return the error threshold below which a point is added to the consensus set
     */
    public double getMaxPointError() {
        return this.maxPointError;
    }

    /**
     * Sets the error threshold below which a point is added to the consensus set.
     *
     * @param maxPointError the threshold, must be finite and not negative
     * @throws ArithmeticException if the value is negative, infinite or NaN
     */
    public void setMaxPointError(double maxPointError) {
        if (maxPointError < 0.0 || Double.isInfinite(maxPointError) || Double.isNaN(maxPointError)) {
            throw new ArithmeticException("The error must be a positive value, not " + maxPointError);
        }
        this.maxPointError = maxPointError;
    }

    /**
     * @return the minimum consensus set size a candidate model must reach
     */
    public int getMinResultSize() {
        return this.minResultSize;
    }

    /**
     * Sets the minimum consensus set size a candidate model must reach.
     *
     * @param minResultSize the minimum size, must not be smaller than the
     * current initial train size
     * @throws RuntimeException if {@code minResultSize} is smaller than the
     * initial train size
     */
    public void setMinResultSize(int minResultSize) {
        if (minResultSize < this.getInitialTrainSize()) {
            throw new RuntimeException("The min result size must be larger than the intial train size");
        }
        this.minResultSize = minResultSize;
    }

    /**
     * Delegates the prediction to the current base regressor, which is the
     * best fitted model once training has succeeded.
     */
    @Override
    public double regress(DataPoint data) {
        return this.baseRegressor.regress(data);
    }

    /**
     * Runs the RANSAC search and keeps the best candidate model found.
     * The iterations are split between {@code SystemInfo.LogicalCores}
     * workers, plus one extra worker for the remainder.
     *
     * @param dataSet the data to fit
     * @param parallel {@code true} to run the workers on the shared cached
     * thread pool, {@code false} to run them through a {@link FakeExecutor}
     * @throws FailedToFitException if the data set has fewer points than the
     * initial train size, if no candidate ever reached the minimum consensus
     * size, or if a worker failed or was interrupted
     */
    @Override
    public void train(RegressionDataSet dataSet, boolean parallel) {
        int sampleSize = dataSet.getSampleSize();
        if (sampleSize < this.initialTrainSize) {
            // The workers draw distinct indices until they have initialTrainSize
            // of them, which could never terminate on a smaller data set.
            throw new FailedToFitException("Can not sample " + this.initialTrainSize + " distinct points from a data set of only " + sampleSize + " points");
        }
        try {
            int workSize = this.iterations / SystemInfo.LogicalCores;
            int leftOver = this.iterations % SystemInfo.LogicalCores;
            ExecutorService threadPool = parallel ? ParallelUtils.CACHED_THREAD_POOL : new FakeExecutor();
            ArrayList<Future<RANSACWorker>> futures = new ArrayList<Future<RANSACWorker>>(SystemInfo.LogicalCores + 1);
            if (leftOver != 0) {
                futures.add(threadPool.submit(new RANSACWorker(this.baseRegressor, leftOver, dataSet)));
            }
            // With fewer iterations than cores every regular worker would have nothing to do
            if (workSize > 0) {
                for (int i = 0; i < SystemInfo.LogicalCores; ++i) {
                    futures.add(threadPool.submit(new RANSACWorker(this.baseRegressor, workSize, dataSet)));
                }
            }
            // Pick the worker with the smallest error
            RANSACWorker bestResult = null;
            for (Future<RANSACWorker> future : futures) {
                RANSACWorker result = future.get();
                if (bestResult == null || result.compareTo(bestResult) < 0) {
                    bestResult = result;
                }
            }
            if (bestResult == null || Double.isInfinite(bestResult.bestError)) {
                throw new FailedToFitException("Model could not be fit, inlier set never reach minimum size");
            }
            // Only touch the fitted state once we know the fit succeeded
            this.modelError = bestResult.bestError;
            this.baseRegressor = bestResult.bestModel;
            this.consensusSet = bestResult.bestConsensusSet;
        }
        catch (InterruptedException ex) {
            // Restore the interrupt status so callers further up can still observe it
            Thread.currentThread().interrupt();
            Logger.getLogger(RANSAC.class.getName()).log(Level.SEVERE, null, ex);
            throw new FailedToFitException(ex);
        }
        catch (ExecutionException ex) {
            Logger.getLogger(RANSAC.class.getName()).log(Level.SEVERE, null, ex);
            throw new FailedToFitException(ex);
        }
    }

    /**
     * @return whatever the base regressor reports for weighted data support
     */
    @Override
    public boolean supportsWeightedData() {
        return this.baseRegressor.supportsWeightedData();
    }

    /**
     * Creates a copy of this object, including the cloned base regressor and,
     * if training has happened, the consensus set and model error.
     */
    @Override
    public RANSAC clone() {
        RANSAC clone = new RANSAC(this.baseRegressor.clone(), this.iterations, this.initialTrainSize, this.minResultSize, this.maxPointError);
        if (this.consensusSet != null) {
            clone.consensusSet = Arrays.copyOf(this.consensusSet, this.consensusSet.length);
        }
        clone.modelError = this.modelError;
        return clone;
    }

    /**
     * @return a clone of the current base regressor
     */
    public Regressor getBaseRegressorClone() {
        return this.baseRegressor.clone();
    }

    /**
     * Returns a copy of the inlier mask of the best model: index {@code i} is
     * {@code true} if the i'th training point was part of its consensus set.
     * The consensus set is {@code null} until training has succeeded, so
     * calling this earlier results in a {@link NullPointerException}.
     *
     * @return a copy of the consensus set
     */
    public boolean[] getConsensusSet() {
        return Arrays.copyOf(this.consensusSet, this.consensusSet.length);
    }

    /**
     * @return the summed absolute error of the best model over its consensus
     * set, as computed by the last successful training run
     */
    public double getModelError() {
        return this.modelError;
    }

    /**
     * Performs a share of the RANSAC iterations and remembers the best
     * candidate it saw. Workers are ordered by their best error.
     */
    private class RANSACWorker
    implements Callable<RANSACWorker>,
    Comparable<RANSACWorker> {
        int maxIterations;
        RegressionDataSet dataset;
        Random rand;
        Regressor baseModel;
        /** Best candidate found so far, {@code null} if none qualified. */
        Regressor bestModel = null;
        boolean[] bestConsensusSet = null;
        /** Error of {@link #bestModel}; stays infinite if no candidate qualified. */
        double bestError = Double.POSITIVE_INFINITY;

        public RANSACWorker(Regressor baseModel, int maxIterations, RegressionDataSet dataset) {
            this.baseModel = baseModel;
            this.maxIterations = maxIterations;
            this.dataset = dataset;
            this.rand = RandomUtil.getRandom();
        }

        /**
         * Runs this worker's iterations.
         *
         * @return this worker, holding the best model, its consensus set and its error
         */
        @Override
        public RANSACWorker call() throws Exception {
            int sampleSize = this.dataset.getSampleSize();
            this.bestConsensusSet = new boolean[sampleSize];
            boolean[] workingSet = new boolean[sampleSize];
            IntSet maybeInliers = new IntSet(RANSAC.this.initialTrainSize * 2);
            for (int iter = 0; iter < this.maxIterations; ++iter) {
                maybeInliers.clear();
                Arrays.fill(workingSet, false);
                // Draw distinct random indices for the initial sample
                while (maybeInliers.size() < RANSAC.this.initialTrainSize) {
                    maybeInliers.add(Integer.valueOf(this.rand.nextInt(sampleSize)));
                }
                int consensusSize = maybeInliers.size();
                RegressionDataSet subDataSet = new RegressionDataSet(this.dataset.getNumNumericalVars(), this.dataset.getCategories());
                for (int index : maybeInliers) {
                    subDataSet.addDataPointPair(this.dataset.getDataPointPair(index));
                    workingSet[index] = true;
                }
                Regressor maybeModel = this.baseModel.clone();
                maybeModel.train(subDataSet);
                // Grow the sample with every remaining point the candidate explains well enough
                for (int i = 0; i < sampleSize; ++i) {
                    if (workingSet[i]) {
                        continue;
                    }
                    DataPointPair<Double> dpp = this.dataset.getDataPointPair(i);
                    double guess = maybeModel.regress(dpp.getDataPoint());
                    double diff = Math.abs(guess - dpp.getPair());
                    if (diff < RANSAC.this.maxPointError) {
                        workingSet[i] = true;
                        subDataSet.addDataPointPair(dpp);
                        ++consensusSize;
                    }
                }
                if (consensusSize < RANSAC.this.minResultSize) {
                    // Too few inliers, discard this candidate
                    continue;
                }
                // Re-fit on the whole consensus set and score the result on it
                maybeModel.train(subDataSet);
                double thisError = 0.0;
                for (int i = 0; i < sampleSize; ++i) {
                    if (!workingSet[i]) {
                        continue;
                    }
                    DataPointPair<Double> dpp = this.dataset.getDataPointPair(i);
                    double guess = maybeModel.regress(dpp.getDataPoint());
                    thisError += Math.abs(guess - dpp.getPair());
                }
                if (thisError < this.bestError) {
                    this.bestError = thisError;
                    this.bestModel = maybeModel;
                    System.arraycopy(workingSet, 0, this.bestConsensusSet, 0, sampleSize);
                }
            }
            return this;
        }

        /**
         * Orders workers by their best error, smallest first.
         */
        @Override
        public int compareTo(RANSACWorker o) {
            return Double.compare(this.bestError, o.bestError);
        }
    }
}

