package jsat.classifiers.boosting;

import java.util.Random;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import jsat.DataSet;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.distributions.ContinuousDistribution;
import jsat.exceptions.FailedToFitException;
import jsat.exceptions.UntrainedModelException;
import jsat.parameters.Parameterized;
import jsat.regression.RegressionDataSet;
import jsat.regression.Regressor;
import jsat.utils.SystemInfo;
import jsat.utils.concurrent.ParallelUtils;
import jsat.utils.random.RandomUtil;

/**
 * Wagging ("weight aggregation") ensemble. Every ensemble member is a clone of the weak learner
 * trained on the whole training set, after each data point has been given a fresh random weight
 * sampled from a {@link ContinuousDistribution} (through its inverse CDF) and floored at a small
 * positive value. Classification returns the normalized vote counts of the members' most likely
 * classes; regression returns the mean of the members' predictions.
 *
 * <p>Whether the model can classify, regress, or both depends on the weak learner supplied: a
 * learner implementing both {@link Classifier} and {@link Regressor} is used for both roles.
 */
public class Wagging implements Classifier, Regressor, Parameterized {
    private static final long serialVersionUID = 4999034730848794619L;

    /** Lower bound applied to every sampled weight so no data point gets a non-positive weight. */
    private static final double MIN_WEIGHT = 1.0E-6d;

    /** Distribution the training weights are sampled from. */
    private ContinuousDistribution dist;
    /** Number of ensemble members trained by each call to {@code train}. */
    private int iterations;
    /** Weak learner cloned for classification, or null if none was supplied. */
    private Classifier weakL;
    /** Weak learner cloned for regression, or null if none was supplied. */
    private Regressor weakR;
    /** Target variable of the last classification training set. */
    private CategoricalData predicting;
    /** Trained classification ensemble; null unless the last successful fit was classification. */
    private Classifier[] hypotsL;
    /** Trained regression ensemble; null unless the last successful fit was regression. */
    private Regressor[] hypotsR;

    /**
     * Worker that trains the ensemble members with indices in {@code [start, end)} on its own
     * re-weighted copy of the training data. Only intended to be created by the enclosing
     * {@link Wagging} during training.
     */
    public class WagFill implements Runnable {
        int start;
        int end;
        DataSet ds;
        Random rand;
        CountDownLatch latch;
        /**
         * Exception raised while this task ran, or null if it completed normally. It is written
         * before {@link #latch} is counted down, so the thread that awaited the latch sees it.
         */
        Throwable failure;

        /**
         * @param start first ensemble index (inclusive) this task trains
         * @param end last ensemble index (exclusive) this task trains
         * @param dataSet training data; a private copy is made
         * @param random source of randomness used only by this task
         * @param latch counted down exactly once when the task finishes, successfully or not
         */
        public WagFill(int start, int end, DataSet dataSet, Random random, CountDownLatch latch) {
            this.start = start;
            this.end = end;
            this.ds = dataSet.shallowClone2();
            this.rand = random;
            this.latch = latch;
            // Give this task its own DataPoint objects so the weights assigned in run() do not
            // leak into the caller's data (presumably shallowClone2 shares the originals - TODO confirm).
            for (int idx = 0; idx < this.ds.getSampleSize(); idx++) {
                DataPoint dp = this.ds.getDataPoint(idx);
                this.ds.setDataPoint(idx, new DataPoint(dp.getNumericalValues(), dp.getCategoricalValues(), dp.getCategoricalData()));
            }
        }

        /**
         * Trains this task's range of ensemble members. Any exception is captured in
         * {@link #failure} rather than escaping, and the latch is always counted down so the
         * coordinating thread can never block forever.
         */
        @Override // java.lang.Runnable
        public void run() {
            try {
                if (this.ds instanceof ClassificationDataSet) {
                    ClassificationDataSet cds = (ClassificationDataSet) this.ds;
                    for (int model = this.start; model < this.end; model++) {
                        reweight();
                        Classifier hypot = Wagging.this.weakL.clone();
                        hypot.train(cds);
                        Wagging.this.hypotsL[model] = hypot;
                    }
                } else if (this.ds instanceof RegressionDataSet) {
                    RegressionDataSet rds = (RegressionDataSet) this.ds;
                    for (int model = this.start; model < this.end; model++) {
                        reweight();
                        Regressor hypot = Wagging.this.weakR.clone();
                        hypot.train(rds);
                        Wagging.this.hypotsR[model] = hypot;
                    }
                } else {
                    throw new RuntimeException("BUG: please report");
                }
            } catch (Throwable t) {
                this.failure = t;
            } finally {
                this.latch.countDown();
            }
        }

        /** Assigns every data point a new weight sampled from the distribution, floored at MIN_WEIGHT. */
        private void reweight() {
            for (int j = 0; j < this.ds.getSampleSize(); j++) {
                double w = Wagging.this.dist.invCdf(this.rand.nextDouble());
                this.ds.getDataPoint(j).setWeight(Math.max(MIN_WEIGHT, w));
            }
        }
    }

