package jsat.parameters;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.PriorityQueue;
import java.util.Random;
import jsat.DataSet;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.ClassificationModelEvaluation;
import jsat.classifiers.Classifier;
import jsat.distributions.Distribution;
import jsat.regression.RegressionDataSet;
import jsat.regression.RegressionModelEvaluation;
import jsat.regression.Regressor;
import jsat.utils.concurrent.ParallelUtils;
import jsat.utils.random.RandomUtil;

/* loaded from: input_file:jsat/parameters/RandomSearch.class */
/**
 * Hyper-parameter search that scores randomly sampled parameter settings of a base model.
 * <p>
 * Every parameter registered through one of the {@code addParameter} methods is paired with a
 * {@link Distribution}. On each of {@link #getTrials()} trials, a value is drawn for every
 * registered parameter by passing a uniform random number through the distribution's inverse
 * CDF. Integer parameters are rounded to the nearest {@code int}. A clone of the base model is
 * then taken with those values. Each candidate is scored by cross validation with
 * {@code folds} folds, and the candidate with the best mean target score is kept.
 * <p>
 * NOTE(review): sampled values are written into the registered {@link Parameter} objects before
 * each clone is taken. If those objects are bound to the base model (presumably the case for
 * parameters obtained from its {@code getParameters()}), the base model is left holding the
 * values of the last trial once training finishes.
 */
public class RandomSearch extends ModelSearch {
    /** Number of random settings evaluated unless {@link #setTrials(int)} is called. */
    private static final int DEFAULT_TRIALS = 25;

    /** Number of random parameter settings evaluated per call to {@code train}; always >= 1. */
    private int trials;
    /**
     * Sampling distribution for each entry of {@code searchParams}. The two lists are kept
     * index-aligned by the {@code addParameter} methods.
     */
    private List<Distribution> searchValues;

    /**
     * Creates a random search over the parameters of a regressor.
     *
     * @param regressor the base regressor whose parameters will be searched
     * @param folds passed to the {@link ModelSearch} constructor (presumably the number of
     *              cross-validation folds used to score each trial)
     */
    public RandomSearch(Regressor regressor, int folds) {
        super(regressor, folds);
        this.trials = DEFAULT_TRIALS;
        this.searchValues = new ArrayList<>();
    }

    /**
     * Creates a random search over the parameters of a classifier.
     *
     * @param classifier the base classifier whose parameters will be searched
     * @param folds passed to the {@link ModelSearch} constructor (presumably the number of
     *              cross-validation folds used to score each trial)
     */
    public RandomSearch(Classifier classifier, int folds) {
        super(classifier, folds);
        this.trials = DEFAULT_TRIALS;
        this.searchValues = new ArrayList<>();
    }

    /**
     * Copy constructor. The trial count is copied, and every sampling distribution is cloned so
     * the copy does not share distribution objects with {@code toCopy}.
     *
     * @param toCopy the search to copy
     */
    public RandomSearch(RandomSearch toCopy) {
        super(toCopy);
        this.trials = toCopy.trials;
        this.searchValues = new ArrayList<>(toCopy.searchValues.size());
        for (Distribution dist : toCopy.searchValues) {
            this.searchValues.add(dist.mo146clone());
        }
    }

    /**
     * Registers every {@link DoubleParameter} and {@link IntParameter} of the base model for
     * which the parameter can supply a guessed distribution for {@code dataSet}. Parameters
     * whose {@code getGuess} returns {@code null} are skipped.
     * <p>
     * The base model is the base classifier when one is set, otherwise the base regressor.
     *
     * @param dataSet the data used to derive the guessed distributions
     * @return the number of parameters that were added
     */
    public int autoAddParameters(DataSet dataSet) {
        Parameterized model = this.baseClassifier != null
                ? (Parameterized) this.baseClassifier
                : (Parameterized) this.baseRegressor;
        int added = 0;
        for (Parameter param : model.getParameters()) {
            if (param instanceof DoubleParameter) {
                DoubleParameter doubleParam = (DoubleParameter) param;
                Distribution guess = doubleParam.getGuess(dataSet);
                if (guess != null) {
                    addParameter(doubleParam, guess);
                    added++;
                }
            } else if (param instanceof IntParameter) {
                IntParameter intParam = (IntParameter) param;
                Distribution guess = intParam.getGuess(dataSet);
                if (guess != null) {
                    addParameter(intParam, guess);
                    added++;
                }
            }
        }
        return added;
    }

    /**
     * Sets how many random parameter settings are evaluated per call to {@code train}.
     *
     * @param trials the number of trials; must be at least 1
     * @throws IllegalArgumentException if {@code trials < 1}
     */
    public void setTrials(int trials) {
        if (trials < 1) {
            throw new IllegalArgumentException("number of trials must be positive, not " + trials);
        }
        this.trials = trials;
    }

    /**
     * Returns the number of random parameter settings evaluated per call to {@code train}.
     *
     * @return the number of trials
     */
    public int getTrials() {
        return this.trials;
    }

