package cc.mallet.fst;

import cc.mallet.types.Alphabet;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.MatrixOps;
import cc.mallet.types.Sequence;
import cc.mallet.util.MalletLogger;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.logging.Logger;

/* loaded from: input_file:cc/mallet/fst/PerClassAccuracyEvaluator.class */
/**
 * Evaluates a {@link Transducer} token by token and logs, for every label of the input pipe's
 * target alphabet, precision, recall and F1, followed by two macro-averages.
 */
public class PerClassAccuracyEvaluator extends TransducerEvaluator {
    private static final Logger logger = MalletLogger.getLogger(PerClassAccuracyEvaluator.class.getName());

    /**
     * @param instanceLists the data sets to evaluate
     * @param descriptions  one description per data set, used as a prefix in the log output
     */
    public PerClassAccuracyEvaluator(InstanceList[] instanceLists, String[] descriptions) {
        super(instanceLists, descriptions);
    }

    /** Convenience constructor for a single data set. */
    public PerClassAccuracyEvaluator(InstanceList instanceList, String description) {
        this(new InstanceList[]{instanceList}, new String[]{description});
    }

    /** Convenience constructor for two data sets (e.g. training and testing). */
    public PerClassAccuracyEvaluator(InstanceList instanceList1, String description1,
                                     InstanceList instanceList2, String description2) {
        this(new InstanceList[]{instanceList1, instanceList2}, new String[]{description1, description2});
    }

    /**
     * Transduces every instance in {@code instanceList}, tallies per-label token counts and logs
     * per-label precision, recall and F1 plus macro-averages.
     *
     * <p>For a label L: precision = tokens whose true and predicted label are both L / tokens
     * predicted as L; recall = the same numerator / tokens whose true label is L. A ratio with a
     * zero denominator is NaN. It is logged as such, contributes 0 to the "including non-used
     * labels" average, and is left out of the "excluding non-used labels" average.
     *
     * @param transducerTrainer supplies the transducer to evaluate
     * @param instanceList      instances whose data and target are both {@link Sequence}s
     * @param str               description used as a prefix in the log output
     */
    @Override
    public void evaluateInstanceList(TransducerTrainer transducerTrainer, InstanceList instanceList, String str) {
        Transducer transducer = transducerTrainer.getTransducer();
        Alphabet labels = transducer.getInputPipe().getTargetAlphabet();
        int numLabels = labels.size();
        int[] numCorrect = new int[numLabels];   // indexed by true label
        int[] numPredicted = new int[numLabels]; // indexed by predicted label
        int[] numTrue = new int[numLabels];      // indexed by true label

        logger.info("Per-token results for " + str);
        for (int i = 0; i < instanceList.size(); i++) {
            Instance instance = instanceList.get(i);
            Sequence input = (Sequence) instance.getData();
            Sequence trueOutput = (Sequence) instance.getTarget();
            assert input.size() == trueOutput.size();
            Sequence predOutput = transducer.transduce(input);
            assert predOutput.size() == trueOutput.size();
            for (int j = 0; j < trueOutput.size(); j++) {
                // NOTE(review): assumes every true/predicted label is already in the target
                // alphabet; an index >= numLabels would overflow the count arrays.
                int trueIdx = labels.lookupIndex(trueOutput.get(j));
                numTrue[trueIdx]++;
                numPredicted[labels.lookupIndex(predOutput.get(j))]++;
                if (trueOutput.get(j).equals(predOutput.get(j))) {
                    numCorrect[trueIdx]++;
                }
            }
        }

        DecimalFormat f = new DecimalFormat("0.####");
        double[] allPrecisions = new double[numLabels];
        double[] allRecalls = new double[numLabels];
        double[] allF1s = new double[numLabels];
        ArrayList<Double> usedPrecisions = new ArrayList<Double>();
        ArrayList<Double> usedRecalls = new ArrayList<Double>();
        ArrayList<Double> usedF1s = new ArrayList<Double>();
        for (int k = 0; k < numLabels; k++) {
            Object label = labels.lookupObject(k);
            // Floating-point division on purpose: int/int would truncate to 0 or 1 and would throw
            // ArithmeticException for a label that is never predicted or never present, instead
            // of producing the NaN that the checks below expect.
            double precision = (double) numCorrect[k] / numPredicted[k];
            double recall = (double) numCorrect[k] / numTrue[k];
            double f1 = (2.0d * precision * recall) / (precision + recall);
            if (!Double.isNaN(precision)) {
                allPrecisions[k] = precision;
                usedPrecisions.add(precision);
            }
            if (!Double.isNaN(recall)) {
                allRecalls[k] = recall;
                usedRecalls.add(recall);
            }
            if (!Double.isNaN(f1)) {
                allF1s[k] = f1;
                usedF1s.add(f1);
            }
            logger.info(str + " label " + label + " P " + f.format(precision) + " R " + f.format(recall) + " F1 " + f.format(f1));
        }
        logger.info("Macro-average (including non-used labels) P " + f.format(MatrixOps.mean(allPrecisions)) + " R " + f.format(MatrixOps.mean(allRecalls)) + " F " + f.format(MatrixOps.mean(allF1s)));
        logger.info("Macro-average (excluding non-used labels) P " + f.format(MatrixOps.mean(usedPrecisions)) + " R " + f.format(MatrixOps.mean(usedRecalls)) + " F " + f.format(MatrixOps.mean(usedF1s)));
    }
}
