package jsat.classifiers.boosting;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.classifiers.DataPointPair;
import jsat.classifiers.calibration.BinaryScoreClassifier;
import jsat.parameters.Parameterized;
import jsat.utils.DoubleList;

/* loaded from: input_file:jsat/classifiers/boosting/ModestAdaBoost.class */
/**
 * Modest AdaBoost: a boosting variant for binary classification that scales each weak
 * hypothesis by a coefficient computed from both the current boosting distribution and its
 * "inverted" distribution, which tends to make the ensemble more conservative than Discrete
 * AdaBoost.
 *
 * <p>See Vezhnevets, A. &amp; Vezhnevets, V. (2005), "'Modest AdaBoost' - Teaching AdaBoost to
 * Generalize Better", GraphiCon.
 *
 * <p>Only binary problems are supported, and the weak learner must support weighted data.
 */
public class ModestAdaBoost implements Classifier, Parameterized, BinaryScoreClassifier {
    private static final long serialVersionUID = 8223388561185098909L;
    private Classifier weakLearner;
    private int maxIterations;
    /** Accepted weak hypotheses, in training order; {@code null} until trained. */
    protected List<Classifier> hypoths;
    /** Coefficient of each entry of {@link #hypoths}; {@code null} until trained. */
    protected List<Double> hypWeights;
    /** The target variable seen during training; {@code null} until trained. */
    protected CategoricalData predicting;

    /**
     * Creates a new, untrained booster.
     *
     * @param weakLearner   the prototype weak learner; it is cloned once per boosting round
     * @param maxIterations the maximum number of boosting rounds (must be at least 1)
     * @throws IllegalArgumentException if {@code weakLearner} does not support weighted data, or
     *                                  if {@code maxIterations < 1}
     */
    public ModestAdaBoost(Classifier weakLearner, int maxIterations) {
        setWeakLearner(weakLearner);
        setMaxIterations(maxIterations);
    }

    /**
     * Copy constructor. The weak learner prototype is cloned. If {@code toCopy} has been
     * trained, its hypotheses, their weights, and the target description are deep-copied too.
     *
     * @param toCopy the booster to copy
     */
    protected ModestAdaBoost(ModestAdaBoost toCopy) {
        this(toCopy.weakLearner.m25clone(), toCopy.maxIterations);
        if (toCopy.hypWeights != null) {
            this.hypWeights = new DoubleList(toCopy.hypWeights);
            this.hypoths = new ArrayList<>(toCopy.hypoths.size());
            for (Classifier hypothesis : toCopy.hypoths) {
                this.hypoths.add(hypothesis.m25clone());
            }
            if (toCopy.predicting != null) {
                this.predicting = toCopy.predicting.m1clone();
            }
        }
    }

    /**
     * Returns the weak hypotheses learned so far.
     *
     * @return an unmodifiable view of the hypotheses, or an empty list if not yet trained
     */
    public List<Classifier> getModels() {
        if (this.hypoths == null) {
            return Collections.emptyList();
        }
        return Collections.unmodifiableList(this.hypoths);
    }

    /**
     * Returns the coefficient of each learned hypothesis, index-aligned with {@link #getModels()}.
     *
     * @return an unmodifiable view of the weights, or an empty list if not yet trained
     */
    public List<Double> getModelWeights() {
        if (this.hypWeights == null) {
            return Collections.emptyList();
        }
        return Collections.unmodifiableList(this.hypWeights);
    }

    /**
     * @return the maximum number of boosting rounds
     */
    public int getMaxIterations() {
        return this.maxIterations;
    }

    /**
     * Sets the maximum number of boosting rounds. Training may stop earlier (see
     * {@link #train(ClassificationDataSet, boolean)}).
     *
     * @param maxIterations the maximum number of rounds, at least 1
     * @throws IllegalArgumentException if {@code maxIterations < 1}
     */
    public void setMaxIterations(int maxIterations) {
        if (maxIterations < 1) {
            throw new IllegalArgumentException("Iterations must be positive, not " + maxIterations);
        }
        this.maxIterations = maxIterations;
    }

