package jsat.classifiers.svm;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Random;
import java.util.concurrent.Callable;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.DataPoint;
import jsat.classifiers.calibration.BinaryScoreClassifier;
import jsat.classifiers.svm.SupportVectorLearner;
import jsat.distributions.kernels.KernelTrick;
import jsat.exceptions.FailedToFitException;
import jsat.exceptions.UntrainedModelException;
import jsat.parameters.Parameterized;
import jsat.utils.ListUtils;
import jsat.utils.concurrent.AtomicDouble;
import jsat.utils.concurrent.ParallelUtils;
import jsat.utils.random.RandomUtil;

/* loaded from: input_file:jsat/classifiers/svm/PegasosK.class */
/**
 * Kernelized Pegasos: a stochastic sub-gradient solver for the binary SVM
 * objective, run in the dual so that any {@link KernelTrick} can be used.
 * <p>
 * Each training iteration draws one training point uniformly at random.
 * If the scaled margin of the current model on that point is below 1, the
 * point's count in {@code alphas} is incremented. After training, only
 * points with a non-zero count are kept, as support vectors.
 * <p>
 * Reference: Shalev-Shwartz, Singer, Srebro &amp; Cotter, "Pegasos: Primal
 * Estimated sub-GrAdient SOlver for SVM", Mathematical Programming, 2011.
 */
public class PegasosK extends SupportVectorLearner implements BinaryScoreClassifier, Parameterized {
    private static final long serialVersionUID = 5405460830472328107L;

    /** Regularization constant (lambda); validated as positive and finite by the setter. */
    private double regularization;
    /** Number of stochastic update steps performed by {@link #train}. */
    private int iterations;

    /**
     * Creates a kernelized Pegasos learner using cache mode {@code NONE}.
     *
     * @param regularization the regularization constant; must be positive and finite
     * @param iterations     the number of stochastic training iterations
     * @param kernel         the kernel to use
     * @throws ArithmeticException if {@code regularization} is not positive and finite
     */
    public PegasosK(double regularization, int iterations, KernelTrick kernel) {
        this(regularization, iterations, kernel, SupportVectorLearner.CacheMode.NONE);
    }

    /**
     * Creates a kernelized Pegasos learner.
     *
     * @param regularization the regularization constant; must be positive and finite
     * @param iterations     the number of stochastic training iterations
     * @param kernel         the kernel to use
     * @param cacheMode      the kernel cache mode passed to {@link SupportVectorLearner}
     * @throws ArithmeticException if {@code regularization} is not positive and finite
     */
    public PegasosK(double regularization, int iterations, KernelTrick kernel,
            SupportVectorLearner.CacheMode cacheMode) {
        super(kernel, cacheMode);
        setRegularization(regularization);
        setIterations(iterations);
    }

    /**
     * Sets the number of training iterations.
     * No validation is performed; a value below 1 means no updates are made.
     *
     * @param iterations the number of stochastic update steps
     */
    public void setIterations(int iterations) {
        this.iterations = iterations;
    }

    /** @return the number of stochastic update steps used by {@link #train} */
    public int getIterations() {
        return this.iterations;
    }

    /**
     * Sets the regularization constant (lambda).
     *
     * @param regularization a positive, finite value
     * @throws ArithmeticException if the value is NaN, infinite, or not positive
     */
    public void setRegularization(double regularization) {
        if (Double.isNaN(regularization) || Double.isInfinite(regularization) || regularization <= 0.0d) {
            throw new ArithmeticException("Regularization must be a positive constant, not " + regularization);
        }
        this.regularization = regularization;
    }

    /** @return the regularization constant (lambda) */
    public double getRegularization() {
        return this.regularization;
    }

    /**
     * Creates a deep copy of this learner.
     * The copy has the same hyper-parameters and a cloned kernel. If this
     * learner is trained, the copy also gets cloned support vectors and a
     * copy of their weights.
     * <p>
     * This is the decompiler-renamed {@code clone()}; the name is kept as-is
     * for interface compatibility.
     * NOTE(review): {@link #train} finishes with {@code setAlphas(...)} but this
     * copy does not; confirm whether SupportVectorLearner needs that call.
     */
    @Override // jsat.classifiers.calibration.BinaryScoreClassifier, jsat.classifiers.Classifier
    public PegasosK m98clone() {
        PegasosK copy = new PegasosK(this.regularization, this.iterations, getKernel().mo154clone(), getCacheMode());
        if (this.vecs != null) {
            copy.vecs = new ArrayList<>(this.vecs.size());
            for (int j = 0; j < this.vecs.size(); j++) {
                copy.vecs.add(this.vecs.get(j).mo46clone());
            }
        }
        if (this.alphas != null) {
            copy.alphas = Arrays.copyOf(this.alphas, this.alphas.length);
        }
        return copy;
    }

