package cc.mallet.fst.semi_supervised.constraints;

import cc.mallet.fst.SumLattice;
import cc.mallet.fst.semi_supervised.StateLabelMap;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.FeatureVectorSequence;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import com.carrotsearch.hppc.IntArrayList;
import com.carrotsearch.hppc.IntIntHashMap;
import com.carrotsearch.hppc.cursors.IntIntCursor;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.Iterator;

/* loaded from: input_file:cc/mallet/fst/semi_supervised/constraints/TwoLabelGEConstraints.class */
/**
 * Base class for generalized-expectation (GE) constraints defined over pairs of labels.
 *
 * <p>Each constraint is associated with an input feature: {@link #constraintsMap} maps a feature
 * index to a position in {@link #constraintsList}. Every constraint holds a
 * {@code numLabels x numLabels} target and expectation matrix indexed by
 * (source-state label, destination-state label).
 *
 * <p>Sequence position 0 is ignored throughout. Constraint counts, composite feature values and
 * expectations only consider positions {@code >= 1}.
 */
public abstract class TwoLabelGEConstraints implements GEConstraint {
    /**
     * Sentinel returned by {@link StateLabelMap#getLabelIndex(int)} for states that are skipped here
     * (presumably the start state -- confirm against StateLabelMap).
     */
    private static final int START_LABEL = -2;

    protected ArrayList<TwoLabelGEConstraint> constraintsList;
    protected IntIntHashMap constraintsMap;
    protected StateLabelMap map;
    /** Constraint-list indices active for the feature vector last passed to preProcess(FeatureVector). */
    protected IntArrayList cache;

    /** A single two-label constraint: a target matrix, its accumulated expectation and a weight. */
    protected abstract class TwoLabelGEConstraint {
        protected double[][] target;
        /** Allocated by zeroExpectations(); null until then. */
        protected double[][] expectation = null;
        /** Number of positions (>= 1) at which this constraint's feature occurs; see preProcess(InstanceList). */
        protected double count = 0.0d;
        protected double weight;

        public TwoLabelGEConstraint(double[][] target, double weight) {
            this.target = target;
            this.weight = weight;
        }

        /**
         * Returns this constraint's contribution for the label pair.
         *
         * @param li1 label index of the source state
         * @param li2 label index of the destination state
         */
        public abstract double getValue(int li1, int li2);
    }

    public TwoLabelGEConstraints() {
        this.constraintsList = new ArrayList<>();
        this.constraintsMap = new IntIntHashMap();
        this.map = null;
        this.cache = new IntArrayList();
    }

    public TwoLabelGEConstraints(ArrayList<TwoLabelGEConstraint> constraintsList,
            IntIntHashMap constraintsMap, StateLabelMap map) {
        this.constraintsList = constraintsList;
        this.constraintsMap = constraintsMap;
        this.map = map;
        this.cache = new IntArrayList();
    }

    /**
     * Registers a constraint for the given feature.
     *
     * @param fi feature index the constraint is attached to
     * @param target target matrix indexed by (source label, destination label)
     * @param weight constraint weight
     */
    public abstract void addConstraint(int fi, double[][] target, double weight);

    @Override
    public boolean isOneStateConstraint() {
        return false;
    }

    @Override
    public void setStateLabelMap(StateLabelMap stateLabelMap) {
        this.map = stateLabelMap;
    }

    /**
     * Caches the indices of the constraints whose features occur in {@code featureVector}.
     * The cache is read by {@link #getCompositeConstraintFeatureValue}.
     */
    @Override
    public void preProcess(FeatureVector featureVector) {
        this.cache.clear();
        collectConstraintIndices(featureVector, this.cache);
    }

    /**
     * Counts, for every constraint, the positions {@code >= 1} at which its feature occurs.
     *
     * @return a bit set with bit {@code i} set iff instance {@code i} has at least one such
     *     occurrence
     */
    @Override
    public BitSet preProcess(InstanceList instanceList) {
        BitSet constrained = new BitSet(instanceList.size());
        int ii = 0;
        for (Instance instance : instanceList) {
            FeatureVectorSequence fvs = (FeatureVectorSequence) instance.getData();
            for (int ip = 1; ip < fvs.size(); ip++) {
                FeatureVector fv = fvs.getFeatureVector(ip);
                for (IntIntCursor cursor : this.constraintsMap) {
                    if (fv.location(cursor.key) >= 0) {
                        this.constraintsList.get(cursor.value).count += 1.0d;
                        constrained.set(ii);
                    }
                }
            }
            ii++;
        }
        return constrained;
    }

