package jsat.classifiers.imbalance;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.classifiers.DataPointPair;
import jsat.exceptions.FailedToFitException;
import jsat.linear.DenseVector;
import jsat.linear.Vec;
import jsat.linear.distancemetrics.DistanceMetric;
import jsat.linear.distancemetrics.EuclideanDistance;
import jsat.linear.vectorcollection.DefaultVectorCollection;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.utils.IntList;
import jsat.utils.ListUtils;
import jsat.utils.concurrent.ParallelUtils;
import jsat.utils.random.RandomUtil;

/* loaded from: input_file:jsat/classifiers/imbalance/SMOTE.class */
/**
 * Classifier wrapper that applies the Synthetic Minority Over-sampling
 * TEchnique (SMOTE) to the training data before handing it to a base
 * classifier.
 *
 * <p>For every class that has fewer samples than
 * {@code majorityCount * targetRatio}, synthetic points are generated by
 * interpolating between an existing point of that class and one of its nearest
 * same-class neighbours. The base classifier is then trained on the union of
 * the synthetic and the original points. Classification is delegated to the
 * base classifier unchanged.
 *
 * <p>Only data sets without categorical features are supported.
 */
public class SMOTE implements Classifier, Parameterized {

    /** Classifier that is trained on the re-balanced data and used for prediction. */
    @Parameter.ParameterHolder
    protected Classifier baseClassifier;
    /** Metric used to find the nearest same-class neighbours of each point. */
    protected DistanceMetric dm;
    /** Number of nearest neighbours a synthetic point's partner is drawn from. */
    protected int smoteNeighbors;
    /** Desired class size expressed as a fraction of the majority class size. */
    protected double targetRatio;

    /**
     * Creates a SMOTE wrapper using the Euclidean distance, 5 neighbours and a
     * target ratio of 1.0.
     *
     * @param classifier the classifier to train on the over-sampled data
     */
    public SMOTE(Classifier classifier) {
        this(classifier, new EuclideanDistance());
    }

    /**
     * Creates a SMOTE wrapper using 5 neighbours and a target ratio of 1.0.
     *
     * @param classifier the classifier to train on the over-sampled data
     * @param distanceMetric the metric used for the neighbour search
     */
    public SMOTE(Classifier classifier, DistanceMetric distanceMetric) {
        this(classifier, distanceMetric, 1.0d);
    }

    /**
     * Creates a SMOTE wrapper using 5 neighbours.
     *
     * @param classifier the classifier to train on the over-sampled data
     * @param distanceMetric the metric used for the neighbour search
     * @param d the target ratio, see {@link #setTargetRatio(double)}
     */
    public SMOTE(Classifier classifier, DistanceMetric distanceMetric, double d) {
        this(classifier, distanceMetric, 5, d);
    }

    /**
     * Creates a fully configured SMOTE wrapper.
     *
     * @param classifier the classifier to train on the over-sampled data
     * @param distanceMetric the metric used for the neighbour search
     * @param i the number of neighbours, see {@link #setSmoteNeighbors(int)}
     * @param d the target ratio, see {@link #setTargetRatio(double)}
     * @throws IllegalArgumentException if {@code i} is less than 1
     */
    public SMOTE(Classifier classifier, DistanceMetric distanceMetric, int i, double d) {
        setBaseClassifier(classifier);
        setDistanceMetric(distanceMetric);
        setSmoteNeighbors(i);
        setTargetRatio(d);
    }

    /**
     * Copy constructor. The base classifier and the distance metric are cloned;
     * the numeric settings are copied.
     *
     * @param smote the instance to copy
     */
    public SMOTE(SMOTE smote) {
        this.baseClassifier = smote.baseClassifier.mo38clone();
        this.dm = smote.dm.mo185clone();
        this.smoteNeighbors = smote.smoteNeighbors;
        this.targetRatio = smote.targetRatio;
    }

    /**
     * Sets the metric used to find same-class nearest neighbours.
     *
     * @param distanceMetric the metric to use
     */
    public void setDistanceMetric(DistanceMetric distanceMetric) {
        this.dm = distanceMetric;
    }

    /**
     * @return the metric used to find same-class nearest neighbours
     */
    public DistanceMetric getDistanceMetric() {
        return this.dm;
    }

    /**
     * Sets how many nearest neighbours of a point are candidates for being the
     * interpolation partner of a synthetic sample.
     *
     * @param i the number of neighbours, must be at least 1
     * @throws IllegalArgumentException if {@code i} is less than 1
     */
    public void setSmoteNeighbors(int i) {
        if (i < 1) {
            throw new IllegalArgumentException("number of neighbors considered must be a positive value");
        }
        this.smoteNeighbors = i;
    }

    /**
     * @return the number of nearest neighbours considered per point
     */
    public int getSmoteNeighbors() {
        return this.smoteNeighbors;
    }

    /**
     * Sets the size every class is over-sampled towards, relative to the
     * majority class: a class receives
     * {@code (int) (majorityCount * targetRatio - classCount)} synthetic
     * samples when that number is positive and none otherwise. The value is
     * stored without validation; a non-positive ratio therefore simply results
     * in no over-sampling.
     *
     * @param d the target ratio
     */
    public void setTargetRatio(double d) {
        this.targetRatio = d;
    }

    /**
     * @return the target ratio relative to the majority class size
     */
    public double getTargetRatio() {
        return this.targetRatio;
    }

    /**
     * Sets the classifier that is trained on the over-sampled data.
     *
     * @param classifier the classifier to wrap
     */
    public void setBaseClassifier(Classifier classifier) {
        this.baseClassifier = classifier;
    }

