package jsat.classifiers.imbalance;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.classifiers.DataPointPair;
import jsat.exceptions.FailedToFitException;
import jsat.linear.DenseVector;
import jsat.linear.Vec;
import jsat.linear.distancemetrics.DistanceMetric;
import jsat.linear.distancemetrics.EuclideanDistance;
import jsat.linear.vectorcollection.DefaultVectorCollection;
import jsat.utils.IntList;
import jsat.utils.ListUtils;
import jsat.utils.concurrent.ParallelUtils;
import jsat.utils.random.RandomUtil;

/* loaded from: input_file:jsat/classifiers/imbalance/BorderlineSMOTE.class */
/**
 * Borderline-SMOTE oversampling wrapper.
 * <p>
 * Like {@link SMOTE}, this creates synthetic points for every class whose size is
 * below {@code targetRatio} times the size of the largest class. It then trains the
 * wrapped base classifier on the original data plus the synthetic points.
 * <p>
 * Unlike plain SMOTE, interpolation seeds are drawn only from "danger" points. A
 * danger point is a member of the class for which at least half, but not all, of
 * its {@code smoteNeighbors} nearest neighbors in the full data set belong to a
 * different class. If a class has no danger points, every member of that class is
 * used as a seed instead.
 * <p>
 * When majority interpolation is enabled and danger points exist, each synthetic
 * point has a 50% chance of being interpolated toward one of the seed's other-class
 * neighbors instead of a same-class neighbor. In that case the interpolation step
 * is less than half the distance, so the new point stays closer to the seed.
 */
public class BorderlineSMOTE extends SMOTE {
    /** Whether synthetic points may also be interpolated toward other-class neighbors. */
    private boolean majorityInterpolation;

    /**
     * Creates a Borderline-SMOTE wrapper with Euclidean distance, 5 neighbors, a
     * target ratio of 1.0 and majority interpolation disabled.
     *
     * @param base the classifier trained on the oversampled data
     */
    public BorderlineSMOTE(Classifier base) {
        this(base, false);
    }

    /**
     * Creates a Borderline-SMOTE wrapper with Euclidean distance, 5 neighbors and a
     * target ratio of 1.0.
     *
     * @param base the classifier trained on the oversampled data
     * @param majorityInterpolation whether to also interpolate toward other-class neighbors
     */
    public BorderlineSMOTE(Classifier base, boolean majorityInterpolation) {
        this(base, new EuclideanDistance(), majorityInterpolation);
    }

    /**
     * Creates a Borderline-SMOTE wrapper with 5 neighbors and a target ratio of 1.0.
     *
     * @param base the classifier trained on the oversampled data
     * @param dm the distance metric used for nearest-neighbor searches
     * @param majorityInterpolation whether to also interpolate toward other-class neighbors
     */
    public BorderlineSMOTE(Classifier base, DistanceMetric dm, boolean majorityInterpolation) {
        this(base, dm, 1.0d, majorityInterpolation);
    }

    /**
     * Creates a Borderline-SMOTE wrapper that uses 5 neighbors.
     *
     * @param base the classifier trained on the oversampled data
     * @param dm the distance metric used for nearest-neighbor searches
     * @param targetRatio desired size of each class relative to the largest class
     * @param majorityInterpolation whether to also interpolate toward other-class neighbors
     */
    public BorderlineSMOTE(Classifier base, DistanceMetric dm, double targetRatio, boolean majorityInterpolation) {
        this(base, dm, 5, targetRatio, majorityInterpolation);
    }

    /**
     * Creates a fully configured Borderline-SMOTE wrapper.
     *
     * @param base the classifier trained on the oversampled data
     * @param dm the distance metric used for nearest-neighbor searches
     * @param smoteNeighbors number of nearest neighbors considered per point
     * @param targetRatio desired size of each class relative to the largest class
     * @param majorityInterpolation whether to also interpolate toward other-class neighbors
     */
    public BorderlineSMOTE(Classifier base, DistanceMetric dm, int smoteNeighbors, double targetRatio,
            boolean majorityInterpolation) {
        super(base, dm, smoteNeighbors, targetRatio);
        setMajorityInterpolation(majorityInterpolation);
    }

    /**
     * Copy constructor.
     *
     * @param toCopy the object to copy
     */
    public BorderlineSMOTE(BorderlineSMOTE toCopy) {
        super((SMOTE) toCopy);
        this.majorityInterpolation = toCopy.majorityInterpolation;
    }

    /**
     * Sets whether synthetic points may also be interpolated toward neighbors from
     * other classes. Those interpolations use a step of less than half the distance.
     *
     * @param majorityInterpolation {@code true} to enable majority interpolation
     */
    public void setMajorityInterpolation(boolean majorityInterpolation) {
        this.majorityInterpolation = majorityInterpolation;
    }

    /**
     * @return {@code true} if synthetic points may be interpolated toward other-class neighbors
     */
    public boolean isMajorityInterpolation() {
        return this.majorityInterpolation;
    }

