package jsat.classifiers.linear.kernelized;

import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import jsat.DataSet;
import jsat.classifiers.BaseUpdateableClassifier;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.DataPoint;
import jsat.distributions.Distribution;
import jsat.distributions.LogUniform;
import jsat.distributions.kernels.KernelTrick;
import jsat.exceptions.FailedToFitException;
import jsat.linear.Vec;
import jsat.parameters.Parameterized;
import jsat.utils.DoubleList;
import jsat.utils.random.RandomUtil;

/**
 * Online binary kernel classifier that learns a kernel logistic-regression model
 * one example at a time. The name presumably stands for "Conservative Stochastic
 * Kernel Logistic Regression" — NOTE(review): confirm against the reference paper.
 * <p>
 * The model stores vectors {@code v_i} with coefficients {@code alpha_i}. The raw
 * score of an input is the kernel expansion computed by {@link KernelTrick#evalSum},
 * and the probability of class 1 is the logistic function of that score. Depending
 * on the {@link UpdateMode}, an example is either always added to the model
 * ({@link UpdateMode#NC}) or added only with a mode-specific probability. Whenever
 * the tracked norm exceeds {@code R}, all coefficients are rescaled so that the
 * tracked norm equals {@code R}.
 */
public class CSKLR extends BaseUpdateableClassifier implements Parameterized {
    private static final long serialVersionUID = 2325605193408720811L;

    /** Learning rate applied to every new coefficient. */
    private double eta;
    /** Coefficients of the stored vectors, parallel to {@link #vecs}. */
    private DoubleList alpha;
    /** Stored vectors of the kernel expansion. */
    private List<Vec> vecs;
    /** Running value of {@code sum_i |alpha_i| * K(v_i, v_i)}, used as the model norm. */
    private double curNorm;
    /** Kernel used for all evaluations. */
    private KernelTrick k;
    /** Maximum allowed value of {@link #curNorm} before rescaling. */
    private double R;
    /** Source of randomness for the probabilistic update modes. */
    private Random rand;
    /** Strategy deciding whether and how to update on each example. */
    private UpdateMode mode;
    /** Mode-specific parameter passed to {@link UpdateMode#pt} and {@link UpdateMode#grad}. */
    private double gamma = 2.0d;
    /** Kernel acceleration cache, or {@code null} when the kernel does not support one. */
    private List<Double> accelCache;

    /**
     * Strategy that decides whether an example is added to the model, and with what
     * gradient. {@link CSKLR#update} consults {@code pt} for every mode except
     * {@code NC}, and {@code grad} supplies the loss-gradient factor of the new
     * coefficient.
     */
    public enum UpdateMode {
        /** Always updates; {@code update} never draws against {@code pt} for this mode. */
        NC {
            @Override
            public double pt(double y, double score, double pre, double eta, double gamma) {
                return 1.0d;
            }

            @Override
            public double grad(double y, double score, double pre, double gamma) {
                return score - 1.0d;
            }
        },
        /** Updates with probability {@code (2-eta) / ((2-eta) + eta*score)}. */
        MARGIN {
            @Override
            public double pt(double y, double score, double pre, double eta, double gamma) {
                return (2.0d - eta) / ((2.0d - eta) + (eta * score));
            }

            @Override
            public double grad(double y, double score, double pre, double gamma) {
                return score - 1.0d;
            }
        },
        /** Auxiliary-function mode 1; {@code gamma} appears additively inside the log. */
        AUXILIARY_1 {
            @Override
            public double pt(double y, double score, double pre, double eta, double gamma) {
                double margin = y * pre;
                return Math.log(1.0d + Math.exp(-margin)) / Math.log(gamma + Math.exp(-margin));
            }

            @Override
            public double grad(double y, double score, double pre, double gamma) {
                return (-1.0d) / (1.0d + (gamma * Math.exp(y * pre)));
            }
        },
        /** Auxiliary-function mode 2; {@code gamma} scales the exponential inside the log. */
        AUXILIARY_2 {
            @Override
            public double pt(double y, double score, double pre, double eta, double gamma) {
                double margin = y * pre;
                return Math.log(1.0d + Math.exp(-margin)) / Math.log(1.0d + (gamma * Math.exp(-margin)));
            }

            @Override
            public double grad(double y, double score, double pre, double gamma) {
                return (-gamma) / (gamma + Math.exp(y * pre));
            }
        },
        /** Auxiliary-function mode 3; {@code gamma} acts as a fixed reference margin. */
        AUXILIARY_3 {
            @Override
            public double pt(double y, double score, double pre, double eta, double gamma) {
                return Math.log(1.0d + Math.exp(-(y * pre))) / Math.log(1.0d + Math.exp(-gamma));
            }

            @Override
            public double grad(double y, double score, double pre, double gamma) {
                return score - 1.0d;
            }
        };