    /**
     * Sums {@code getValue(li1, li2)} over the constraints cached by the last call to
     * {@link #preProcess(FeatureVector)}. The {@code input} argument itself is not read.
     *
     * <p>Returns 0 at position 0, and whenever either state maps to the start-label sentinel.
     * This matches {@link #computeExpectations}, which never accumulates expectation for such
     * pairs. It also avoids passing a negative label index to {@code getValue}.
     */
    @Override
    public double getCompositeConstraintFeatureValue(FeatureVector input, int ip, int si1, int si2) {
        if (ip == 0) {
            return 0.0d;
        }
        int li1 = this.map.getLabelIndex(si1);
        if (li1 == START_LABEL) {
            return 0.0d;
        }
        int li2 = this.map.getLabelIndex(si2);
        if (li2 == START_LABEL) {
            return 0.0d;
        }
        double value = 0.0d;
        for (int i = 0; i < this.cache.size(); i++) {
            value += this.constraintsList.get(this.cache.get(i)).getValue(li1, li2);
        }
        return value;
    }

    @Override
    public abstract double getValue();

    /** Replaces every constraint's expectation with a fresh zero {@code numLabels x numLabels} matrix. */
    @Override
    public void zeroExpectations() {
        for (TwoLabelGEConstraint constraint : this.constraintsList) {
            int numLabels = this.map.getNumLabels();
            constraint.expectation = new double[numLabels][numLabels];
        }
    }

    /**
     * Accumulates {@code exp(xis[ip][si1][si2])} into {@code expectation[li1][li2]} of every
     * constraint whose feature is present at position {@code ip >= 1}.
     *
     * <p>Null lattices are skipped. State pairs where either state maps to the start-label
     * sentinel are skipped as well. The xis are presumably log transition marginals (hence the
     * {@code exp}). Expectation matrices must already be allocated, e.g. by
     * {@link #zeroExpectations()}.
     */
    @Override
    public void computeExpectations(ArrayList<SumLattice> lattices) {
        IntArrayList active = new IntArrayList();
        int[] stateLabels = null; // computed lazily: only needed once some constraint fires
        for (SumLattice lattice : lattices) {
            if (lattice == null) {
                continue;
            }
            FeatureVectorSequence fvs = (FeatureVectorSequence) lattice.getInput();
            double[][][] xis = lattice.getXis();
            for (int ip = 1; ip < fvs.size(); ip++) {
                active.clear();
                collectConstraintIndices(fvs.getFeatureVector(ip), active);
                if (active.size() == 0) {
                    continue; // no constraint fires here: nothing to accumulate
                }
                if (stateLabels == null) {
                    stateLabels = stateLabelIndices();
                }
                addTransitionMarginals(xis[ip], stateLabels, active);
            }
        }
    }

    /** Appends to {@code out} the constraint-list index of every constrained feature in {@code fv}. */
    private void collectConstraintIndices(FeatureVector fv, IntArrayList out) {
        for (int loc = 0; loc < fv.numLocations(); loc++) {
            int fi = fv.indexAtLocation(loc);
            if (this.constraintsMap.containsKey(fi)) {
                out.add(this.constraintsMap.get(fi));
            }
        }
    }

    /** Returns the label index of each state, as reported by the current state-label map. */
    private int[] stateLabelIndices() {
        int numStates = this.map.getNumStates();
        int[] labels = new int[numStates];
        for (int si = 0; si < numStates; si++) {
            labels[si] = this.map.getLabelIndex(si);
        }
        return labels;
    }

    /** Adds exp(logXis[si1][si2]) to the (li1, li2) expectation cell of each active constraint. */
    private void addTransitionMarginals(double[][] logXis, int[] stateLabels, IntArrayList active) {
        for (int si1 = 0; si1 < stateLabels.length; si1++) {
            int li1 = stateLabels[si1];
            if (li1 == START_LABEL) {
                continue;
            }
            for (int si2 = 0; si2 < stateLabels.length; si2++) {
                int li2 = stateLabels[si2];
                if (li2 == START_LABEL) {
                    continue;
                }
                double prob = Math.exp(logXis[si1][si2]);
                for (int k = 0; k < active.size(); k++) {
                    this.constraintsList.get(active.get(k)).expectation[li1][li2] += prob;
                }
            }
        }
    }
}
