package jsat.classifiers;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicReference;
import java.util.logging.Level;
import java.util.logging.Logger;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.utils.FakeExecutor;
import jsat.utils.SystemInfo;

/**
 * Multi-class classifier that reduces a K-class problem to K(K-1)/2 binary problems by training
 * one clone of a base classifier for every unordered pair of classes. At prediction time every
 * pairwise model casts one vote for the class it prefers, and the normalized vote counts are
 * returned.
 */
public class OneVSOne implements Classifier, Parameterized {
    private static final long serialVersionUID = 733202830281869416L;

    /** Template classifier; a fresh clone of it is trained for each pair of classes. */
    @Parameter.ParameterHolder
    protected Classifier baseClassifier;

    /**
     * Pairwise models. {@code oneVone[i][j]} separates class {@code i} (binary label 0) from class
     * {@code i + j + 1} (binary label 1), so row {@code i} holds {@code K - i - 1} models.
     */
    protected Classifier[][] oneVone;

    /** Whether the pairwise models are trained as separate tasks on an executor. */
    private boolean concurrentTrain;

    /** Target variable of the data set used for the last training run; null until trained. */
    protected CategoricalData predicting;

    /**
     * Creates a one-vs-one wrapper that trains its pairwise models sequentially.
     *
     * @param classifier the base classifier cloned for every pair of classes
     */
    public OneVSOne(Classifier classifier) {
        this(classifier, false);
    }

    /**
     * Creates a one-vs-one wrapper.
     *
     * @param classifier the base classifier cloned for every pair of classes
     * @param concurrentTrain whether the pairwise models should be trained as independent tasks
     */
    public OneVSOne(Classifier classifier, boolean concurrentTrain) {
        this.baseClassifier = classifier;
        this.concurrentTrain = concurrentTrain;
    }

    /**
     * Sets whether the pairwise models are trained as independent tasks.
     *
     * @param concurrentTrain true to train pairwise models as separate executor tasks
     */
    public void setConcurrentTraining(boolean concurrentTrain) {
        this.concurrentTrain = concurrentTrain;
    }

    /**
     * Returns whether the pairwise models are trained as independent tasks.
     *
     * @return true if concurrent training of the pairwise models is enabled
     */
    public boolean isConcurrentTraining() {
        return this.concurrentTrain;
    }

    /**
     * Classifies a point by majority vote over all pairwise models.
     *
     * @param dataPoint the point to classify
     * @return normalized vote counts over the original classes
     */
    @Override // jsat.classifiers.Classifier
    public CategoricalResults classify(DataPoint dataPoint) {
        CategoricalResults votes = new CategoricalResults(this.predicting.getNumOfCategories());
        for (int i = 0; i < this.oneVone.length; i++) {
            for (int j = 0; j < this.oneVone[i].length; j++) {
                // Binary label 0 means class i won this duel; label 1 means class i + j + 1 won.
                int winner = this.oneVone[i][j].classify(dataPoint).mostLikely() == 0 ? i : i + j + 1;
                votes.incProb(winner, 1.0d);
            }
        }
        votes.normalize();
        return votes;
    }