        /**
         * Returns the probability with which the current example is added to the model.
         * {@code update} skips the example when a uniform draw is greater than this value.
         *
         * @param y     label of the example, -1 or +1
         * @param score logistic score {@code 1 / (1 + exp(-y * pre))}
         * @param pre   raw kernel-expansion score of the example
         * @param eta   learning rate
         * @param gamma mode-specific gamma parameter
         * @return the acceptance probability
         */
        public abstract double pt(double y, double score, double pre, double eta, double gamma);

        /**
         * Returns the gradient factor {@code g}. The new coefficient added by
         * {@code update} is {@code -eta * y * g * weight}.
         *
         * @param y     label of the example, -1 or +1
         * @param score logistic score {@code 1 / (1 + exp(-y * pre))}
         * @param pre   raw kernel-expansion score of the example
         * @param gamma mode-specific gamma parameter
         * @return the gradient factor
         */
        public abstract double grad(double y, double score, double pre, double gamma);
    }

    /**
     * Creates a new classifier. {@code gamma} starts at 2.0.
     *
     * @param eta    learning rate, must be finite and non-negative
     * @param kernel kernel to use
     * @param R      maximum norm, must be finite and non-negative
     * @param mode   update strategy
     * @throws IllegalArgumentException if {@code eta} or {@code R} is invalid
     */
    public CSKLR(double eta, KernelTrick kernel, double R, UpdateMode mode) {
        setEta(eta);
        setKernel(kernel);
        setR(R);
        setMode(mode);
    }

    /**
     * Returns a search distribution for {@code R}. The data set is currently ignored.
     *
     * @param dataSet the data set (unused)
     * @return a log-uniform distribution over [1, 100000]
     */
    public static Distribution guessR(DataSet dataSet) {
        return new LogUniform(1.0d, 100000.0d);
    }

    /**
     * Copy constructor. The kernel is cloned and the coefficient and cache lists are
     * copied. The vector list is copied, but the {@code Vec} objects are shared. A
     * fresh {@link Random} is obtained.
     *
     * @param toCopy the instance to copy
     */
    protected CSKLR(CSKLR toCopy) {
        if (toCopy.alpha != null) {
            this.alpha = new DoubleList(toCopy.alpha);
        }
        if (toCopy.vecs != null) {
            this.vecs = new ArrayList<>(toCopy.vecs);
        }
        this.curNorm = toCopy.curNorm;
        this.mode = toCopy.mode;
        this.R = toCopy.R;
        this.eta = toCopy.eta;
        setKernel(toCopy.k.mo154clone());
        if (toCopy.accelCache != null) {
            this.accelCache = new DoubleList(toCopy.accelCache);
        }
        this.gamma = toCopy.gamma;
        this.rand = RandomUtil.getRandom();
        setEpochs(toCopy.getEpochs());
    }

    /**
     * Sets the learning rate.
     *
     * @param eta the learning rate
     * @throws IllegalArgumentException if {@code eta} is negative, NaN or infinite
     */
    public void setEta(double eta) {
        // NOTE(review): 0 is accepted even though the message says "(0, Inf)".
        if (eta < 0.0d || Double.isNaN(eta) || Double.isInfinite(eta)) {
            throw new IllegalArgumentException("The learning rate should be in (0, Inf), not " + eta);
        }
        this.eta = eta;
    }

    /** @return the learning rate */
    public double getEta() {
        return this.eta;
    }

    /**
     * Sets the maximum norm allowed before the coefficients are rescaled.
     *
     * @param R the maximum norm
     * @throws IllegalArgumentException if {@code R} is negative, NaN or infinite
     */
    public void setR(double R) {
        if (R < 0.0d || Double.isNaN(R) || Double.isInfinite(R)) {
            throw new IllegalArgumentException("The max norm should be in (0, Inf), not " + R);
        }
        this.R = R;
    }

    /** @return the maximum norm */
    public double getR() {
        return this.R;
    }

    /** @param mode the update strategy to use */
    public void setMode(UpdateMode mode) {
        this.mode = mode;
    }

    /** @return the update strategy */
    public UpdateMode getMode() {
        return this.mode;
    }

