package jsat.classifiers.svm.extended;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import jsat.DataSet;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.classifiers.calibration.BinaryScoreClassifier;
import jsat.distributions.Distribution;
import jsat.distributions.LogUniform;
import jsat.distributions.Uniform;
import jsat.exceptions.FailedToFitException;
import jsat.linear.DenseMatrix;
import jsat.linear.DenseVector;
import jsat.linear.Matrix;
import jsat.linear.MatrixOfVecs;
import jsat.linear.ScaledVector;
import jsat.linear.Vec;
import jsat.parameters.Parameterized;
import jsat.utils.IntList;
import jsat.utils.ListUtils;

/* loaded from: input_file:jsat/classifiers/svm/extended/CPM.class */
/**
 * Binary classifier built from two banks of {@code K} linear scoring functions.
 *
 * <p>One bank ({@code Wp}, {@code bp}) is trained with the labels as given and
 * the other ({@code Wn}, {@code bn}) with the labels flipped. For a data point
 * {@code x} each bank is reduced to the maximum entry of {@code W x + b};
 * {@link #getScore(DataPoint)} returns the difference of the two maxima and
 * {@link #classify(DataPoint)} picks class 0 only when the flipped-label bank
 * produces the strictly larger maximum.
 *
 * <p>Both banks are fit by stochastic gradient descent over shuffled passes of
 * the training data. Only numerical features are read, and instance weights
 * are not used.
 *
 * <p>The model is untrained until {@link #train(ClassificationDataSet, boolean)}
 * has been called; scoring an untrained instance fails because the weight
 * matrices are still {@code null}.
 */
public class CPM implements BinaryScoreClassifier, Classifier, Parameterized {
    private static final long serialVersionUID = 3171068484917637037L;

    /** Number of shuffled passes made over the training data. */
    private int epochs;
    /** Regularization strength; scales the step size and must be positive. */
    private double lambda;
    /** Number of linear scoring functions (rows) in each bank. */
    private int K;
    /** User supplied entropy threshold from which {@link #h} is derived. */
    private double entropyThreshold;
    /** Derived value {@code max(0, log2(entropyThreshold * K / 10))}. */
    private double h;
    /** Learned weights of the bank trained on the labels as given. */
    private Matrix Wp;
    /** Learned weights of the bank trained on the flipped labels. */
    private Matrix Wn;
    /** Bias terms paired with {@link #Wp}. */
    private Vec bp;
    /** Bias terms paired with {@link #Wn}. */
    private Vec bn;

    /**
     * Creates a model with {@code lambda = 1}, {@code K = 16}, an entropy
     * threshold of 3 and 50 epochs.
     */
    public CPM() {
        this(1.0d);
    }

    /**
     * Creates a model with {@code lambda = 1} and default threshold and epochs.
     *
     * @param k the number of scoring functions per bank, at least 1
     */
    public CPM(int k) {
        this(1.0d, k);
    }

    /**
     * Creates a model with {@code K = 16} and default threshold and epochs.
     *
     * @param lambda the positive, finite regularization strength
     */
    public CPM(double lambda) {
        this(lambda, 16);
    }

    /**
     * Creates a model with an entropy threshold of 3 and 50 epochs.
     *
     * @param lambda the positive, finite regularization strength
     * @param k the number of scoring functions per bank, at least 1
     */
    public CPM(double lambda, int k) {
        this(lambda, k, 3.0d);
    }

    /**
     * Creates a model that trains for 50 epochs.
     *
     * @param lambda the positive, finite regularization strength
     * @param k the number of scoring functions per bank, at least 1
     * @param entropyThreshold the non-negative, finite entropy threshold
     */
    public CPM(double lambda, int k, double entropyThreshold) {
        this(lambda, k, entropyThreshold, 50);
    }

    /**
     * Creates a fully specified model.
     *
     * @param lambda the positive, finite regularization strength
     * @param k the number of scoring functions per bank, at least 1
     * @param entropyThreshold the non-negative, finite entropy threshold
     * @param epochs the number of training passes, at least 1
     * @throws IllegalArgumentException if any argument is outside its range
     */
    public CPM(double lambda, int k, double entropyThreshold, int epochs) {
        setEpochs(epochs);
        setLambda(lambda);
        // K is set before the threshold so the final value of h is computed
        // from both user supplied values.
        setK(k);
        setEntropyThreshold(entropyThreshold);
    }