    /**
     * Creates a Wagging ensemble that uses a classifier as its weak learner.
     *
     * @param continuousDistribution distribution the training weights are sampled from
     * @param classifier weak learner cloned for every ensemble member
     * @param i number of ensemble members; must be at least 1
     * @throws NullPointerException if the distribution or classifier is null
     * @throws ArithmeticException if {@code i < 1}
     */
    public Wagging(ContinuousDistribution continuousDistribution, Classifier classifier, int i) {
        setDistribution(continuousDistribution);
        setIterations(i);
        setWeakLearner(classifier);
    }

    /**
     * Creates a Wagging ensemble that uses a regressor as its weak learner.
     *
     * @param continuousDistribution distribution the training weights are sampled from
     * @param regressor weak learner cloned for every ensemble member
     * @param i number of ensemble members; must be at least 1
     * @throws NullPointerException if the distribution or regressor is null
     * @throws ArithmeticException if {@code i < 1}
     */
    public Wagging(ContinuousDistribution continuousDistribution, Regressor regressor, int i) {
        setDistribution(continuousDistribution);
        setIterations(i);
        setWeakLearner(regressor);
    }

    /**
     * Copy constructor: clones the distribution, weak learners, target description and any
     * trained ensemble members of {@code wagging}.
     *
     * @param wagging the instance to copy
     */
    public Wagging(Wagging wagging) {
        this.dist = wagging.dist.mo146clone();
        this.iterations = wagging.iterations;
        if (wagging.weakL != null) {
            setWeakLearner(wagging.weakL.clone());
        }
        if (wagging.weakR != null) {
            setWeakLearner(wagging.weakR.clone());
        }
        if (wagging.predicting != null) {
            this.predicting = wagging.predicting.m1clone();
        }
        if (wagging.hypotsL != null) {
            this.hypotsL = new Classifier[wagging.hypotsL.length];
            for (int i = 0; i < this.hypotsL.length; i++) {
                this.hypotsL[i] = wagging.hypotsL[i].clone();
            }
        }
        if (wagging.hypotsR != null) {
            this.hypotsR = new Regressor[wagging.hypotsR.length];
            for (int i = 0; i < this.hypotsR.length; i++) {
                this.hypotsR[i] = wagging.hypotsR[i].clone();
            }
        }
    }

    /**
     * Sets the classification weak learner. If it is also a {@link Regressor}, it becomes the
     * regression weak learner as well.
     *
     * @param classifier the weak learner
     * @throws NullPointerException if {@code classifier} is null
     */
    public void setWeakLearner(Classifier classifier) {
        if (classifier == null) {
            throw new NullPointerException();
        }
        this.weakL = classifier;
        if (classifier instanceof Regressor) {
            this.weakR = (Regressor) classifier;
        }
    }

    /** @return the classification weak learner, or null if none was set */
    public Classifier getWeakClassifier() {
        return this.weakL;
    }

    /**
     * Sets the regression weak learner. If it is also a {@link Classifier}, it becomes the
     * classification weak learner as well.
     *
     * @param regressor the weak learner
     * @throws NullPointerException if {@code regressor} is null
     */
    public void setWeakLearner(Regressor regressor) {
        if (regressor == null) {
            throw new NullPointerException();
        }
        this.weakR = regressor;
        if (regressor instanceof Classifier) {
            this.weakL = (Classifier) regressor;
        }
    }

    /** @return the regression weak learner, or null if none was set */
    public Regressor getWeakRegressor() {
        return this.weakR;
    }

    /**
     * Sets the number of ensemble members to train.
     *
     * @param i the number of members; must be at least 1
     * @throws ArithmeticException if {@code i < 1}
     */
    public void setIterations(int i) {
        if (i < 1) {
            throw new ArithmeticException("The number of iterations must be positive");
        }
        this.iterations = i;
    }

    /** @return the number of ensemble members trained */
    public int getIterations() {
        return this.iterations;
    }

    /**
     * Sets the distribution that per-point training weights are sampled from.
     *
     * @param continuousDistribution the distribution
     * @throws NullPointerException if {@code continuousDistribution} is null
     */
    public void setDistribution(ContinuousDistribution continuousDistribution) {
        if (continuousDistribution == null) {
            throw new NullPointerException();
        }
        this.dist = continuousDistribution;
    }

    /** @return the distribution training weights are sampled from */
    public ContinuousDistribution getDistribution() {
        return this.dist;
    }

