package jsat.parameters;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.PriorityQueue;
import jsat.DataSet;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.ClassificationModelEvaluation;
import jsat.classifiers.Classifier;
import jsat.classifiers.WarmClassifier;
import jsat.distributions.Distribution;
import jsat.regression.RegressionDataSet;
import jsat.regression.RegressionModelEvaluation;
import jsat.regression.Regressor;
import jsat.regression.WarmRegressor;
import jsat.utils.DoubleList;
import jsat.utils.concurrent.ParallelUtils;

/* loaded from: input_file:jsat/parameters/GridSearch.class */
/**
 * Exhaustive search over a grid of parameter values.
 *
 * <p>Every registered parameter contributes one dimension of candidate values.
 * Training enumerates every combination, scores each resulting model with
 * cross validation, and keeps the model with the best mean target score.
 *
 * <p>Invariant: {@code searchParams.get(d)} (inherited) and
 * {@code searchValues.get(d)} always describe the same grid dimension
 * {@code d}. Warm-startable parameters are stored at the front of both lists:
 * {@link #incrementCombination(int[])} advances dimension 0 first, so
 * consecutive candidate models differ in those parameters, which is what the
 * warm-start training path chains on.
 */
public class GridSearch extends ModelSearch {
    private static final long serialVersionUID = -1987196172499143753L;

    /** Candidate values per grid dimension, index-aligned with {@code searchParams}. */
    private final List<List<Double>> searchValues = new ArrayList<>();

    /** Whether warm starts are used when the base model supports them. */
    private boolean useWarmStarts = true;

    /**
     * Creates a grid search around a regressor.
     *
     * @param regressor the base regressor, forwarded to {@link ModelSearch}
     * @param i         forwarded to {@link ModelSearch}; presumably the number
     *                  of cross validation folds (confirm against ModelSearch)
     */
    public GridSearch(Regressor regressor, int i) {
        super(regressor, i);
    }

    /**
     * Creates a grid search around a classifier.
     *
     * @param classifier the base classifier, forwarded to {@link ModelSearch}
     * @param i          forwarded to {@link ModelSearch}; presumably the number
     *                   of cross validation folds (confirm against ModelSearch)
     */
    public GridSearch(Classifier classifier, int i) {
        super(classifier, i);
    }

    /**
     * Copy constructor. The value lists are copied so the new instance does not
     * share them with {@code gridSearch}.
     *
     * @param gridSearch the search to copy
     */
    public GridSearch(GridSearch gridSearch) {
        super(gridSearch);
        this.useWarmStarts = gridSearch.useWarmStarts;
        if (gridSearch.searchValues != null) {
            for (List<Double> dimension : gridSearch.searchValues) {
                this.searchValues.add(new DoubleList(dimension));
            }
        }
    }

    /**
     * Same as {@link #autoAddParameters(DataSet, int)} with 10 values per
     * parameter.
     *
     * @param dataSet the data handed to each parameter's {@code getGuess}
     * @return the number of parameters that were added
     */
    public int autoAddParameters(DataSet dataSet) {
        return autoAddParameters(dataSet, 10);
    }

