package jsat.classifiers.linear;

import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.Random;
import jsat.DataSet;
import jsat.SimpleWeightVectorModel;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.classifiers.WarmClassifier;
import jsat.classifiers.neuralnetwork.LVQ;
import jsat.distributions.Distribution;
import jsat.distributions.LogUniform;
import jsat.distributions.Uniform;
import jsat.exceptions.FailedToFitException;
import jsat.linear.DenseVector;
import jsat.linear.IndexValue;
import jsat.linear.Vec;
import jsat.lossfunctions.LogisticLoss;
import jsat.lossfunctions.LossC;
import jsat.lossfunctions.LossFunc;
import jsat.lossfunctions.LossMC;
import jsat.lossfunctions.LossR;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.regression.RegressionDataSet;
import jsat.regression.Regressor;
import jsat.regression.WarmRegressor;
import jsat.utils.IntList;
import jsat.utils.ListUtils;
import jsat.utils.random.RandomUtil;

/* loaded from: input_file:jsat/classifiers/linear/SDCA.class */
/**
 * Linear classifier / regressor trained with a proximal Stochastic Dual Coordinate Ascent (SDCA)
 * loop.
 * <p>
 * The learned model is a single weight vector plus an optional bias term.
 * <ul>
 * <li>Regularization is controlled by {@link #setLambda(double) lambda} and
 * {@link #setAlpha(double) alpha}.</li>
 * <li>When {@code alpha > 0}, a soft-thresholding (L1-style proximal) step is applied to the
 * weights.</li>
 * <li>{@code alpha == 0} disables the soft-thresholding step.</li>
 * </ul>
 * Classification is currently limited to binary problems. The two classes are mapped to the
 * targets -1 and +1.
 * <p>
 * Warm starting re-uses the dual variables of another SDCA instance. It is only valid when
 * training on the same data set (see {@link #warmFromSameDataOnly()}).
 */
public class SDCA implements Classifier, Regressor, Parameterized, SimpleWeightVectorModel, WarmClassifier, WarmRegressor {
    /** Loss function minimized during training. */
    private LossFunc loss;
    /** Whether a bias term is learned. */
    private boolean useBias = true;
    /** Convergence tolerance on the average primal-dual gap. */
    private double tol = 0.001d;
    /** Regularization strength; always positive and finite. */
    private double lambda;
    /** Mixing parameter in [0, 1] between the soft-threshold (L1-like) and L2 parts of the penalty. */
    private double alpha = 0.5d;
    /** Maximum number of passes (epochs) over the training data. */
    private int max_epochs = LVQ.DEFAULT_ITERATIONS;
    /** Dual variables from the last training run, one per data point; used for warm starts. */
    private double[] dual_alphas;
    /** Number of epochs performed by the last training run. */
    protected int epochs_taken;
    /** Learned weight vectors; training currently always produces exactly one. */
    private Vec[] ws;
    /** Learned bias terms, parallel to {@link #ws}. */
    private double[] bs;

    /**
     * Creates a new SDCA model with {@code lambda = 1e-5} and the {@link LogisticLoss}.
     */
    public SDCA() {
        this(1.0E-5d);
    }

    /**
     * Creates a new SDCA model using the {@link LogisticLoss}.
     *
     * @param lambda the regularization strength
     * @throws IllegalArgumentException if {@code lambda} is not positive and finite
     */
    public SDCA(double lambda) {
        this(lambda, new LogisticLoss());
    }

    /**
     * Creates a new SDCA model.
     *
     * @param lambda   the regularization strength
     * @param lossFunc the loss function to minimize
     * @throws IllegalArgumentException if {@code lambda} is not positive and finite
     */
    public SDCA(double lambda, LossFunc lossFunc) {
        setLoss(lossFunc);
        setLambda(lambda);
    }

    /**
     * Copy constructor. Produces a deep copy of the loss, the dual variables and any learned
     * weights.
     *
     * @param toCopy the object to copy
     */
    public SDCA(SDCA toCopy) {
        this.loss = toCopy.loss.m217clone();
        this.useBias = toCopy.useBias;
        this.tol = toCopy.tol;
        this.lambda = toCopy.lambda;
        this.alpha = toCopy.alpha;
        this.max_epochs = toCopy.max_epochs;
        this.epochs_taken = toCopy.epochs_taken;
        if (toCopy.dual_alphas != null) {
            this.dual_alphas = Arrays.copyOf(toCopy.dual_alphas, toCopy.dual_alphas.length);
        }
        if (toCopy.ws != null) {
            this.ws = new Vec[toCopy.ws.length];
            for (int i = 0; i < toCopy.ws.length; i++) {
                this.ws[i] = toCopy.ws[i].mo46clone();
            }
            this.bs = Arrays.copyOf(toCopy.bs, toCopy.bs.length);
        }
    }

