package jsat.classifiers.svm;

import java.util.Arrays;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.stream.IntStream;
import jsat.DataSet;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.classifiers.WarmClassifier;
import jsat.classifiers.calibration.BinaryScoreClassifier;
import jsat.classifiers.svm.SupportVectorLearner;
import jsat.distributions.Distribution;
import jsat.distributions.kernels.KernelTrick;
import jsat.distributions.kernels.LinearKernel;
import jsat.exceptions.FailedToFitException;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.regression.RegressionDataSet;
import jsat.regression.Regressor;
import jsat.regression.WarmRegressor;
import jsat.utils.concurrent.ParallelUtils;

/* loaded from: input_file:jsat/classifiers/svm/LSSVM.class */
/**
 * A least-squares support vector machine (LS-SVM) usable both as a binary
 * classifier and as a regressor.
 *
 * <p>Training keeps one cached value per training point in {@code fcache}
 * (initially {@code -y_i} on a cold start, then updated incrementally as the
 * dual coefficients change). Each iteration picks the points holding the
 * largest ({@code i_up}) and smallest ({@code i_low}) cached values and
 * re-optimizes their two coefficients analytically while keeping their sum
 * fixed. Iteration stops once the duality gap is no larger than
 * {@code TOL * dualObjective}, or when a step makes no progress.
 *
 * <p>Predictions are {@code kEvalSum(x) - b} (presumably
 * {@code sum_j alpha_j K(x_j, x) - b}; confirm against
 * {@link SupportVectorLearner}). For classification a strictly positive score
 * maps to class 1 and anything else to class 0.
 *
 * <p>Warm starts are accepted only from another {@code LSSVM} trained on the
 * same data; its {@code C} may differ.
 */
public class LSSVM extends SupportVectorLearner implements BinaryScoreClassifier, Regressor, Parameterized, WarmRegressor, WarmClassifier {
    private static final long serialVersionUID = -7569924400631719451L;

    /** Bias subtracted from the kernel sum when predicting. */
    protected double b;
    /** Smallest {@code fcache} value found by the most recent scan (held at index {@code i_low}). */
    protected double b_low;
    /** Largest {@code fcache} value found by the most recent scan (held at index {@code i_up}). */
    protected double b_up;
    /** Regularization parameter, always in (0, Infinity). */
    private double C;
    /** Index of the training point with the largest {@code fcache} value. */
    private int i_up;
    /** Index of the training point with the smallest {@code fcache} value. */
    private int i_low;
    /** One incrementally maintained value per training point / alpha. */
    private double[] fcache;
    /** Running dual objective value; scales the stopping tolerance. */
    private double dualObjective;

    /** Relative threshold below which a coefficient change counts as "no progress". */
    private static final double EPSILON = 1.0E-12d;
    /** Relative duality-gap tolerance used as the stopping criterion. */
    private static final double TOL = 0.001d;

    /** Creates an LS-SVM using a {@link LinearKernel} and no kernel caching. */
    public LSSVM() {
        this(new LinearKernel());
    }

    /**
     * Creates an LS-SVM with the given kernel and no kernel caching.
     *
     * @param kernelTrick the kernel to use
     */
    public LSSVM(KernelTrick kernelTrick) {
        this(kernelTrick, SupportVectorLearner.CacheMode.NONE);
    }

    /**
     * Creates an LS-SVM with the given kernel and cache mode, {@code C = 1}.
     *
     * @param kernelTrick the kernel to use
     * @param cacheMode   the kernel cache mode to use during training
     */
    public LSSVM(KernelTrick kernelTrick, SupportVectorLearner.CacheMode cacheMode) {
        super(kernelTrick, cacheMode);
        this.b = 0.0d;
        this.C = 1.0d;
    }

