package jsat.classifiers.trees;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.logging.Level;
import java.util.logging.Logger;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.classifiers.DataPointPair;
import jsat.classifiers.trees.ImpurityScore;
import jsat.exceptions.FailedToFitException;
import jsat.linear.Vec;
import jsat.math.OnLineStatistics;
import jsat.parameters.Parameterized;
import jsat.regression.RegressionDataSet;
import jsat.regression.Regressor;
import jsat.utils.DoubleList;
import jsat.utils.FakeExecutor;
import jsat.utils.IntList;
import jsat.utils.IntSet;
import jsat.utils.PairedReturn;
import jsat.utils.QuickSort;
import jsat.utils.concurrent.AtomicDouble;
import jsat.utils.concurrent.ParallelUtils;

/* loaded from: input_file:jsat/classifiers/trees/DecisionStump.class */
public class DecisionStump implements Classifier, Regressor, Parameterized {
    private static final long serialVersionUID = -2849268862089019514L;
    private int splittingAttribute;
    private CategoricalData predicting;
    private CategoricalData[] catAttributes;
    private int numNumericFeatures;
    private List<Double> boundries;
    private List<Integer> owners;
    private CategoricalResults[] results;
    protected double[] pathRatio;
    private double[] regressionResults;
    private static final double almost0 = 1.0E-6d;
    private static final double almost1 = 0.999999d;
    private int minResultSplitSize = 10;
    private ImpurityScore.ImpurityMeasure gainMethod = ImpurityScore.ImpurityMeasure.INFORMATION_GAIN_RATIO;
    private boolean removeContinuousAttributes = false;

    /**
     * Controls whether a numeric attribute is removed from the caller's set of
     * candidate attributes after it has been selected as the splitting
     * attribute. Categorical attributes are always removed once used.
     *
     * @param z {@code true} to also remove numeric attributes after use
     */
    public void setRemoveContinuousAttributes(boolean z) {
        removeContinuousAttributes = z;
    }

    /**
     * Sets the impurity measure handed to every {@link ImpurityScore} this
     * stump creates while training for classification.
     *
     * @param impurityMeasure the impurity measure to use
     */
    public void setGainMethod(ImpurityScore.ImpurityMeasure impurityMeasure) {
        gainMethod = impurityMeasure;
    }