    /**
     * Registers every double/int parameter of the base model that can supply a
     * guess distribution for {@code dataSet}. The candidate values of each
     * parameter are the distribution's inverse CDF evaluated at the {@code i}
     * evenly spaced probabilities {@code (k + 1) / (i + 1)}, {@code k = 0..i-1};
     * values for int parameters are rounded to the nearest integer.
     *
     * @param dataSet the data handed to each parameter's {@code getGuess}
     * @param i       the number of candidate values per parameter
     * @return the number of parameters that were added, 0 if none offers a guess
     */
    public int autoAddParameters(DataSet dataSet, int i) {
        Parameterized model = this.baseClassifier != null
                ? (Parameterized) this.baseClassifier
                : (Parameterized) this.baseRegressor;

        // Collect the guessable parameters first; each guess is queried once.
        List<Parameter> guessable = new ArrayList<>();
        List<Distribution> guesses = new ArrayList<>();
        for (Parameter candidate : model.getParameters()) {
            Distribution guess = guessFor(candidate, dataSet);
            if (guess != null) {
                guessable.add(candidate);
                guesses.add(guess);
            }
        }
        if (guessable.isEmpty()) {
            return 0;
        }

        double[] quantiles = new double[i];
        for (int q = 0; q < quantiles.length; q++) {
            quantiles[q] = (q + 1.0d) / (i + 1.0d);
        }

        for (int p = 0; p < guessable.size(); p++) {
            Parameter target = guessable.get(p);
            Distribution guess = guesses.get(p);
            if (target instanceof DoubleParameter) {
                double[] values = new double[i];
                for (int q = 0; q < values.length; q++) {
                    values[q] = guess.invCdf(quantiles[q]);
                }
                addParameter((DoubleParameter) target, values);
            } else {
                // guessFor only yields a guess for double or int parameters
                int[] values = new int[i];
                for (int q = 0; q < values.length; q++) {
                    values[q] = (int) Math.round(guess.invCdf(quantiles[q]));
                }
                addParameter((IntParameter) target, values);
            }
        }
        return guessable.size();
    }

    /**
     * Returns the guess distribution of a double or int parameter, or
     * {@code null} when the parameter is of another kind or has no guess.
     */
    private static Distribution guessFor(Parameter parameter, DataSet dataSet) {
        if (parameter instanceof DoubleParameter) {
            return ((DoubleParameter) parameter).getGuess(dataSet);
        }
        if (parameter instanceof IntParameter) {
            return ((IntParameter) parameter).getGuess(dataSet);
        }
        return null;
    }

    /**
     * Enables or disables warm starts.
     *
     * @param z {@code true} to use warm starts when the base model supports them
     */
    public void setUseWarmStarts(boolean z) {
        this.useWarmStarts = z;
    }

    /**
     * @return {@code true} if warm starts are enabled
     */
    public boolean isUseWarmStarts() {
        return this.useWarmStarts;
    }

    /**
     * Adds a grid dimension for a double valued parameter. The values are
     * stored sorted ascending, or descending for a warm parameter that does not
     * prefer low-to-high order. The caller's array is not modified.
     *
     * @param doubleParameter the parameter to search over
     * @param dArr            the candidate values
     * @throws IllegalArgumentException if {@code doubleParameter} is null
     */
    public void addParameter(DoubleParameter doubleParameter, double... dArr) {
        if (doubleParameter == null) {
            throw new IllegalArgumentException("null not allowed for parameter");
        }
        // Copy before touching any state so a bad argument leaves the search unchanged.
        double[] values = dArr.clone();
        boolean warm = doubleParameter.isWarmParameter();
        boolean descending = warm && !doubleParameter.preferredLowToHigh();
        registerDimension(doubleParameter, values, warm, descending);
    }

    /**
     * Adds a grid dimension for the double valued parameter with the given name.
     *
     * @param str  the parameter name, resolved with {@code getParameterByName}
     * @param dArr the candidate values
     * @throws IllegalArgumentException if the named parameter is not a
     *                                  {@link DoubleParameter}
     */
    public void addParameter(String str, double... dArr) {
        Parameter found = getParameterByName(str);
        if (found instanceof DoubleParameter) {
            addParameter((DoubleParameter) found, dArr);
            return;
        }
        throw new IllegalArgumentException("Parameter " + str + " is not for double values");
    }

    /**
     * Adds a grid dimension for an integer valued parameter. The values are
     * stored sorted ascending, or descending for a warm parameter that does not
     * prefer low-to-high order. The caller's array is not modified.
     *
     * @param intParameter the parameter to search over
     * @param iArr         the candidate values
     * @throws IllegalArgumentException if {@code intParameter} is null
     */
    public void addParameter(IntParameter intParameter, int... iArr) {
        if (intParameter == null) {
            throw new IllegalArgumentException("null not allowed for parameter");
        }
        double[] values = new double[iArr.length];
        for (int k = 0; k < iArr.length; k++) {
            values[k] = iArr[k];
        }
        boolean warm = intParameter.isWarmParameter();
        boolean descending = warm && !intParameter.preferredLowToHigh();
        registerDimension(intParameter, values, warm, descending);
    }

