package cc.mallet.classify.constraints.pr;

import cc.mallet.classify.constraints.pr.MaxEntFLPRConstraints;
import cc.mallet.types.FeatureVector;
import com.carrotsearch.hppc.IntIntHashMap;
import com.carrotsearch.hppc.cursors.IntObjectCursor;
import java.util.Iterator;

/* loaded from: input_file:cc/mallet/classify/constraints/pr/MaxEntL2FLPRConstraints.class */
/**
 * Feature/label expectation constraints penalized by squared (L2) error between each
 * constraint's per-label target and the (optionally count-normalized) model expectation.
 *
 * <p>Each constrained feature is assigned a dense constraint index {@code ci} in insertion order
 * (see {@link #addConstraint}). The parameter vector has {@link #numDimensions()} entries laid out
 * label-major: the parameter for constraint {@code ci} and label {@code li} lives at
 * {@code ci + li * constraints.size()}.
 */
public class MaxEntL2FLPRConstraints extends MaxEntFLPRConstraints {
    /** Maps a constrained feature index to its dense constraint index in {@code [0, constraints.size())}. */
    private final IntIntHashMap constraintIndices;
    /** When true, model expectations (and scores) are divided by each constraint's {@code count}. */
    private final boolean normalize;

    /** Constraint type created by {@link #addConstraint}; adds nothing beyond its superclass. */
    protected class MaxEntL2FLPRConstraint extends MaxEntFLPRConstraints.MaxEntFLPRConstraint {
        /**
         * @param target per-label target values (presumably stored as {@code target} by the superclass)
         * @param weight penalty weight (presumably stored as {@code weight} by the superclass)
         */
        public MaxEntL2FLPRConstraint(double[] target, double weight) {
            super(target, weight);
        }
    }

    /**
     * Creates an empty constraint set.
     *
     * @param numFeatures forwarded to the superclass (presumably the feature-alphabet size)
     * @param numLabels   forwarded to the superclass (presumably initializes {@code numLabels})
     * @param useValues   forwarded to the superclass (presumably initializes {@code useValues})
     * @param normalize   whether expectations and scores are divided by each constraint's count
     */
    public MaxEntL2FLPRConstraints(int numFeatures, int numLabels, boolean useValues, boolean normalize) {
        super(numFeatures, numLabels, useValues);
        this.constraintIndices = new IntIntHashMap();
        this.normalize = normalize;
    }

    /**
     * Adds the constraint for {@code featureIndex}, or replaces it if one already exists.
     *
     * <p>A feature seen for the first time receives the next free dense index. Re-adding a feature
     * replaces its constraint but keeps its existing index, so indices stay dense and every
     * parameter position stays within {@link #numDimensions()}.
     *
     * @param featureIndex index of the constrained feature
     * @param target       per-label target values
     * @param weight       penalty weight; parameters are divided by it, so it should be non-zero
     */
    @Override
    public void addConstraint(int featureIndex, double[] target, double weight) {
        this.constraints.put(featureIndex, new MaxEntL2FLPRConstraint(target, weight));
        // Only new features get a fresh index. Unconditionally putting size() here would remap a
        // repeated feature to one past the last valid index, aliasing another constraint's slot and
        // overflowing the parameter vector for the last label.
        if (!this.constraintIndices.containsKey(featureIndex)) {
            this.constraintIndices.put(featureIndex, this.constraintIndices.size());
        }
    }

    /** Returns the parameter-vector length: one entry per (constraint, label) pair. */
    @Override
    public int numDimensions() {
        return this.constraints.size() * this.numLabels;
    }

    /**
     * Returns {@code sum over constraints c and labels l of
     * target[c][l] * p - p * p / (2 * weight[c])}, where {@code p} is the parameter for (c, l).
     *
     * @param parameters label-major parameter vector of length {@link #numDimensions()}
     */
    @Override
    public double getAuxiliaryValueContribution(double[] parameters) {
        final int numConstraints = this.constraints.size();
        double value = 0.0d;
        Iterator<?> it = this.constraints.iterator();
        while (it.hasNext()) {
            IntObjectCursor<?> cursor = (IntObjectCursor<?>) it.next();
            MaxEntFLPRConstraints.MaxEntFLPRConstraint constraint =
                (MaxEntFLPRConstraints.MaxEntFLPRConstraint) cursor.value;
            int ci = this.constraintIndices.get(cursor.key);
            for (int li = 0; li < this.numLabels; li++) {
                double param = parameters[ci + li * numConstraints];
                // Left-to-right evaluation matches the original: (value + t*p) - p^2/(2w).
                value = value + constraint.target[li] * param - (param * param) / (2.0d * constraint.weight);
            }
        }
        return value;
    }