    /**
     * Trains all {@link #iterations} ensemble members. The members are split into at most
     * {@link SystemInfo#LogicalCores} contiguous index ranges, one {@link WagFill} task per range.
     * The caller must already have allocated {@link #hypotsL} or {@link #hypotsR}. If training
     * fails for any reason, both ensemble arrays are reset to null so the model reads as untrained.
     *
     * @param parallel passed to {@link ParallelUtils#getNewExecutor} to select the executor
     * @param dataSet the training data; a ClassificationDataSet or RegressionDataSet
     * @throws FailedToFitException if interrupted while waiting for the workers
     */
    private void performTraining(boolean parallel, DataSet dataSet) {
        int cores = SystemInfo.LogicalCores;
        int perTask = this.iterations / cores;
        int leftover = this.iterations % cores;
        // With fewer iterations than cores, only `leftover` single-model tasks are needed
        int taskCount = perTask > 0 ? cores : leftover;
        CountDownLatch latch = new CountDownLatch(taskCount);
        WagFill[] tasks = new WagFill[taskCount];
        Random random = RandomUtil.getRandom();
        ExecutorService executor = ParallelUtils.getNewExecutor(parallel);
        boolean trained = false;
        try {
            int start = 0;
            for (int t = 0; t < taskCount; t++) {
                // The first `leftover` tasks take one extra model so every index is covered once
                int end = start + perTask + (t < leftover ? 1 : 0);
                tasks[t] = new WagFill(start, end, dataSet, new Random(random.nextInt()), latch);
                executor.submit(tasks[t]);
                start = end;
            }
            latch.await();
            for (WagFill task : tasks) {
                rethrowFailure(task.failure);
            }
            trained = true;
        } catch (InterruptedException e) {
            // Preserve the interrupt for callers higher up the stack
            Thread.currentThread().interrupt();
            throw new FailedToFitException(e);
        } finally {
            executor.shutdownNow();
            if (!trained) {
                // Do not leave a partially filled ensemble behind
                this.hypotsL = null;
                this.hypotsR = null;
            }
        }
    }

    /**
     * Rethrows a failure captured by a worker task; does nothing if {@code failure} is null.
     * Unchecked exceptions and errors are rethrown as-is.
     */
    private static void rethrowFailure(Throwable failure) {
        if (failure == null) {
            return;
        }
        if (failure instanceof RuntimeException) {
            throw (RuntimeException) failure;
        }
        if (failure instanceof Error) {
            throw (Error) failure;
        }
        // Only reachable if a checked exception escaped run() by unchecked means
        throw new FailedToFitException(new Exception(failure));
    }

    /**
     * Classifies a point by majority vote: each ensemble member adds 1 to its most likely class,
     * and the counts are normalized.
     *
     * @param dataPoint the point to classify
     * @return normalized vote counts per class
     * @throws UntrainedModelException if the model is not trained for classification
     */
    @Override // jsat.classifiers.Classifier
    public CategoricalResults classify(DataPoint dataPoint) {
        if (this.hypotsL == null) {
            throw new UntrainedModelException("Model has not been trained for classification");
        }
        CategoricalResults categoricalResults = new CategoricalResults(this.predicting.getNumOfCategories());
        for (Classifier classifier : this.hypotsL) {
            categoricalResults.incProb(classifier.classify(dataPoint).mostLikely(), 1.0d);
        }
        categoricalResults.normalize();
        return categoricalResults;
    }

    /**
     * Trains the classification ensemble and discards any regression ensemble.
     *
     * @param classificationDataSet the training data
     * @param z whether to train in parallel
     * @throws FailedToFitException if no classification weak learner was provided or training
     *         was interrupted
     */
    @Override // jsat.classifiers.Classifier
    public void train(ClassificationDataSet classificationDataSet, boolean z) {
        if (this.weakL == null) {
            throw new FailedToFitException("No classification weak learner was provided");
        }
        this.predicting = classificationDataSet.getPredicting();
        this.hypotsL = new Classifier[this.iterations];
        this.hypotsR = null;
        performTraining(z, classificationDataSet);
    }

    /**
     * Returns false: every training point's weight is replaced by a sampled weight, so weights
     * on the input data are not honored.
     */
    @Override // jsat.classifiers.Classifier
    public boolean supportsWeightedData() {
        return false;
    }

    /**
     * Predicts the mean of the ensemble members' predictions.
     *
     * @param dataPoint the point to predict
     * @return the averaged prediction
     * @throws UntrainedModelException if the model is not trained for regression
     */
    @Override // jsat.regression.Regressor
    public double regress(DataPoint dataPoint) {
        if (this.hypotsR == null) {
            throw new UntrainedModelException("Model has not been trained for regression");
        }
        double sum = 0.0d;
        for (Regressor regressor : this.hypotsR) {
            sum += regressor.regress(dataPoint);
        }
        return sum / this.hypotsR.length;
    }

    /**
     * Trains the regression ensemble and discards any classification ensemble.
     *
     * @param regressionDataSet the training data
     * @param z whether to train in parallel
     * @throws FailedToFitException if no regression weak learner was provided or training was
     *         interrupted
     */
    @Override // jsat.regression.Regressor
    public void train(RegressionDataSet regressionDataSet, boolean z) {
        if (this.weakR == null) {
            throw new FailedToFitException("No regression weak learner was provided");
        }
        this.hypotsL = null;
        this.hypotsR = new Regressor[this.iterations];
        performTraining(z, regressionDataSet);
    }

    /** @return a deep copy of this model made with the copy constructor */
    @Override // jsat.regression.Regressor
    public Wagging clone() {
        return new Wagging(this);
    }
}