    /**
     * Registers a double-valued parameter to be searched. Values are drawn from a private clone
     * of {@code distribution}.
     *
     * @param doubleParameter the parameter to search
     * @param distribution the distribution to sample values from
     * @throws IllegalArgumentException if either argument is {@code null}
     */
    public void addParameter(DoubleParameter doubleParameter, Distribution distribution) {
        if (doubleParameter == null) {
            throw new IllegalArgumentException("null not allowed for parameter");
        }
        registerSearchParameter(doubleParameter, distribution);
    }

    /**
     * Registers an integer-valued parameter to be searched. Values are drawn from a private
     * clone of {@code distribution} and rounded to the nearest {@code int}.
     *
     * @param intParameter the parameter to search
     * @param distribution the distribution to sample values from
     * @throws IllegalArgumentException if either argument is {@code null}
     */
    public void addParameter(IntParameter intParameter, Distribution distribution) {
        if (intParameter == null) {
            throw new IllegalArgumentException("null not allowed for parameter");
        }
        registerSearchParameter(intParameter, distribution);
    }

    /**
     * Registers a parameter to be searched, looked up by name via {@code getParameterByName}.
     *
     * @param str the name of the parameter
     * @param distribution the distribution to sample values from
     * @throws IllegalArgumentException if no parameter with that name is found, if it is
     *         neither a {@link DoubleParameter} nor an {@link IntParameter}, or if
     *         {@code distribution} is {@code null}
     */
    public void addParameter(String str, Distribution distribution) {
        Parameter param = getParameterByName(str);
        if (param instanceof DoubleParameter) {
            addParameter((DoubleParameter) param, distribution);
        } else if (param instanceof IntParameter) {
            addParameter((IntParameter) param, distribution);
        } else if (param == null) {
            // NOTE(review): assumes getParameterByName signals "not found" with null; if it
            // throws instead, this branch is simply never reached.
            throw new IllegalArgumentException("Parameter " + str + " was not found");
        } else {
            throw new IllegalArgumentException("Parameter " + str + " is not for double or int values");
        }
    }

    /**
     * Appends a parameter and its sampling distribution to the index-aligned search lists. The
     * distribution is validated and cloned BEFORE either list is touched, so a failure cannot
     * leave {@code searchParams} and {@code searchValues} with different lengths.
     */
    private void registerSearchParameter(Parameter param, Distribution distribution) {
        if (distribution == null) {
            throw new IllegalArgumentException("null not allowed for distribution");
        }
        Distribution ownCopy = distribution.mo146clone();
        this.searchParams.add(param);
        this.searchValues.add(ownCopy);
    }

    /**
     * Draws one value for every registered search parameter and writes it into that parameter.
     * Exactly one {@code rand.nextDouble()} is consumed per parameter, in registration order.
     */
    private void applyRandomSample(Random rand) {
        for (int i = 0; i < this.searchParams.size(); i++) {
            double sampled = this.searchValues.get(i).invCdf(rand.nextDouble());
            Parameter param = this.searchParams.get(i);
            if (param instanceof DoubleParameter) {
                ((DoubleParameter) param).setValue(sampled);
            } else if (param instanceof IntParameter) {
                ((IntParameter) param).setValue(roundToInt(sampled));
            }
        }
    }

    /**
     * Rounds to the nearest {@code int}, saturating at the {@code int} range. A plain
     * {@code (int) Math.round(x)} wraps around for values beyond the int range; for example,
     * positive infinity would become -1.
     */
    private static int roundToInt(double value) {
        long rounded = Math.round(value);
        return (int) Math.max(Integer.MIN_VALUE, Math.min(Integer.MAX_VALUE, rounded));
    }

