package jsat.classifiers.boosting;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import java.util.concurrent.Callable;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.concurrent.Semaphore;
import java.util.logging.Level;
import java.util.logging.Logger;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.math.OnLineStatistics;
import jsat.parameters.Parameterized;
import jsat.regression.RegressionDataSet;
import jsat.regression.Regressor;
import jsat.utils.SystemInfo;
import jsat.utils.concurrent.ParallelUtils;

/* loaded from: input_file:jsat/classifiers/boosting/Bagging.class */
/**
 * Bootstrap aggregating ("bagging") ensemble for either classification or
 * regression.
 * <p>
 * Training builds {@link #getRounds() rounds} clones of the base learner. Each
 * clone is trained on a bootstrap sample drawn with replacement from the
 * training set, containing {@code n + extraSamples} draws, where {@code n} is
 * the size of the training set.
 * <p>
 * Classification combines the learners by majority vote: the returned
 * probabilities are the normalized vote counts. Regression returns the mean of
 * the learners' predictions.
 * <p>
 * An instance serves exactly one task. The classifier constructors enable
 * {@link #classify} and {@link #train(ClassificationDataSet, boolean)}. The
 * regressor constructors enable {@link #regress} and
 * {@link #train(RegressionDataSet, boolean)}.
 */
public class Bagging implements Classifier, Regressor, Parameterized {
    private static final long serialVersionUID = -6566453570170428838L;
    private Classifier baseClassifier;
    private Regressor baseRegressor;
    private CategoricalData predicting;
    private int extraSamples;
    private int rounds;
    private boolean simultaniousTraining;
    private Random random;
    /** Trained ensemble members: all Classifiers or all Regressors, depending on which train method ran. */
    private List<Object> learners;
    public static final int DEFAULT_ROUNDS = 20;
    public static final int DEFAULT_EXTRA_SAMPLES = 0;
    public static final boolean DEFAULT_SIMULTANIOUS_TRAINING = true;

    /**
     * Creates a bagged classifier with default settings.
     * <p>
     * The defaults are {@value #DEFAULT_EXTRA_SAMPLES} extra samples,
     * simultaneous training enabled, {@value #DEFAULT_ROUNDS} rounds, and a
     * {@link Random} seeded with 1.
     *
     * @param classifier the base classifier to clone for every round
     */
    public Bagging(Classifier classifier) {
        this(classifier, DEFAULT_EXTRA_SAMPLES, DEFAULT_SIMULTANIOUS_TRAINING);
    }

    /**
     * Creates a bagged classifier with {@value #DEFAULT_ROUNDS} rounds and a
     * {@link Random} seeded with 1.
     *
     * @param classifier the base classifier to clone for every round
     * @param extraSamples number of draws added to each bootstrap sample beyond the data set size
     * @param simultaniousTraining whether rounds may be trained concurrently
     */
    public Bagging(Classifier classifier, int extraSamples, boolean simultaniousTraining) {
        this(classifier, extraSamples, simultaniousTraining, DEFAULT_ROUNDS, new Random(1L));
    }

    /**
     * Creates a bagged classifier.
     *
     * @param classifier the base classifier to clone for every round
     * @param extraSamples number of draws added to each bootstrap sample beyond the data set size
     * @param simultaniousTraining whether rounds may be trained concurrently
     * @param rounds number of learners to train; must be positive
     * @param random source of randomness for bootstrap sampling
     * @throws ArithmeticException if {@code rounds} is not positive
     */
    public Bagging(Classifier classifier, int extraSamples, boolean simultaniousTraining, int rounds, Random random) {
        this(extraSamples, simultaniousTraining, rounds, random);
        this.baseClassifier = classifier;
    }

    /**
     * Creates a bagged regressor with default settings.
     * <p>
     * The defaults are {@value #DEFAULT_EXTRA_SAMPLES} extra samples,
     * simultaneous training enabled, {@value #DEFAULT_ROUNDS} rounds, and a
     * {@link Random} seeded with 1.
     *
     * @param regressor the base regressor to clone for every round
     */
    public Bagging(Regressor regressor) {
        this(regressor, DEFAULT_EXTRA_SAMPLES, DEFAULT_SIMULTANIOUS_TRAINING);
    }

