/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicReference;
import java.util.logging.Level;
import java.util.logging.Logger;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.utils.FakeExecutor;
import jsat.utils.SystemInfo;

/**
 * Multi-class meta-classifier that trains one binary model for every unordered
 * pair of classes. Each pairwise model is a clone of the base classifier.
 * At prediction time every pairwise model casts one vote for the class it
 * prefers, and the normalized vote counts are returned as the result.
 */
public class OneVSOne
implements Classifier,
Parameterized {
    private static final long serialVersionUID = 733202830281869416L;
    /** Prototype classifier; it is cloned once for every pair of classes. */
    @Parameter.ParameterHolder
    protected Classifier baseClassifier;
    /**
     * Triangular array of pairwise models. {@code oneVone[i][j]} separates
     * class {@code i} (binary label 0) from class {@code i + j + 1} (binary label 1).
     */
    protected Classifier[][] oneVone;
    /** When true, each pairwise model is trained as a separately submitted task. */
    private boolean concurrentTrain;
    /** Target variable of the data set most recently passed to {@link #train}. */
    protected CategoricalData predicting;

    /**
     * Creates a one-vs-one classifier that trains its pairwise models sequentially.
     *
     * @param baseClassifier the classifier cloned for every pair of classes
     */
    public OneVSOne(Classifier baseClassifier) {
        this(baseClassifier, false);
    }

    /**
     * Creates a one-vs-one classifier.
     *
     * @param baseClassifier the classifier cloned for every pair of classes
     * @param concurrentTrain whether each pairwise model is trained as its own task
     */
    public OneVSOne(Classifier baseClassifier, boolean concurrentTrain) {
        this.baseClassifier = baseClassifier;
        this.concurrentTrain = concurrentTrain;
    }

    /**
     * Sets whether the pairwise models are trained as separate tasks.
     *
     * @param concurrentTrain {@code true} to submit each pairwise training as its own task
     */
    public void setConcurrentTraining(boolean concurrentTrain) {
        this.concurrentTrain = concurrentTrain;
    }

    /**
     * Returns whether the pairwise models are trained as separate tasks.
     *
     * @return {@code true} if concurrent training is enabled
     */
    public boolean isConcurrentTraining() {
        return this.concurrentTrain;
    }

    /**
     * Classifies a data point by majority vote over all pairwise models.
     *
     * @param data the point to classify
     * @return normalized vote counts, indexed by class
     */
    @Override
    public CategoricalResults classify(DataPoint data) {
        CategoricalResults cr = new CategoricalResults(this.predicting.getNumOfCategories());
        for (int i = 0; i < this.oneVone.length; ++i) {
            for (int j = 0; j < this.oneVone[i].length; ++j) {
                // Binary label 0 stands for class i; label 1 stands for class i + j + 1.
                int winner = this.oneVone[i][j].classify(data).mostLikely() == 0 ? i : i + j + 1;
                cr.incProb(winner, 1.0);
            }
        }
        cr.normalize();
        return cr;
    }

    /**
     * Trains one clone of the base classifier for every pair of classes.
     *
     * <p>When concurrent training is enabled, each pairwise model is submitted as
     * a task. The tasks go to a fixed thread pool if {@code parallel} is true, and
     * to a {@link FakeExecutor} otherwise. If any pairwise training fails, the first
     * failure is rethrown after all tasks have finished.
     *
     * @param dataSet the multi-class training data
     * @param parallel whether parallelism may be used
     */
    @Override
    public void train(ClassificationDataSet dataSet, boolean parallel) {
        final int classCount = dataSet.getClassSize();
        this.oneVone = new Classifier[classCount][];
        List<List<DataPoint>> dataByCategory = new ArrayList<>(classCount);
        for (int i = 0; i < classCount; ++i) {
            dataByCategory.add(dataSet.getSamples(i));
        }
        CountDownLatch latch = new CountDownLatch(classCount * (classCount - 1) / 2);
        AtomicReference<Throwable> failure = new AtomicReference<>();
        // An executor is only needed when pairwise models are trained as separate tasks.
        ExecutorService threadPool = null;
        if (this.concurrentTrain) {
            threadPool = parallel ? Executors.newFixedThreadPool(SystemInfo.LogicalCores) : new FakeExecutor();
        }
        try {
            for (int i = 0; i < classCount; ++i) {
                this.oneVone[i] = new Classifier[classCount - i - 1];
                for (int j = 0; j < classCount - i - 1; ++j) {
                    Classifier curClassifier = this.baseClassifier.clone();
                    this.oneVone[i][j] = curClassifier;
                    ClassificationDataSet subDataSet = pairDataSet(dataSet, dataByCategory, i, i + j + 1);
                    if (!this.concurrentTrain) {
                        curClassifier.train(subDataSet, parallel);
                        continue;
                    }
                    threadPool.submit(() -> {
                        try {
                            curClassifier.train(subDataSet);
                        } catch (Throwable t) {
                            // Remember the first failure so it is not lost inside an ignored Future.
                            failure.compareAndSet(null, t);
                        } finally {
                            // Always count down; otherwise a failed task would block await() forever.
                            latch.countDown();
                        }
                    });
                }
            }
            if (this.concurrentTrain) {
                try {
                    latch.await();
                } catch (InterruptedException ex) {
                    Thread.currentThread().interrupt();
                    Logger.getLogger(OneVSOne.class.getName()).log(Level.SEVERE, null, ex);
                }
                Throwable t = failure.get();
                if (t != null) {
                    rethrow(t);
                }
            }
        } finally {
            if (threadPool != null) {
                threadPool.shutdownNow();
            }
        }
        this.predicting = dataSet.getPredicting();
    }

    /**
     * Builds the binary problem that separates class {@code first} (label 0) from
     * class {@code second} (label 1).
     * NOTE(review): only numeric and categorical values are copied, so any per-point
     * weights are not forwarded to the sub-problem. Confirm this is intended.
     */
    private static ClassificationDataSet pairDataSet(ClassificationDataSet dataSet,
            List<List<DataPoint>> dataByCategory, int first, int second) {
        CategoricalData subPred = new CategoricalData(2);
        subPred.setOptionName(dataSet.getPredicting().getOptionName(first), 0);
        subPred.setOptionName(dataSet.getPredicting().getOptionName(second), 1);
        ClassificationDataSet subDataSet = new ClassificationDataSet(
                dataSet.getNumNumericalVars(), dataSet.getCategories(), subPred);
        for (DataPoint dp : dataByCategory.get(first)) {
            subDataSet.addDataPoint(dp.getNumericalValues(), dp.getCategoricalValues(), 0);
        }
        for (DataPoint dp : dataByCategory.get(second)) {
            subDataSet.addDataPoint(dp.getNumericalValues(), dp.getCategoricalValues(), 1);
        }
        return subDataSet;
    }

    /** Rethrows a worker failure, preserving unchecked exception types. */
    private static void rethrow(Throwable t) {
        if (t instanceof RuntimeException) {
            throw (RuntimeException) t;
        }
        if (t instanceof Error) {
            throw (Error) t;
        }
        throw new RuntimeException("Training of a one-vs-one sub-classifier failed", t);
    }

    /**
     * Returns whether the base classifier supports weighted data.
     *
     * @return the base classifier's {@code supportsWeightedData()}
     */
    @Override
    public boolean supportsWeightedData() {
        return this.baseClassifier.supportsWeightedData();
    }

    /**
     * Deep-copies this classifier, including any trained pairwise models and the
     * stored target variable.
     *
     * @return an independent copy
     */
    @Override
    public OneVSOne clone() {
        OneVSOne clone = new OneVSOne(this.baseClassifier.clone(), this.concurrentTrain);
        if (this.oneVone != null) {
            clone.oneVone = new Classifier[this.oneVone.length][];
            for (int i = 0; i < this.oneVone.length; ++i) {
                clone.oneVone[i] = new Classifier[this.oneVone[i].length];
                for (int j = 0; j < this.oneVone[i].length; ++j) {
                    clone.oneVone[i][j] = this.oneVone[i][j].clone();
                }
            }
        }
        if (this.predicting != null) {
            clone.predicting = this.predicting.clone();
        }
        return clone;
    }
}

