package jsat.classifiers.linear;

import java.util.Iterator;
import jsat.DataSet;
import jsat.SingleWeightVectorModel;
import jsat.classifiers.BaseUpdateableClassifier;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.DataPoint;
import jsat.classifiers.calibration.BinaryScoreClassifier;
import jsat.distributions.Distribution;
import jsat.distributions.LogUniform;
import jsat.exceptions.FailedToFitException;
import jsat.exceptions.UntrainedModelException;
import jsat.linear.DenseVector;
import jsat.linear.IndexValue;
import jsat.linear.Matrix;
import jsat.linear.Vec;
import jsat.parameters.Parameterized;

/* loaded from: input_file:jsat/classifiers/linear/AROW.class */
/**
 * Adaptive Regularization of Weight vectors (AROW): an online, binary, linear classifier.
 * <p>
 * The model maintains a mean weight vector together with a covariance estimate that expresses
 * the confidence in each weight. The covariance is stored either as a full {@link Matrix} or,
 * when {@link #isDiagonalOnly() diagonal-only} mode is enabled, as a {@link Vec} holding only
 * its diagonal. No bias term is learned: {@link #getBias()} is always 0.
 * <p>
 * Class indices 0 and 1 are mapped internally to the labels -1 and +1.
 */
public class AROW extends BaseUpdateableClassifier implements BinaryScoreClassifier, Parameterized, SingleWeightVectorModel {
    private static final long serialVersionUID = 443803827811508204L;
    // NOTE: field names are part of the serialized form (fixed serialVersionUID); do not rename.
    /** Mean weight vector; {@code null} until {@link #setUp} has been called. */
    private Vec w;
    /** Full covariance matrix; only allocated when {@code diagonalOnly} is false. */
    private Matrix sigmaM;
    /** Diagonal of the covariance; only allocated when {@code diagonalOnly} is true. */
    private Vec sigmaV;
    /** Whether only the diagonal of the covariance is tracked. */
    private boolean diagonalOnly;
    /** Regularization constant; finite and strictly positive. */
    private double r;
    /** Scratch buffer that holds Sigma * x during a full-covariance update. */
    private Vec Sigma_xt;

    /**
     * Creates a new AROW learner with {@code r = 0.01} that tracks only the diagonal of the
     * covariance.
     */
    public AROW() {
        this(0.01d, true);
    }

    /**
     * Creates a new AROW learner.
     *
     * @param r            the regularization constant; must be finite and positive
     * @param diagonalOnly {@code true} to track only the diagonal of the covariance,
     *                     {@code false} to track the full covariance matrix
     * @throws IllegalArgumentException if {@code r} is NaN, infinite, or not positive
     */
    public AROW(double r, boolean diagonalOnly) {
        setR(r);
        setDiagonalOnly(diagonalOnly);
    }

    /**
     * Copy constructor; deep-copies all learned state.
     *
     * @param toCopy the learner to copy
     */
    protected AROW(AROW toCopy) {
        this.r = toCopy.r;
        this.diagonalOnly = toCopy.diagonalOnly;
        if (toCopy.w != null) {
            this.w = toCopy.w.mo46clone();
        }
        if (toCopy.sigmaM != null) {
            this.sigmaM = toCopy.sigmaM.mo171clone();
        }
        if (toCopy.sigmaV != null) {
            this.sigmaV = toCopy.sigmaV.mo46clone();
        }
        if (toCopy.Sigma_xt != null) {
            this.Sigma_xt = toCopy.Sigma_xt.mo46clone();
        }
    }

    /**
     * Sets whether only the diagonal of the covariance is tracked. Takes effect on the next call
     * to {@link #setUp}.
     *
     * @param diagonalOnly {@code true} for a diagonal covariance, {@code false} for a full matrix
     */
    public void setDiagonalOnly(boolean diagonalOnly) {
        this.diagonalOnly = diagonalOnly;
    }

