package jsat.classifiers.linear.kernelized;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import jsat.DataSet;
import jsat.classifiers.BaseUpdateableClassifier;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.DataPoint;
import jsat.classifiers.calibration.BinaryScoreClassifier;
import jsat.distributions.Distribution;
import jsat.distributions.Uniform;
import jsat.distributions.kernels.KernelTrick;
import jsat.exceptions.FailedToFitException;
import jsat.linear.ScaledVector;
import jsat.linear.Vec;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.utils.DoubleList;
import jsat.utils.IntList;

/* loaded from: input_file:jsat/classifiers/linear/kernelized/ALMA2K.class */
/**
 * Kernelized Approximate Large Margin Algorithm ALMA<sub>p</sub> with the norm
 * parameter fixed at p = 2.
 * <p>
 * The weight vector lives in the kernel's feature space and is stored implicitly
 * as a list of support vectors. Each example is normalized in feature space
 * (divided by {@code sqrt(K(x, x))}). When its normalized margin is at most
 * {@code (1 - alpha) * gamma_k}, the example is added as a support with step size
 * {@code eta_k}, and the weight vector is projected back onto the unit ball.
 * <p>
 * Only binary problems with numeric features are supported, and instance weights
 * are ignored. Optionally ({@link #setAveraged(boolean)}), scores can be the
 * round-weighted sum over every intermediate hypothesis instead of the score of
 * the final one.
 */
public class ALMA2K extends BaseUpdateableClassifier implements BinaryScoreClassifier, Parameterized {
    private static final long serialVersionUID = 7247320234799227009L;
    /** The norm parameter p of ALMA<sub>p</sub>; fixed at 2 for this class. */
    private static final double p = 2.0d;
    private double alpha;
    private double B;
    private double C = Math.sqrt(2.0d);
    /** One more than the number of updates made since {@link #setUp}. */
    private int k;
    /** Number of rounds the current hypothesis has gone without an update. */
    private int curRounds;

    @Parameter.ParameterHolder
    private KernelTrick K;
    /** Support vectors, in the order they were added. */
    private List<Vec> supports;
    /** {@code eta_j * y_j} for each support j. */
    private List<Double> signedEtas;
    /** {@code <w, x-hat_j>} under the hypothesis in force just before support j was added. */
    private List<Double> associatedScores;
    /** {@code sqrt(K(x_j, x_j))} for each support j. */
    private List<Double> normalizers;
    /** Rounds survived by the hypothesis with j supports, recorded when support j was added. */
    private List<Integer> rounds;
    private boolean averaged = false;

    /**
     * Creates a new ALMA2K learner.
     *
     * @param kernelTrick the kernel used to compare examples
     * @param alpha the approximation parameter, in (0, 1]; also sets B to 1/alpha
     * @throws ArithmeticException if alpha is outside (0, 1] or NaN
     */
    public ALMA2K(KernelTrick kernelTrick, double alpha) {
        setKernelTrick(kernelTrick);
        setAlpha(alpha);
    }

    /**
     * Deep-copy constructor.
     *
     * @param other the learner to copy; its kernel and support vectors are cloned
     */
    protected ALMA2K(ALMA2K other) {
        this.alpha = other.alpha;
        this.B = other.B;
        this.C = other.C;
        this.k = other.k;
        // Needed so a clone's averaged predictor and future round counts match the original.
        this.curRounds = other.curRounds;
        this.K = other.K.mo154clone();
        this.averaged = other.averaged;
        if (other.supports != null) {
            this.supports = new ArrayList<Vec>(other.supports.size());
            for (Vec v : other.supports) {
                this.supports.add(v.mo46clone());
            }
            this.signedEtas = new DoubleList(other.signedEtas);
            this.associatedScores = new DoubleList(other.associatedScores);
            this.normalizers = new DoubleList(other.normalizers);
            this.rounds = new IntList(other.rounds);
        }
    }

    @Override // jsat.classifiers.BaseUpdateableClassifier
    /* renamed from: clone, reason: merged with bridge method [inline-methods] */
    public ALMA2K mo0clone() {
        return new ALMA2K(this);
    }

