package jsat.datatransform.featureselection;

import java.util.Arrays;
import java.util.Iterator;
import java.util.Random;
import java.util.Set;
import jsat.DataSet;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.datatransform.RemoveAttributeTransform;
import jsat.regression.RegressionDataSet;
import jsat.regression.Regressor;
import jsat.utils.IntList;
import jsat.utils.IntSet;
import jsat.utils.ListUtils;
import jsat.utils.random.RandomUtil;

/* loaded from: input_file:jsat/datatransform/featureselection/SBS.class */
/**
 * Sequential Backward Selection (SBS) feature selector.
 * <p>
 * Starting from the full feature set, SBS repeatedly removes the single feature
 * whose removal yields the lowest score from {@code SFS.getScore} (computed with
 * the configured number of CV folds; lower scores are treated as better).
 * Removal stops when the number of selected features reaches
 * {@link #getMinFeatures()}, or when the count is already at or below
 * {@link #getMaxFeatures()} and the best possible removal would worsen the score
 * by at least {@link #getMaxDecrease()}.
 * <p>
 * The selected categorical / numerical indices are stored, sorted ascending, in
 * the inherited {@code catIndexMap} / {@code numIndexMap} arrays.
 */
public class SBS extends RemoveAttributeTransform {
    private static final long serialVersionUID = -2516121100148559742L;
    /** Largest tolerated score worsening per removal once at or below {@link #maxFeatures}. */
    private double maxDecrease;
    /** Number of cross-validation folds used when scoring a candidate subset. */
    private int folds;
    /** Selection never continues once the selected-feature count is at or below this value. */
    private int minFeatures;
    /** Above this count, features are removed regardless of the score change. */
    private int maxFeatures;
    /** Learner used to score subsets (a {@link Classifier} or a {@link Regressor}). */
    private Object evaluator;

    /**
     * Copy constructor used by {@link #clone()}.
     * The evaluator reference is shared with {@code toCopy}, not cloned.
     */
    private SBS(SBS toCopy) {
        super(toCopy);
        this.maxDecrease = toCopy.maxDecrease;
        this.folds = toCopy.folds;
        this.minFeatures = toCopy.minFeatures;
        this.maxFeatures = toCopy.maxFeatures;
        this.evaluator = toCopy.evaluator;
    }

    /**
     * Creates an unfitted selector that scores subsets with a classifier using 3 CV folds.
     *
     * @param minFeatures the minimum number of features to keep
     * @param maxFeatures the count above which features are removed unconditionally
     * @param evaluator   the classifier used to score feature subsets
     * @param maxDecrease the largest tolerated score worsening once at or below {@code maxFeatures}
     * @throws IllegalArgumentException if {@code maxDecrease} is negative or NaN
     */
    public SBS(int minFeatures, int maxFeatures, Classifier evaluator, double maxDecrease) {
        this(minFeatures, maxFeatures, evaluator, 3, maxDecrease);
    }

    /** Shared initialiser: validates and stores the configuration, performs no selection. */
    private SBS(int minFeatures, int maxFeatures, Object evaluator, int folds, double maxDecrease) {
        setMaxDecrease(maxDecrease);
        setMinFeatures(minFeatures);
        setMaxFeatures(maxFeatures);
        setEvaluator(evaluator);
        setFolds(folds);
    }

    /**
     * Creates a selector and immediately runs selection on a classification data set.
     *
     * @param minFeatures the minimum number of features to keep
     * @param maxFeatures the count above which features are removed unconditionally
     * @param dataSet     the data to select features from
     * @param evaluator   the classifier used to score feature subsets
     * @param folds       the number of CV folds; must be positive
     * @param maxDecrease the largest tolerated score worsening once at or below {@code maxFeatures}
     * @throws IllegalArgumentException if {@code folds <= 0} or {@code maxDecrease} is negative or NaN
     */
    public SBS(int minFeatures, int maxFeatures, ClassificationDataSet dataSet, Classifier evaluator, int folds, double maxDecrease) {
        this(minFeatures, maxFeatures, evaluator, folds, maxDecrease);
        search(dataSet, evaluator, minFeatures, maxFeatures, folds);
    }

    /**
     * Creates an unfitted selector that scores subsets with a regressor using 3 CV folds.
     *
     * @param minFeatures the minimum number of features to keep
     * @param maxFeatures the count above which features are removed unconditionally
     * @param evaluator   the regressor used to score feature subsets
     * @param maxDecrease the largest tolerated score worsening once at or below {@code maxFeatures}
     * @throws IllegalArgumentException if {@code maxDecrease} is negative or NaN
     */
    public SBS(int minFeatures, int maxFeatures, Regressor evaluator, double maxDecrease) {
        this(minFeatures, maxFeatures, evaluator, 3, maxDecrease);
    }