    /**
     * @param useBias {@code true} to learn a bias term, {@code false} to keep it at zero
     */
    public void setUseBias(boolean useBias) {
        this.useBias = useBias;
    }

    /**
     * @return {@code true} if a bias term is learned
     */
    public boolean isUseBias() {
        return this.useBias;
    }

    /**
     * Sets the regularization strength.
     *
     * @param lambda the regularization strength; must be positive and finite
     * @throws IllegalArgumentException if {@code lambda} is not positive and finite
     */
    @Parameter.WarmParameter(prefLowToHigh = false)
    public void setLambda(double lambda) {
        if (lambda <= 0.0d || Double.isInfinite(lambda) || Double.isNaN(lambda)) {
            throw new IllegalArgumentException("Regularization term lambda must be a positive value, not " + lambda);
        }
        this.lambda = lambda;
    }

    /**
     * @return the regularization strength
     */
    public double getLambda() {
        return this.lambda;
    }

    /**
     * Sets the mixing parameter between the soft-threshold (L1-like) and L2 parts of the penalty.
     * {@code 0} disables soft-thresholding. {@code 1} uses {@code lambda} as the soft-threshold
     * with only a small tolerance-derived L2 term.
     *
     * @param alpha the mixing value in [0, 1]
     * @throws IllegalArgumentException if {@code alpha} is outside [0, 1] or NaN
     */
    public void setAlpha(double alpha) {
        if (alpha < 0.0d || alpha > 1.0d || Double.isNaN(alpha)) {
            throw new IllegalArgumentException("alpha must be in [0, 1], not " + alpha);
        }
        this.alpha = alpha;
    }

    /**
     * @return the mixing parameter in [0, 1]
     */
    public double getAlpha() {
        return this.alpha;
    }

    /**
     * Sets the maximum number of epochs (full shuffled passes over the data).
     *
     * @param iters the maximum number of epochs; must be at least 1
     * @throws IllegalArgumentException if {@code iters < 1}
     */
    public void setMaxIters(int iters) {
        if (iters < 1) {
            throw new IllegalArgumentException("Number of training iterations must be positive, not " + iters);
        }
        this.max_epochs = iters;
    }

    /**
     * @return the maximum number of epochs
     */
    public int getMaxIters() {
        return this.max_epochs;
    }

    /**
     * Sets the convergence tolerance. Training stops once the average primal-dual gap of an
     * epoch falls below this value.
     *
     * @param tolerance the positive convergence tolerance
     * @throws IllegalArgumentException if {@code tolerance} is not positive
     */
    public void setTolerance(double tolerance) {
        if (tolerance <= 0.0d || Double.isNaN(tolerance)) {
            throw new IllegalArgumentException("convergence tolerance paramter must be positive, not " + tolerance);
        }
        this.tol = tolerance;
    }

    /**
     * @return the convergence tolerance
     */
    public double getTolerance() {
        return this.tol;
    }

    /**
     * @param lossFunc the loss function to minimize. It must implement {@link LossC} for
     *                 classification or {@link LossR} for regression, since prediction casts to
     *                 those types.
     */
    public void setLoss(LossFunc lossFunc) {
        this.loss = lossFunc;
    }

    /**
     * @return the loss function being minimized
     */
    public LossFunc getLoss() {
        return this.loss;
    }

    /**
     * Classifies a data point from the linear score {@code w . x + b}.
     * With one weight vector, the score is converted by {@link LossC#getClassification(double)}.
     * Otherwise the per-class scores are processed by the {@link LossMC} loss.
     */
    @Override // jsat.classifiers.Classifier
    public CategoricalResults classify(DataPoint data) {
        Vec x = data.getNumericalValues();
        if (this.ws.length == 1) {
            return ((LossC) this.loss).getClassification(this.ws[0].dot(x) + this.bs[0]);
        }
        DenseVector scores = new DenseVector(this.ws.length);
        for (int i = 0; i < this.ws.length; i++) {
            scores.set(i, this.ws[i].dot(x) + this.bs[i]);
        }
        ((LossMC) this.loss).process(scores, scores);
        return ((LossMC) this.loss).getClassification(scores);
    }

