/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers.imbalance;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.classifiers.DataPointPair;
import jsat.exceptions.FailedToFitException;
import jsat.linear.DenseVector;
import jsat.linear.Vec;
import jsat.linear.distancemetrics.DistanceMetric;
import jsat.linear.distancemetrics.EuclideanDistance;
import jsat.linear.vectorcollection.DefaultVectorCollection;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.utils.IntList;
import jsat.utils.ListUtils;
import jsat.utils.concurrent.ParallelUtils;
import jsat.utils.random.RandomUtil;

/**
 * Classifier wrapper that oversamples under-represented classes with
 * synthetic points (SMOTE) and then trains a base classifier on the enlarged
 * data set.
 * <p>
 * For every class whose size is below {@code targetRatio} times the size of
 * the largest class, new points are created by interpolating between an
 * existing sample of that class and one of its nearest same-class neighbours
 * (as measured by the configured {@link DistanceMetric}). Classification is
 * delegated unchanged to the base classifier.
 * <p>
 * Only data sets without categorical features are accepted by
 * {@link #train(ClassificationDataSet, boolean)}.
 */
public class SMOTE implements Classifier, Parameterized {
    /** Classifier that is trained on the oversampled data and answers all queries. */
    @Parameter.ParameterHolder
    protected Classifier baseClassifier;
    /** Metric used to find the nearest same-class neighbours of a sample. */
    protected DistanceMetric dm;
    /** Number of nearest neighbours a sample may be interpolated with. */
    protected int smoteNeighbors;
    /** Desired size of every class relative to the largest class. */
    protected double targetRatio;

    /**
     * Creates a SMOTE wrapper using the Euclidean distance, 5 neighbours and a
     * target ratio of 1.0.
     *
     * @param baseClassifier the classifier to train on the oversampled data
     */
    public SMOTE(Classifier baseClassifier) {
        this(baseClassifier, new EuclideanDistance());
    }

    /**
     * Creates a SMOTE wrapper using 5 neighbours and a target ratio of 1.0.
     *
     * @param baseClassifier the classifier to train on the oversampled data
     * @param dm the metric used for the nearest neighbour search
     */
    public SMOTE(Classifier baseClassifier, DistanceMetric dm) {
        this(baseClassifier, dm, 1.0);
    }

    /**
     * Creates a SMOTE wrapper using 5 neighbours.
     *
     * @param baseClassifier the classifier to train on the oversampled data
     * @param dm the metric used for the nearest neighbour search
     * @param targetRatio desired class size relative to the largest class
     */
    public SMOTE(Classifier baseClassifier, DistanceMetric dm, double targetRatio) {
        this(baseClassifier, dm, 5, targetRatio);
    }

    /**
     * Creates a fully specified SMOTE wrapper.
     *
     * @param baseClassifier the classifier to train on the oversampled data
     * @param dm the metric used for the nearest neighbour search
     * @param smoteNeighbors number of neighbours to interpolate with, at least 1
     * @param targetRatio desired class size relative to the largest class
     * @throws IllegalArgumentException if {@code smoteNeighbors} is less than 1
     */
    public SMOTE(Classifier baseClassifier, DistanceMetric dm, int smoteNeighbors, double targetRatio) {
        this.setBaseClassifier(baseClassifier);
        this.setDistanceMetric(dm);
        this.setSmoteNeighbors(smoteNeighbors);
        this.setTargetRatio(targetRatio);
    }

    /**
     * Copy constructor. The base classifier and the distance metric are
     * cloned; the numeric settings are copied.
     *
     * @param toCopy the instance to copy
     */
    public SMOTE(SMOTE toCopy) {
        this.baseClassifier = toCopy.baseClassifier.clone();
        this.dm = toCopy.dm.clone();
        this.smoteNeighbors = toCopy.smoteNeighbors;
        this.targetRatio = toCopy.targetRatio;
    }

    /**
     * Sets the metric used to find nearest neighbours.
     *
     * @param dm the distance metric to use
     */
    public void setDistanceMetric(DistanceMetric dm) {
        this.dm = dm;
    }

    /**
     * @return the metric used to find nearest neighbours
     */
    public DistanceMetric getDistanceMetric() {
        return this.dm;
    }

    /**
     * Sets how many nearest neighbours of a sample are candidates for
     * interpolation.
     *
     * @param smoteNeighbors the number of neighbours, at least 1
     * @throws IllegalArgumentException if {@code smoteNeighbors} is less than 1
     */
    public void setSmoteNeighbors(int smoteNeighbors) {
        if (smoteNeighbors < 1) {
            throw new IllegalArgumentException("number of neighbors considered must be a positive value");
        }
        this.smoteNeighbors = smoteNeighbors;
    }

    /**
     * @return the number of neighbours considered for interpolation
     */
    public int getSmoteNeighbors() {
        return this.smoteNeighbors;
    }

    /**
     * Sets the desired size of each class relative to the largest class.
     * Classes already at or above that size receive no synthetic points.
     *
     * @param targetRatio the desired ratio
     */
    public void setTargetRatio(double targetRatio) {
        this.targetRatio = targetRatio;
    }

    /**
     * @return the desired size of each class relative to the largest class
     */
    public double getTargetRatio() {
        return this.targetRatio;
    }

    /**
     * Sets the classifier that is trained on the oversampled data.
     *
     * @param baseClassifier the classifier to wrap
     */
    public void setBaseClassifier(Classifier baseClassifier) {
        this.baseClassifier = baseClassifier;
    }