    /**
     * @return the wrapped classifier
     */
    public Classifier getBaseClassifier() {
        return this.baseClassifier;
    }

    /**
     * Classifies the point with the wrapped classifier; no SMOTE-specific
     * processing happens at prediction time.
     *
     * @param dataPoint the point to classify
     * @return the result produced by the base classifier
     */
    @Override
    public CategoricalResults classify(DataPoint dataPoint) {
        return this.baseClassifier.classify(dataPoint);
    }

    /**
     * Over-samples the under-represented classes of the data set and trains the
     * base classifier on the original points plus the synthetic ones.
     *
     * <p>Classes with fewer than two members are left untouched because there
     * is no same-class neighbour to interpolate with.
     *
     * @param classificationDataSet the training data; must not contain
     *        categorical features
     * @param z flag forwarded to the neighbour search, to
     *        {@code ParallelUtils.run} and to the base classifier's
     *        {@code train} (presumably enables parallel execution)
     * @throws FailedToFitException if the data set has categorical features
     */
    @Override
    public void train(ClassificationDataSet classificationDataSet, boolean z) {
        if (classificationDataSet.getNumCategoricalVars() != 0) {
            throw new FailedToFitException("SMOTE only works with numeric-only feature values");
        }
        final List<Vec> dataVectors = classificationDataSet.getDataVectors();

        // Bucket the sample indices by class label.
        final int numClasses = classificationDataSet.getClassSize();
        final IntList[] classIndex = new IntList[numClasses];
        for (int c = 0; c < classIndex.length; c++) {
            classIndex[c] = new IntList();
        }
        final int sampleCount = classificationDataSet.getSampleSize();
        for (int idx = 0; idx < sampleCount; idx++) {
            classIndex[classificationDataSet.getDataPointCategory(idx)].add(idx);
        }

        // Size of the majority class, derived from the largest class prior.
        final double maxPrior = DenseVector.toDenseVec(classificationDataSet.getPriors()).max();
        final int majorityCount = (int) (sampleCount * maxPrior);

        // Shared across worker tasks; every write is guarded by synchronized (synthetics).
        final List<DataPointPair<Integer>> synthetics = new ArrayList<>();

        for (final int classId : ListUtils.range(0, numClasses)) {
            final int samplesNeeded = (int) ((majorityCount * this.targetRatio) - classIndex[classId].size());
            if (samplesNeeded <= 0) {
                continue;
            }
            // An empty class would cause a modulo-by-zero below, and a single
            // point has no same-class neighbour to interpolate with.
            if (classIndex[classId].size() < 2) {
                continue;
            }

            // Collect the vectors of this class; synthetic points are built from them.
            final List<Vec> members = new ArrayList<>(classIndex[classId].size());
            for (int idx : classIndex[classId]) {
                members.add(dataVectors.get(idx));
            }

            // Neighbour lists for every member. One extra neighbour is requested
            // because slot 0 is skipped when choosing a partner (presumably it is
            // the query point itself).
            final DefaultVectorCollection<Vec> collection = new DefaultVectorCollection<>(this.dm, members, z);
            final List<List<Integer>> neighbors = new ArrayList<>();
            final List<List<Double>> distances = new ArrayList<>();
            collection.search(collection, this.smoteNeighbors + 1, neighbors, distances, z);

            ParallelUtils.run(z, samplesNeeded, (start, end) -> {
                final Random rand = RandomUtil.getRandom();
                // Accumulate locally so the shared list is locked once per task.
                final List<DataPoint> localSynthetics = new ArrayList<>();
                for (int i = start; i < end; i++) {
                    // Cycle through the members so each one seeds a similar number of samples.
                    final int sampleIndex = i % members.size();
                    final List<Integer> nearest = neighbors.get(sampleIndex);
                    // Never index past the neighbours that were actually returned.
                    final int candidates = Math.min(this.smoteNeighbors, nearest.size() - 1);
                    if (candidates < 1) {
                        continue;
                    }
                    final int neighborIndex = nearest.get(rand.nextInt(candidates) + 1);
                    final Vec neighbor = collection.get(neighborIndex);
                    final double gap = rand.nextDouble();

                    // synthetic = sample + gap * (neighbor - sample)
                    //           = (1 - gap) * sample + gap * neighbor
                    final Vec synthetic = members.get(sampleIndex).mo46clone();
                    synthetic.mutableMultiply(1.0d - gap);
                    synthetic.mutableAdd(gap, neighbor);
                    localSynthetics.add(new DataPoint(synthetic));
                }
                synchronized (synthetics) {
                    for (DataPoint point : localSynthetics) {
                        synthetics.add(new DataPointPair<>(point, Integer.valueOf(classId)));
                    }
                }
            });
        }

        final List<DataPointPair<Integer>> combined =
                ListUtils.mergedView(synthetics, classificationDataSet.getAsDPPList());
        this.baseClassifier.train(
                new ClassificationDataSet(combined, classificationDataSet.getPredicting()), z);
    }

    /**
     * Trains with the boolean flag of
     * {@link #train(ClassificationDataSet, boolean)} set to {@code false}.
     *
     * @param classificationDataSet the training data
     */
    @Override
    public void train(ClassificationDataSet classificationDataSet) {
        train(classificationDataSet, false);
    }

    /**
     * @return always {@code false}
     */
    @Override
    public boolean supportsWeightedData() {
        return false;
    }

    /**
     * Returns a copy of this object created through the copy constructor.
     *
     * @return a new {@code SMOTE} with cloned base classifier and metric
     */
    @Override
    /* renamed from: clone, reason: merged with bridge method [inline-methods] */
    public SMOTE mo38clone() {
        return new SMOTE(this);
    }
}