    /**
     * Copy constructor.
     *
     * @param lssvm the model to copy
     */
    public LSSVM(LSSVM lssvm) {
        // NOTE(review): only a kernel clone and the cache mode reach the
        // superclass; the support vectors (vecs) are not copied here. Confirm
        // that a copy of a trained model can actually make predictions.
        super(lssvm.getKernel().m159clone(), lssvm.getCacheMode());
        // The bias is part of every prediction and must travel with the alphas.
        this.b = lssvm.b;
        this.b_low = lssvm.b_low;
        this.b_up = lssvm.b_up;
        this.i_up = lssvm.i_up;
        this.i_low = lssvm.i_low;
        this.C = lssvm.C;
        if (lssvm.alphas != null) {
            this.alphas = Arrays.copyOf(lssvm.alphas, lssvm.alphas.length);
        }
        if (lssvm.fcache != null) {
            this.fcache = Arrays.copyOf(lssvm.fcache, lssvm.fcache.length);
        }
    }

    /**
     * Sets the regularization parameter.
     *
     * @param d the new value of C
     * @throws IllegalArgumentException if {@code d} is not a finite positive number
     */
    @Parameter.WarmParameter(prefLowToHigh = true)
    public void setC(double d) {
        if (d <= 0.0d || Double.isNaN(d) || Double.isInfinite(d)) {
            throw new IllegalArgumentException("C must be in (0, Infty), not " + d);
        }
        this.C = d;
    }

    /**
     * Returns the regularization parameter.
     *
     * @return the current value of C
     */
    public double getC() {
        return this.C;
    }

    /**
     * Jointly re-optimizes {@code alphas[i1]} and {@code alphas[i2]}, keeping
     * their sum fixed. It then updates every {@code fcache} entry and rescans
     * for the new extreme values ({@code b_up}/{@code i_up},
     * {@code b_low}/{@code i_low}).
     *
     * @param i1              index of the first coefficient
     * @param i2              index of the second coefficient
     * @param executorService executor used for the parallel fcache update
     * @param parallel        whether to run the fcache update in parallel
     * @return {@code false} if the pair is degenerate or the change would be
     *         negligible (nothing is modified), {@code true} otherwise
     */
    private boolean takeStep(int i1, int i2, ExecutorService executorService, boolean parallel) throws InterruptedException, ExecutionException {
        final double alpha1Old = this.alphas[i1];
        final double alpha2Old = this.alphas[i2];
        final double f1 = this.fcache[i1];
        final double f2 = this.fcache[i2];
        final double gamma = alpha1Old + alpha2Old;

        // NOTE(review): the warm-start adjustment in initializeVariables shifts
        // fcache by alpha_i * (1/C_new - 1/C_old), which suggests fcache includes
        // an alpha_i / C diagonal term. Neither eta nor the fcache update below
        // adds 1/C to the diagonal. Confirm whether kEval already accounts for it.
        final double eta = 2.0d * kEval(i2, i1) - kEval(i1, i1) - kEval(i2, i2);

        // A degenerate pair (same index, or points identical in feature space)
        // gives eta == 0. Dividing by it would fill alphas with Infinity/NaN.
        if (i1 == i2 || eta == 0.0d || Double.isNaN(eta)) {
            return false;
        }

        final double step = (f1 - f2) / eta;
        final double alpha2New = alpha2Old - step;
        if (Math.abs(alpha2New - alpha2Old) < EPSILON * (alpha2New + alpha2Old + EPSILON)) {
            return false;
        }
        final double alpha1New = gamma - alpha2New;
        this.alphas[i1] = alpha1New;
        this.alphas[i2] = alpha2New;
        this.dualObjective -= ((eta / 2.0d) * step) * step;

        final double delta1 = alpha1New - alpha1Old;
        final double delta2 = alpha2New - alpha2Old;

        this.b_up = Double.NEGATIVE_INFINITY;
        this.b_low = Double.POSITIVE_INFINITY;
        ParallelUtils.run(parallel, this.fcache.length, (start, end) -> {
            // Each chunk finds its own extremes over [start, end), then merges
            // them into the shared fields under a lock.
            double localMax = Double.NEGATIVE_INFINITY;
            double localMin = Double.POSITIVE_INFINITY;
            int localUp = start;
            int localLow = start;
            for (int j = start; j < end; j++) {
                final double updated = this.fcache[j] + delta1 * kEval(i1, j) + delta2 * kEval(i2, j);
                this.fcache[j] = updated;
                if (updated > localMax) {
                    localMax = updated;
                    localUp = j;
                }
                if (updated < localMin) {
                    localMin = updated;
                    localLow = j;
                }
            }
            // Compare the tracked values, not fcache[localUp], so an empty
            // chunk (start == end) never indexes past the end of the array.
            synchronized (this.fcache) {
                if (localMax > this.b_up) {
                    this.b_up = localMax;
                    this.i_up = localUp;
                }
                if (localMin < this.b_low) {
                    this.b_low = localMin;
                    this.i_low = localLow;
                }
            }
        }, executorService);
        return true;
    }