    /**
     * Predicts a hard label: class 1 when {@link #getScore} is strictly
     * positive, otherwise class 0. The chosen class gets probability 1.
     *
     * @throws UntrainedModelException if the model has not been trained
     */
    @Override // jsat.classifiers.Classifier
    public CategoricalResults classify(DataPoint dataPoint) {
        requireTrained();
        CategoricalResults result = new CategoricalResults(2);
        if (getScore(dataPoint) > 0.0d) {
            result.setProb(1, 1.0d);
        } else {
            result.setProb(0, 1.0d);
        }
        return result;
    }

    /**
     * Returns the kernel-weighted sum over the support vectors for this point.
     * Positive values favour class 1.
     * <p>
     * The stored weights are the signed update counts {@code alpha_j * y_j}.
     * They are not divided by {@code lambda * T}. That factor is positive, so
     * the sign (and {@link #classify}) is unaffected, but the magnitude
     * differs from the paper's decision value.
     *
     * @throws UntrainedModelException if the model has not been trained
     */
    @Override // jsat.classifiers.calibration.BinaryScoreClassifier
    public double getScore(DataPoint dataPoint) {
        requireTrained();
        return kEvalSum(dataPoint.getNumericalValues());
    }

    /**
     * Trains the model with kernelized Pegasos.
     * <p>
     * At step {@code t} a sample {@code i} is drawn uniformly and its count
     * {@code alpha_i} is incremented when
     * <pre>
     *   y_i / (lambda * t) * sum_{j != i} alpha_j * y_j * K(x_i, x_j) &lt; 1
     * </pre>
     * The label inside the sum must be {@code y_j}. Using {@code y_i} there
     * (the previous behaviour) cancels against the outer {@code y_i} and makes
     * the test independent of the labels.
     *
     * @param dataSet  the training data; must have exactly two classes and at least one point
     * @param parallel whether the per-iteration kernel sum may be computed in parallel
     * @throws FailedToFitException if the data is not binary or is empty
     */
    @Override // jsat.classifiers.Classifier
    public void train(ClassificationDataSet dataSet, boolean parallel) {
        if (dataSet.getClassSize() != 2) {
            throw new FailedToFitException("Pegasos only supports binary classification problems");
        }
        final int m = dataSet.getSampleSize();
        if (m == 0) {
            // Without this check, rand.nextInt(0) below would throw an opaque IllegalArgumentException.
            throw new FailedToFitException("Pegasos requires at least one training point");
        }
        Random rand = RandomUtil.getRandom();

        // During training alphas hold unsigned update counts; labels live in sign[].
        this.alphas = new double[m];
        final int[] sign = new int[m];
        this.vecs = new ArrayList<>(m);
        for (int j = 0; j < m; j++) {
            this.vecs.add(dataSet.getDataPoint(j).getNumericalValues());
            sign[j] = dataSet.getDataPointCategory(j) == 1 ? 1 : -1;
        }

        // Re-applies the current cache mode now that vecs is populated.
        // NOTE(review): presumably this (re)builds the kernel cache used by kEval(i, j); confirm.
        setCacheMode(getCacheMode());

        for (int t = 1; t <= this.iterations; t++) {
            final int i = rand.nextInt(m);
            final double signI = sign[i];
            final AtomicDouble sum = new AtomicDouble(0.0d);

            ParallelUtils.run(parallel, m, (start, end) -> {
                double local = 0.0d;
                for (int j = start; j < end; j++) {
                    // The sample's own term is skipped; zero-count samples contribute nothing.
                    if (j != i && this.alphas[j] != 0.0d) {
                        local += this.alphas[j] * sign[j] * kEval(i, j);
                    }
                }
                sum.addAndGet(local);
            });

            double margin = signI * sum.get() / (this.regularization * t);
            if (margin < 1.0d) {
                this.alphas[i] += 1.0d;
            }
        }

        // Compact the support vectors to the front, converting counts to signed weights alpha_j * y_j.
        int supportCount = 0;
        for (int j = 0; j < this.alphas.length; j++) {
            if (this.alphas[j] != 0.0d) {
                this.alphas[supportCount] = this.alphas[j] * sign[j];
                ListUtils.swap(this.vecs, supportCount, j);
                supportCount++;
            }
        }
        this.alphas = Arrays.copyOf(this.alphas, supportCount);
        this.vecs = new ArrayList<>(this.vecs.subList(0, supportCount));
        setCacheMode(null);
        setAlphas(this.alphas);
    }

    /** @return {@code false}; per-sample weights are not used by this learner */
    @Override // jsat.classifiers.Classifier
    public boolean supportsWeightedData() {
        return false;
    }

    /** Throws {@link UntrainedModelException} if {@link #train} has not produced a model yet. */
    private void requireTrained() {
        if (this.alphas == null) {
            throw new UntrainedModelException("Model has not been trained");
        }
    }
}