    /**
     * @return the prototype weak learner that is cloned for each boosting round
     */
    public Classifier getWeakLearner() {
        return this.weakLearner;
    }

    /**
     * Sets the prototype weak learner.
     *
     * @param weakLearner the weak learner; must support weighted data
     * @throws IllegalArgumentException if {@code weakLearner} does not support weighted data
     */
    public void setWeakLearner(Classifier weakLearner) {
        if (!weakLearner.supportsWeightedData()) {
            throw new IllegalArgumentException("WeakLearner must support weighted data to be boosted");
        }
        this.weakLearner = weakLearner;
    }

    /**
     * Computes the ensemble's raw score for a point. Each hypothesis contributes
     * {@code (2 * P(class 1) - 1)} scaled by its coefficient. A negative score corresponds to
     * category 0 and a non-negative score to category 1.
     *
     * @param dataPoint the point to score
     * @return the weighted sum of the hypotheses' signed confidences
     * @throws IllegalStateException if the classifier has not been trained
     */
    @Override // jsat.classifiers.calibration.BinaryScoreClassifier
    public double getScore(DataPoint dataPoint) {
        if (this.hypoths == null || this.hypWeights == null) {
            throw new IllegalStateException("Classifier has not been trained yet");
        }
        double score = 0.0d;
        for (int i = 0; i < this.hypoths.size(); i++) {
            double signedConfidence = this.hypoths.get(i).classify(dataPoint).getProb(1) * 2.0d - 1.0d;
            score += signedConfidence * this.hypWeights.get(i);
        }
        return score;
    }

    /**
     * Classifies a point with a hard decision: all probability mass goes to category 0 when
     * {@link #getScore(DataPoint)} is negative, otherwise to category 1.
     *
     * @param dataPoint the point to classify
     * @return a result with probability 1 on the predicted category
     * @throws IllegalStateException if the classifier has not been trained
     */
    @Override // jsat.classifiers.Classifier
    public CategoricalResults classify(DataPoint dataPoint) {
        if (this.predicting == null) {
            throw new IllegalStateException("Classifier has not been trained yet");
        }
        CategoricalResults categoricalResults = new CategoricalResults(this.predicting.getNumOfCategories());
        int predicted = getScore(dataPoint) < 0.0d ? 0 : 1;
        categoricalResults.setProb(predicted, 1.0d);
        return categoricalResults;
    }