    /**
     * Warm starts require a model trained on the same data set.
     *
     * @return always {@code true}
     */
    @Override // jsat.regression.WarmRegressor
    public boolean warmFromSameDataOnly() {
        return true;
    }

    /**
     * Updates the bias {@link #b} and returns the current duality gap.
     *
     * @param fast     if {@code true}, use {@code b = (b_up + b_low) / 2};
     *                 otherwise use the mean of {@code fcache[i] - alphas[i] / C}
     * @param parallel whether the reductions may run in parallel
     * @return the duality gap for the current alphas and the updated bias
     */
    private double computeDualityGap(boolean fast, boolean parallel) throws InterruptedException, ExecutionException {
        if (fast) {
            this.b = (this.b_up + this.b_low) / 2.0d;
        } else {
            this.b = ParallelUtils.streamP(IntStream.range(0, this.alphas.length), parallel)
                    .mapToDouble(i -> this.fcache[i] - (this.alphas[i] / this.C))
                    .sum();
            this.b /= this.alphas.length;
        }
        return ParallelUtils.streamP(IntStream.range(0, this.alphas.length), parallel).mapToDouble(i -> {
            final double residual = (this.b + (this.alphas[i] / this.C)) - this.fcache[i];
            return (this.alphas[i] * (this.fcache[i] - ((0.5d * this.alphas[i]) / this.C)))
                    + (((this.C * residual) * residual) / 2.0d);
        }).sum();
    }

    /**
     * Allocates and initializes {@code alphas}, {@code fcache},
     * {@code dualObjective} and the extreme-value trackers. It then resets the
     * kernel cache for the current {@code vecs}.
     *
     * @param dArr  target values (+/-1 for classification)
     * @param lssvm optional warm-start model trained on the same data, or {@code null}
     * @throws FailedToFitException if the warm model is untrained or has a
     *                              different number of alphas
     */
    private void initializeVariables(double[] dArr, LSSVM lssvm) {
        final int n = dArr.length;
        this.alphas = new double[n];
        this.fcache = new double[n];
        this.dualObjective = 0.0d;
        if (lssvm == null) {
            // Cold start: all alphas are zero, so each cached value is just -y_i.
            for (int i = 0; i < n; i++) {
                this.fcache[i] = -dArr[i];
            }
        } else {
            if (lssvm.alphas == null || lssvm.fcache == null) {
                throw new FailedToFitException("Warm LS-SVM solution has not been trained");
            }
            if (lssvm.alphas.length != n) {
                throw new FailedToFitException("Warm LS-SVM solution could not have been trained on the same data, different number of alpha values present");
            }
            // Re-use the old alphas. fcache is shifted by alpha_i * (1/C - 1/C_old)
            // to account for the change in C.
            final double ratio = this.C / lssvm.C;
            for (int i = 0; i < n; i++) {
                this.alphas[i] = lssvm.alphas[i];
                this.fcache[i] = lssvm.fcache[i] - (((ratio - 1.0d) * lssvm.alphas[i]) / this.C);
                this.dualObjective += this.alphas[i] * (dArr[i] - this.fcache[i]);
            }
            this.dualObjective /= 2.0d;
        }
        this.b_up = Double.NEGATIVE_INFINITY;
        this.b_low = Double.POSITIVE_INFINITY;
        for (int i = 0; i < n; i++) {
            final double f = this.fcache[i];
            if (f > this.b_up) {
                this.b_up = f;
                this.i_up = i;
            }
            if (f < this.b_low) {
                this.b_low = f;
                this.i_low = i;
            }
        }
        // Re-apply the cache mode so the kernel cache is rebuilt for the new vecs.
        setCacheMode(getCacheMode());
    }