    /**
     * Copy constructor. Settings are copied by value and any learned weights
     * and biases are cloned so the two instances share no mutable state.
     *
     * @param toCopy the model to copy
     */
    public CPM(CPM toCopy) {
        this.epochs = toCopy.epochs;
        this.lambda = toCopy.lambda;
        this.K = toCopy.K;
        this.entropyThreshold = toCopy.entropyThreshold;
        this.h = toCopy.h;
        if (toCopy.Wp != null) {
            this.Wp = toCopy.Wp.mo171clone();
        }
        if (toCopy.Wn != null) {
            this.Wn = toCopy.Wn.mo171clone();
        }
        if (toCopy.bp != null) {
            this.bp = toCopy.bp.mo46clone();
        }
        if (toCopy.bn != null) {
            this.bn = toCopy.bn.mo46clone();
        }
    }

    /**
     * Sets the entropy threshold and recomputes the derived value {@code h}.
     *
     * @param entropyThreshold a non-negative, finite value
     * @throws IllegalArgumentException if the value is negative, NaN or infinite
     */
    public void setEntropyThreshold(double entropyThreshold) {
        if (entropyThreshold < 0.0d || Double.isNaN(entropyThreshold) || Double.isInfinite(entropyThreshold)) {
            throw new IllegalArgumentException("Entropy threshold must be non-negative, not " + entropyThreshold);
        }
        this.entropyThreshold = entropyThreshold;
        set_h_properly();
    }

    /**
     * Recomputes {@code h = log2(entropyThreshold * K / 10)}, replacing any
     * non-positive result (including negative infinity) with 0.
     */
    private void set_h_properly() {
        this.h = Math.log((this.entropyThreshold * this.K) / 10.0d) / Math.log(2.0d);
        if (this.h <= 0.0d) {
            this.h = 0.0d;
        }
    }

    /**
     * @return the entropy threshold currently in use
     */
    public double getEntropyThreshold() {
        return this.entropyThreshold;
    }

    /**
     * Sets the regularization strength. The training step size is computed as
     * {@code (n * epochs) / (lambda * t)}, so a zero, negative or non-finite
     * value would yield infinite, sign-flipped or NaN updates and is rejected.
     *
     * @param lambda a positive, finite value
     * @throws IllegalArgumentException if the value is not positive and finite
     */
    public void setLambda(double lambda) {
        if (lambda <= 0.0d || Double.isNaN(lambda) || Double.isInfinite(lambda)) {
            throw new IllegalArgumentException("Regularization term lambda must be a positive value, not " + lambda);
        }
        this.lambda = lambda;
    }

    /**
     * @return the regularization strength
     */
    public double getLambda() {
        return this.lambda;
    }

    /**
     * Sets the number of scoring functions in each bank and recomputes the
     * derived value {@code h}. Training allocates {@code K} rows per bank, so
     * values below 1 are rejected here instead of failing later.
     *
     * @param k the number of scoring functions per bank, at least 1
     * @throws IllegalArgumentException if {@code k < 1}
     */
    public void setK(int k) {
        if (k < 1) {
            throw new IllegalArgumentException("K must be a positive value, not " + k);
        }
        this.K = k;
        set_h_properly();
    }

    /**
     * @return the number of scoring functions in each bank
     */
    public int getK() {
        return this.K;
    }

    /**
     * Sets the number of passes made over the training data.
     *
     * @param epochs the number of passes, at least 1
     * @throws IllegalArgumentException if {@code epochs < 1}
     */
    public void setEpochs(int epochs) {
        if (epochs < 1) {
            throw new IllegalArgumentException("epochs must be a positive value");
        }
        this.epochs = epochs;
    }

    /**
     * @return the number of passes made over the training data
     */
    public int getEpochs() {
        return this.epochs;
    }

    /**
     * Returns the largest entry of {@code weights * x + biases}.
     */
    private static double maxScore(Matrix weights, Vec biases, Vec x) {
        return weights.multiply(x).add(biases).max();
    }

    /**
     * Classifies a data point with a hard decision: the chosen class receives
     * probability 1. Class 0 is chosen only when the flipped-label bank's
     * maximum is strictly greater than the other bank's maximum; ties and NaN
     * comparisons resolve to class 1.
     *
     * @param dataPoint the point to classify; only its numerical values are read
     * @return a two-class result with all mass on the chosen class
     */
    @Override // jsat.classifiers.Classifier
    public CategoricalResults classify(DataPoint dataPoint) {
        Vec x = dataPoint.getNumericalValues();
        double posScore = maxScore(this.Wp, this.bp, x);
        double negScore = maxScore(this.Wn, this.bn, x);
        CategoricalResults result = new CategoricalResults(2);
        // When exactly one score is positive that score is necessarily the
        // larger one, so a single comparison covers every sign combination.
        result.setProb(negScore > posScore ? 0 : 1, 1.0d);
        return result;
    }