    /**
     * Oversamples every class below the target size using Borderline-SMOTE, then
     * trains the base classifier on the original points merged with the synthetic ones.
     *
     * @param dataSet the training data; must contain only numeric features
     * @param parallel whether neighbor searches and sample generation may run in parallel
     * @throws FailedToFitException if the data set has categorical features
     */
    @Override // jsat.classifiers.imbalance.SMOTE, jsat.classifiers.Classifier
    public void train(ClassificationDataSet dataSet, boolean parallel) {
        if (dataSet.getNumCategoricalVars() != 0) {
            throw new FailedToFitException("SMOTE only works with numeric-only feature values");
        }
        final List<Vec> allVecs = dataSet.getDataVectors();
        final int classCount = dataSet.getClassSize();

        // Indices of the training points belonging to each class.
        final IntList[] classMembers = new IntList[classCount];
        for (int c = 0; c < classCount; c++) {
            classMembers[c] = new IntList();
        }
        for (int i = 0; i < dataSet.getSampleSize(); i++) {
            classMembers[dataSet.getDataPointCategory(i)].add(i);
        }

        // Number of points in the largest class. Each class is grown toward
        // targetRatio times this count.
        final int majorityCount = (int) (dataSet.getSampleSize() * DenseVector.toDenseVec(dataSet.getPriors()).max());

        final List<DataPointPair<Integer>> synthetics = new ArrayList<>();
        final DefaultVectorCollection<Vec> allCollection = new DefaultVectorCollection<>(this.dm, allVecs, parallel);
        // Each search asks for k + 1 neighbors because the first neighbor returned for a
        // query point is taken to be the point itself. It is skipped below.
        final int k = this.smoteNeighbors;

        for (final int classId : ListUtils.range(0, classCount)) {
            final int samplesNeeded = (int) (majorityCount * this.targetRatio - classMembers[classId].size());
            // An empty class has nothing to interpolate from. It would otherwise cause a
            // division by zero when seeds are picked below, so it is skipped.
            if (samplesNeeded <= 0 || classMembers[classId].isEmpty()) {
                continue;
            }

            final List<Vec> classVecs = new ArrayList<>(classMembers[classId].size());
            for (int idx : classMembers[classId]) {
                classVecs.add(allVecs.get(idx));
            }
            final DefaultVectorCollection<Vec> classCollection = new DefaultVectorCollection<>(this.dm, classVecs, parallel);

            // Neighbors of each class member among ALL points. Used to find danger points.
            final List<List<Integer>> allNeighbors = new ArrayList<>();
            allCollection.search(classVecs, k + 1, allNeighbors, new ArrayList<>(), parallel);

            final List<List<Vec>> majorityNeighbors = new ArrayList<>();
            if (this.majorityInterpolation) {
                for (int i = 0; i < allNeighbors.size(); i++) {
                    majorityNeighbors.add(new ArrayList<>(k));
                }
            }

            final IntList danger = new IntList();
            for (int i = 0; i < classCollection.size(); i++) {
                List<Integer> neighbors = allNeighbors.get(i);
                int sameClass = 0;
                for (int j = 1; j <= k; j++) {
                    int neighborIdx = neighbors.get(j);
                    if (dataSet.getDataPointCategory(neighborIdx) == classId) {
                        sameClass++;
                    } else if (this.majorityInterpolation) {
                        majorityNeighbors.get(i).add(allCollection.get(neighborIdx));
                    }
                }
                // Fraction of the k neighbors that come from other classes. Floating-point
                // division is required here. With integer division the value is always 0 or 1,
                // so no point would ever be marked as in danger.
                double otherFraction = 1.0 - sameClass / (double) k;
                if (0.5 <= otherFraction && otherFraction < 1.0) {
                    danger.add(i);
                }
            }

            // Neighbors of each class member within its own class. Used for regular interpolation.
            final List<List<Integer>> minorityNeighbors = new ArrayList<>();
            classCollection.search(classCollection, k + 1, minorityNeighbors, new ArrayList<>(), parallel);

            ParallelUtils.run(parallel, samplesNeeded, (start, end) -> {
                Random rand = RandomUtil.getRandom();
                List<DataPoint> localNew = new ArrayList<>();
                for (int i = start; i < end; i++) {
                    // Cycle through the danger points, or through all members if there are none.
                    int seed = danger.isEmpty() ? i % classVecs.size() : danger.getI(i % danger.size());
                    boolean useMajority = rand.nextBoolean() && this.majorityInterpolation && !danger.isEmpty();
                    Vec neighbor;
                    if (useMajority) {
                        // Non-empty: a danger point has at least one other-class neighbor.
                        List<Vec> candidates = majorityNeighbors.get(seed);
                        neighbor = candidates.get(rand.nextInt(candidates.size()));
                    } else {
                        neighbor = classCollection.get(minorityNeighbors.get(seed).get(rand.nextInt(k) + 1));
                    }
                    double gap = rand.nextDouble();
                    if (useMajority) {
                        // Stay closer to the seed when moving toward another class.
                        gap /= 2.0d;
                    }
                    // synthetic = seed + gap * (neighbor - seed) = (1 - gap) * seed + gap * neighbor.
                    // The factor must be (1 - gap), not (1 + gap). Otherwise the new point falls
                    // outside the segment between the seed and its neighbor.
                    Vec synthetic = classVecs.get(seed).mo46clone();
                    synthetic.mutableMultiply(1.0d - gap);
                    synthetic.mutableAdd(gap, neighbor);
                    localNew.add(new DataPoint(synthetic));
                }
                synchronized (synthetics) {
                    for (DataPoint dp : localNew) {
                        synthetics.add(new DataPointPair<>(dp, classId));
                    }
                }
            });
        }

        List<DataPointPair<Integer>> augmented = ListUtils.mergedView(synthetics, dataSet.getAsDPPList());
        this.baseClassifier.train(new ClassificationDataSet(augmented, dataSet.getPredicting()), parallel);
    }

    /**
     * @return a deep copy of this object, made via the copy constructor
     */
    @Override // jsat.classifiers.imbalance.SMOTE
    /* renamed from: clone, reason: merged with bridge method [inline-methods] */
    public BorderlineSMOTE mo38clone() {
        return new BorderlineSMOTE(this);
    }
}