    /**
     * Adds a grid dimension for the integer valued parameter with the given name.
     *
     * @param str  the parameter name, resolved with {@code getParameterByName}
     * @param iArr the candidate values
     * @throws IllegalArgumentException if the named parameter is not an
     *                                  {@link IntParameter}
     */
    public void addParameter(String str, int... iArr) {
        Parameter found = getParameterByName(str);
        if (found instanceof IntParameter) {
            addParameter((IntParameter) found, iArr);
            return;
        }
        throw new IllegalArgumentException("Parameter " + str + " is not for int values");
    }

    /**
     * Stores one grid dimension, keeping {@code searchParams} and
     * {@code searchValues} index-aligned.
     *
     * @param parameter  the parameter of the new dimension
     * @param values     a private copy of the candidate values; sorted in place
     * @param warm       whether the dimension goes to the front of the grid
     * @param descending whether the sorted values are stored high-to-low
     */
    private void registerDimension(Parameter parameter, double[] values, boolean warm, boolean descending) {
        Arrays.sort(values);
        DoubleList ordered = new DoubleList(values.length);
        for (double value : values) {
            ordered.add(value);
        }
        if (descending) {
            Collections.reverse(ordered);
        }
        if (warm) {
            // Both lists must receive the new dimension at the same index,
            // otherwise setParameters would pair parameters with foreign values.
            this.searchParams.add(0, parameter);
            this.searchValues.add(0, ordered);
        } else {
            this.searchParams.add(parameter);
            this.searchValues.add(ordered);
        }
    }

    /**
     * Evaluates every grid combination of the base regressor with cross
     * validation and keeps the one with the best mean target score in
     * {@code trainedRegressor}. When {@code trainFinalModel} is set, that model
     * is trained once more on the complete data set.
     *
     * <p>Overrides the training method of {@code jsat.regression.Regressor}.
     *
     * @param regressionDataSet the data to search and train on
     * @param z                 whether parallel execution is allowed
     */
    @Override
    public void train(RegressionDataSet regressionDataSet, boolean z) {
        // One clone of the base model per grid combination.
        final List<Regressor> candidates = new ArrayList<>();
        final int[] gridIndex = new int[this.searchParams.size()];
        do {
            setParameters(gridIndex);
            candidates.add(this.baseRegressor.mo251clone());
        } while (!incrementCombination(gridIndex));

        // Head of the queue is the evaluation with the best mean score.
        final PriorityQueue<RegressionModelEvaluation> ranked = new PriorityQueue<>(candidates.size(), (left, right) -> {
            int direction = this.regressionTargetScore.lowerIsBetter() ? 1 : -1;
            return direction * Double.compare(
                    left.getScoreStats(this.regressionTargetScore).getMean(),
                    right.getScoreStats(this.regressionTargetScore).getMean());
        });

        final List<RegressionDataSet> testFolds;
        final List<RegressionDataSet> trainFolds;
        if (this.reuseSameCVFolds) {
            testFolds = regressionDataSet.cvSet(this.folds);
            trainFolds = new ArrayList<>(testFolds.size());
            for (int fold = 0; fold < testFolds.size(); fold++) {
                trainFolds.add(RegressionDataSet.comineAllBut(testFolds, fold));
            }
        } else {
            testFolds = null;
            trainFolds = null;
        }

        // Either the models are evaluated in parallel, or each evaluation is.
        final boolean outerParallel = z && this.trainModelsInParallel;
        final boolean innerParallel = !this.trainModelsInParallel && z;

        // Warm chaining needs a warm-capable base model; a model that can only be
        // warmed from the same data additionally needs the folds to be reused.
        final boolean warmChain = this.useWarmStarts
                && (this.baseRegressor instanceof WarmRegressor)
                && !(((WarmRegressor) this.baseRegressor).warmFromSameDataOnly() && !this.reuseSameCVFolds);

        if (warmChain) {
            ParallelUtils.run(outerParallel, candidates.size(), (from, to) -> {
                // Models kept by the previous evaluation of this range seed the next one.
                Regressor[] warmModels = null;
                for (Regressor candidate : candidates.subList(from, to)) {
                    RegressionModelEvaluation eval = new RegressionModelEvaluation(candidate, regressionDataSet, innerParallel);
                    eval.setKeepModels(true);
                    eval.setWarmModels(warmModels);
                    crossValidate(eval, testFolds, trainFolds);
                    warmModels = eval.getKeptModels();
                    synchronized (ranked) {
                        ranked.add(eval);
                    }
                }
            });
        } else {
            ParallelUtils.run(outerParallel, candidates.size(), idx -> {
                RegressionModelEvaluation eval = new RegressionModelEvaluation(candidates.get(idx), regressionDataSet, innerParallel);
                crossValidate(eval, testFolds, trainFolds);
                synchronized (ranked) {
                    ranked.add(eval);
                }
            });
        }

        Regressor best = ranked.peek().getRegressor();
        if (this.trainFinalModel) {
            boolean warmRetrain = this.useWarmStarts
                    && (best instanceof WarmRegressor)
                    && !((WarmRegressor) best).warmFromSameDataOnly();
            if (warmRetrain) {
                WarmRegressor warmBest = (WarmRegressor) best;
                warmBest.train(regressionDataSet, warmBest.mo251clone(), z);
            } else {
                best.train(regressionDataSet, z);
            }
        }
        this.trainedRegressor = best;
    }