    /**
     * Creates a bagged regressor with {@value #DEFAULT_ROUNDS} rounds and a
     * {@link Random} seeded with 1.
     *
     * @param regressor the base regressor to clone for every round
     * @param extraSamples number of draws added to each bootstrap sample beyond the data set size
     * @param simultaniousTraining whether rounds may be trained concurrently
     */
    public Bagging(Regressor regressor, int extraSamples, boolean simultaniousTraining) {
        this(regressor, extraSamples, simultaniousTraining, DEFAULT_ROUNDS, new Random(1L));
    }

    /**
     * Creates a bagged regressor.
     *
     * @param regressor the base regressor to clone for every round
     * @param extraSamples number of draws added to each bootstrap sample beyond the data set size
     * @param simultaniousTraining whether rounds may be trained concurrently
     * @param rounds number of learners to train; must be positive
     * @param random source of randomness for bootstrap sampling
     * @throws ArithmeticException if {@code rounds} is not positive
     */
    public Bagging(Regressor regressor, int extraSamples, boolean simultaniousTraining, int rounds, Random random) {
        this(extraSamples, simultaniousTraining, rounds, random);
        this.baseRegressor = regressor;
    }

    private Bagging(int extraSamples, boolean simultaniousTraining, int rounds, Random random) {
        setExtraSamples(extraSamples);
        setSimultaniousTraining(simultaniousTraining);
        setRounds(rounds);
        this.random = random;
    }

    /**
     * Sets how many draws each bootstrap sample has beyond the training set size.
     *
     * @param extraSamples the number of extra draws per bootstrap sample
     */
    public void setExtraSamples(int extraSamples) {
        this.extraSamples = extraSamples;
    }

    /**
     * Returns how many draws each bootstrap sample has beyond the training set size.
     *
     * @return the number of extra draws per bootstrap sample
     */
    public int getExtraSamples() {
        return this.extraSamples;
    }

    /**
     * Sets the number of learners in the ensemble.
     *
     * @param rounds the number of learners to train
     * @throws ArithmeticException if {@code rounds} is not positive
     */
    public void setRounds(int rounds) {
        if (rounds <= 0) {
            throw new ArithmeticException("Must train a positive number of learners");
        }
        this.rounds = rounds;
    }

    /**
     * Returns the number of learners in the ensemble.
     *
     * @return the number of learners trained per call to {@code train}
     */
    public int getRounds() {
        return this.rounds;
    }

    /**
     * Controls how rounds are scheduled when training is requested in parallel.
     * <p>
     * When {@code true}, separate learners are trained concurrently on a thread
     * pool, and each learner itself trains single-threaded. When {@code false},
     * learners are trained one after another, and each receives the caller's
     * parallel flag.
     *
     * @param simultaniousTraining whether to train rounds concurrently
     */
    public void setSimultaniousTraining(boolean simultaniousTraining) {
        this.simultaniousTraining = simultaniousTraining;
    }

    /**
     * Classifies a point by majority vote of the ensemble.
     * <p>
     * Each learner casts one vote for its most likely category, and the vote
     * counts are normalized.
     *
     * @param dataPoint the point to classify
     * @return the normalized vote distribution over the categories
     * @throws RuntimeException if this instance was built for regression or has not been trained
     */
    @Override // jsat.classifiers.Classifier
    public CategoricalResults classify(DataPoint dataPoint) {
        if (this.baseClassifier == null) {
            throw new RuntimeException("Bagging instance created for regression, not classification");
        }
        if (this.learners == null || this.learners.isEmpty()) {
            throw new RuntimeException("Classifier has not yet been trained");
        }
        CategoricalResults votes = new CategoricalResults(this.predicting.getNumOfCategories());
        for (Object learner : this.learners) {
            votes.incProb(((Classifier) learner).classify(dataPoint).mostLikely(), 1.0d);
        }
        votes.normalize();
        return votes;
    }