    /**
     * Returns the maximum score of the {@code Wp} bank minus the maximum score
     * of the {@code Wn} bank for the given point.
     *
     * @param dataPoint the point to score; only its numerical values are read
     * @return the signed score difference
     */
    @Override // jsat.classifiers.calibration.BinaryScoreClassifier
    public double getScore(DataPoint dataPoint) {
        Vec x = dataPoint.getNumericalValues();
        return maxScore(this.Wp, this.bp, x) - maxScore(this.Wn, this.bn, x);
    }

    /**
     * @return {@code false}; training never reads instance weights
     */
    @Override // jsat.classifiers.Classifier
    public boolean supportsWeightedData() {
        return false;
    }

    /**
     * Chooses which row of the bank should receive the update for a positive
     * training instance.
     *
     * <p>Once more than {@code 10 * K} instances have been assigned, the
     * entropy (in bits) of the ownership counts is computed both as it stands
     * and as it would be after giving this instance to {@code kTrueMax}. If
     * the resulting value is at least {@code h}, or too few instances have
     * been assigned to evaluate it, {@code kTrueMax} is returned. Otherwise a
     * row owning fewer instances is searched for.
     *
     * @param scores current scores of every row for this instance
     * @param index index of the instance in the training set
     * @param kTrueMax row with the highest score for this instance
     * @param owned number of instances currently assigned to each row
     * @param assignments current row of every instance, or -1 if unassigned
     * @param assignedPositives number of instances that have an assignment
     * @return the row to update
     */
    private int ASSIGN(Vec scores, int index, int kTrueMax, int[] owned, int[] assignments, int assignedPositives) {
        final int curOwner = assignments[index];
        double curEntropy = 0.0d;
        // Infinity makes the threshold test below pass when entropy is not evaluated.
        double newEntropy = Double.POSITIVE_INFINITY;
        int maxOwned = 0;
        if (assignedPositives > this.K * 10) {
            double candidateEntropy = 0.0d;
            for (int k = 0; k < this.K; k++) {
                maxOwned = Math.max(maxOwned, owned[k]);
                double count = owned[k];
                double total = assignedPositives;
                if (count > 0.0d) {
                    // -(count/total) * log2(count/total)
                    curEntropy += ((-count) * (Math.log(count) - Math.log(total))) / (Math.log(2.0d) * total);
                }
                if (curOwner < 0) {
                    // Unassigned instance: the assignment adds one to the total.
                    double grownTotal = total + 1.0d;
                    if (k == kTrueMax) {
                        count += 1.0d;
                    }
                    if (count > 0.0d) {
                        candidateEntropy += ((-count) * (Math.log(count) - Math.log(grownTotal))) / (Math.log(2.0d) * grownTotal);
                    }
                } else if (curOwner == kTrueMax) {
                    // Ownership would not change.
                    candidateEntropy = curEntropy;
                } else {
                    // Instance moves from its current owner to kTrueMax.
                    if (k == kTrueMax) {
                        count += 1.0d;
                    } else if (k == curOwner) {
                        count -= 1.0d;
                    }
                    if (count > 0.0d) {
                        candidateEntropy += ((-count) * (Math.log(count) - Math.log(total))) / (Math.log(2.0d) * total);
                    }
                }
            }
            // NOTE(review): the value tested against h is the candidate entropy
            // plus the current entropy; kept exactly as found.
            newEntropy = candidateEntropy + curEntropy;
        }
        if (newEntropy >= this.h) {
            return kTrueMax;
        }
        int chosen = 0;
        if (curOwner >= 0) {
            // Among rows owning fewer instances than the current owner, take
            // the one that out-scores the running choice (which starts at row 0).
            for (int k = 1; k < scores.length(); k++) {
                if (owned[curOwner] > owned[k] && scores.get(k) > scores.get(chosen)) {
                    chosen = k;
                }
            }
        } else {
            // Among rows owning fewer instances than the largest owner, take
            // the best scoring one; the scan starts at row 1.
            double bestScore = Double.NEGATIVE_INFINITY;
            for (int k = 1; k < scores.length(); k++) {
                if (maxOwned > owned[k] && scores.get(k) > bestScore) {
                    chosen = k;
                    bestScore = scores.get(k);
                }
            }
            if (Double.isInfinite(bestScore)) {
                // No row qualified.
                return kTrueMax;
            }
        }
        return chosen;
    }