    /**
     * Returns the impurity measure currently configured for classification
     * training.
     *
     * @return the configured impurity measure
     */
    public ImpurityScore.ImpurityMeasure getGainMethod() {
        return gainMethod;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Returns the number of numeric features recorded from the first data
     * point of the most recent training call.
     *
     * @return the number of numeric features
     */
    public int numNumeric() {
        return numNumericFeatures;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Returns the number of categorical attributes recorded by the most recent
     * training call. The attribute array is only assigned during training, so
     * calling this on an untrained stump fails with a
     * {@link NullPointerException}.
     *
     * @return the number of categorical attributes
     */
    public int numCategorical() {
        final CategoricalData[] attributes = this.catAttributes;
        return attributes.length;
    }

    /**
     * Sets the minimum number of data points used as the per-side bound when
     * numeric splits are searched. Training also refuses to split a set that
     * holds fewer than twice this many points.
     *
     * @param i the new minimum split size; must be at least 2
     * @throws ArithmeticException if {@code i} is less than 2
     */
    public void setMinResultSplitSize(int i) {
        if (i <= 1) {
            // The guard has always rejected 1 as well; the message now says so
            // instead of claiming any positive value is acceptable.
            throw new ArithmeticException("Min split size must be greater than 1, but was " + i);
        }
        this.minResultSplitSize = i;
    }

    /**
     * Returns the configured minimum split size.
     *
     * @return the minimum number of points used as the per-side split bound
     */
    public int getMinResultSplitSize() {
        return minResultSplitSize;
    }

    /**
     * Returns the splitting attribute re-indexed so that numeric features come
     * first. Internally categorical attributes occupy indices
     * {@code [0, numCategorical)} and numeric ones follow; the returned index
     * places numeric features at {@code [0, numNumeric)} and categorical
     * attributes after them.
     *
     * @return the index of the splitting attribute in numeric-first order
     */
    public int getSplittingAttribute() {
        final int categoricalCount = this.catAttributes.length;
        if (this.splittingAttribute >= categoricalCount) {
            // numeric attribute: strip the categorical offset
            return this.splittingAttribute - categoricalCount;
        }
        // categorical attribute: shift past all numeric features
        return this.numNumericFeatures + this.splittingAttribute;
    }

    /**
     * Sets the description of the class attribute being predicted. It must be
     * set before {@code trainC} is called directly, which otherwise throws.
     *
     * @param categoricalData the categorical description of the target class
     */
    public void setPredicting(CategoricalData categoricalData) {
        predicting = categoricalData;
    }

    /**
     * Predicts the regression target for a data point. When the point can be
     * routed down a single path, that path's stored value is returned. When
     * {@link #whichPath(DataPoint)} reports a negative path (the splitting
     * value is missing), the stored values are blended using the path ratios.
     *
     * @param dataPoint the data point to regress
     * @return the predicted value
     * @throws RuntimeException if the stump has not been trained for regression
     */
    @Override // jsat.regression.Regressor
    public double regress(DataPoint dataPoint) {
        if (this.regressionResults == null) {
            throw new RuntimeException("Decusion stump has not been trained for regression");
        }
        final int path = whichPath(dataPoint);
        if (path < 0) {
            // No single path applies: weight every branch by its path ratio.
            double blended = 0.0d;
            int branch = 0;
            for (double ratio : this.pathRatio) {
                blended += ratio * this.regressionResults[branch];
                branch++;
            }
            return blended;
        }
        return this.regressionResults[path];
    }

    /**
     * Trains this stump for regression, considering every feature of the data
     * set as a candidate splitting attribute.
     *
     * @param regressionDataSet the training data
     * @param z {@code true} to evaluate candidate attributes on the shared
     *          thread pool, {@code false} to use the in-file {@code FakeExecutor}
     * @throws FailedToFitException if no split could be found
     */
    @Override // jsat.regression.Regressor
    public void train(RegressionDataSet regressionDataSet, boolean z) {
        final int numFeatures = regressionDataSet.getNumFeatures();
        IntSet options = new IntSet(numFeatures);
        for (int feature = 0; feature < numFeatures; feature++) {
            // The decompiler emitted an (IntSet) cast on the boxed Integer here,
            // which is an inconvertible-type cast; the plain boxed value is added.
            options.add(Integer.valueOf(feature));
        }
        if (trainR(regressionDataSet.getDPPList(), options, z) == null) {
            throw new FailedToFitException("Tree could not be fit, make sure your data is good. Potentially file a bug");
        }
    }

    /**
     * Computes the gain of a proposed split by scoring each partition and
     * delegating to {@code ImpurityScore.gain}.
     *
     * @param impurityScore the impurity score of the unsplit data
     * @param list the partitions produced by the split
     * @return the gain reported by {@code ImpurityScore.gain}
     */
    protected double getGain(ImpurityScore impurityScore, List<List<DataPointPair<Integer>>> list) {
        final ImpurityScore[] partitionScores = getSplitScores(list);
        return ImpurityScore.gain(impurityScore, partitionScores);
    }

    /**
     * Builds one impurity score per partition of a split.
     *
     * @param list the partitions of a split
     * @return an array whose i-th entry scores the i-th partition
     */
    private ImpurityScore[] getSplitScores(List<List<DataPointPair<Integer>>> list) {
        final int partitions = list.size();
        final ImpurityScore[] scores = new ImpurityScore[partitions];
        int index = 0;
        while (index < partitions) {
            scores[index] = getClassGainScore(list.get(index));
            index++;
        }
        return scores;
    }

    /**
     * Determines which branch of the stump a data point follows.
     *
     * @param dataPoint the data point to route
     * @return the branch index; the negative "not trained" value from
     *         {@link #getNumberOfPaths()} if the stump is untrained; {@code -1}
     *         if the numeric splitting value is NaN; for a categorical split the
     *         point's categorical value is returned as-is
     */
    public int whichPath(DataPoint dataPoint) {
        final int paths = getNumberOfPaths();
        if (paths < 0) {
            // not trained: propagate the sentinel
            return paths;
        }
        if (paths == 1) {
            // only one branch exists, nothing to decide
            return 0;
        }
        final int categoricalCount = this.catAttributes.length;
        if (this.splittingAttribute < categoricalCount) {
            // categorical split: the category value is the branch index
            return dataPoint.getCategoricalValue(this.splittingAttribute);
        }
        final double value = dataPoint.getNumericalValues().get(this.splittingAttribute - categoricalCount);
        if (Double.isNaN(value)) {
            return -1;
        }
        if (this.results != null) {
            // classification: look the value up in the sorted boundaries and
            // map the boundary position to the branch that owns it
            int position = Collections.binarySearch(this.boundries, Double.valueOf(value));
            if (position < 0) {
                position = (-position) - 1;
            }
            return this.owners.get(position).intValue();
        }
        // regression: index 2 of regressionResults holds the split threshold
        if (this.regressionResults.length == 1) {
            return 0;
        }
        return value > this.regressionResults[2] ? 1 : 0;
    }

    /**
     * Returns the number of branches of the trained stump.
     *
     * @return the number of classification results if trained for
     *         classification; otherwise, for regression, 1 for a single-leaf
     *         stump, the number of categories for a categorical split, or 2 for
     *         a numeric split; {@link Integer#MIN_VALUE} if no model is available
     */
    public int getNumberOfPaths() {
        if (this.results != null) {
            return this.results.length;
        }
        // catAttributes is assigned at the start of regression training, before
        // regressionResults exists, and stays set if no split is found. Treat
        // that state as "not trained" rather than dereferencing null below.
        if (this.catAttributes == null || this.regressionResults == null) {
            return Integer.MIN_VALUE;
        }
        if (this.regressionResults.length == 1) {
            return 1;
        }
        if (this.splittingAttribute < this.catAttributes.length) {
            return this.catAttributes[this.splittingAttribute].getNumOfCategories();
        }
        // numeric splits are always binary
        return 2;
    }

    /**
     * Classifies a data point. If the point follows a single branch, that
     * branch's stored result is returned directly (not a copy). If
     * {@link #whichPath(DataPoint)} returns a negative value, the results of
     * all branches are combined, each scaled by its path ratio.
     *
     * @param dataPoint the data point to classify
     * @return the class distribution for the data point
     * @throws RuntimeException if the stump has not been trained for classification
     */
    @Override // jsat.classifiers.Classifier
    public CategoricalResults classify(DataPoint dataPoint) {
        if (this.results == null) {
            throw new RuntimeException("DecisionStump has not been trained for classification");
        }
        final int path = whichPath(dataPoint);
        if (path >= 0) {
            return this.results[path];
        }
        // Start from a scaled copy of branch 0, then accumulate the rest.
        final Vec blended = this.results[0].getVecView().mo46clone();
        blended.mutableMultiply(this.pathRatio[0]);
        int branch = 1;
        while (branch < this.results.length) {
            blended.mutableAdd(this.pathRatio[branch], this.results[branch].getVecView());
            branch++;
        }
        return new CategoricalResults(blended.arrayCopy());
    }

    /**
     * Returns the stored classification result of one branch.
     *
     * @param i the branch index, in {@code [0, getNumberOfPaths())}
     * @return the classification result stored for that branch
     * @throws IndexOutOfBoundsException if {@code i} is not a valid branch index
     */
    public CategoricalResults result(int i) {
        final boolean validPath = i >= 0 && i < getNumberOfPaths();
        if (!validPath) {
            throw new IndexOutOfBoundsException("Invalid path, can to return a result for path " + i);
        }
        return this.results[i];
    }

    /**
     * Trains this stump for classification, considering every feature of the
     * data set as a candidate splitting attribute. The predicted class
     * description is taken from the data set.
     *
     * @param classificationDataSet the training data
     * @param z {@code true} to evaluate candidate attributes on the shared
     *          thread pool, {@code false} to use the in-file {@code FakeExecutor}
     */
    @Override // jsat.classifiers.Classifier
    public void train(ClassificationDataSet classificationDataSet, boolean z) {
        final int numFeatures = classificationDataSet.getNumFeatures();
        IntSet options = new IntSet(numFeatures);
        for (int feature = 0; feature < numFeatures; feature++) {
            // The decompiler emitted an (IntSet) cast on the boxed Integer here,
            // which is an inconvertible-type cast; the plain boxed value is added.
            options.add(Integer.valueOf(feature));
        }
        this.predicting = classificationDataSet.getPredicting();
        trainC(classificationDataSet.getAsDPPList(), options, z);
    }

    /**
     * Trains this stump for classification without the thread pool. Equivalent
     * to calling the three-argument overload with {@code false}.
     *
     * @param list the labelled training points
     * @param set the candidate attribute indices; may be modified, see the
     *            three-argument overload
     * @return the partitions of {@code list} produced by the chosen split
     */
    public List<List<DataPointPair<Integer>>> trainC(List<DataPointPair<Integer>> list, Set<Integer> set) {
        final boolean parallel = false;
        return trainC(list, set, parallel);
    }

    /**
     * Trains this stump for classification by evaluating every candidate
     * attribute and keeping the split with the largest gain.
     *
     * <p>Attribute indices below the number of categorical attributes refer to
     * categorical attributes; the remaining indices refer to numeric features,
     * offset by the number of categorical attributes.
     *
     * <p>If the data is already pure, or holds fewer than twice the minimum
     * split size, no split is attempted and a single leaf holding the weighted
     * class distribution of the data is stored. The same happens when no
     * candidate attribute yields an admissible split.
     *
     * @param list the labelled training points; must be non-empty
     * @param set the candidate attribute indices. When a split is chosen, the
     *            chosen attribute is removed from this set if it is categorical,
     *            or if removal of continuous attributes is enabled
     * @param z {@code true} to evaluate attributes on the shared thread pool,
     *          {@code false} to use the in-file {@code FakeExecutor}
     * @return the partitions of {@code list} induced by the chosen split, or a
     *         single partition containing {@code list} if no split was made
     * @throws RuntimeException if the predicting attribute has not been set
     * @throws FailedToFitException if the calling thread is interrupted while
     *         waiting for the attribute evaluations to finish
     */
    public List<List<DataPointPair<Integer>>> trainC(List<DataPointPair<Integer>> list, Set<Integer> set, boolean z) {
        if (this.predicting == null) {
            throw new RuntimeException("Predicting value has not been set");
        }
        this.catAttributes = list.get(0).getDataPoint().getCategoricalData();
        this.numNumericFeatures = list.get(0).getVector().length();
        final ImpurityScore parentScore = getClassGainScore(list);
        if (parentScore.getScore() == 0.0d || list.size() < this.minResultSplitSize * 2) {
            // Pure, or too small to split. Report the weighted class distribution:
            // assigning probability 1 to the first point's class is only right in
            // the pure case and wrong for a small but mixed set.
            this.results = new CategoricalResults[]{weightedClassDistribution(list)};
            this.pathRatio = new double[]{0.0d};
            List<List<DataPointPair<Integer>>> single = new ArrayList<List<DataPointPair<Integer>>>();
            single.add(list);
            return single;
        }
        final List<List<DataPointPair<Integer>>> bestSplit =
                Collections.synchronizedList(new ArrayList<List<DataPointPair<Integer>>>());
        final AtomicDouble bestGain = new AtomicDouble(-1.0d);
        // Holds the per-branch weight ratios of the best split; also used as the
        // monitor guarding every "best so far" update.
        final DoubleList bestRatios = new DoubleList();
        this.splittingAttribute = -1;
        final CountDownLatch latch = new CountDownLatch(set.size());
        // Each worker thread sorts and reorders its own copy of the data.
        final ThreadLocal<List<DataPointPair<Integer>>> localCopies =
                ThreadLocal.withInitial(() -> new ArrayList<DataPointPair<Integer>>(list));
        final ExecutorService executor = z ? ParallelUtils.CACHED_THREAD_POOL : new FakeExecutor();
        for (Integer candidate : set) {
            final int attribute = candidate.intValue();
            executor.submit(() -> {
                try {
                    final List<DataPointPair<Integer>> points = localCopies.get();
                    // Filled in by the numeric split search, which computes the gain itself.
                    final double[] precomputedGain = {Double.NaN};
                    final List<List<DataPointPair<Integer>>> split;
                    final ImpurityScore[] splitScores;
                    PairedReturn<List<Double>, List<Integer>> numericSplit = null;
                    // Fraction of the total weight whose attribute value is known.
                    double knownFraction = 1.0d;
                    if (attribute < this.catAttributes.length) {
                        split = listOfLists(this.catAttributes[attribute].getNumOfCategories());
                        splitScores = new ImpurityScore[split.size()];
                        for (int c = 0; c < splitScores.length; c++) {
                            splitScores[c] = new ImpurityScore(this.predicting.getNumOfCategories(), this.gainMethod);
                        }
                        final List<DataPointPair<Integer>> missing = new ArrayList<DataPointPair<Integer>>(0);
                        double missingWeight = 0.0d;
                        for (DataPointPair<Integer> point : points) {
                            final int category = point.getDataPoint().getCategoricalValue(attribute);
                            final double weight = point.getDataPoint().getWeight();
                            if (category >= 0) {
                                split.get(category).add(point);
                                splitScores[category].addPoint(weight, point.getPair().intValue());
                            } else {
                                // negative category: value is missing
                                missing.add(point);
                                missingWeight += weight;
                            }
                        }
                        int nonEmpty = 0;
                        for (List<DataPointPair<Integer>> branch : split) {
                            if (!branch.isEmpty()) {
                                nonEmpty++;
                            }
                        }
                        if (nonEmpty <= 1) {
                            // every known value falls in one category: not a split
                            return;
                        }
                        if (missingWeight > 0.0d) {
                            final double knownWeight = parentScore.getSumOfWeights() - missingWeight;
                            knownFraction = knownWeight / parentScore.getSumOfWeights();
                            final double[] fractions = new double[splitScores.length];
                            for (int c = 0; c < fractions.length; c++) {
                                fractions[c] = splitScores[c].getSumOfWeights() / knownWeight;
                            }
                            distributMissing(split, fractions, missing);
                        }
                    } else {
                        split = listOfLists(2);
                        splitScores = new ImpurityScore[2];
                        numericSplit = createNumericCSplit(points, this.predicting.getNumOfCategories(),
                                attribute - this.catAttributes.length, split, parentScore, precomputedGain, splitScores);
                        if (numericSplit == null) {
                            // no admissible numeric split for this feature
                            return;
                        }
                    }
                    final double gain = Double.isNaN(precomputedGain[0])
                            ? ImpurityScore.gain(parentScore, knownFraction, splitScores)
                            : precomputedGain[0];
                    // Cheap unlocked check first, then re-check under the monitor.
                    if (gain > bestGain.get()) {
                        synchronized (bestRatios) {
                            if (gain > bestGain.get()) {
                                bestGain.set(gain);
                                this.splittingAttribute = attribute;
                                bestSplit.clear();
                                bestSplit.addAll(split);
                                bestRatios.clear();
                                // small offset keeps the division below away from zero
                                double totalWeight = 1.0E-8d;
                                for (int b = 0; b < splitScores.length; b++) {
                                    totalWeight += splitScores[b].getSumOfWeights();
                                    bestRatios.add(splitScores[b].getSumOfWeights());
                                }
                                for (int b = 0; b < splitScores.length; b++) {
                                    bestRatios.set(b, bestRatios.getD(b) / totalWeight);
                                }
                                if (attribute >= this.catAttributes.length) {
                                    this.boundries = numericSplit.getFirstItem();
                                    this.owners = numericSplit.getSecondItem();
                                }
                            }
                        }
                    }
                } finally {
                    // Always release the latch, even if the evaluation throws,
                    // so the waiting thread cannot block forever.
                    latch.countDown();
                }
            });
        }
        try {
            latch.await();
        } catch (InterruptedException e) {
            // Restore the interrupt flag for callers before reporting the failure.
            Thread.currentThread().interrupt();
            Logger.getLogger(DecisionStump.class.getName()).log(Level.SEVERE, (String) null, (Throwable) e);
            throw new FailedToFitException(e);
        }
        if (this.splittingAttribute == -1) {
            // No attribute produced a split: fall back to one weighted leaf.
            bestSplit.clear();
            bestSplit.add(list);
            this.results = new CategoricalResults[]{weightedClassDistribution(list)};
            this.pathRatio = new double[]{1.0d};
            return bestSplit;
        }
        if (this.splittingAttribute < this.catAttributes.length || this.removeContinuousAttributes) {
            set.remove(Integer.valueOf(this.splittingAttribute));
        }
        this.results = new CategoricalResults[bestSplit.size()];
        this.pathRatio = bestRatios.getVecView().arrayCopy();
        for (int b = 0; b < bestSplit.size(); b++) {
            this.results[b] = weightedClassDistribution(bestSplit.get(b));
        }
        return bestSplit;
    }

    /**
     * Builds the normalized class distribution of a set of labelled points,
     * counting each point with its data-point weight.
     *
     * @param points the labelled points
     * @return the normalized, weighted class distribution
     */
    private CategoricalResults weightedClassDistribution(List<DataPointPair<Integer>> points) {
        CategoricalResults distribution = new CategoricalResults(this.predicting.getNumOfCategories());
        for (DataPointPair<Integer> point : points) {
            distribution.incProb(point.getPair().intValue(), point.getDataPoint().getWeight());
        }
        distribution.normalize();
        return distribution;
    }

    /**
     * Searches for the best binary split of a numeric feature for
     * classification.
     *
     * <p>The given list is reordered in place: points whose feature value is
     * NaN are moved to the end and the remaining points are sorted by feature
     * value. Candidate thresholds are midpoints between consecutive distinct
     * values, restricted by the minimum split size bound on both sides.
     *
     * @param list the points to split; reordered in place
     * @param i the number of target classes
     * @param i2 the index of the numeric feature to split on
     * @param list2 a two-element list; on success element 0 receives the points
     *              below the threshold and element 1 the rest, with points
     *              missing the value distributed fractionally to both
     * @param impurityScore the impurity score of the whole set
     * @param dArr if non-null, element 0 receives the gain of the chosen split
     * @param impurityScoreArr a two-element array that receives the impurity
     *                         scores of the two sides of the chosen split
     * @return the boundaries (threshold, then positive infinity) paired with the
     *         owning branch indices (0, 1), or {@code null} if no admissible
     *         split exists
     */
    private PairedReturn<List<Double>, List<Integer>> createNumericCSplit(List<DataPointPair<Integer>> list, int i, int i2, List<List<DataPointPair<Integer>>> list2, ImpurityScore impurityScore, double[] dArr, ImpurityScore[] impurityScoreArr) {
        final double[] values = new double[list.size()];
        int missingCount = 0;
        int index = 0;
        while (index < list.size() - missingCount) {
            final double value = list.get(index).getVector().get(i2);
            if (Double.isNaN(value)) {
                // Move the missing value to the tail and re-examine whatever was
                // swapped into this position.
                Collections.swap(list, (values.length - missingCount) - 1, index);
                missingCount++;
            } else {
                values[index] = value;
                index++;
            }
        }
        final int knownCount = list.size() - missingCount;
        QuickSort.sort(values, 0, knownCount, Arrays.asList(list));
        if (knownCount < this.minResultSplitSize) {
            // Too few known values to seed the left side. No split is possible,
            // so report that instead of running past the known values (or past
            // the end of the list).
            return null;
        }
        double bestGain = Double.NEGATIVE_INFINITY;
        double bestThreshold = Double.NEGATIVE_INFINITY;
        int splitIndex = -1;
        final ImpurityScore rightSide = impurityScore.m109clone();
        final ImpurityScore leftSide = new ImpurityScore(i, this.gainMethod);
        // Points with a missing value take no part in the search.
        double missingWeight = 0.0d;
        for (int m = knownCount; m < list.size(); m++) {
            final double weight = list.get(m).getDataPoint().getWeight();
            missingWeight += weight;
            rightSide.removePoint(weight, list.get(m).getPair().intValue());
        }
        final double knownFraction = rightSide.getSumOfWeights() / (rightSide.getSumOfWeights() + missingWeight);
        // Seed the left side with the first minResultSplitSize points.
        for (int p = 0; p < this.minResultSplitSize; p++) {
            final double weight = list.get(p).getDataPoint().getWeight();
            final int label = list.get(p).getPair().intValue();
            leftSide.addPoint(weight, label);
            rightSide.removePoint(weight, label);
        }
        for (int p = this.minResultSplitSize; p < (knownCount - this.minResultSplitSize) - 1; p++) {
            final DataPointPair<Integer> point = list.get(p);
            rightSide.removePoint(point.getDataPoint(), point.getPair().intValue());
            leftSide.addPoint(point.getDataPoint(), point.getPair().intValue());
            final double current = values[p];
            final double next = values[p + 1];
            if (next - current < 1.0E-14d) {
                // (nearly) identical values cannot be separated by a threshold
                continue;
            }
            final double gain = ImpurityScore.gain(impurityScore, knownFraction, leftSide, rightSide);
            if (gain >= bestGain) {
                bestGain = gain;
                bestThreshold = (current + next) / 2.0d;
                splitIndex = p + 1;
                impurityScoreArr[0] = leftSide.m109clone();
                impurityScoreArr[1] = rightSide.m109clone();
            }
        }
        if (splitIndex == -1) {
            return null;
        }
        if (dArr != null) {
            dArr[0] = bestGain;
        }
        list2.set(0, new ArrayList<DataPointPair<Integer>>(list.subList(0, splitIndex)));
        list2.set(1, new ArrayList<DataPointPair<Integer>>(list.subList(splitIndex, knownCount)));
        if (missingCount > 0) {
            // Distribute missing values in proportion to the weights of the CHOSEN
            // split. The running scores above hold the state at the end of the
            // scan, which in general is a different split.
            final double leftWeight = impurityScoreArr[0].getSumOfWeights();
            final double rightWeight = impurityScoreArr[1].getSumOfWeights();
            final double leftFraction = leftWeight / (leftWeight + rightWeight);
            distributMissing(list2, new double[]{leftFraction, 1.0d - leftFraction}, list.subList(knownCount, list.size()));
        }
        return new PairedReturn<>(Arrays.asList(Double.valueOf(bestThreshold), Double.valueOf(Double.POSITIVE_INFINITY)), Arrays.asList(0, 1));
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Distributes points with a missing value over the branches of a split, in
     * proportion to the total data-point weight already present in each
     * branch. The fractions are computed here and the work is delegated to the
     * three-argument overload.
     *
     * @param <T> the type of the target value paired with each data point
     * @param list the branches of the split; each receives weighted copies
     * @param list2 the points whose splitting value is missing
     */
    public static <T> void distributMissing(List<List<DataPointPair<T>>> list, List<DataPointPair<T>> list2) {
        final int branches = list.size();
        final double[] fractions = new double[branches];
        double totalWeight = 0.0d;
        for (int branch = 0; branch < branches; branch++) {
            double branchWeight = 0.0d;
            for (DataPointPair<T> point : list.get(branch)) {
                branchWeight += point.getDataPoint().getWeight();
            }
            fractions[branch] = branchWeight;
            totalWeight += branchWeight;
        }
        // Turn absolute branch weights into fractions of the total.
        for (int branch = 0; branch < branches; branch++) {
            fractions[branch] /= totalWeight;
        }
        distributMissing(list, fractions, list2);
    }

    /**
     * Adds a re-weighted copy of every missing-value point to each branch of a
     * split. The copy added to branch {@code i} carries the original weight
     * multiplied by {@code dArr[i]}; copies whose weight is NaN or not greater
     * than 1e-13 are skipped. The copies share the numeric values, categorical
     * values and categorical descriptions of the original point.
     *
     * @param <T> the type of the target value paired with each data point
     * @param list the branches of the split; modified in place
     * @param dArr the fraction of each point's weight to assign to each branch
     * @param list2 the points whose splitting value is missing
     */
    protected static <T> void distributMissing(List<List<DataPointPair<T>>> list, double[] dArr, List<DataPointPair<T>> list2) {
        for (DataPointPair<T> missingPoint : list2) {
            final DataPoint source = missingPoint.getDataPoint();
            final Vec numeric = source.getNumericalValues();
            final int[] categories = source.getCategoricalValues();
            final CategoricalData[] categoryInfo = source.getCategoricalData();
            final double sourceWeight = source.getWeight();
            for (int branch = 0; branch < dArr.length; branch++) {
                final double share = dArr[branch] * sourceWeight;
                if (Double.isNaN(share) || share <= 1.0E-13d) {
                    continue;
                }
                final DataPoint weightedCopy = new DataPoint(numeric, categories, categoryInfo, share);
                list.get(branch).add(new DataPointPair<>(weightedCopy, missingPoint.getPair()));
            }
        }
    }

    /**
     * Trains this stump for regression without the thread pool. Equivalent to
     * calling the three-argument overload with {@code false}.
     *
     * @param list the training points paired with their target values
     * @param set the candidate attribute indices; may be modified, see the
     *            three-argument overload
     * @return the partitions produced by the chosen split, or {@code null} if
     *         no split could be found
     */
    public List<List<DataPointPair<Double>>> trainR(List<DataPointPair<Double>> list, Set<Integer> set) {
        final boolean parallel = false;
        return trainR(list, set, parallel);
    }

    /**
     * Trains this stump for regression by evaluating every candidate attribute
     * and keeping the split with the smallest weighted sum of per-branch
     * variances.
     *
     * <p>Attribute indices below the number of categorical attributes refer to
     * categorical attributes; the remaining indices refer to numeric features,
     * offset by the number of categorical attributes. For a numeric split the
     * stored regression results hold the left mean, the right mean and the
     * threshold, in that order.
     *
     * <p>If the data holds at most twice the minimum split size, no split is
     * attempted and a single leaf holding the weighted mean target is stored.
     *
     * @param list the training points paired with their target values; must be
     *             non-empty
     * @param set the candidate attribute indices. When a split is chosen, the
     *            chosen attribute is removed from this set if it is categorical,
     *            or if removal of continuous attributes is enabled. The set is
     *            left untouched when no split is found
     * @param z {@code true} to evaluate attributes on the shared thread pool,
     *          {@code false} to use the in-file {@code FakeExecutor}
     * @return the partitions of {@code list} induced by the chosen split, a
     *         single partition if the data was too small to split, or
     *         {@code null} if no split could be found
     * @throws FailedToFitException if the calling thread is interrupted while
     *         waiting for the attribute evaluations to finish
     */
    public List<List<DataPointPair<Double>>> trainR(List<DataPointPair<Double>> list, Set<Integer> set, boolean z) {
        this.catAttributes = list.get(0).getDataPoint().getCategoricalData();
        this.numNumericFeatures = list.get(0).getVector().length();
        if (list.size() <= this.minResultSplitSize * 2) {
            // Too little data to split: store the weighted mean as a single leaf.
            this.splittingAttribute = this.catAttributes.length;
            double weightedSum = 0.0d;
            double totalWeight = 0.0d;
            for (DataPointPair<Double> point : list) {
                final double weight = point.getDataPoint().getWeight();
                weightedSum += point.getPair().doubleValue() * weight;
                totalWeight += weight;
            }
            this.regressionResults = new double[]{weightedSum / totalWeight};
            // Path lookups prefer classification results when present, so drop
            // any left over from an earlier classification training.
            this.results = null;
            List<List<DataPointPair<Double>>> single = new ArrayList<List<DataPointPair<Double>>>(1);
            single.add(list);
            return single;
        }
        // Holds the best split found so far; also used as the monitor guarding
        // every "best so far" update.
        final List<List<DataPointPair<Double>>> bestSplit = new ArrayList<List<DataPointPair<Double>>>();
        final AtomicDouble lowestError = new AtomicDouble(Double.MAX_VALUE);
        // Each worker thread sorts its own copy of the data.
        final ThreadLocal<List<DataPointPair<Double>>> localCopies =
                ThreadLocal.withInitial(() -> new ArrayList<DataPointPair<Double>>(list));
        final ExecutorService executor = z ? ParallelUtils.CACHED_THREAD_POOL : new FakeExecutor();
        final CountDownLatch latch = new CountDownLatch(set.size());
        for (Integer candidate : set) {
            final int attribute = candidate.intValue();
            executor.submit(() -> {
                try {
                    final List<DataPointPair<Double>> points = localCopies.get();
                    List<List<DataPointPair<Double>>> split = null;
                    double error;
                    double[] branchValues;
                    double[] branchRatios;
                    if (attribute < this.catAttributes.length) {
                        split = listOfListsD(this.catAttributes[attribute].getNumOfCategories());
                        final OnLineStatistics[] stats = new OnLineStatistics[split.size()];
                        branchRatios = new double[split.size()];
                        for (int c = 0; c < split.size(); c++) {
                            stats[c] = new OnLineStatistics();
                        }
                        final List<DataPointPair<Double>> missing = new ArrayList<DataPointPair<Double>>(0);
                        for (DataPointPair<Double> point : points) {
                            final int category = point.getDataPoint().getCategoricalValue(attribute);
                            if (category >= 0) {
                                split.get(category).add(point);
                                stats[category].add(point.getPair().doubleValue(), point.getDataPoint().getWeight());
                            } else {
                                // negative category: value is missing
                                missing.add(point);
                            }
                        }
                        branchValues = new double[stats.length];
                        error = 0.0d;
                        double knownWeight = 0.0d;
                        for (int c = 0; c < stats.length; c++) {
                            final double weight = stats[c].getSumOfWeights();
                            branchRatios[c] = weight;
                            knownWeight += weight;
                            error += stats[c].getVarance() * stats[c].getSumOfWeights();
                            branchValues[c] = stats[c].getMean();
                        }
                        for (int c = 0; c < stats.length; c++) {
                            branchRatios[c] = branchRatios[c] / knownWeight;
                        }
                        if (!missing.isEmpty()) {
                            distributMissing(split, branchRatios, missing);
                        }
                    } else {
                        final int numericIndex = attribute - this.catAttributes.length;
                        // Double.compare orders NaN after every other value, so
                        // points with a missing value end up at the tail.
                        Collections.sort(points, (a, b) -> Double.compare(
                                a.getVector().get(numericIndex), b.getVector().get(numericIndex)));
                        final OnLineStatistics rightSide = new OnLineStatistics();
                        final OnLineStatistics leftSide = new OnLineStatistics();
                        int missingCount = 0;
                        for (DataPointPair<Double> point : points) {
                            if (Double.isNaN(point.getVector().get(numericIndex))) {
                                missingCount++;
                            } else {
                                rightSide.add(point.getPair().doubleValue(), point.getDataPoint().getWeight());
                            }
                        }
                        final int knownCount = points.size() - missingCount;
                        int bestIndex = 0;
                        error = Double.POSITIVE_INFINITY;
                        final double knownWeight = rightSide.getSumOfWeights();
                        branchValues = new double[3];
                        branchRatios = new double[2];
                        for (int p = 0; p < knownCount; p++) {
                            final DataPointPair<Double> point = points.get(p);
                            final double weight = point.getDataPoint().getWeight();
                            final double target = point.getPair().doubleValue();
                            rightSide.remove(target, weight);
                            leftSide.add(target, weight);
                            if (p < this.minResultSplitSize) {
                                // left side is still below the minimum size bound
                                continue;
                            }
                            if (p > knownCount - this.minResultSplitSize) {
                                // right side would fall below the minimum size bound
                                break;
                            }
                            final double candidateError = (rightSide.getVarance() * rightSide.getSumOfWeights())
                                    + (leftSide.getVarance() * leftSide.getSumOfWeights());
                            if (candidateError < error && !Double.isInfinite(candidateError)) {
                                error = candidateError;
                                bestIndex = p;
                                branchValues[0] = leftSide.getMean();
                                branchValues[1] = rightSide.getMean();
                                // threshold: midpoint between the last left and first right value
                                branchValues[2] = (points.get(bestIndex).getVector().get(numericIndex)
                                        + points.get(bestIndex + 1).getVector().get(numericIndex)) / 2.0d;
                                branchRatios[0] = leftSide.getSumOfWeights() / knownWeight;
                                branchRatios[1] = rightSide.getSumOfWeights() / knownWeight;
                            }
                        }
                        if (knownCount >= this.minResultSplitSize) {
                            split = listOfListsD(2);
                            split.get(0).addAll(points.subList(0, bestIndex + 1));
                            split.get(1).addAll(points.subList(bestIndex + 1, knownCount));
                            if (missingCount > 0) {
                                distributMissing(split, branchRatios, points.subList(knownCount, points.size()));
                            }
                        } else {
                            // negative error marks this attribute as unusable below
                            error = Double.NEGATIVE_INFINITY;
                        }
                    }
                    if (Math.abs(error) < 1.0E-13d) {
                        // treat tiny negative round-off as zero error
                        error = Math.abs(error);
                    }
                    // Cheap unlocked check first, then re-check under the monitor.
                    if (error >= 0.0d && error < lowestError.get()) {
                        synchronized (bestSplit) {
                            if (error < lowestError.get()) {
                                lowestError.set(error);
                                bestSplit.clear();
                                bestSplit.addAll(split);
                                this.splittingAttribute = attribute;
                                this.regressionResults = branchValues;
                                this.pathRatio = branchRatios;
                            }
                        }
                    }
                } finally {
                    // Always release the latch, even if the evaluation throws,
                    // so the waiting thread cannot block forever.
                    latch.countDown();
                }
            });
        }
        try {
            latch.await();
        } catch (InterruptedException e) {
            // Restore the interrupt flag for callers before reporting the failure.
            Thread.currentThread().interrupt();
            Logger.getLogger(DecisionStump.class.getName()).log(Level.SEVERE, (String) null, (Throwable) e);
            throw new FailedToFitException(e);
        }
        if (bestSplit.isEmpty()) {
            // No split was selected. splittingAttribute still holds whatever it
            // held before this call, so it must not be used to prune the set.
            return null;
        }
        // Path lookups prefer classification results when present, so drop any
        // left over from an earlier classification training.
        this.results = null;
        if (this.splittingAttribute < this.catAttributes.length || this.removeContinuousAttributes) {
            set.remove(Integer.valueOf(this.splittingAttribute));
        }
        return bestSplit;
    }

    /**
     * Creates a list holding the requested number of empty, mutable lists,
     * used as the branches of a classification split.
     *
     * @param i the number of empty lists to create
     * @return a new list containing {@code i} empty lists
     */
    private static List<List<DataPointPair<Integer>>> listOfLists(int i) {
        final List<List<DataPointPair<Integer>>> branches = new ArrayList<List<DataPointPair<Integer>>>(i);
        int remaining = i;
        while (remaining-- > 0) {
            branches.add(new ArrayList<DataPointPair<Integer>>());
        }
        return branches;
    }

    /**
     * Creates a list holding the requested number of empty, mutable lists,
     * used as the branches of a regression split.
     *
     * @param i the number of empty lists to create
     * @return a new list containing {@code i} empty lists
     */
    private static List<List<DataPointPair<Double>>> listOfListsD(int i) {
        final List<List<DataPointPair<Double>>> branches = new ArrayList<List<DataPointPair<Double>>>(i);
        int remaining = i;
        while (remaining-- > 0) {
            branches.add(new ArrayList<DataPointPair<Double>>());
        }
        return branches;
    }

    /**
     * Reports whether this learner supports weighted data points.
     *
     * @return always {@code true}
     */
    @Override // jsat.classifiers.Classifier
    public boolean supportsWeightedData() {
        final boolean weighted = true;
        return weighted;
    }

    /**
     * Builds the impurity score of a set of labelled points using the
     * configured impurity measure.
     *
     * @param list the labelled points to score
     * @return an impurity score to which every point of {@code list} was added
     */
    private ImpurityScore getClassGainScore(List<DataPointPair<Integer>> list) {
        final int classCount = this.predicting.getNumOfCategories();
        final ImpurityScore score = new ImpurityScore(classCount, this.gainMethod);
        final Iterator<DataPointPair<Integer>> points = list.iterator();
        while (points.hasNext()) {
            final DataPointPair<Integer> point = points.next();
            score.addPoint(point.getDataPoint(), point.getPair().intValue());
        }
        return score;
    }

    /**
     * Creates an independent copy of this stump. Configuration values are
     * copied directly; trained state (categorical attribute descriptions,
     * classification results, boundaries, owners, the predicting attribute,
     * regression results and path ratios) is copied only where present, so an
     * untrained stump yields an untrained copy.
     *
     * @return a copy of this stump
     */
    @Override // jsat.regression.Regressor
    /* renamed from: clone, reason: merged with bridge method [inline-methods] and merged with bridge method [inline-methods] */
    public DecisionStump m93clone() {
        final DecisionStump copy = new DecisionStump();

        // plain configuration and scalar state
        copy.minResultSplitSize = this.minResultSplitSize;
        copy.gainMethod = this.gainMethod;
        copy.removeContinuousAttributes = this.removeContinuousAttributes;
        copy.numNumericFeatures = this.numNumericFeatures;
        copy.splittingAttribute = this.splittingAttribute;

        // attribute descriptions
        if (this.catAttributes != null) {
            copy.catAttributes = CategoricalData.copyOf(this.catAttributes);
        }
        if (this.predicting != null) {
            copy.predicting = this.predicting.m1clone();
        }

        // classification model
        if (this.results != null) {
            final CategoricalResults[] clonedResults = new CategoricalResults[this.results.length];
            for (int branch = 0; branch < clonedResults.length; branch++) {
                clonedResults[branch] = this.results[branch].m2clone();
            }
            copy.results = clonedResults;
        }
        if (this.boundries != null) {
            copy.boundries = new DoubleList(this.boundries);
        }
        if (this.owners != null) {
            copy.owners = new IntList(this.owners);
        }

        // regression model and shared path ratios
        if (this.regressionResults != null) {
            copy.regressionResults = Arrays.copyOf(this.regressionResults, this.regressionResults.length);
        }
        if (this.pathRatio != null) {
            copy.pathRatio = Arrays.copyOf(this.pathRatio, this.pathRatio.length);
        }
        return copy;
    }
}