    /**
     * Sets whether scores use the round-weighted sum over all intermediate hypotheses.
     *
     * @param averaged {@code true} to use the averaged predictor, {@code false} for the final one
     */
    public void setAveraged(boolean averaged) {
        this.averaged = averaged;
    }

    /**
     * @return {@code true} if the averaged predictor is used for scoring
     */
    public boolean isAveraged() {
        return this.averaged;
    }

    /**
     * @param kernelTrick the kernel to use
     */
    public void setKernelTrick(KernelTrick kernelTrick) {
        this.K = kernelTrick;
    }

    /**
     * @return the kernel in use
     */
    public KernelTrick getKernelTrick() {
        return this.K;
    }

    /**
     * Sets the approximation parameter alpha and resets B to {@code 1/alpha}.
     *
     * @param alpha value in (0, 1]
     * @throws ArithmeticException if alpha is outside (0, 1] or NaN
     */
    public void setAlpha(double alpha) {
        if (alpha <= 0.0d || alpha > 1.0d || Double.isNaN(alpha)) {
            throw new ArithmeticException("alpha must be in (0, 1], not " + alpha);
        }
        this.alpha = alpha;
        setB(1.0d / alpha);
    }

    /**
     * @return the approximation parameter alpha
     */
    public double getAlpha() {
        return this.alpha;
    }

    /**
     * Sets B, which scales the margin threshold {@code gamma_k}. The value is not validated.
     *
     * @param B the new value
     */
    public void setB(double B) {
        this.B = B;
    }

    /**
     * @return the margin-threshold scale B
     */
    public double getB() {
        return this.B;
    }

    /**
     * Sets C, which scales the step size {@code eta_k}.
     *
     * @param C a finite positive value
     * @throws ArithmeticException if C is not finite and positive
     */
    public void setC(double C) {
        if (C <= 0.0d || Double.isInfinite(C) || Double.isNaN(C)) {
            throw new ArithmeticException("C must be a positive constant, not " + C);
        }
        this.C = C;
    }

    /**
     * @return the step-size scale C
     */
    public double getC() {
        return this.C;
    }

    /**
     * Validates the problem and resets all learned state.
     *
     * @throws FailedToFitException if there are no numeric features or the target is not binary
     */
    @Override // jsat.classifiers.UpdateableClassifier
    public void setUp(CategoricalData[] categoricalAttributes, int numericAttributes, CategoricalData predicting) {
        if (numericAttributes <= 0) {
            throw new FailedToFitException("ALMA2 requires numeric features");
        }
        if (predicting.getNumOfCategories() != 2) {
            throw new FailedToFitException("ALMA2 works only for binary classification");
        }
        this.supports = new ArrayList<Vec>();
        this.signedEtas = new DoubleList();
        this.associatedScores = new DoubleList();
        this.normalizers = new DoubleList();
        this.rounds = new IntList();
        this.k = 1;
        this.curRounds = 0;
    }

    /**
     * Performs one ALMA<sub>2</sub> round on a single example.
     * <p>
     * An update happens when {@code y * <w, x-hat> <= (1 - alpha) * B * sqrt(p-1) / sqrt(k)},
     * where {@code x-hat = phi(x) / ||phi(x)||}. The step size is
     * {@code eta = C / sqrt(p-1) / sqrt(k)}.
     *
     * @param dataPoint the example; its numeric vector is stored by reference if it becomes a support
     * @param targetClass the class index, 0 or 1 (mapped to -1 / +1)
     */
    @Override // jsat.classifiers.UpdateableClassifier
    public void update(DataPoint dataPoint, int targetClass) {
        Vec x = dataPoint.getNumericalValues();
        double y = targetClass * 2 - 1;
        // ||phi(x)|| in the kernel's feature space.
        double xNorm = Math.sqrt(this.K.eval(x, x));
        if (!(xNorm > 0.0d)) {
            // x-hat is undefined for a zero (or NaN) feature-space norm. Storing it would put
            // 0 in normalizers and make every later score NaN. Such an example cannot move w,
            // so count it as a round the current hypothesis survived.
            this.curRounds++;
            return;
        }
        // <w, x-hat> = <w, phi(x)> / ||phi(x)||. Rescaling the input vector itself is not
        // equivalent for non-linear kernels, so divide the feature-space score instead.
        double normalizedScore = score(x, false) / xNorm;
        double gamma = this.B * Math.sqrt(p - 1) / Math.sqrt(this.k);
        if (y * normalizedScore > (1.0d - this.alpha) * gamma) {
            this.curRounds++;
            return;
        }
        double eta = this.C / Math.sqrt(p - 1) / Math.sqrt(this.k);
        this.k++;
        this.associatedScores.add(normalizedScore);
        this.supports.add(x);
        this.normalizers.add(xNorm);
        this.signedEtas.add(eta * y);
        this.rounds.add(this.curRounds);
        this.curRounds = 0;
    }

