package jsat.classifiers.bayesian.graphicalmodel;

import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.classifiers.DataPointPair;
import jsat.classifiers.bayesian.ConditionalProbabilityTable;
import jsat.exceptions.FailedToFitException;
import jsat.utils.IntSet;

/* loaded from: input_file:jsat/classifiers/bayesian/graphicalmodel/DiscreteBayesNetwork.class */
public class DiscreteBayesNetwork implements Classifier {
    private static final long serialVersionUID = 2980734594356260141L;
    /** One conditional probability table per direct child of the class node, keyed by that child's index. */
    protected Map<Integer, ConditionalProbabilityTable> cpts;
    /** Description of the target (class) variable, captured during training. */
    protected CategoricalData predicting;
    /** Class prior probabilities, captured during training. */
    protected double[] priors;
    /** Default value for whether class priors are included in {@link #classify(DataPoint)}. */
    public static final boolean DEFAULT_USE_PRIORS = true;
    private boolean usePriors = DEFAULT_USE_PRIORS;
    /** Network structure; an edge {@code parent -> child} is added by {@link #depends(int, int)}. */
    protected DirectedGraph<Integer> dag = new DirectedGraph<>();

    /**
     * Computes the posterior class distribution for a data point.
     *
     * <p>The class variable is treated as the node whose index equals
     * {@code dataPoint.numCategoricalValues()}, mirroring {@link #train(ClassificationDataSet)}, which
     * uses the number of categorical variables as the class node's index. For each candidate class the
     * log-probabilities returned by the CPTs of the class node's children (plus the log prior, if
     * enabled) are summed. The result is normalized with a log-sum-exp so the output sums to one.
     *
     * @param dataPoint the point to classify
     * @return the normalized class probabilities
     */
    @Override // jsat.classifiers.Classifier
    public CategoricalResults classify(DataPoint dataPoint) {
        CategoricalResults results = new CategoricalResults(this.predicting.getNumOfCategories());
        int classId = dataPoint.numCategoricalValues();
        int numClasses = results.size();

        // Accumulate in log space so the product of many small probabilities does not underflow.
        double[] logProbs = new double[numClasses];
        double maxLogProb = Double.NEGATIVE_INFINITY;
        for (int c = 0; c < numClasses; c++) {
            DataPointPair<Integer> candidate = new DataPointPair<>(dataPoint, c);
            double logProb = 0.0;
            for (int feature : this.dag.getChildren(classId)) {
                logProb += Math.log(this.cpts.get(feature).query(feature, candidate));
            }
            if (this.usePriors) {
                logProb += Math.log(this.priors[c]);
            }
            logProbs[c] = logProb;
            maxLogProb = Math.max(maxLogProb, logProb);
        }

        if (maxLogProb == Double.NEGATIVE_INFINITY) {
            // Every class has zero likelihood; normalizing would compute -inf - -inf = NaN.
            // Fall back to a uniform distribution instead.
            for (int c = 0; c < numClasses; c++) {
                results.setProb(c, 1.0 / numClasses);
            }
            return results;
        }

        // Stable log-sum-exp: log(sum_j exp(l_j)) = max + log(sum_j exp(l_j - max)).
        double scaledSum = 0.0;
        for (double logProb : logProbs) {
            scaledSum += Math.exp(logProb - maxLogProb);
        }
        double logNormalizer = maxLogProb + Math.log(scaledSum);
        for (int c = 0; c < numClasses; c++) {
            results.setProb(c, Math.exp(logProbs[c] - logNormalizer));
        }
        return results;
    }

    /**
     * Adds the directed edge {@code parent -> child} to the network, creating either node if it is not
     * already present.
     *
     * @param parent index of the parent variable
     * @param child index of the child variable
     */
    public void depends(int parent, int child) {
        this.dag.addNode(child);
        this.dag.addNode(parent);
        this.dag.addEdge(parent, child);
    }

    /**
     * Trains the network. The {@code parallel} flag is ignored; training always runs on the calling
     * thread.
     *
     * @param classificationDataSet the training data
     * @param parallel ignored
     */
    @Override // jsat.classifiers.Classifier
    public void train(ClassificationDataSet classificationDataSet, boolean parallel) {
        train(classificationDataSet);
    }

    /**
     * Trains one conditional probability table for each direct child of the class node.
     *
     * <p>The class node's index is the number of categorical variables in the data set. If no structure
     * has been declared via {@link #depends(int, int)}, a naive-Bayes structure is used: the class node
     * becomes the parent of every categorical variable. Each CPT is trained over the child variable,
     * that child's own children in the graph, and the class variable.
     *
     * @param classificationDataSet the training data
     * @throws FailedToFitException if the data set has no categorical variables
     */
    @Override // jsat.classifiers.Classifier
    public void train(ClassificationDataSet classificationDataSet) {
        int classId = classificationDataSet.getNumCategoricalVars();
        if (classId == 0) {
            throw new FailedToFitException("Network needs categorical attribtues to work");
        }
        this.predicting = classificationDataSet.getPredicting();
        this.priors = classificationDataSet.getPriors();
        this.cpts = new HashMap<>();

        if (this.dag.getNodes().isEmpty()) {
            // No user-supplied structure: fall back to naive Bayes (class -> every feature).
            for (int feature = 0; feature < classId; feature++) {
                depends(classId, feature);
            }
        }

        IntSet cptVariables = new IntSet();
        for (int classChild : this.dag.getChildren(classId)) {
            Set<Integer> dependents = this.dag.getChildren(classChild);
            cptVariables.clear();
            cptVariables.addAll(dependents);
            cptVariables.add(classChild);
            cptVariables.add(classId);

            ConditionalProbabilityTable cpt = new ConditionalProbabilityTable();
            cpt.trainC(classificationDataSet, cptVariables);
            this.cpts.put(classChild, cpt);
        }
    }

    /**
     * Sets whether class prior probabilities are added to the score in {@link #classify(DataPoint)}.
     *
     * @param usePriors {@code true} to include the priors
     */
    public void setUsePriors(boolean usePriors) {
        this.usePriors = usePriors;
    }

    /**
     * Reports whether class prior probabilities are used during classification.
     *
     * @return {@code true} if the priors are included
     */
    public boolean isUsePriors() {
        return this.usePriors;
    }

    /**
     * Reports whether this classifier supports weighted training data.
     *
     * @return always {@code false}; instance weights are not supported
     */
    @Override // jsat.classifiers.Classifier
    public boolean supportsWeightedData() {
        return false;
    }

    /**
     * Cloning is not implemented.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // jsat.classifiers.Classifier
    /* renamed from: clone, reason: merged with bridge method [inline-methods] */
    public Classifier m18clone() {
        throw new UnsupportedOperationException("Not supported yet.");
    }
}