    /**
     * Writes, for every (constraint c, label l) slot,
     * {@code target[c][l] - expectation[c][l] / norm[c] - p / weight[c]} into {@code gradient},
     * where {@code norm[c]} is the constraint's count when normalizing and 1 otherwise.
     * Slots not belonging to any constraint are left untouched.
     *
     * <p>NOTE(review): with normalization on, a constraint whose count is 0 yields NaN/infinity.
     *
     * @param parameters label-major parameter vector of length {@link #numDimensions()}
     * @param gradient   output array, indexed the same way as {@code parameters}
     */
    @Override
    public void getGradient(double[] parameters, double[] gradient) {
        final int numConstraints = this.constraints.size();
        Iterator<?> it = this.constraints.iterator();
        while (it.hasNext()) {
            IntObjectCursor<?> cursor = (IntObjectCursor<?>) it.next();
            MaxEntFLPRConstraints.MaxEntFLPRConstraint constraint =
                (MaxEntFLPRConstraints.MaxEntFLPRConstraint) cursor.value;
            int ci = this.constraintIndices.get(cursor.key);
            double norm = this.normalize ? constraint.count : 1.0d;
            for (int li = 0; li < this.numLabels; li++) {
                int pi = ci + li * numConstraints;
                // Read the parameter before writing, in case the caller passes the same array twice.
                double param = parameters[pi];
                gradient[pi] = constraint.target[li] - constraint.expectation[li] / norm - param / constraint.weight;
            }
        }
    }

    /**
     * Returns {@code -sum over constraints c and labels l of
     * weight[c] * (target[c][l] - expectation[c][l] / norm[c])^2 / 2},
     * where {@code norm[c]} is the constraint's count when normalizing and 1 otherwise.
     */
    @Override
    public double getCompleteValueContribution() {
        double value = 0.0d;
        Iterator<?> it = this.constraints.iterator();
        while (it.hasNext()) {
            IntObjectCursor<?> cursor = (IntObjectCursor<?>) it.next();
            MaxEntFLPRConstraints.MaxEntFLPRConstraint constraint =
                (MaxEntFLPRConstraints.MaxEntFLPRConstraint) cursor.value;
            double norm = this.normalize ? constraint.count : 1.0d;
            for (int li = 0; li < this.numLabels; li++) {
                double residual = constraint.target[li] - constraint.expectation[li] / norm;
                value -= (constraint.weight * Math.pow(residual, 2.0d)) / 2.0d;
            }
        }
        return value;
    }

    /**
     * Returns the score contribution for {@code label}. For each cached feature index, the score is
     * the sum of that constraint's label parameter, multiplied by the cached feature value when
     * {@code useValues} is set and divided by the constraint's count when normalizing.
     *
     * <p>{@code input} is not read here. NOTE(review): presumably {@code indexCache}/{@code valueCache}
     * were filled for this input by the superclass — confirm.
     *
     * @param input      the instance being scored (unused in this method)
     * @param label      label index whose parameters are summed
     * @param parameters label-major parameter vector of length {@link #numDimensions()}
     */
    @Override
    public double getScore(FeatureVector input, int label, double[] parameters) {
        final int numConstraints = this.constraints.size();
        double score = 0.0d;
        for (int i = 0; i < this.indexCache.size(); i++) {
            int featureIndex = this.indexCache.get(i);
            double param = parameters[this.constraintIndices.get(featureIndex) + label * numConstraints];
            double norm = this.normalize
                ? ((MaxEntFLPRConstraints.MaxEntFLPRConstraint) this.constraints.get(featureIndex)).count
                : 1.0d;
            double contribution = this.useValues ? param * this.valueCache.get(i) : param;
            score += contribution / norm;
        }
        return score;
    }
}