    /**
     * Attaches a clone of the regression target score to {@code eval} and runs
     * cross validation, on the given folds when {@code reuseSameCVFolds} is set
     * and on {@code folds} freshly drawn folds otherwise.
     */
    private void crossValidate(RegressionModelEvaluation eval, List<RegressionDataSet> testFolds, List<RegressionDataSet> trainFolds) {
        eval.addScorer(this.regressionTargetScore.m258clone());
        if (this.reuseSameCVFolds) {
            eval.evaluateCrossValidation(testFolds, trainFolds);
        } else {
            eval.evaluateCrossValidation(this.folds);
        }
    }

    /**
     * Evaluates every grid combination of the base classifier with cross
     * validation and keeps the one with the best mean target score in
     * {@code trainedClassifier}. When {@code trainFinalModel} is set, that model
     * is trained once more on the complete data set.
     *
     * <p>Overrides the training method of {@code jsat.classifiers.Classifier}.
     *
     * @param classificationDataSet the data to search and train on
     * @param z                     whether parallel execution is allowed
     */
    @Override
    public void train(ClassificationDataSet classificationDataSet, boolean z) {
        // One clone of the base model per grid combination.
        final List<Classifier> candidates = new ArrayList<>();
        final int[] gridIndex = new int[this.searchParams.size()];
        do {
            setParameters(gridIndex);
            candidates.add(this.baseClassifier.mo251clone());
        } while (!incrementCombination(gridIndex));

        // Head of the queue is the evaluation with the best mean score.
        final PriorityQueue<ClassificationModelEvaluation> ranked = new PriorityQueue<>(candidates.size(), (left, right) -> {
            int direction = this.classificationTargetScore.lowerIsBetter() ? 1 : -1;
            return direction * Double.compare(
                    left.getScoreStats(this.classificationTargetScore).getMean(),
                    right.getScoreStats(this.classificationTargetScore).getMean());
        });

        final List<ClassificationDataSet> testFolds;
        final List<ClassificationDataSet> trainFolds;
        if (this.reuseSameCVFolds) {
            testFolds = classificationDataSet.cvSet(this.folds);
            trainFolds = new ArrayList<>(testFolds.size());
            for (int fold = 0; fold < testFolds.size(); fold++) {
                trainFolds.add(ClassificationDataSet.comineAllBut(testFolds, fold));
            }
        } else {
            testFolds = null;
            trainFolds = null;
        }

        // Either the models are evaluated in parallel, or each evaluation is.
        final boolean outerParallel = z && this.trainModelsInParallel;
        final boolean innerParallel = !this.trainModelsInParallel && z;

        // Warm chaining needs a warm-capable base model; a model that can only be
        // warmed from the same data additionally needs the folds to be reused.
        final boolean warmChain = this.useWarmStarts
                && (this.baseClassifier instanceof WarmClassifier)
                && !(((WarmClassifier) this.baseClassifier).warmFromSameDataOnly() && !this.reuseSameCVFolds);

        if (warmChain) {
            ParallelUtils.run(outerParallel, candidates.size(), (from, to) -> {
                // Models kept by the previous evaluation of this range seed the next one.
                Classifier[] warmModels = null;
                for (Classifier candidate : candidates.subList(from, to)) {
                    ClassificationModelEvaluation eval = new ClassificationModelEvaluation(candidate, classificationDataSet, innerParallel);
                    eval.setKeepModels(true);
                    eval.setWarmModels(warmModels);
                    crossValidate(eval, testFolds, trainFolds);
                    warmModels = eval.getKeptModels();
                    synchronized (ranked) {
                        ranked.add(eval);
                    }
                }
            });
        } else {
            ParallelUtils.run(outerParallel, candidates.size(), idx -> {
                ClassificationModelEvaluation eval = new ClassificationModelEvaluation(candidates.get(idx), classificationDataSet, innerParallel);
                crossValidate(eval, testFolds, trainFolds);
                synchronized (ranked) {
                    ranked.add(eval);
                }
            });
        }

        Classifier best = ranked.peek().getClassifier();
        if (this.trainFinalModel) {
            boolean warmRetrain = this.useWarmStarts
                    && (best instanceof WarmClassifier)
                    && !((WarmClassifier) best).warmFromSameDataOnly();
            if (warmRetrain) {
                WarmClassifier warmBest = (WarmClassifier) best;
                warmBest.train(classificationDataSet, warmBest.mo251clone(), z);
            } else {
                best.train(classificationDataSet, z);
            }
        }
        this.trainedClassifier = best;
    }