    /**
     * Creates a selector and immediately runs selection on a regression data set.
     *
     * @param minFeatures the minimum number of features to keep
     * @param maxFeatures the count above which features are removed unconditionally
     * @param dataSet     the data to select features from
     * @param evaluator   the regressor used to score feature subsets
     * @param folds       the number of CV folds; must be positive
     * @param maxDecrease the largest tolerated score worsening once at or below {@code maxFeatures}
     * @throws IllegalArgumentException if {@code folds <= 0} or {@code maxDecrease} is negative or NaN
     */
    public SBS(int minFeatures, int maxFeatures, RegressionDataSet dataSet, Regressor evaluator, int folds, double maxDecrease) {
        this(minFeatures, maxFeatures, evaluator, folds, maxDecrease);
        search(dataSet, evaluator, minFeatures, maxFeatures, folds);
    }

    /**
     * Runs backward selection on {@code dataSet} using this object's current configuration.
     */
    @Override // jsat.datatransform.RemoveAttributeTransform, jsat.datatransform.DataTransform
    public void fit(DataSet dataSet) {
        search(dataSet, this.evaluator, this.minFeatures, this.maxFeatures, this.folds);
    }

    /**
     * Performs the backward search and stores the surviving indices in
     * {@code catIndexMap} / {@code numIndexMap}, each sorted ascending.
     */
    private void search(DataSet dataSet, Object learner, int minFeatures, int maxFeatures, int folds) {
        Random rand = RandomUtil.getRandom();
        int nFeatures = dataSet.getNumFeatures();
        int nCat = dataSet.getNumCategoricalVars();

        // Candidates use one combined index space [0, nFeatures); the categorical ones come first.
        IntSet available = new IntSet();
        ListUtils.addRange(available, 0, nFeatures, 1);

        // Backward selection starts with every feature selected and nothing removed.
        IntSet catSelected = new IntSet(nCat);
        IntSet numSelected = new IntSet(dataSet.getNumNumericalVars());
        IntSet catRemoved = new IntSet(nCat);
        IntSet numRemoved = new IntSet(dataSet.getNumNumericalVars());
        ListUtils.addRange(catSelected, 0, nCat, 1);
        ListUtils.addRange(numSelected, 0, nFeatures - nCat, 1);

        // One-element holder so SBSRemoveFeature can carry the previous best score between calls.
        double[] bestScore = {Double.POSITIVE_INFINITY};
        while (catSelected.size() + numSelected.size() > minFeatures) {
            int removed = SBSRemoveFeature(available, dataSet, catRemoved, numRemoved,
                    catSelected, numSelected, learner, folds, rand, maxFeatures, bestScore, this.maxDecrease);
            if (removed < 0) {
                break;
            }
        }

        this.catIndexMap = toSortedArray(catSelected);
        this.numIndexMap = toSortedArray(numSelected);
    }

    /** Copies the given index set into a new array sorted in ascending order. */
    private static int[] toSortedArray(Set<Integer> indices) {
        int[] out = new int[indices.size()];
        int pos = 0;
        for (int idx : indices) {
            out[pos++] = idx;
        }
        Arrays.sort(out);
        return out;
    }

    /** Returns a SBS copy with the same configuration and selected indices. */
    @Override // jsat.datatransform.RemoveAttributeTransform
    public SBS clone() {
        return new SBS(this);
    }

    /** Returns a newly constructed set holding the selected categorical attribute indices. */
    public Set<Integer> getSelectedCategorical() {
        return new IntSet(IntList.view(this.catIndexMap, this.catIndexMap.length));
    }

    /** Returns a newly constructed set holding the selected numerical attribute indices. */
    public Set<Integer> getSelectedNumerical() {
        return new IntSet(IntList.view(this.numIndexMap, this.numIndexMap.length));
    }