    /**
     * Predicts a regression target from the linear score {@code w . x + b} using
     * {@link LossR#getRegression(double)}.
     */
    @Override // jsat.regression.Regressor
    public double regress(DataPoint data) {
        return ((LossR) this.loss).getRegression(this.ws[0].dot(data.getNumericalValues()) + this.bs[0]);
    }

    /** Trains on the given binary data set; the parallel flag is ignored. */
    @Override // jsat.classifiers.Classifier
    public void train(ClassificationDataSet dataSet, boolean parallel) {
        train(dataSet);
    }

    /**
     * Trains from scratch on a binary classification data set.
     *
     * @throws RuntimeException if the data set does not have exactly two categories
     */
    @Override // jsat.classifiers.Classifier
    public void train(ClassificationDataSet dataSet) {
        trainProxSDCA(dataSet, toSignedLabels(dataSet), null);
    }

    /** Warm-start training; the parallel flag is ignored. */
    @Override // jsat.classifiers.WarmClassifier
    public void train(ClassificationDataSet dataSet, Classifier warmSolution, boolean parallel) {
        train(dataSet, warmSolution);
    }

    /**
     * Trains starting from the dual variables of another SDCA model that was trained on the same
     * data set.
     *
     * @throws FailedToFitException if {@code warmSolution} is not an SDCA instance, or was trained
     *                              on a data set of a different size
     * @throws RuntimeException     if the data set does not have exactly two categories
     */
    @Override // jsat.classifiers.WarmClassifier
    public void train(ClassificationDataSet dataSet, Classifier warmSolution) {
        if (!(warmSolution instanceof SDCA)) {
            throw new FailedToFitException("SDCA implementation can only be warm-started from another instance of SDCA");
        }
        trainProxSDCA(dataSet, toSignedLabels(dataSet), ((SDCA) warmSolution).dual_alphas);
    }

    /** Trains on the given regression data set; the parallel flag is ignored. */
    @Override // jsat.regression.Regressor
    public void train(RegressionDataSet dataSet, boolean parallel) {
        train(dataSet);
    }

    /** Trains from scratch on a regression data set. */
    @Override // jsat.regression.Regressor
    public void train(RegressionDataSet dataSet) {
        trainProxSDCA(dataSet, targetsOf(dataSet), null);
    }

    /** Warm-start training; the parallel flag is ignored. */
    @Override // jsat.regression.WarmRegressor
    public void train(RegressionDataSet dataSet, Regressor warmSolution, boolean parallel) {
        train(dataSet, warmSolution);
    }

    /**
     * Trains starting from the dual variables of another SDCA model that was trained on the same
     * data set.
     *
     * @throws FailedToFitException if {@code warmSolution} is not an SDCA instance (including
     *                              {@code null}), or was trained on a data set of a different size
     */
    @Override // jsat.regression.WarmRegressor
    public void train(RegressionDataSet dataSet, Regressor warmSolution) {
        // Validate like the classification overload does, instead of failing with a
        // ClassCastException / NullPointerException on the cast below.
        if (!(warmSolution instanceof SDCA)) {
            throw new FailedToFitException("SDCA implementation can only be warm-started from another instance of SDCA");
        }
        trainProxSDCA(dataSet, targetsOf(dataSet), ((SDCA) warmSolution).dual_alphas);
    }

    /**
     * Converts binary category labels {0, 1} into the signed targets {-1, +1} used by training.
     *
     * @throws RuntimeException if the data set does not have exactly two categories
     */
    private static double[] toSignedLabels(ClassificationDataSet dataSet) {
        if (dataSet.getPredicting().getNumOfCategories() != 2) {
            throw new RuntimeException("Current SDCA implementation only support binary classification problems");
        }
        double[] labels = new double[dataSet.getSampleSize()];
        for (int i = 0; i < labels.length; i++) {
            labels[i] = dataSet.getDataPointCategory(i) * 2 - 1;
        }
        return labels;
    }

    /** Copies the regression targets of a data set into an array. */
    private static double[] targetsOf(RegressionDataSet dataSet) {
        double[] targets = new double[dataSet.getSampleSize()];
        for (int i = 0; i < targets.length; i++) {
            targets[i] = dataSet.getTargetValue(i);
        }
        return targets;
    }

    /**
     * Soft-thresholding operator. Shrinks {@code value} towards zero by {@code threshold}, and
     * yields 0 when its magnitude is at most {@code threshold}.
     */
    private static double softThreshold(double value, double threshold) {
        return Math.signum(value) * Math.max(Math.abs(value) - threshold, 0.0d);
    }