    /**
     * Trains one clone of the base classifier for every unordered pair of classes.
     *
     * <p>When concurrent training is enabled, each pairwise model is submitted as its own task: to
     * a fixed thread pool if {@code parallel} is true, otherwise to a {@link FakeExecutor}. When
     * concurrent training is disabled, each pairwise model is trained in turn with the
     * {@code parallel} flag passed through.
     *
     * @param classificationDataSet the multi-class training data
     * @param parallel whether to use multiple threads
     * @throws RuntimeException if training of any pairwise model failed on a worker task
     */
    @Override // jsat.classifiers.Classifier
    public void train(ClassificationDataSet classificationDataSet, boolean parallel) {
        final int numClasses = classificationDataSet.getClassSize();
        final Classifier[][] models = new Classifier[numClasses][];

        List<List<DataPoint>> samplesByClass = new ArrayList<>(numClasses);
        for (int c = 0; c < numClasses; c++) {
            samplesByClass.add(classificationDataSet.getSamples(c));
        }

        // Only create an executor when it will actually be used.
        ExecutorService executor = null;
        if (this.concurrentTrain) {
            executor = parallel ? Executors.newFixedThreadPool(SystemInfo.LogicalCores) : new FakeExecutor();
        }
        final CountDownLatch latch = new CountDownLatch((numClasses * (numClasses - 1)) / 2);
        final AtomicReference<Throwable> failure = new AtomicReference<>();
        try {
            for (int i = 0; i < numClasses; i++) {
                models[i] = new Classifier[numClasses - i - 1];
                for (int j = 0; j < models[i].length; j++) {
                    final Classifier model = this.baseClassifier.mo4clone();
                    models[i][j] = model;
                    final ClassificationDataSet pairData =
                            buildPairDataSet(classificationDataSet, samplesByClass, i, i + j + 1);
                    if (executor != null) {
                        executor.submit(() -> {
                            try {
                                model.train(pairData);
                            } catch (RuntimeException | Error ex) {
                                failure.compareAndSet(null, ex);
                            } finally {
                                // Always count down so a failing task cannot deadlock await().
                                latch.countDown();
                            }
                        });
                    } else {
                        model.train(pairData, parallel);
                    }
                }
            }
            if (executor != null) {
                try {
                    latch.await();
                } catch (InterruptedException e) {
                    // Restore the interrupt flag; some pairwise models may be untrained.
                    Thread.currentThread().interrupt();
                    Logger.getLogger(OneVSOne.class.getName()).log(Level.SEVERE, (String) null, (Throwable) e);
                }
                Throwable cause = failure.get();
                if (cause instanceof Error) {
                    throw (Error) cause;
                }
                if (cause != null) {
                    throw new RuntimeException("Training of a one-vs-one pairwise model failed", cause);
                }
            }
        } finally {
            if (executor != null) {
                executor.shutdownNow();
            }
        }
        // Publish the models and target together so the instance is never half-updated.
        this.oneVone = models;
        this.predicting = classificationDataSet.getPredicting();
    }

    /**
     * Builds the binary data set for class {@code first} (label 0) versus class {@code second}
     * (label 1), reusing the original option names for the two labels.
     */
    private static ClassificationDataSet buildPairDataSet(ClassificationDataSet dataSet,
            List<List<DataPoint>> samplesByClass, int first, int second) {
        CategoricalData pairLabels = new CategoricalData(2);
        pairLabels.setOptionName(dataSet.getPredicting().getOptionName(first), 0);
        pairLabels.setOptionName(dataSet.getPredicting().getOptionName(second), 1);
        ClassificationDataSet pairData = new ClassificationDataSet(
                dataSet.getNumNumericalVars(), dataSet.getCategories(), pairLabels);
        // NOTE(review): per-point weights are not carried into the pairwise data set even though
        // supportsWeightedData() delegates to the base classifier — confirm whether intended.
        for (DataPoint dp : samplesByClass.get(first)) {
            pairData.addDataPoint(dp.getNumericalValues(), dp.getCategoricalValues(), 0);
        }
        for (DataPoint dp : samplesByClass.get(second)) {
            pairData.addDataPoint(dp.getNumericalValues(), dp.getCategoricalValues(), 1);
        }
        return pairData;
    }

    /**
     * Reports weighted-data support, delegating to the base classifier.
     *
     * @return whatever the base classifier reports
     */
    @Override // jsat.classifiers.Classifier
    public boolean supportsWeightedData() {
        return this.baseClassifier.supportsWeightedData();
    }

    /**
     * Returns a deep copy: base classifier, every trained pairwise model, and the target variable.
     * (This is the decompiled name of {@code clone()}.)
     *
     * @return an independent copy of this classifier
     */
    @Override // jsat.classifiers.Classifier
    public OneVSOne mo4clone() {
        OneVSOne copy = new OneVSOne(this.baseClassifier.mo4clone(), this.concurrentTrain);
        if (this.oneVone != null) {
            copy.oneVone = new Classifier[this.oneVone.length][];
            for (int i = 0; i < this.oneVone.length; i++) {
                copy.oneVone[i] = new Classifier[this.oneVone[i].length];
                for (int j = 0; j < this.oneVone[i].length; j++) {
                    copy.oneVone[i][j] = this.oneVone[i][j].mo4clone();
                }
            }
        }
        if (this.predicting != null) {
            copy.predicting = this.predicting.m1clone();
        }
        return copy;
    }
}