    /**
     * Trains {@link #getRounds()} clones of the base classifier, each on its own
     * bootstrap sample of {@code dataSet}.
     *
     * @param dataSet the training data
     * @param parallel whether multi-threaded training is allowed
     * @throws RuntimeException if this instance was built for regression, if a
     *         learner fails to train, or if the thread is interrupted while
     *         waiting for concurrent rounds (the interrupt flag is restored)
     */
    @Override // jsat.classifiers.Classifier
    public void train(ClassificationDataSet dataSet, boolean parallel) {
        if (this.baseClassifier == null) {
            throw new RuntimeException("Bagging instance created for regression, not classification");
        }
        this.predicting = dataSet.getPredicting();
        this.learners = new ArrayList<>(this.rounds);
        boolean concurrentRounds = this.simultaniousTraining && parallel;
        int[] counts = new int[dataSet.getSampleSize()];
        List<Future<Classifier>> pending = new ArrayList<>(this.rounds);
        Semaphore slots = new Semaphore(SystemInfo.LogicalCores);
        ExecutorService executor = concurrentRounds ? ParallelUtils.getNewExecutor(true) : null;
        try {
            for (int round = 0; round < this.rounds; round++) {
                // Samples are drawn on this thread, in round order. The sequence
                // of bootstrap samples therefore depends only on the Random's state.
                sampleWithReplacement(counts, counts.length + this.extraSamples, this.random);
                ClassificationDataSet sample = getSampledDataSet(dataSet, counts);
                Classifier learner = this.baseClassifier.clone();
                if (concurrentRounds) {
                    pending.add(submitRound(executor, slots, () -> {
                        learner.train(sample);
                        return learner;
                    }));
                } else {
                    learner.train(sample, parallel);
                    this.learners.add(learner);
                }
            }
            collectRounds(pending, this.learners);
        } finally {
            if (executor != null) {
                executor.shutdownNow();
            }
        }
    }

    /**
     * Submits one training round to {@code executor}.
     * <p>
     * The call blocks while all {@code slots} permits are held, which bounds how
     * many bootstrap samples are alive at once. The permit is released when the
     * round finishes, whether it succeeds or fails.
     *
     * @return the pending result of the round
     * @throws RuntimeException if interrupted while waiting for a permit (the
     *         interrupt flag is restored), or if the executor rejects the task
     */
    private static <T> Future<T> submitRound(ExecutorService executor, Semaphore slots, Callable<T> round) {
        try {
            slots.acquire();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException("Interrupted while scheduling bagged learners", e);
        }
        try {
            return executor.submit(() -> {
                try {
                    return round.call();
                } finally {
                    slots.release();
                }
            });
        } catch (RuntimeException e) {
            // The task never ran, so its permit would otherwise be lost.
            slots.release();
            throw e;
        }
    }