    /**
     * Core proximal SDCA training loop shared by every {@code train} overload. It replaces
     * {@link #ws}, {@link #bs}, {@link #dual_alphas} and {@link #epochs_taken}.
     *
     * @param dataSet    the training data (only numerical features are used)
     * @param targets    per-point targets: -1/+1 labels for classification, raw values for
     *                   regression
     * @param warmAlphas dual variables to start from, or {@code null} for a cold start (all zeros)
     * @throws FailedToFitException if {@code warmAlphas} does not have one entry per data point
     */
    private void trainProxSDCA(DataSet dataSet, double[] targets, double[] warmAlphas) {
        final int n = dataSet.getSampleSize();
        final int dims = dataSet.getNumNumericalVars();
        this.ws = new Vec[]{new DenseVector(dims)};
        this.bs = new double[1];
        // Un-thresholded primal vector implied by the dual variables, kept in scaled units.
        final DenseVector v = new DenseVector(dims);

        // Per-point L2 norms. For classification, every input is rescaled by the largest norm
        // (never less than 1) so that all scaled norms are <= 1. Regression uses scale = 1.
        final double[] xNorms = new double[n];
        double scale = 1.0d;
        final boolean isRegression = dataSet instanceof RegressionDataSet;
        for (int i = 0; i < n; i++) {
            xNorms[i] = dataSet.getDataPoint(i).getNumericalValues().pNorm(2.0d);
            if (!isRegression) {
                scale = Math.max(scale, xNorms[i]);
            }
        }
        for (int i = 0; i < n; i++) {
            xNorms[i] /= scale;
        }

        // lambdaL2: strength of the L2 term; sigma: soft-threshold applied to v; gapTol: stop tolerance.
        final double lambdaL2;
        final double sigma;
        final double gapTol;
        if (this.alpha == 1.0d) {
            // Pure soft-threshold penalty: use a small L2 term derived from the tolerance and the
            // average loss of the all-zero model.
            double lossAtZero = 0.0d;
            for (int i = 0; i < n; i++) {
                lossAtZero += this.loss.getLoss(0.0d, targets[i]);
            }
            sigma = this.lambda;
            lambdaL2 = this.tol * Math.pow(this.lambda / Math.max(lossAtZero / n, 1.0E-7d), 2.0d);
            gapTol = this.tol / 2.0d;
        } else {
            lambdaL2 = this.lambda;
            sigma = this.alpha / (1.0d - this.alpha);
            gapTol = this.tol;
        }
        // Common divisor that maps dual-variable changes into changes of v and the bias.
        final double dualScale = scale * lambdaL2 * n;

        // With soft-thresholding, predictions use the thresholded copy of v (wThresh). It is
        // refreshed lazily, only on the coordinates touched by the current point.
        // Without soft-thresholding, predictions use v directly.
        final double[] wThresh;
        final DenseVector w;
        if (this.alpha > 0.0d) {
            wThresh = new double[dims];
            w = new DenseVector(wThresh);
        } else {
            wThresh = null;
            w = v;
        }

        if (warmAlphas == null) {
            this.dual_alphas = new double[n];
        } else {
            if (n != warmAlphas.length) {
                throw new FailedToFitException("SDCA only supports warm-start training from the same dataset. A dataset of side " + n + " was given for training, but the warm solution was trained on " + warmAlphas.length + " points.");
            }
            // Rebuild v and the bias from the supplied dual variables.
            this.dual_alphas = Arrays.copyOf(warmAlphas, warmAlphas.length);
            for (int i = 0; i < n; i++) {
                v.mutableAdd(this.dual_alphas[i], dataSet.getDataPoint(i).getNumericalValues());
                if (this.useBias) {
                    this.bs[0] += this.dual_alphas[i];
                }
            }
            v.mutableDivide(dualScale);
            this.bs[0] /= dualScale;
        }

        final Random rand = RandomUtil.getRandom();
        final double lipschitz = this.loss.lipschitz();
        final IntList order = new IntList(n);
        ListUtils.addRange(order, 0, n, 1);

        this.epochs_taken = 0;
        // Average primal loss of the previous epoch; used to detect when progress has stalled.
        double prevPrimal = Double.POSITIVE_INFINITY;
        int stalledEpochs = 0;
        for (int epoch = 0; epoch < this.max_epochs; epoch++) {
            this.epochs_taken++;
            double dualSum = 0.0d;
            double primalSum = 0.0d;
            Collections.shuffle(order, rand);
            for (int i : order) {
                final double alpha_i = this.dual_alphas[i];
                final Vec x_i = dataSet.getDataPoint(i).getNumericalValues();
                final double y_i = targets[i];
                if (wThresh != null) {
                    Iterator<IndexValue> nonZeros = x_i.iterator();
                    while (nonZeros.hasNext()) {
                        int j = nonZeros.next().getIndex();
                        wThresh[j] = softThreshold(v.get(j), sigma);
                    }
                }
                final double margin = w.dot(x_i) / scale + this.bs[0];
                final double grad = -this.loss.getDeriv(margin, y_i) - alpha_i;
                final double gradSq = grad * grad;
                if (gradSq > 1.0E-32d) {
                    final double lossVal = this.loss.getLoss(margin, y_i);
                    final double conj = this.loss.getConjugate(-alpha_i, margin, y_i);
                    final double normSq = xNorms[i] * xNorms[i];
                    // Closed-form step size along grad, capped at 1.
                    final double step = Math.min(1.0d,
                            (lossVal + conj + margin * alpha_i + lipschitz * gradSq / 2.0d)
                            / (gradSq * (lipschitz + normSq / (lambdaL2 * n))));
                    primalSum += lossVal;
                    if (!Double.isInfinite(conj)) {
                        dualSum += -conj;
                    }
                    if (step != 0.0d) {
                        final double change = step * grad;
                        this.dual_alphas[i] += change;
                        v.mutableAdd(change / dualScale, x_i);
                        if (this.useBias) {
                            this.bs[0] += change / dualScale;
                        }
                    }
                }
            }
            // Converged: the average primal-dual gap is below the tolerance.
            if (Math.abs(primalSum - dualSum) / n < gapTol) {
                break;
            }
            // Stall detection. Stop once the average primal loss has failed to improve by at
            // least gapTol/5 for more than 10 consecutive epochs. The previous code compared
            // against a constant +Infinity, so this check could never trigger.
            final double primal = primalSum / n;
            if (prevPrimal - primal < gapTol / 5.0d) {
                if (stalledEpochs++ > 10) {
                    break;
                }
            } else {
                stalledEpochs = 0;
            }
            prevPrimal = primal;
        }

        // Final weights: soft-thresholded v, mapped back to the original (unscaled) input space.
        for (int j = 0; j < dims; j++) {
            this.ws[0].set(j, softThreshold(v.get(j), sigma) / scale);
        }
    }

