/*
 * Decompiled with CFR 0.152.
 */
package jsat.parameters;

import java.util.ArrayList;
import java.util.List;
import java.util.PriorityQueue;
import java.util.Random;
import jsat.DataSet;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.ClassificationModelEvaluation;
import jsat.classifiers.Classifier;
import jsat.distributions.Distribution;
import jsat.parameters.DoubleParameter;
import jsat.parameters.IntParameter;
import jsat.parameters.ModelSearch;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.regression.RegressionDataSet;
import jsat.regression.RegressionModelEvaluation;
import jsat.regression.Regressor;
import jsat.utils.concurrent.ParallelUtils;
import jsat.utils.random.RandomUtil;

/**
 * Randomized hyper-parameter search for a base {@link Classifier} or {@link Regressor}.
 * <p>
 * Each of the {@link #getTrials() trials} works as follows:
 * <ol>
 * <li>Draw a value for every registered parameter from its {@link Distribution}.</li>
 * <li>Write the value into the parameter.</li>
 * <li>Clone the base model.</li>
 * <li>Score the clone with cross validation.</li>
 * </ol>
 * The clone with the best target score becomes the trained model.
 */
public class RandomSearch
extends ModelSearch {
    /** Number of random configurations evaluated per call to {@code train}; always at least 1. */
    private int trials = 25;
    /**
     * Sampling distribution for each searched parameter. Entry {@code i} belongs to entry
     * {@code i} of {@code searchParams}, so the two lists must always have the same length.
     */
    private List<Distribution> searchValues;

    /**
     * Creates a random search over the parameters of a regressor.
     *
     * @param baseRegressor the model whose parameters are searched
     * @param folds the number of cross validation folds used to score each trial
     */
    public RandomSearch(Regressor baseRegressor, int folds) {
        super(baseRegressor, folds);
        this.searchValues = new ArrayList<Distribution>();
    }

    /**
     * Creates a random search over the parameters of a classifier.
     *
     * @param baseClassifier the model whose parameters are searched
     * @param folds the number of cross validation folds used to score each trial
     */
    public RandomSearch(Classifier baseClassifier, int folds) {
        super(baseClassifier, folds);
        this.searchValues = new ArrayList<Distribution>();
    }

    /**
     * Copy constructor. The sampling distributions are deep-copied via {@code clone()}.
     *
     * @param toCopy the search to copy
     */
    public RandomSearch(RandomSearch toCopy) {
        super(toCopy);
        this.trials = toCopy.trials;
        this.searchValues = new ArrayList<Distribution>(toCopy.searchValues.size());
        for (Distribution d : toCopy.searchValues) {
            this.searchValues.add(d.clone());
        }
    }

    /**
     * Registers every double or int parameter of the base model that can suggest a
     * search distribution for the given data (its {@code getGuess(data)} is non-null).
     *
     * @param data the data set used to derive the suggested distributions
     * @return the number of parameters that were added to the search
     */
    public int autoAddParameters(DataSet data) {
        // The base model is assumed to implement Parameterized; otherwise this cast fails.
        Parameterized obj = this.baseClassifier != null
                ? (Parameterized) this.baseClassifier
                : (Parameterized) this.baseRegressor;
        int totalParms = 0;
        for (Parameter param : obj.getParameters()) {
            if (param instanceof DoubleParameter) {
                Distribution dist = ((DoubleParameter) param).getGuess(data);
                if (dist != null) {
                    this.addParameter((DoubleParameter) param, dist);
                    ++totalParms;
                }
            } else if (param instanceof IntParameter) {
                Distribution dist = ((IntParameter) param).getGuess(data);
                if (dist != null) {
                    this.addParameter((IntParameter) param, dist);
                    ++totalParms;
                }
            }
        }
        return totalParms;
    }

    /**
     * Sets how many random parameter configurations are evaluated.
     *
     * @param trials the number of trials, must be at least 1
     * @throws IllegalArgumentException if {@code trials < 1}
     */
    public void setTrials(int trials) {
        if (trials < 1) {
            throw new IllegalArgumentException("number of trials must be positive, not " + trials);
        }
        this.trials = trials;
    }

    /**
     * @return the number of random parameter configurations evaluated per training call
     */
    public int getTrials() {
        return this.trials;
    }

    /**
     * Adds a double-valued parameter to the search.
     *
     * @param param the parameter to vary
     * @param dist the distribution its values are drawn from (a copy is stored)
     * @throws IllegalArgumentException if {@code param} or {@code dist} is null
     */
    public void addParameter(DoubleParameter param, Distribution dist) {
        this.addSearchParameter(param, dist);
    }

    /**
     * Adds an int-valued parameter to the search. Sampled values are rounded to the
     * nearest integer.
     *
     * @param param the parameter to vary
     * @param dist the distribution its values are drawn from (a copy is stored)
     * @throws IllegalArgumentException if {@code param} or {@code dist} is null
     */
    public void addParameter(IntParameter param, Distribution dist) {
        this.addSearchParameter(param, dist);
    }

    /**
     * Adds the parameter with the given name to the search.
     *
     * @param name the name of the parameter, as resolved by {@code getParameterByName}
     * @param dist the distribution its values are drawn from (a copy is stored)
     * @throws IllegalArgumentException if no such parameter is found, or it is neither a
     *         double nor an int parameter
     */
    public void addParameter(String name, Distribution dist) {
        Parameter param = this.getParameterByName(name);
        if (param instanceof DoubleParameter) {
            this.addParameter((DoubleParameter) param, dist);
        } else if (param instanceof IntParameter) {
            this.addParameter((IntParameter) param, dist);
        } else if (param == null) {
            // NOTE(review): assumes getParameterByName signals "not found" by returning null.
            throw new IllegalArgumentException("Parameter " + name + " does not exist");
        } else {
            throw new IllegalArgumentException("Parameter " + name + " is not for double or int values");
        }
    }

    /**
     * Validates and registers a parameter/distribution pair. Every check and the clone
     * run before either list is modified, so a failure cannot leave
     * {@code searchParams} and {@code searchValues} with different lengths.
     */
    private void addSearchParameter(Parameter param, Distribution dist) {
        if (param == null) {
            throw new IllegalArgumentException("null not allowed for parameter");
        }
        if (dist == null) {
            throw new IllegalArgumentException("null not allowed for distribution");
        }
        Distribution copy = dist.clone();
        this.searchParams.add(param);
        this.searchValues.add(copy);
    }

    /**
     * Draws one value for every registered parameter and writes it into that parameter.
     * Int parameters receive the sampled value rounded to the nearest integer. One random
     * number is consumed per parameter, in registration order.
     */
    private void sampleParameters(Random rand) {
        for (int i = 0; i < this.searchParams.size(); ++i) {
            double sampledValue = this.searchValues.get(i).invCdf(rand.nextDouble());
            Parameter param = (Parameter) this.searchParams.get(i);
            if (param instanceof DoubleParameter) {
                ((DoubleParameter) param).setValue(sampledValue);
            } else if (param instanceof IntParameter) {
                ((IntParameter) param).setValue((int) Math.round(sampledValue));
            }
        }
    }

    /**
     * Evaluates {@link #getTrials()} random configurations of the base classifier with
     * cross validation and keeps the best one according to
     * {@code classificationTargetScore}.
     * <p>
     * The sampled values are written into the registered Parameter objects before each
     * clone is taken. For parameters obtained from the base model, this presumably leaves
     * the base model set to the last sampled configuration.
     *
     * @param dataSet the data to search over
     * @param parallel whether parallelism may be used
     */
    @Override
    public void train(ClassificationDataSet dataSet, boolean parallel) {
        // The head of the queue is the best model; the comparator order flips when lower scores are better.
        PriorityQueue<ClassificationModelEvaluation> bestModels = new PriorityQueue<ClassificationModelEvaluation>(this.trials, (t, t1) -> {
            double v0 = t.getScoreStats(this.classificationTargetScore).getMean();
            double v1 = t1.getScoreStats(this.classificationTargetScore).getMean();
            int order = this.classificationTargetScore.lowerIsBetter() ? 1 : -1;
            return order * Double.compare(v0, v1);
        });
        List<Classifier> paramsToEval = new ArrayList<Classifier>(this.trials);
        Random rand = RandomUtil.getRandom();
        for (int trial = 0; trial < this.trials; ++trial) {
            this.sampleParameters(rand);
            paramsToEval.add(this.baseClassifier.clone());
        }
        // Optionally build the folds once so every trial is scored on identical splits.
        List<ClassificationDataSet> preFolded;
        List<ClassificationDataSet> trainCombinations;
        if (this.reuseSameCVFolds) {
            preFolded = dataSet.cvSet(this.folds);
            trainCombinations = new ArrayList<ClassificationDataSet>(preFolded.size());
            for (int i = 0; i < preFolded.size(); ++i) {
                trainCombinations.add(ClassificationDataSet.comineAllBut(preFolded, i));
            }
        } else {
            preFolded = null;
            trainCombinations = null;
        }
        ParallelUtils.run(parallel && this.trainModelsInParallel, paramsToEval.size(), indx -> {
            Classifier c = paramsToEval.get(indx);
            // A single evaluation may use parallelism only when the trials themselves are not run in parallel.
            ClassificationModelEvaluation cme = new ClassificationModelEvaluation(c, dataSet, !this.trainModelsInParallel && parallel);
            cme.addScorer(this.classificationTargetScore.clone());
            if (this.reuseSameCVFolds) {
                cme.evaluateCrossValidation(preFolded, trainCombinations);
            } else {
                cme.evaluateCrossValidation(this.folds);
            }
            synchronized (bestModels) {
                bestModels.add(cme);
            }
        });
        // trials >= 1 is enforced by setTrials, so the queue is never empty here.
        Classifier bestClassifier = bestModels.peek().getClassifier();
        if (this.trainFinalModel) {
            bestClassifier.train(dataSet, parallel);
        }
        this.trainedClassifier = bestClassifier;
    }

    /**
     * Evaluates {@link #getTrials()} random configurations of the base regressor with
     * cross validation and keeps the best one according to {@code regressionTargetScore}.
     * <p>
     * The sampled values are written into the registered Parameter objects before each
     * clone is taken. For parameters obtained from the base model, this presumably leaves
     * the base model set to the last sampled configuration.
     *
     * @param dataSet the data to search over
     * @param parallel whether parallelism may be used
     */
    @Override
    public void train(RegressionDataSet dataSet, boolean parallel) {
        // The head of the queue is the best model; the comparator order flips when lower scores are better.
        PriorityQueue<RegressionModelEvaluation> bestModels = new PriorityQueue<RegressionModelEvaluation>(this.trials, (t, t1) -> {
            double v0 = t.getScoreStats(this.regressionTargetScore).getMean();
            double v1 = t1.getScoreStats(this.regressionTargetScore).getMean();
            int order = this.regressionTargetScore.lowerIsBetter() ? 1 : -1;
            return order * Double.compare(v0, v1);
        });
        List<Regressor> paramsToEval = new ArrayList<Regressor>(this.trials);
        Random rand = RandomUtil.getRandom();
        for (int trial = 0; trial < this.trials; ++trial) {
            this.sampleParameters(rand);
            paramsToEval.add(this.baseRegressor.clone());
        }
        // Optionally build the folds once so every trial is scored on identical splits.
        List<RegressionDataSet> preFolded;
        List<RegressionDataSet> trainCombinations;
        if (this.reuseSameCVFolds) {
            preFolded = dataSet.cvSet(this.folds);
            trainCombinations = new ArrayList<RegressionDataSet>(preFolded.size());
            for (int i = 0; i < preFolded.size(); ++i) {
                trainCombinations.add(RegressionDataSet.comineAllBut(preFolded, i));
            }
        } else {
            preFolded = null;
            trainCombinations = null;
        }
        ParallelUtils.run(parallel && this.trainModelsInParallel, paramsToEval.size(), indx -> {
            Regressor r = paramsToEval.get(indx);
            // A single evaluation may use parallelism only when the trials themselves are not run in parallel.
            RegressionModelEvaluation rme = new RegressionModelEvaluation(r, dataSet, !this.trainModelsInParallel && parallel);
            rme.addScorer(this.regressionTargetScore.clone());
            if (this.reuseSameCVFolds) {
                rme.evaluateCrossValidation(preFolded, trainCombinations);
            } else {
                rme.evaluateCrossValidation(this.folds);
            }
            synchronized (bestModels) {
                bestModels.add(rme);
            }
        });
        // trials >= 1 is enforced by setTrials, so the queue is never empty here.
        Regressor bestRegressor = bestModels.peek().getRegressor();
        if (this.trainFinalModel) {
            bestRegressor.train(dataSet, parallel);
        }
        this.trainedRegressor = bestRegressor;
    }

    /**
     * @return a deep copy of this search, made via the copy constructor
     */
    @Override
    public RandomSearch clone() {
        return new RandomSearch(this);
    }
}

