package jsat.classifiers.svm.extended;

import java.util.Collections;
import java.util.Iterator;
import java.util.Map;
import java.util.Random;
import jsat.classifiers.ClassificationDataSet;
import jsat.linear.Vec;
import jsat.utils.IntList;
import jsat.utils.ListUtils;
import jsat.utils.random.RandomUtil;

/* loaded from: input_file:jsat/classifiers/svm/extended/AMM.class */
/**
 * Batch trainer for the Adaptive Multi-hyperplane Machine (AMM).
 * <p>
 * Training reuses the per-example {@code update} rule inherited from {@link OnlineAMM}. It
 * starts with one shuffled pass that gives every data point an initial weight-vector
 * assignment. It then alternates between two phases for up to {@link #getEpochs()} epochs:
 * <ol>
 * <li>{@link #getSubEpochs()} shuffled passes of {@code update}, each passing the point's
 * current assignment;</li>
 * <li>re-assigning every point to the weight vector of its own class with the largest
 * non-negative dot product.</li>
 * </ol>
 * Training stops early as soon as a re-assignment phase changes no assignment.
 */
public class AMM extends OnlineAMM {
    private static final long serialVersionUID = -9198419566231617395L;

    /** Number of weight-update passes run per outer epoch; validated to be at least 1. */
    private int subEpochs = 1;

    /**
     * Creates a new batch AMM learner with the default settings. Equivalent to
     * {@code new AMM(0.01, 50)}.
     */
    public AMM() {
        this(0.01d);
    }

    /**
     * Creates a new batch AMM learner. Equivalent to {@code new AMM(lambda, 50)}.
     *
     * @param lambda forwarded as the first argument of the {@link OnlineAMM} constructor
     *        (presumably the regularization value — confirm against OnlineAMM)
     */
    public AMM(double lambda) {
        this(lambda, 50);
    }

    /**
     * Creates a new batch AMM learner that runs 10 epochs with 1 sub-epoch each by default.
     *
     * @param lambda forwarded as the first argument of the {@link OnlineAMM} constructor
     *        (presumably the regularization value — confirm against OnlineAMM)
     * @param classBudget forwarded as the second argument of the {@link OnlineAMM} constructor
     *        (presumably the maximum number of weight vectors per class — confirm)
     */
    public AMM(double lambda, int classBudget) {
        super(lambda, classBudget);
        setEpochs(10);
    }

    /**
     * Copy constructor. Copies the state handled by the {@link OnlineAMM} copy constructor
     * plus this learner's sub-epoch count.
     *
     * @param toCopy the learner to copy
     */
    public AMM(AMM toCopy) {
        super(toCopy);
        this.subEpochs = toCopy.subEpochs;
    }

    /**
     * Sets how many weight-update passes are run over the data in each outer epoch, before
     * the point-to-weight-vector assignments are recomputed.
     *
     * @param subEpochs the number of sub-epochs per epoch
     * @throws IllegalArgumentException if {@code subEpochs < 1}
     */
    public void setSubEpochs(int subEpochs) {
        if (subEpochs < 1) {
            throw new IllegalArgumentException("subEpochs must be positive, not " + subEpochs);
        }
        this.subEpochs = subEpochs;
    }

    /**
     * Returns the number of weight-update passes run per outer epoch.
     *
     * @return the number of sub-epochs per epoch
     */
    public int getSubEpochs() {
        return this.subEpochs;
    }

    /**
     * Trains on the given data set. The {@code parallel} flag is ignored; this simply
     * delegates to {@link #train(ClassificationDataSet)}.
     *
     * @param dataSet the data to train on
     * @param parallel ignored
     */
    @Override // jsat.classifiers.BaseUpdateableClassifier, jsat.classifiers.Classifier
    public void train(ClassificationDataSet dataSet, boolean parallel) {
        train(dataSet);
    }

