/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers;

import java.util.List;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;

/**
 * An ensemble classifier that combines several base classifiers by simple majority vote.
 *
 * <p>Each non-null voter casts exactly one vote for the category it considers most likely
 * ({@link CategoricalResults#mostLikely()}). The returned {@link CategoricalResults} holds the
 * normalized vote counts, so the probability of a category is the fraction of voters that chose it.
 * Null entries in the voter array are ignored by classification, training and cloning.
 */
public class MajorityVote implements Classifier {
    private static final long serialVersionUID = 7945429768861275845L;

    /** The base classifiers; individual entries may be null and are then skipped. */
    private final Classifier[] voters;

    /**
     * Creates a majority-vote ensemble from the given classifiers.
     *
     * @param voters the base classifiers; the array is copied, so later changes to it do not affect
     *     this ensemble
     */
    public MajorityVote(Classifier... voters) {
        // Defensive copy so the caller cannot swap voters out from under us (the List constructor
        // already copies for the same reason).
        this.voters = voters.clone();
    }

    /**
     * Creates a majority-vote ensemble from the given classifiers.
     *
     * @param voters the base classifiers; the list is copied into an internal array
     */
    public MajorityVote(List<? extends Classifier> voters) {
        this.voters = voters.toArray(new Classifier[0]);
    }

    /**
     * Classifies a data point by letting every non-null voter cast one vote for its most likely
     * category and normalizing the vote counts.
     *
     * <p>NOTE(review): the result object returned by the first non-null voter is reused (and
     * mutated) as the tally, as in the original implementation; this assumes voters return a fresh
     * {@code CategoricalResults} per call — confirm for all base classifiers in use.
     *
     * @param data the data point to classify
     * @return the normalized vote distribution over the categories
     * @throws IllegalStateException if the ensemble contains no non-null voters
     */
    @Override
    public CategoricalResults classify(DataPoint data) {
        CategoricalResults toReturn = null;
        for (Classifier classifier : this.voters) {
            if (classifier == null) {
                continue;
            }
            CategoricalResults vote = classifier.classify(data);
            // Compute the winner once, before any probabilities are modified.
            int choice = vote.mostLikely();
            if (toReturn == null) {
                // First voter: turn its result into a one-hot vote and use it as the tally.
                for (int i = 0; i < vote.size(); ++i) {
                    vote.setProb(i, i == choice ? 1.0 : 0.0);
                }
                toReturn = vote;
            } else {
                // Every subsequent voter contributes exactly one vote.
                toReturn.incProb(choice, 1.0);
            }
        }
        if (toReturn == null) {
            throw new IllegalStateException("MajorityVote has no non-null voters to classify with");
        }
        toReturn.normalize();
        return toReturn;
    }

    /**
     * Trains every non-null voter on the given data set.
     *
     * @param dataSet the training data
     * @param parallel passed through unchanged to each voter's {@code train(dataSet, parallel)}
     */
    @Override
    public void train(ClassificationDataSet dataSet, boolean parallel) {
        for (Classifier classifier : this.voters) {
            if (classifier != null) {
                classifier.train(dataSet, parallel);
            }
        }
    }

    /**
     * Trains every non-null voter on the given data set.
     *
     * @param dataSet the training data
     */
    @Override
    public void train(ClassificationDataSet dataSet) {
        for (Classifier classifier : this.voters) {
            if (classifier != null) {
                classifier.train(dataSet);
            }
        }
    }

    /**
     * Returns {@code false}: this ensemble does not declare support for weighted data.
     *
     * @return always {@code false}
     */
    @Override
    public boolean supportsWeightedData() {
        return false;
    }

    /**
     * Creates a copy of this ensemble in which every non-null voter is itself cloned, so training
     * the copy does not affect this instance. Null voter slots are preserved as null.
     *
     * @return an independent copy of this ensemble
     */
    @Override
    public Classifier clone() {
        Classifier[] votersClone = new Classifier[this.voters.length];
        for (int i = 0; i < this.voters.length; ++i) {
            if (this.voters[i] != null) {
                votersClone[i] = this.voters[i].clone();
            }
        }
        // Must use the cloned voters; passing this.voters would share state with the original.
        return new MajorityVote(votersClone);
    }
}