    /**
     * Attaches a clone of the classification target score to {@code eval} and
     * runs cross validation, on the given folds when {@code reuseSameCVFolds} is
     * set and on {@code folds} freshly drawn folds otherwise.
     */
    private void crossValidate(ClassificationModelEvaluation eval, List<ClassificationDataSet> testFolds, List<ClassificationDataSet> trainFolds) {
        eval.addScorer(this.classificationTargetScore.m33clone());
        if (this.reuseSameCVFolds) {
            eval.evaluateCrossValidation(testFolds, trainFolds);
        } else {
            eval.evaluateCrossValidation(this.folds);
        }
    }

    /**
     * Returns a copy of this search made with the copy constructor.
     *
     * <p>Overrides the method of {@code jsat.parameters.ModelSearch}.
     */
    @Override
    /* renamed from: clone, reason: merged with bridge method [inline-methods] */
    public GridSearch mo251clone() {
        return new GridSearch(this);
    }

    /**
     * Advances {@code iArr} to the next grid combination. Dimension 0 moves
     * first; a dimension that runs past its last value is reset to 0 and
     * carries into the next dimension.
     *
     * @param iArr the current value index of every dimension, updated in place
     * @return {@code true} once the last dimension has run past its final
     *         value, i.e. every combination has been visited; also
     *         {@code true} for a grid without dimensions
     */
    private boolean incrementCombination(int[] iArr) {
        if (iArr.length == 0) {
            // No dimensions registered: the single (empty) combination is already done.
            return true;
        }
        iArr[0]++;
        final int last = iArr.length - 1;
        for (int dim = 0; dim < last && iArr[dim] >= this.searchValues.get(dim).size(); dim++) {
            iArr[dim] = 0;
            iArr[dim + 1]++;
        }
        return iArr[last] >= this.searchValues.get(last).size();
    }

    /**
     * Applies one grid combination to the searched parameters.
     *
     * @param iArr for every dimension, the index of the value to apply
     */
    private void setParameters(int[] iArr) {
        for (int dim = 0; dim < iArr.length; dim++) {
            Parameter target = this.searchParams.get(dim);
            Double chosen = this.searchValues.get(dim).get(iArr[dim]);
            if (target instanceof DoubleParameter) {
                ((DoubleParameter) target).setValue(chosen.doubleValue());
            } else if (target instanceof IntParameter) {
                ((IntParameter) target).setValue(chosen.intValue());
            }
        }
    }
}