    /**
     * Trains the model from scratch on the given data set. Any previous model state is
     * discarded by {@code setUp}.
     *
     * @param dataSet the data to train on
     */
    @Override // jsat.classifiers.BaseUpdateableClassifier, jsat.classifiers.Classifier
    public void train(ClassificationDataSet dataSet) {
        final int n = dataSet.getSampleSize();
        IntList order = new IntList(n);
        ListUtils.addRange(order, 0, n, 1);
        Random rand = RandomUtil.getRandom();
        // assignments[i] = key of the weight vector (within point i's own class) that i is assigned to
        int[] assignments = new int[n];

        setUp(dataSet.getCategories(), dataSet.getNumNumericalVars(), dataSet.getPredicting());

        // Initial online pass. Integer.MIN_VALUE is passed as the assignment, presumably a
        // sentinel meaning "not yet assigned" to OnlineAMM.update — confirm.
        Collections.shuffle(order, rand);
        for (int idx : order) {
            assignments[idx] = update(dataSet.getDataPoint(idx), dataSet.getDataPointCategory(idx), Integer.MIN_VALUE);
        }

        // Reset the inherited time counter before the batch refinement phase
        // (presumably the step counter used by update — confirm).
        this.time = 1;

        // A plain for loop: with zero epochs no refinement is performed (the decompiled
        // do-while always ran at least one epoch).
        for (int epoch = 0; epoch < getEpochs(); epoch++) {
            runSubEpochs(dataSet, order, rand, assignments);
            if (reassignAll(dataSet, assignments) == 0) {
                return; // assignments are stable: converged
            }
        }
    }

    /**
     * Runs {@link #subEpochs} shuffled passes of {@code update} over every data point.
     * Each call passes the point's current assignment and stores the assignment it returns.
     */
    private void runSubEpochs(ClassificationDataSet dataSet, IntList order, Random rand, int[] assignments) {
        for (int pass = 0; pass < this.subEpochs; pass++) {
            Collections.shuffle(order, rand);
            for (int idx : order) {
                assignments[idx] = update(dataSet.getDataPoint(idx), dataSet.getDataPointCategory(idx), assignments[idx]);
            }
        }
    }

    /**
     * Recomputes every point's assignment with {@link #bestWeightKey(Vec, int)}.
     *
     * @return the number of points whose assignment changed
     */
    private int reassignAll(ClassificationDataSet dataSet, int[] assignments) {
        int changed = 0;
        for (int i = 0; i < assignments.length; i++) {
            int best = bestWeightKey(dataSet.getDataPoint(i).getNumericalValues(), dataSet.getDataPointCategory(i));
            if (assignments[i] != best) {
                changed++;
                assignments[i] = best;
            }
        }
        return changed;
    }

    /**
     * Finds the weight vector of the given class whose dot product with {@code x} is largest.
     * Only scores {@code >= 0} are considered. On ties, the entry visited last wins.
     *
     * @return the key of that weight vector, or -1 if no entry scores {@code >= 0}
     *         (including when the class has no weight vectors)
     */
    private int bestWeightKey(Vec x, int category) {
        double bestScore = 0.0d;
        int bestKey = -1;
        for (Map.Entry<Integer, Vec> entry : this.weightMatrix.get(category).entrySet()) {
            double score = x.dot(entry.getValue());
            if (score >= bestScore) {
                bestScore = score;
                bestKey = entry.getKey();
            }
        }
        return bestKey;
    }

    /**
     * Returns a copy of this learner made with the copy constructor.
     * <p>
     * The name {@code mo0clone} comes from the decompiler: it was renamed from {@code clone}
     * and merged with its bridge method. It must match the parent's method name to override it.
     *
     * @return a new {@code AMM} copied from this one
     */
    @Override // jsat.classifiers.svm.extended.OnlineAMM, jsat.classifiers.BaseUpdateableClassifier
    public AMM mo0clone() {
        return new AMM(this);
    }
}