    /**
     * Waits for every pending round and appends its learner to {@code sink},
     * in submission order, so the ensemble order is deterministic.
     *
     * @throws RuntimeException if a round failed (its unchecked cause is rethrown
     *         as-is), or if interrupted while waiting (the interrupt flag is restored)
     */
    private static <T> void collectRounds(List<Future<T>> pending, List<Object> sink) {
        try {
            for (Future<T> future : pending) {
                sink.add(future.get());
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException("Interrupted while waiting for bagged learners to train", e);
        } catch (ExecutionException e) {
            Throwable cause = e.getCause();
            if (cause instanceof RuntimeException) {
                throw (RuntimeException) cause;
            }
            if (cause instanceof Error) {
                throw (Error) cause;
            }
            throw new RuntimeException("Training of a bagged learner failed", cause);
        }
    }

    /**
     * Builds a data set in which point {@code i} of {@code dataSet} appears
     * {@code sampledCounts[i]} times. Non-positive counts omit the point.
     *
     * @param dataSet the source data
     * @param sampledCounts per-point copy counts, indexed like {@code dataSet}
     * @return a new data set containing the repeated points
     */
    public static ClassificationDataSet getSampledDataSet(ClassificationDataSet dataSet, int[] sampledCounts) {
        ClassificationDataSet sampled = new ClassificationDataSet(dataSet.getNumNumericalVars(), dataSet.getCategories(), dataSet.getPredicting());
        for (int i = 0; i < sampledCounts.length; i++) {
            if (sampledCounts[i] <= 0) {
                continue;
            }
            // Look the point up once, not once per copy.
            DataPoint dataPoint = dataSet.getDataPoint(i);
            int category = dataSet.getDataPointCategory(i);
            for (int copy = 0; copy < sampledCounts[i]; copy++) {
                sampled.addDataPoint(dataPoint.getNumericalValues(), dataPoint.getCategoricalValues(), category);
            }
        }
        return sampled;
    }

    /**
     * Builds a data set in which every point with a positive count appears
     * once, with its weight multiplied by that count.
     *
     * @param dataSet the source data
     * @param sampledCounts per-point counts, indexed like {@code dataSet}
     * @return a new, weighted data set
     */
    public static ClassificationDataSet getWeightSampledDataSet(ClassificationDataSet dataSet, int[] sampledCounts) {
        ClassificationDataSet sampled = new ClassificationDataSet(dataSet.getNumNumericalVars(), dataSet.getCategories(), dataSet.getPredicting());
        for (int i = 0; i < sampledCounts.length; i++) {
            if (sampledCounts[i] > 0) {
                DataPoint dataPoint = dataSet.getDataPoint(i);
                sampled.addDataPoint(dataPoint.getNumericalValues(), dataPoint.getCategoricalValues(), dataSet.getDataPointCategory(i), dataPoint.getWeight() * sampledCounts[i]);
            }
        }
        return sampled;
    }

    /**
     * Builds a data set in which point {@code i} of {@code dataSet} (with its
     * target value) appears {@code sampledCounts[i]} times. Non-positive counts
     * omit the point.
     *
     * @param dataSet the source data
     * @param sampledCounts per-point copy counts, indexed like {@code dataSet}
     * @return a new data set containing the repeated points
     */
    public static RegressionDataSet getSampledDataSet(RegressionDataSet dataSet, int[] sampledCounts) {
        RegressionDataSet sampled = new RegressionDataSet(dataSet.getNumNumericalVars(), dataSet.getCategories());
        for (int i = 0; i < sampledCounts.length; i++) {
            if (sampledCounts[i] <= 0) {
                continue;
            }
            // Look the point up once, not once per copy.
            DataPoint dataPoint = dataSet.getDataPoint(i);
            double target = dataSet.getTargetValue(i);
            for (int copy = 0; copy < sampledCounts[i]; copy++) {
                sampled.addDataPoint(dataPoint, target);
            }
        }
        return sampled;
    }

    /**
     * Builds a data set in which every point with a positive count appears
     * once, with its weight multiplied by that count.
     *
     * @param dataSet the source data
     * @param sampledCounts per-point counts, indexed like {@code dataSet}
     * @return a new, weighted data set
     */
    public static RegressionDataSet getWeightSampledDataSet(RegressionDataSet dataSet, int[] sampledCounts) {
        RegressionDataSet sampled = new RegressionDataSet(dataSet.getNumNumericalVars(), dataSet.getCategories());
        for (int i = 0; i < sampledCounts.length; i++) {
            if (sampledCounts[i] > 0) {
                DataPoint dataPoint = dataSet.getDataPoint(i);
                sampled.addDataPoint(new DataPoint(dataPoint.getNumericalValues(), dataPoint.getCategoricalValues(), dataPoint.getCategoricalData(), dataPoint.getWeight() * sampledCounts[i]), dataSet.getTargetValue(i));
            }
        }
        return sampled;
    }

    /**
     * Performs bootstrap sampling.
     * <p>
     * Zeroes {@code sampledCounts}, then draws {@code samples} indices
     * uniformly from {@code [0, sampledCounts.length)}, incrementing the count
     * at each drawn index.
     *
     * @param sampledCounts output array of per-index counts
     * @param samples number of draws
     * @param random source of randomness
     * @throws IllegalArgumentException from {@link Random#nextInt(int)} if
     *         {@code sampledCounts} is empty and {@code samples > 0}
     */
    public static void sampleWithReplacement(int[] sampledCounts, int samples, Random random) {
        Arrays.fill(sampledCounts, 0);
        for (int draw = 0; draw < samples; draw++) {
            sampledCounts[random.nextInt(sampledCounts.length)]++;
        }
    }

    @Override // jsat.classifiers.Classifier
    public boolean supportsWeightedData() {
        return false;
    }

    /**
     * Predicts a value as the mean of all ensemble members' predictions.
     *
     * @param dataPoint the point to predict
     * @return the average prediction
     * @throws RuntimeException if this instance was built for classification or has not been trained
     */
    @Override // jsat.regression.Regressor
    public double regress(DataPoint dataPoint) {
        if (this.baseRegressor == null) {
            throw new RuntimeException("Bagging instance created for classification, not regression");
        }
        if (this.learners == null || this.learners.isEmpty()) {
            throw new RuntimeException("Regressor has not yet been trained");
        }
        OnLineStatistics predictions = new OnLineStatistics();
        for (Object learner : this.learners) {
            predictions.add(((Regressor) learner).regress(dataPoint));
        }
        return predictions.getMean();
    }

    /**
     * Trains {@link #getRounds()} clones of the base regressor, each on its own
     * bootstrap sample of {@code dataSet}.
     *
     * @param dataSet the training data
     * @param parallel whether multi-threaded training is allowed
     * @throws RuntimeException if this instance was built for classification, if
     *         a learner fails to train, or if the thread is interrupted while
     *         waiting for concurrent rounds (the interrupt flag is restored)
     */
    @Override // jsat.regression.Regressor
    public void train(RegressionDataSet dataSet, boolean parallel) {
        if (this.baseRegressor == null) {
            throw new RuntimeException("Bagging instance created for classification, not regression");
        }
        this.learners = new ArrayList<>(this.rounds);
        boolean concurrentRounds = this.simultaniousTraining && parallel;
        int[] counts = new int[dataSet.getSampleSize()];
        List<Future<Regressor>> pending = new ArrayList<>(this.rounds);
        Semaphore slots = new Semaphore(SystemInfo.LogicalCores);
        ExecutorService executor = concurrentRounds ? ParallelUtils.getNewExecutor(true) : null;
        try {
            for (int round = 0; round < this.rounds; round++) {
                // Samples are drawn on this thread, in round order. The sequence
                // of bootstrap samples therefore depends only on the Random's state.
                sampleWithReplacement(counts, counts.length + this.extraSamples, this.random);
                RegressionDataSet sample = getSampledDataSet(dataSet, counts);
                Regressor learner = this.baseRegressor.clone();
                if (concurrentRounds) {
                    pending.add(submitRound(executor, slots, () -> {
                        learner.train(sample);
                        return learner;
                    }));
                } else {
                    learner.train(sample, parallel);
                    this.learners.add(learner);
                }
            }
            collectRounds(pending, this.learners);
        } finally {
            if (executor != null) {
                executor.shutdownNow();
            }
        }
    }

    /**
     * Returns a deep copy of this ensemble, including the base learner and any
     * trained members.
     * <p>
     * The copy gets a fresh {@link Random} seeded with the round count, rather
     * than a copy of this instance's generator state.
     *
     * @return the copy
     */
    @Override // jsat.regression.Regressor
    public Bagging clone() {
        Bagging copy = new Bagging(this.extraSamples, this.simultaniousTraining, this.rounds, new Random(this.rounds));
        if (this.baseClassifier != null) {
            copy.baseClassifier = this.baseClassifier.clone();
        }
        if (this.predicting != null) {
            copy.predicting = this.predicting.m1clone();
        }
        if (this.baseRegressor != null) {
            copy.baseRegressor = this.baseRegressor.clone();
        }
        if (this.learners != null && !this.learners.isEmpty()) {
            copy.learners = new ArrayList<>(this.learners.size());
            for (Object learner : this.learners) {
                if (learner instanceof Classifier) {
                    copy.learners.add(((Classifier) learner).clone());
                } else {
                    copy.learners.add(((Regressor) learner).clone());
                }
            }
        }
        return copy;
    }
}