    /**
     * Trains the ensemble. Each round proceeds as follows:
     * <ol>
     * <li>Train a fresh clone of the weak learner on the currently weighted data.</li>
     * <li>Form the inverted distribution {@code (1 - D) / sum(1 - D)}.</li>
     * <li>Compute the coefficient
     * {@code alpha = P_D(+) * (1 - P_invD(+)) - P_D(-) * (1 - P_invD(-))}.</li>
     * <li>Reweight the points by {@code exp(-y * alpha * h(x))} and renormalize.</li>
     * </ol>
     * Training stops early, keeping the hypotheses accepted so far, when a round produces a
     * non-positive coefficient or one that disagrees in sign with {@code P_D(+) - P_D(-)}.
     *
     * <p>Weights are assigned on {@code getTwiceShallowClone()} of the data, presumably so the
     * caller's data points keep their original weights (TODO confirm clone depth).
     *
     * @param classificationDataSet the binary training data
     * @param z                     forwarded unchanged to each weak learner's {@code train} call
     * @throws IllegalArgumentException if the target does not have exactly two categories
     */
    @Override // jsat.classifiers.Classifier
    public void train(ClassificationDataSet classificationDataSet, boolean z) {
        CategoricalData target = classificationDataSet.getPredicting();
        if (target.getNumOfCategories() != 2) {
            throw new IllegalArgumentException("ModestAdaBoost only supports binary problems, not "
                    + target.getNumOfCategories() + " categories");
        }
        this.predicting = target;
        this.hypWeights = new DoubleList(this.maxIterations);
        this.hypoths = new ArrayList<>(this.maxIterations);

        final int n = classificationDataSet.getSampleSize();
        List<DataPointPair<Integer>> dataPoints = classificationDataSet.getTwiceShallowClone().getAsDPPList();

        // dist is the current boosting distribution D_t. It is kept identical to the data point
        // weights so the inverted distribution below reflects the latest reweighting.
        double[] dist = new double[n];
        Arrays.fill(dist, 1.0d / n);
        for (DataPointPair<Integer> dpp : dataPoints) {
            dpp.getDataPoint().setWeight(1.0d / n);
        }
        double[] invDist = new double[n];
        double[] margins = new double[n];

        for (int round = 0; round < this.maxIterations; round++) {
            Classifier weak = this.weakLearner.m25clone();
            weak.train(new ClassificationDataSet(dataPoints, this.predicting), z);

            // Inverted distribution: mass shifts toward points the current D considers easy.
            double invSum = 0.0d;
            for (int i = 0; i < n; i++) {
                invDist[i] = 1.0d - dist[i];
                invSum += invDist[i];
            }
            for (int i = 0; i < n; i++) {
                invDist[i] /= invSum;
            }

            // Signed votes of the weak hypothesis, accumulated per true class under D and inverted D.
            double posD = 0.0d;
            double posInv = 0.0d;
            double negD = 0.0d;
            double negInv = 0.0d;
            for (int i = 0; i < n; i++) {
                DataPointPair<Integer> dpp = dataPoints.get(i);
                margins[i] = weak.classify(dpp.getDataPoint()).getProb(1) * 2.0d - 1.0d;
                double vote = Math.signum(margins[i]);
                if (dpp.getPair().intValue() == 1) {
                    posD += vote * dist[i];
                    posInv += vote * invDist[i];
                } else {
                    negD += vote * dist[i];
                    negInv += vote * invDist[i];
                }
            }

            double alpha = posD * (1.0d - posInv) - negD * (1.0d - negInv);
            // Stop when the hypothesis is no longer useful. NaN (e.g. a single training point,
            // where invSum is 0) also fails the sign comparison and stops training.
            if (Math.signum(alpha) != Math.signum(posD - negD) || Math.abs(posD - negD) < 1.0E-6d || alpha <= 0.0d) {
                return;
            }

            // Multiplicative reweighting, guarding against overflow and underflow to zero.
            double total = 0.0d;
            for (int i = 0; i < n; i++) {
                double y = dataPoints.get(i).getPair().intValue() * 2 - 1;
                double w = dist[i] * Math.exp(-y * alpha * margins[i]);
                if (Double.isInfinite(w)) {
                    w = 1.0d;
                } else if (w <= 0.0d) {
                    w = 0.001d / n;
                }
                dist[i] = w;
                total += w;
            }
            // Normalize with a small floor, then push the new distribution to the data points.
            for (int i = 0; i < n; i++) {
                dist[i] = Math.max(dist[i] / total, 1.0E-10d);
                dataPoints.get(i).getDataPoint().setWeight(dist[i]);
            }

            this.hypWeights.add(alpha);
            this.hypoths.add(weak);
        }
    }

    /**
     * Returns {@code false}: {@link #train(ClassificationDataSet, boolean)} assigns its own
     * per-point weights and does not use any the caller provides.
     *
     * @return {@code false}
     */
    @Override // jsat.classifiers.Classifier
    public boolean supportsWeightedData() {
        return false;
    }

    /**
     * Deep copy via {@link #ModestAdaBoost(ModestAdaBoost)}.
     *
     * @return an independent copy of this booster, including any trained state
     */
    @Override // jsat.classifiers.Classifier
    /* renamed from: clone, reason: merged with bridge method [inline-methods] */
    public ModestAdaBoost m25clone() {
        return new ModestAdaBoost(this);
    }
}