    /**
     * @return {@code true} if only the diagonal of the covariance is tracked
     */
    public boolean isDiagonalOnly() {
        return this.diagonalOnly;
    }

    /**
     * Sets the regularization constant.
     *
     * @param r the new value; must be finite and strictly positive
     * @throws IllegalArgumentException if {@code r} is NaN, infinite, or not positive
     */
    public void setR(double r) {
        if (Double.isNaN(r) || Double.isInfinite(r) || r <= 0.0d) {
            throw new IllegalArgumentException("r must be a positive constant, not " + r);
        }
        this.r = r;
    }

    /**
     * @return the regularization constant
     */
    public double getR() {
        return this.r;
    }

    /**
     * @return the current mean weight vector (the live internal vector, not a copy), or
     *         {@code null} before {@link #setUp}
     */
    public Vec getWeightVec() {
        return this.w;
    }

    @Override // jsat.classifiers.BaseUpdateableClassifier
    /* renamed from: clone, reason: merged with bridge method [inline-methods] */
    public AROW mo0clone() {
        return new AROW(this);
    }

    /**
     * Initializes the model for {@code numNumeric} numeric features. The weights start at zero
     * and the covariance starts as the identity, stored as either a full matrix or its diagonal
     * depending on {@link #isDiagonalOnly()}.
     *
     * @throws FailedToFitException if there are no numeric features or the target is not binary
     */
    @Override // jsat.classifiers.UpdateableClassifier
    public void setUp(CategoricalData[] categoricalDataArr, int numNumeric, CategoricalData predicting) {
        if (numNumeric <= 0) {
            throw new FailedToFitException("AROW requires numeric attributes to perform classification");
        }
        if (predicting.getNumOfCategories() != 2) {
            throw new FailedToFitException("AROW is a binary classifier");
        }
        this.w = new DenseVector(numNumeric);
        this.Sigma_xt = new DenseVector(numNumeric);
        if (this.diagonalOnly) {
            this.sigmaV = new DenseVector(numNumeric);
            this.sigmaV.mutableAdd(1.0d);
            this.sigmaM = null; // drop any stale full matrix from a previous setUp
        } else {
            this.sigmaM = Matrix.eye(numNumeric);
            this.sigmaV = null; // drop any stale diagonal from a previous setUp
        }
    }

    /**
     * Performs one online update with the given example. The model changes only when the
     * example is misclassified or lies exactly on the decision boundary.
     *
     * @param dataPoint   the example
     * @param targetClass its class index, 0 or 1
     */
    @Override // jsat.classifiers.UpdateableClassifier
    public void update(DataPoint dataPoint, int targetClass) {
        Vec x = dataPoint.getNumericalValues();
        double y = (targetClass * 2) - 1; // class {0, 1} -> label {-1, +1}
        double margin = x.dot(this.w);
        if (y == Math.signum(margin)) {
            return; // correctly classified: no update
        }
        double hingeLoss = Math.max(0.0d, 1.0d - (y * margin));
        if (this.diagonalOnly) {
            updateDiagonalCovariance(x, y, hingeLoss);
        } else {
            updateFullCovariance(x, y, hingeLoss);
        }
    }

    /** Applies the AROW mean/covariance update when only the covariance diagonal is kept. */
    private void updateDiagonalCovariance(Vec x, double y, double hingeLoss) {
        // variance = x^T * diag(sigmaV) * x
        double variance = 0.0d;
        Iterator<IndexValue> it = x.iterator();
        while (it.hasNext()) {
            IndexValue iv = it.next();
            double xi = iv.getValue();
            variance += xi * xi * this.sigmaV.get(iv.getIndex());
        }
        double beta = variance + this.r;
        double alpha = hingeLoss / beta;

        // Each coordinate's mean and variance update only reads that coordinate's old variance,
        // so both can be applied in a single pass.
        Iterator<IndexValue> updates = x.iterator();
        while (updates.hasNext()) {
            IndexValue iv = updates.next();
            int idx = iv.getIndex();
            double sigmaI = this.sigmaV.get(idx);
            this.w.increment(idx, alpha * y * iv.getValue() * sigmaI);
            double sigmaXi = iv.getValue() * sigmaI;
            this.sigmaV.increment(idx, (-(sigmaXi * sigmaXi)) / beta);
        }
    }