    /**
     * Sets the mode-specific gamma parameter.
     *
     * @param gamma the new value
     * @throws IllegalArgumentException if {@code gamma} is negative, NaN or infinite
     */
    public void setGamma(double gamma) {
        if (gamma < 0.0d || Double.isNaN(gamma) || Double.isInfinite(gamma)) {
            throw new IllegalArgumentException("Gamma must be in (0, Infinity), not " + gamma);
        }
        this.gamma = gamma;
    }

    /** @return the mode-specific gamma parameter */
    public double getGamma() {
        return this.gamma;
    }

    /** @param kernel the kernel to use */
    public void setKernel(KernelTrick kernel) {
        this.k = kernel;
    }

    /** @return the kernel in use */
    public KernelTrick getKernel() {
        return this.k;
    }

    /** Raw kernel-expansion score of {@code x} over all stored vectors. Requires {@code setUp}. */
    private double getPreScore(Vec x) {
        return this.k.evalSum(this.vecs, this.accelCache, this.alpha.getBackingArray(), x, 0, this.alpha.size());
    }

    /**
     * Logistic score of a label given a raw score.
     *
     * @param y   the label, -1 or +1
     * @param pre the raw score
     * @return {@code 1 / (1 + exp(-y * pre))}
     */
    public static double getScore(double y, double pre) {
        return 1.0d / (1.0d + Math.exp((-y) * pre));
    }

    // Decompiler-renamed clone(); the name must match the supertype's declaration.
    @Override
    public CSKLR mo0clone() {
        return new CSKLR(this);
    }

    /**
     * Resets the model for a new binary problem.
     *
     * @throws FailedToFitException if the target does not have exactly two categories
     */
    @Override
    public void setUp(CategoricalData[] categoricalAttributes, int numericAttributes, CategoricalData predicting) {
        if (predicting.getNumOfCategories() != 2) {
            throw new FailedToFitException("CSKLR supports only binary classification");
        }
        this.alpha = new DoubleList();
        this.vecs = new ArrayList<>();
        this.curNorm = 0.0d;
        this.rand = RandomUtil.getRandom();
        // Reset explicitly so no cache from an earlier run or kernel survives.
        this.accelCache = this.k.supportsAcceleration() ? new DoubleList() : null;
    }

    /**
     * Processes one labeled example. Depending on the mode, it may be added to the
     * kernel expansion, after which the coefficients are rescaled if the norm exceeds R.
     *
     * @param dataPoint   the example; its weight scales the new coefficient
     * @param targetClass the class index, 0 or 1
     */
    @Override
    public void update(DataPoint dataPoint, int targetClass) {
        // Map the class index {0, 1} onto the label {-1, +1}.
        final double y = (targetClass * 2) - 1;
        final Vec x = dataPoint.getNumericalValues();
        final double pre = getPreScore(x);
        final double score = getScore(y, pre);

        // Every mode except NC adds the example only with probability pt(...).
        if (this.mode != UpdateMode.NC
                && this.rand.nextDouble() > this.mode.pt(y, score, pre, this.eta, this.gamma)) {
            return;
        }

        final double newAlpha = (-this.eta) * y * this.mode.grad(y, score, pre, this.gamma) * dataPoint.getWeight();
        this.alpha.add(newAlpha);
        this.vecs.add(x);
        this.k.addToCache(x, this.accelCache);

        // The vector just added sits at the last index; vecs.size() itself is out of range.
        final int last = this.vecs.size() - 1;
        this.curNorm += Math.abs(newAlpha) * this.k.eval(last, last, this.vecs, this.accelCache);

        if (this.curNorm > this.R) {
            final double scale = this.R / this.curNorm;
            for (int i = 0; i < this.alpha.size(); i++) {
                this.alpha.set(i, this.alpha.get(i).doubleValue() * scale);
            }
            // Every |alpha_i| shrank by `scale`, so the tracked norm is now exactly R.
            this.curNorm = this.R;
        }
    }

    /**
     * Returns class probabilities: index 0 is P(label -1) and index 1 is P(label +1).
     */
    @Override
    public CategoricalResults classify(DataPoint dataPoint) {
        CategoricalResults results = new CategoricalResults(2);
        double negProb = getScore(-1.0d, getPreScore(dataPoint.getNumericalValues()));
        results.setProb(0, negProb);
        results.setProb(1, 1.0d - negProb);
        return results;
    }

    /** Weights are supported: {@code update} multiplies each new coefficient by the point's weight. */
    @Override
    public boolean supportsWeightedData() {
        return true;
    }
}