    /**
     * Samples {@link #getTrials()} random parameter settings and scores a clone of the base
     * classifier for each by cross validation. The clone with the best mean
     * {@code classificationTargetScore} is stored as the trained classifier. Lower is better when
     * the score's {@code lowerIsBetter()} says so; otherwise higher is better. If
     * {@code trainFinalModel} is set, the winner is first retrained on all of
     * {@code classificationDataSet}.
     * <p>
     * When {@code z} and {@code trainModelsInParallel} are both true, candidates are evaluated
     * concurrently. Otherwise {@code z} is handed to each individual evaluation.
     *
     * @param classificationDataSet the data to search and train on
     * @param z whether parallelism may be used
     */
    @Override // jsat.classifiers.Classifier
    public void train(ClassificationDataSet classificationDataSet, boolean z) {
        // The queue ends up holding one evaluation per trial, so size it by trials (>= 1).
        final PriorityQueue<ClassificationModelEvaluation> bestModels =
                new PriorityQueue<ClassificationModelEvaluation>(this.trials, (a, b) -> {
                    int order = this.classificationTargetScore.lowerIsBetter() ? 1 : -1;
                    return order * Double.compare(
                            a.getScoreStats(this.classificationTargetScore).getMean(),
                            b.getScoreStats(this.classificationTargetScore).getMean());
                });

        final List<Classifier> candidates = new ArrayList<>(this.trials);
        Random rand = RandomUtil.getRandom();
        for (int trial = 0; trial < this.trials; trial++) {
            applyRandomSample(rand);
            candidates.add(this.baseClassifier.mo251clone());
        }

        // Optionally build the CV splits once so every candidate sees identical folds.
        final List<ClassificationDataSet> cvFolds;
        final List<ClassificationDataSet> trainCombinations;
        if (this.reuseSameCVFolds) {
            cvFolds = classificationDataSet.cvSet(this.folds);
            trainCombinations = new ArrayList<>(cvFolds.size());
            for (int i = 0; i < cvFolds.size(); i++) {
                trainCombinations.add(ClassificationDataSet.comineAllBut(cvFolds, i));
            }
        } else {
            cvFolds = null;
            trainCombinations = null;
        }

        ParallelUtils.run(z && this.trainModelsInParallel, candidates.size(), idx -> {
            ClassificationModelEvaluation eval = new ClassificationModelEvaluation(
                    candidates.get(idx), classificationDataSet, !this.trainModelsInParallel && z);
            eval.addScorer(this.classificationTargetScore.m33clone());
            if (this.reuseSameCVFolds) {
                eval.evaluateCrossValidation(cvFolds, trainCombinations);
            } else {
                eval.evaluateCrossValidation(this.folds);
            }
            synchronized (bestModels) {
                bestModels.add(eval);
            }
        });

        Classifier best = bestModels.peek().getClassifier();
        if (this.trainFinalModel) {
            best.train(classificationDataSet, z);
        }
        this.trainedClassifier = best;
    }

    /**
     * Samples {@link #getTrials()} random parameter settings and scores a clone of the base
     * regressor for each by cross validation. The clone with the best mean
     * {@code regressionTargetScore} is stored as the trained regressor. Lower is better when the
     * score's {@code lowerIsBetter()} says so; otherwise higher is better. If
     * {@code trainFinalModel} is set, the winner is first retrained on all of
     * {@code regressionDataSet}.
     * <p>
     * When {@code z} and {@code trainModelsInParallel} are both true, candidates are evaluated
     * concurrently. Otherwise {@code z} is handed to each individual evaluation.
     *
     * @param regressionDataSet the data to search and train on
     * @param z whether parallelism may be used
     */
    @Override // jsat.regression.Regressor
    public void train(RegressionDataSet regressionDataSet, boolean z) {
        // The queue ends up holding one evaluation per trial, so size it by trials (>= 1).
        final PriorityQueue<RegressionModelEvaluation> bestModels =
                new PriorityQueue<RegressionModelEvaluation>(this.trials, (a, b) -> {
                    int order = this.regressionTargetScore.lowerIsBetter() ? 1 : -1;
                    return order * Double.compare(
                            a.getScoreStats(this.regressionTargetScore).getMean(),
                            b.getScoreStats(this.regressionTargetScore).getMean());
                });

        final List<Regressor> candidates = new ArrayList<>(this.trials);
        Random rand = RandomUtil.getRandom();
        for (int trial = 0; trial < this.trials; trial++) {
            applyRandomSample(rand);
            candidates.add(this.baseRegressor.mo251clone());
        }

        // Optionally build the CV splits once so every candidate sees identical folds.
        final List<RegressionDataSet> cvFolds;
        final List<RegressionDataSet> trainCombinations;
        if (this.reuseSameCVFolds) {
            cvFolds = regressionDataSet.cvSet(this.folds);
            trainCombinations = new ArrayList<>(cvFolds.size());
            for (int i = 0; i < cvFolds.size(); i++) {
                trainCombinations.add(RegressionDataSet.comineAllBut(cvFolds, i));
            }
        } else {
            cvFolds = null;
            trainCombinations = null;
        }

        ParallelUtils.run(z && this.trainModelsInParallel, candidates.size(), idx -> {
            RegressionModelEvaluation eval = new RegressionModelEvaluation(
                    candidates.get(idx), regressionDataSet, !this.trainModelsInParallel && z);
            eval.addScorer(this.regressionTargetScore.m258clone());
            if (this.reuseSameCVFolds) {
                eval.evaluateCrossValidation(cvFolds, trainCombinations);
            } else {
                eval.evaluateCrossValidation(this.folds);
            }
            synchronized (bestModels) {
                bestModels.add(eval);
            }
        });

        Regressor best = bestModels.peek().getRegressor();
        if (this.trainFinalModel) {
            best.train(regressionDataSet, z);
        }
        this.trainedRegressor = best;
    }

    /**
     * Returns a copy of this search made with {@link #RandomSearch(RandomSearch)}.
     *
     * @return a new, independent {@code RandomSearch}
     */
    @Override // jsat.parameters.ModelSearch
    /* renamed from: clone */
    public RandomSearch mo251clone() {
        return new RandomSearch(this);
    }
}