    /** Applies the AROW mean/covariance update using the full covariance matrix. */
    private void updateFullCovariance(Vec x, double y, double hingeLoss) {
        // Matrix.multiply(b, z, c) accumulates z*A*b into c, so the scratch buffer must be
        // cleared first or Sigma*x from earlier updates would leak into this one.
        this.Sigma_xt.zeroOut();
        this.sigmaM.multiply(x, 1.0d, this.Sigma_xt);
        double beta = x.dot(this.Sigma_xt) + this.r;
        double alpha = hingeLoss / beta;
        this.w.mutableAdd(alpha * y, this.Sigma_xt);
        // Sigma <- Sigma - (Sigma x)(Sigma x)^T / beta
        Matrix.OuterProductUpdate(this.sigmaM, this.Sigma_xt, this.Sigma_xt, (-1.0d) / beta);
    }

    /**
     * Classifies a point, assigning all probability mass to one class.
     *
     * @return class 0 when the score is negative, otherwise class 1
     * @throws UntrainedModelException if the model has not been set up
     */
    @Override // jsat.classifiers.Classifier
    public CategoricalResults classify(DataPoint dataPoint) {
        double score = getScore(dataPoint); // validates that the model is trained
        CategoricalResults categoricalResults = new CategoricalResults(2);
        categoricalResults.setProb(score < 0.0d ? 0 : 1, 1.0d);
        return categoricalResults;
    }

    /**
     * @return the raw linear score {@code w . x}; negative scores favor class 0
     * @throws UntrainedModelException if the model has not been set up
     */
    @Override // jsat.classifiers.calibration.BinaryScoreClassifier
    public double getScore(DataPoint dataPoint) {
        if (this.w == null) {
            throw new UntrainedModelException("Model has not yet been trained");
        }
        return this.w.dot(dataPoint.getNumericalValues());
    }

    /** @return {@code false}; example weights are ignored by {@link #update}. */
    @Override // jsat.classifiers.Classifier
    public boolean supportsWeightedData() {
        return false;
    }

    /** @return the live internal weight vector, or {@code null} before {@link #setUp} */
    @Override // jsat.SingleWeightVectorModel
    public Vec getRawWeight() {
        return this.w;
    }

    /** @return always 0; AROW as implemented here learns no bias term */
    @Override // jsat.SingleWeightVectorModel
    public double getBias() {
        return 0.0d;
    }

    /**
     * @param i weight vector index; only 0 is valid
     * @throws IndexOutOfBoundsException if {@code i != 0}
     */
    @Override // jsat.SimpleWeightVectorModel
    public Vec getRawWeight(int i) {
        if (i != 0) {
            throw new IndexOutOfBoundsException("Model has only 1 weight vector");
        }
        return getRawWeight();
    }

    /**
     * @param i bias index; only 0 is valid
     * @throws IndexOutOfBoundsException if {@code i != 0}
     */
    @Override // jsat.SimpleWeightVectorModel
    public double getBias(int i) {
        if (i != 0) {
            throw new IndexOutOfBoundsException("Model has only 1 weight vector");
        }
        return getBias();
    }

    /** @return always 1 */
    @Override // jsat.SimpleWeightVectorModel
    public int numWeightsVecs() {
        return 1;
    }

    /**
     * Returns a search distribution for {@code r}: log-uniform on [2^-4, 2^4]. The data set is
     * currently not inspected.
     *
     * @param dataSet the data set (unused)
     * @return the distribution to draw candidate {@code r} values from
     */
    public static Distribution guessR(DataSet dataSet) {
        return new LogUniform(Math.pow(2.0d, -4.0d), Math.pow(2.0d, 4.0d));
    }
}