    /**
     * Fits one bank of scoring functions in place by stochastic gradient
     * descent.
     *
     * <p>Each instance's class index (0 or 1) is mapped to -1/+1 and multiplied
     * by {@code sign}. Instances mapped to -1 push down every row whose score
     * exceeds -1. Instances mapped to +1 push up a single row chosen by
     * {@link #ASSIGN} when the best score is below 1. After every instance the
     * weights and biases are scaled by {@code 1 - 1/t}.
     *
     * @param dataSet the training data
     * @param weights the rows to update in place
     * @param biases the bias terms to update in place
     * @param sign +1 to train on the labels as given, -1 to flip them
     */
    private void sgdTrain(ClassificationDataSet dataSet, MatrixOfVecs weights, Vec biases, int sign) {
        final int n = dataSet.getSampleSize();
        IntList order = new IntList(n);
        ListUtils.addRange(order, 0, n, 1);
        final double scaledLambda = this.lambda / (n * this.epochs);
        int[] owned = new int[this.K];
        int assignedPositives = 0;
        int[] assignments = new int[n];
        Arrays.fill(assignments, -1);
        DenseVector scores = new DenseVector(weights.rows());
        long t = 0;
        for (int epoch = 0; epoch < this.epochs; epoch++) {
            Collections.shuffle(order);
            for (int idx : order) {
                t++;
                final double eta = 1.0d / (scaledLambda * t);
                Vec x = dataSet.getDataPoint(idx).getNumericalValues();
                final int y = ((dataSet.getDataPointCategory(idx) * 2) - 1) * sign;
                // scores = biases + weights * x
                biases.copyTo(scores);
                weights.multiply(x, 1.0d, scores);
                if (y == -1) {
                    for (int k = 0; k < this.K; k++) {
                        if (scores.get(k) > -1.0d) {
                            weights.getRowView(k).mutableSubtract(eta, x);
                            biases.increment(k, -eta);
                        }
                    }
                } else {
                    int best = 0;
                    for (int k = 1; k < scores.length(); k++) {
                        if (scores.get(k) > scores.get(best)) {
                            best = k;
                        }
                    }
                    if (scores.get(best) < 1.0d) {
                        final int target = ASSIGN(scores, idx, best, owned, assignments, assignedPositives);
                        weights.getRowView(target).mutableAdd(eta, x);
                        biases.increment(target, eta);
                        // Move ownership of this instance to the updated row.
                        if (assignments[idx] < 0) {
                            assignedPositives++;
                        } else {
                            owned[assignments[idx]]--;
                        }
                        owned[target]++;
                        assignments[idx] = target;
                    }
                }
                final double shrink = 1.0d - (1.0d / t);
                weights.mutableMultiply(shrink);
                biases.mutableMultiply(shrink);
            }
        }
    }

    /**
     * Trains both banks on the given data set, replacing any previous model.
     *
     * @param classificationDataSet the training data; at most two classes
     * @param z accepted for interface compatibility; not used by this
     *     implementation
     * @throws FailedToFitException if the data set has more than two classes
     */
    @Override // jsat.classifiers.Classifier
    public void train(ClassificationDataSet classificationDataSet, boolean z) {
        final int numCategories = classificationDataSet.getPredicting().getNumOfCategories();
        if (numCategories > 2) {
            throw new FailedToFitException("CPM is a binary classifier, it can not be trained on a dataset with " + numCategories + " classes");
        }
        final int dims = classificationDataSet.getNumNumericalVars();
        ArrayList<Vec> posRows = new ArrayList<Vec>(this.K);
        ArrayList<Vec> negRows = new ArrayList<Vec>(this.K);
        this.bp = new DenseVector(this.K);
        this.bn = new DenseVector(this.K);
        for (int k = 0; k < this.K; k++) {
            posRows.add(new ScaledVector(new DenseVector(dims)));
            negRows.add(new ScaledVector(new DenseVector(dims)));
        }
        MatrixOfVecs posWeights = new MatrixOfVecs(posRows);
        MatrixOfVecs negWeights = new MatrixOfVecs(negRows);
        sgdTrain(classificationDataSet, posWeights, this.bp, 1);
        sgdTrain(classificationDataSet, negWeights, this.bn, -1);
        // Copy the trained rows into DenseMatrix instances for prediction.
        this.Wp = new DenseMatrix(posWeights);
        this.Wn = new DenseMatrix(negWeights);
    }

    /**
     * @return an independent copy of this model made with the copy constructor
     */
    @Override // jsat.classifiers.calibration.BinaryScoreClassifier, jsat.classifiers.Classifier
    /* renamed from: clone, reason: merged with bridge method [inline-methods] */
    public CPM m99clone() {
        return new CPM(this);
    }

    /**
     * Suggests a search distribution for lambda.
     *
     * @param dataSet not used
     * @return a log-uniform distribution over [0.1, 10000]
     */
    public static Distribution guessLambda(DataSet dataSet) {
        return new LogUniform(0.1d, 10000.0d);
    }

    /**
     * Suggests a search distribution for the entropy threshold.
     *
     * @param dataSet not used
     * @return a uniform distribution over [0.1, 10]
     */
    public static Distribution guessEntropyThreshold(DataSet dataSet) {
        return new Uniform(0.1d, 10.0d);
    }
}