    /**
     * Returns the raw decision value; identical to {@link #regress(DataPoint)}.
     *
     * @param dataPoint the point to score
     * @return the decision value
     */
    @Override // jsat.classifiers.calibration.BinaryScoreClassifier
    public double getScore(DataPoint dataPoint) {
        return regress(dataPoint);
    }

    /**
     * Hard binary classification: class 1 if the score is strictly positive,
     * class 0 otherwise, with probability 1 on the chosen class.
     *
     * @param dataPoint the point to classify
     * @return a two-class result
     */
    @Override // jsat.classifiers.Classifier
    public CategoricalResults classify(DataPoint dataPoint) {
        CategoricalResults categoricalResults = new CategoricalResults(2);
        if (regress(dataPoint) > 0.0d) {
            categoricalResults.setProb(1, 1.0d);
        } else {
            categoricalResults.setProb(0, 1.0d);
        }
        return categoricalResults;
    }

    /**
     * Trains as a binary classifier from scratch.
     *
     * @param classificationDataSet a two-class data set
     * @param z                     whether to train in parallel
     */
    @Override // jsat.classifiers.Classifier
    public void train(ClassificationDataSet classificationDataSet, boolean z) {
        train(classificationDataSet, (Classifier) null, z);
    }

    /**
     * Trains as a regressor, optionally warm-started.
     *
     * @param regressionDataSet the training data
     * @param regressor         an {@code LSSVM} trained on the same data, or {@code null}
     * @param z                 whether to train in parallel
     * @throws FailedToFitException if {@code regressor} is non-null and not an {@code LSSVM}
     */
    @Override // jsat.regression.WarmRegressor
    public void train(RegressionDataSet regressionDataSet, Regressor regressor, boolean z) {
        if (regressor != null && !(regressor instanceof LSSVM)) {
            throw new FailedToFitException("Warm solution must be an implementation of LS-SVM, not " + regressor.getClass());
        }
        mainLoop(regressionDataSet, (LSSVM) regressor, regressionDataSet.getTargetValues().arrayCopy(), z);
    }

    /**
     * Single-threaded warm-start regression training.
     *
     * @param regressionDataSet the training data
     * @param regressor         an {@code LSSVM} trained on the same data, or {@code null}
     */
    @Override // jsat.regression.WarmRegressor
    public void train(RegressionDataSet regressionDataSet, Regressor regressor) {
        train(regressionDataSet, regressor, false);
    }

    /**
     * Trains as a binary classifier, optionally warm-started. Class 0 is
     * mapped to target -1 and class 1 to target +1.
     *
     * @param classificationDataSet a two-class data set
     * @param classifier            an {@code LSSVM} trained on the same data, or {@code null}
     * @param z                     whether to train in parallel
     * @throws FailedToFitException if the data set is not binary or the warm
     *                              model is not an {@code LSSVM}
     */
    @Override // jsat.classifiers.WarmClassifier
    public void train(ClassificationDataSet classificationDataSet, Classifier classifier, boolean z) {
        if (classificationDataSet.getClassSize() != 2) {
            throw new FailedToFitException("LS-SVM only supports binary classification problems");
        }
        if (classifier != null && !(classifier instanceof LSSVM)) {
            throw new FailedToFitException("Warm solution must be an implementation of LS-SVM, not " + classifier.getClass());
        }
        final int n = classificationDataSet.getSampleSize();
        double[] dArr = new double[n];
        for (int i = 0; i < n; i++) {
            dArr[i] = (classificationDataSet.getDataPointCategory(i) * 2) - 1;
        }
        mainLoop(classificationDataSet, (LSSVM) classifier, dArr, z);
    }