    /** @return {@code false}; data point weights are not used by this implementation */
    @Override // jsat.classifiers.Classifier
    public boolean supportsWeightedData() {
        return false;
    }

    @Override // jsat.regression.Regressor
    /* renamed from: clone, reason: merged with bridge method [inline-methods] and merged with bridge method [inline-methods] */
    public SDCA m99clone() {
        return new SDCA(this);
    }

    @Override // jsat.SimpleWeightVectorModel
    public Vec getRawWeight(int index) {
        return this.ws[index];
    }

    @Override // jsat.SimpleWeightVectorModel
    public double getBias(int index) {
        return this.bs[index];
    }

    @Override // jsat.SimpleWeightVectorModel
    public int numWeightsVecs() {
        return this.ws.length;
    }

    /** @return {@code true}; warm starts re-use per-point dual variables */
    @Override // jsat.classifiers.WarmClassifier, jsat.regression.WarmRegressor
    public boolean warmFromSameDataOnly() {
        return true;
    }

    /**
     * Guesses a search distribution for {@code lambda}: log-uniform over
     * [1/(50 N), min(50/N, 1)] for a data set of N points.
     * Floating-point arithmetic is used throughout. The previous code used integer division
     * ({@code N/50}), which truncated the upper bound, and {@code N*50} in {@code int}, which
     * overflows for very large N.
     *
     * @param dataSet the data set to guess for
     * @return the distribution to search over
     */
    public static Distribution guessLambda(DataSet dataSet) {
        final double n = dataSet.getSampleSize();
        return new LogUniform(1.0d / (n * 50.0d), Math.min(50.0d / n, 1.0d));
    }

    /**
     * Guesses a search distribution for {@code alpha}: uniform over [0, 0.5].
     *
     * @param dataSet the data set to guess for (unused)
     * @return the distribution to search over
     */
    public static Distribution guessAlpha(DataSet dataSet) {
        return new Uniform(0.0d, 0.5d);
    }
}