    /**
     * Computes {@code <w, phi(x)>} by replaying the update history, or the round-weighted
     * sum of that value over every intermediate hypothesis.
     * <p>
     * Support j first adds {@code signedEta_j * K(x_j, x) / ||phi(x_j)||}. The whole weight
     * vector is then divided by {@code max(1, ||w'||)}, where
     * {@code ||w'||^2 = ||w||^2 + 2 * signedEta_j * <w, x-hat_j> + signedEta_j^2}.
     *
     * @param x the vector to score
     * @param useAveraged if {@code true}, return the sum over hypotheses h_j of
     *        {@code rounds_j * <h_j, phi(x)>}, plus {@code curRounds * <w, phi(x)>}
     * @return the (possibly averaged) score
     */
    private double score(Vec x, boolean useAveraged) {
        double current = 0.0d;     // <h_j, phi(x)>, starting from the zero hypothesis
        double wNormSqrd = 0.0d;   // ||h_j||^2
        double averagedScore = 0.0d;
        for (int j = 0; j < this.supports.size(); j++) {
            if (useAveraged) {
                // rounds[j] counts rounds survived by the hypothesis that existed before support j.
                averagedScore += this.rounds.get(j).intValue() * current;
            }
            double signedEta = this.signedEtas.get(j).doubleValue();
            // Clamp at 0: a tiny negative from round-off would make sqrt NaN, and Math.max propagates NaN.
            double unprojectedNormSqrd = Math.max(0.0d,
                    wNormSqrd + 2.0d * signedEta * this.associatedScores.get(j).doubleValue()
                            + signedEta * signedEta);
            double shrink = Math.max(1.0d, Math.sqrt(unprojectedNormSqrd));
            double term = signedEta * this.K.eval(this.supports.get(j), x) / this.normalizers.get(j).doubleValue();
            // The projection rescales the entire weight vector, not just the newest term.
            current = (current + term) / shrink;
            wNormSqrd = unprojectedNormSqrd / (shrink * shrink);
        }
        if (useAveraged) {
            return averagedScore + this.curRounds * current;
        }
        return current;
    }

    /**
     * Predicts class 0 when the score is negative, otherwise class 1, with probability 1.
     *
     * @param dataPoint the example to classify
     * @return a two-category result with all mass on the predicted class
     */
    @Override // jsat.classifiers.Classifier
    public CategoricalResults classify(DataPoint dataPoint) {
        double score = getScore(dataPoint);
        CategoricalResults result = new CategoricalResults(2);
        if (score < 0.0d) {
            result.setProb(0, 1.0d);
        } else {
            result.setProb(1, 1.0d);
        }
        return result;
    }

    /**
     * @return the raw score of the example; it is averaged if {@link #isAveraged()} is true
     */
    @Override // jsat.classifiers.calibration.BinaryScoreClassifier
    public double getScore(DataPoint dataPoint) {
        return score(dataPoint.getNumericalValues(), this.averaged);
    }

    /**
     * @return {@code false}; instance weights are ignored
     */
    @Override // jsat.classifiers.Classifier
    public boolean supportsWeightedData() {
        return false;
    }

    /**
     * Suggests a search distribution for alpha.
     *
     * @param dataSet ignored
     * @return a uniform distribution on [0.001, 1]
     */
    public static Distribution guessAlpha(DataSet dataSet) {
        return new Uniform(0.001d, 1.0d);
    }
}