    /**
     * Single-threaded warm-start classification training.
     *
     * @param classificationDataSet a two-class data set
     * @param classifier            an {@code LSSVM} trained on the same data, or {@code null}
     */
    @Override // jsat.classifiers.WarmClassifier
    public void train(ClassificationDataSet classificationDataSet, Classifier classifier) {
        train(classificationDataSet, classifier, false);
    }

    /**
     * Weighted data points are not supported.
     *
     * @return always {@code false}
     */
    @Override // jsat.classifiers.Classifier
    public boolean supportsWeightedData() {
        return false;
    }

    /**
     * Returns the kernel sum for the point minus the bias {@link #b}.
     *
     * @param dataPoint the point to evaluate
     * @return the regression output / decision value
     */
    @Override // jsat.regression.Regressor
    public double regress(DataPoint dataPoint) {
        return kEvalSum(dataPoint.getNumericalValues()) - this.b;
    }

    /**
     * Trains as a regressor from scratch.
     *
     * @param regressionDataSet the training data
     * @param z                 whether to train in parallel
     */
    @Override // jsat.regression.Regressor
    public void train(RegressionDataSet regressionDataSet, boolean z) {
        train(regressionDataSet, (Regressor) null, z);
    }

    /**
     * Returns a copy made with {@link #LSSVM(LSSVM)}.
     * (Decompiler-assigned name for {@code clone()}.)
     *
     * @return a copy of this model
     */
    @Override // jsat.regression.Regressor
    public LSSVM mo107clone() {
        return new LSSVM(this);
    }

    /**
     * Shared training driver. It initializes the state, then repeatedly steps
     * on the ({@code i_up}, {@code i_low}) pair until the duality gap is
     * within tolerance or no step makes progress. Finally it hands the
     * alphas to the superclass.
     *
     * @param dataSet the training data (its vectors become {@code vecs})
     * @param lssvm   optional warm-start model, or {@code null}
     * @param dArr    target values, one per data point
     * @param z       whether to train in parallel
     * @throws FailedToFitException if training is interrupted or a parallel task fails
     */
    private void mainLoop(DataSet dataSet, LSSVM lssvm, double[] dArr, boolean z) {
        final ExecutorService executor = ParallelUtils.getNewExecutor(z);
        try {
            this.vecs = dataSet.getDataVectors();
            initializeVariables(dArr, lssvm);
            double gap = computeDualityGap(true, z);
            boolean progress = true;
            while (progress && gap > TOL * this.dualObjective) {
                progress = takeStep(this.i_up, this.i_low, executor, z);
                gap = computeDualityGap(true, z);
            }
            setCacheMode(null);
            setAlphas(this.alphas);
        } catch (InterruptedException e) {
            // Preserve the interrupt status for callers further up the stack.
            Thread.currentThread().interrupt();
            throw new FailedToFitException(e);
        } catch (ExecutionException e) {
            throw new FailedToFitException(e);
        } finally {
            // The executor was created for this call only; release its threads.
            executor.shutdownNow();
        }
    }

    /**
     * Suggests a distribution to search for C; delegates to {@link PlattSMO#guessC(DataSet)}.
     *
     * @param dataSet the data set to tune on
     * @return a distribution over candidate C values
     */
    public static Distribution guessC(DataSet dataSet) {
        return PlattSMO.guessC(dataSet);
    }
}