    /**
     * @return the wrapped classifier
     */
    public Classifier getBaseClassifier() {
        return this.baseClassifier;
    }

    /**
     * Classifies a point by delegating to the base classifier.
     *
     * @param data the point to classify
     * @return the base classifier's result
     */
    @Override
    public CategoricalResults classify(DataPoint data) {
        return this.baseClassifier.classify(data);
    }

    /**
     * Oversamples the minority classes of {@code dataSet} with synthetic
     * points and trains the base classifier on the original data plus the
     * synthetic points.
     *
     * @param dataSet the training data; must not contain categorical features
     * @param parallel whether the neighbour search, the sample generation and
     *        the base classifier's training are asked to run in parallel
     * @throws FailedToFitException if the data set has categorical features
     */
    @Override
    public void train(ClassificationDataSet dataSet, boolean parallel) {
        if (dataSet.getNumCategoricalVars() != 0) {
            throw new FailedToFitException("SMOTE only works with numeric-only feature values");
        }
        final List<Vec> vAll = dataSet.getDataVectors();
        final int numClasses = dataSet.getClassSize();

        // Bucket the indices of the data points by class label.
        final IntList[] classIndex = new IntList[numClasses];
        for (int c = 0; c < numClasses; ++c) {
            classIndex[c] = new IntList();
        }
        for (int i = 0; i < dataSet.getSampleSize(); ++i) {
            classIndex[dataSet.getDataPointCategory(i)].add(i);
        }

        // Size of the largest class, recovered from its prior. Round instead of
        // truncating: a product such as 100 * 0.57 evaluates to 56.999... and a
        // plain int cast would under-count the majority class by one.
        final double maxPrior = DenseVector.toDenseVec(dataSet.getPriors()).max();
        final int majorityNum = (int) Math.round((double) dataSet.getSampleSize() * maxPrior);

        final List<DataPointPair<Integer>> synthetics = new ArrayList<DataPointPair<Integer>>();
        for (int c = 0; c < numClasses; ++c) {
            final int classID = c;
            final IntList members = classIndex[classID];
            final int samplesNeeded = (int) ((double) majorityNum * this.targetRatio - (double) members.size());
            // A class with no samples has nothing to interpolate from.
            if (samplesNeeded <= 0 || members.isEmpty()) {
                continue;
            }

            // Collect the vectors of this class only; neighbours are searched within the class.
            final List<Vec> V_id = new ArrayList<Vec>(members.size());
            for (int idx : members) {
                V_id.add(vAll.get(idx));
            }
            final DefaultVectorCollection<Vec> VC_id = new DefaultVectorCollection<Vec>(this.dm, V_id, parallel);
            final List<List<Integer>> neighbors = new ArrayList<List<Integer>>();
            final List<List<Double>> distances = new ArrayList<List<Double>>();
            // One extra neighbour because each query point is itself part of the
            // collection; never ask for more points than the class contains.
            final int k = Math.min(this.smoteNeighbors + 1, V_id.size());
            VC_id.search(VC_id, k, neighbors, distances, parallel);

            ParallelUtils.run(parallel, samplesNeeded, (start, end) -> {
                Random rand = RandomUtil.getRandom();
                // Build results thread-locally so the shared list is locked only once per chunk.
                List<DataPointPair<Integer>> localNew = new ArrayList<DataPointPair<Integer>>();
                for (int i = start; i < end; ++i) {
                    // Cycle through the class members so every sample is used as a seed.
                    int sampleIndex = i % V_id.size();
                    List<Integer> nnList = neighbors.get(sampleIndex);
                    // Entry 0 is assumed to be the query point itself, so real
                    // neighbours start at 1. Small classes may return fewer
                    // neighbours than requested.
                    int candidates = Math.min(this.smoteNeighbors, nnList.size() - 1);
                    Vec newVal = V_id.get(sampleIndex).clone();
                    if (candidates > 0) {
                        int nn = rand.nextInt(candidates) + 1;
                        Vec vec_nn = VC_id.get(nnList.get(nn));
                        double gap = rand.nextDouble();
                        // new = sample + gap * (neighbour - sample)
                        //     = (1 - gap) * sample + gap * neighbour
                        newVal.mutableMultiply(1.0 - gap);
                        newVal.mutableAdd(gap, vec_nn);
                    }
                    // With no neighbour available the synthetic point is a copy of the sample.
                    localNew.add(new DataPointPair<Integer>(new DataPoint(newVal), classID));
                }
                synchronized (synthetics) {
                    synthetics.addAll(localNew);
                }
            });
        }

        ClassificationDataSet newDataSet = new ClassificationDataSet(
                ListUtils.mergedView(synthetics, dataSet.getAsDPPList()), dataSet.getPredicting());
        this.baseClassifier.train(newDataSet, parallel);
    }

    /**
     * Trains without parallelism; equivalent to {@code train(dataSet, false)}.
     *
     * @param dataSet the training data; must not contain categorical features
     */
    @Override
    public void train(ClassificationDataSet dataSet) {
        this.train(dataSet, false);
    }

    /**
     * @return {@code false}; this wrapper does not report support for weighted data
     */
    @Override
    public boolean supportsWeightedData() {
        return false;
    }

    /**
     * @return a copy of this object created with the copy constructor
     */
    @Override
    public SMOTE clone() {
        return new SMOTE(this);
    }
}