    /**
     * Attempts one backward-selection step: tries removing each available feature,
     * and permanently removes the one giving the lowest score if the stopping rule allows it.
     * <p>
     * Feature indices are in the combined space where indices below the data set's
     * categorical count are categorical. Translating them into the per-type sets is
     * delegated to {@code SFS.addFeature} / {@code SFS.removeFeature}.
     *
     * @param available   candidate features still selected, in the combined index space; the removed one is taken out
     * @param dataSet     the original data set; only shallow clones of it are transformed
     * @param catRemoved  categorical indices already removed; the chosen one is added
     * @param numRemoved  numerical indices already removed; the chosen one is added
     * @param catSelected categorical indices still selected; the chosen one is removed
     * @param numSelected numerical indices still selected; the chosen one is removed
     * @param evaluator   learner passed to {@code SFS.getScore}
     * @param folds       number of CV folds passed to {@code SFS.getScore}
     * @param rand        source of randomness passed to {@code SFS.getScore}
     * @param maxFeatures above this selected count, removal happens regardless of the score
     * @param prevBest    one-element in/out holder of the previous step's best score; updated on removal
     * @param maxDecrease the largest tolerated score worsening once at or below {@code maxFeatures}
     * @return the combined index of the removed feature, or -1 if nothing was removed
     *         (in which case none of the sets nor {@code prevBest} are modified)
     */
    public static int SBSRemoveFeature(Set<Integer> available, DataSet dataSet, Set<Integer> catRemoved, Set<Integer> numRemoved, Set<Integer> catSelected, Set<Integer> numSelected, Object evaluator, int folds, Random rand, int maxFeatures, double[] prevBest, double maxDecrease) {
        int nCat = dataSet.getNumCategoricalVars();
        int bestFeature = -1;
        double bestScore = Double.POSITIVE_INFINITY;
        for (int feature : available) {
            DataSet workOn = dataSet.shallowClone2();
            // Tentatively mark the feature as removed, score the reduced data, then undo.
            SFS.addFeature(feature, nCat, catRemoved, numRemoved);
            workOn.applyTransform(new RemoveAttributeTransform(workOn, catRemoved, numRemoved));
            double score = SFS.getScore(workOn, evaluator, folds, rand);
            if (score < bestScore) {
                bestScore = score;
                bestFeature = feature;
            }
            SFS.removeFeature(feature, nCat, catRemoved, numRemoved);
        }

        // No candidate scored below +Infinity (none available, or all scores infinite/NaN):
        // bail out before touching any state rather than "removing" index -1.
        if (bestFeature < 0) {
            return -1;
        }

        boolean overBudget = catSelected.size() + numSelected.size() > maxFeatures;
        // Written in positive form so an undefined (NaN) comparison stops the search.
        boolean withinTolerance = prevBest[0] - bestScore > -maxDecrease;
        if (!overBudget && !withinTolerance) {
            return -1;
        }

        prevBest[0] = bestScore;
        SFS.removeFeature(bestFeature, nCat, catSelected, numSelected);
        SFS.addFeature(bestFeature, nCat, catRemoved, numRemoved);
        available.remove(Integer.valueOf(bestFeature));
        return bestFeature;
    }

    /**
     * Sets the largest tolerated score worsening per removal once the selected-feature
     * count is at or below {@link #getMaxFeatures()}.
     *
     * @param maxDecrease a non-negative tolerance
     * @throws IllegalArgumentException if {@code maxDecrease} is negative or NaN
     */
    public void setMaxDecrease(double maxDecrease) {
        // NaN would make every tolerance comparison false and silently disable the check.
        if (maxDecrease < 0.0d || Double.isNaN(maxDecrease)) {
            throw new IllegalArgumentException("Decrease must be a non-negative value, not " + maxDecrease);
        }
        this.maxDecrease = maxDecrease;
    }

    /** Returns the largest tolerated score worsening per removal. */
    public double getMaxDecrease() {
        return this.maxDecrease;
    }

    /** Sets the selected-feature count at which removal stops; not validated here. */
    public void setMinFeatures(int minFeatures) {
        this.minFeatures = minFeatures;
    }

    /** Returns the selected-feature count at which removal stops. */
    public int getMinFeatures() {
        return this.minFeatures;
    }

    /** Sets the count above which features are removed regardless of the score; not validated here. */
    public void setMaxFeatures(int maxFeatures) {
        this.maxFeatures = maxFeatures;
    }

    /** Returns the count above which features are removed regardless of the score. */
    public int getMaxFeatures() {
        return this.maxFeatures;
    }

    /**
     * Sets the number of CV folds used to score candidate subsets.
     *
     * @throws IllegalArgumentException if {@code folds <= 0}
     */
    public void setFolds(int folds) {
        if (folds <= 0) {
            throw new IllegalArgumentException("Number of CV folds must be positive, not " + folds);
        }
        this.folds = folds;
    }

    /** Returns the number of CV folds used to score candidate subsets. */
    public int getFolds() {
        return this.folds;
    }

    /** Stores the learner (classifier or regressor) used for scoring. */
    private void setEvaluator(Object evaluator) {
        this.evaluator = evaluator;
    }
}
