package org.jdmp.core.algorithm.classification.meta;

import java.util.Objects;

import org.jdmp.core.algorithm.regression.AbstractRegressor;
import org.jdmp.core.algorithm.regression.Regressor;
import org.jdmp.core.dataset.DefaultListDataSet;
import org.jdmp.core.dataset.ListDataSet;
import org.jdmp.core.sample.Sample;
import org.ujmp.core.DenseMatrix;
import org.ujmp.core.Matrix;
import org.ujmp.core.calculation.Calculation;

/**
 * Meta-regressor that wraps another {@link Regressor} and refines it with a self-training,
 * EM-style loop over a fixed set of unlabeled samples.
 *
 * <p>Training proceeds as follows:
 * <ol>
 *   <li>Step 0: the wrapped regressor is reset and trained on the labeled data only. It then
 *       predicts the unlabeled data, and each unlabeled sample's {@code "Target"} is overwritten
 *       with a pseudo-label derived from its {@code "Predicted"} matrix.</li>
 *   <li>Each of the {@code iterations} further steps resets the wrapped regressor, trains it on
 *       the union of labeled and pseudo-labeled data, and re-assigns the pseudo-labels.</li>
 * </ol>
 *
 * <p>A pseudo-label is either the raw prediction matrix ({@code useRawPrediction == true}) or a
 * {@code 1 x classCount} one-hot row marking the column of the largest predicted value.
 *
 * <p>NOTE(review): training mutates the samples of the unlabeled data set passed to the
 * constructor (their {@code "Target"} and, presumably via {@code predictAll}, their
 * {@code "Predicted"} entries).
 */
public class SemiSupervisedEM extends AbstractRegressor {
    private static final long serialVersionUID = 7362798845466035645L;

    /** Number of retraining rounds performed after the initial labeled-only fit. */
    private final int iterations;

    /** The wrapped regressor that is trained and used for every prediction. */
    private final Regressor algorithm;

    /** Samples whose {@code "Target"} entry is (re)assigned from the wrapped regressor. */
    private final ListDataSet unlabeledData;

    /** If true, the raw prediction becomes the target; otherwise a one-hot of its argmax. */
    private final boolean useRawPrediction;

    /**
     * Creates a semi-supervised wrapper around {@code regressor}.
     *
     * @param regressor the regressor to train and delegate predictions to; must not be null
     * @param unlabeledData samples to pseudo-label during training; must not be null and is
     *     mutated by {@link #trainAll(ListDataSet)}
     * @param iterations number of retraining rounds after the initial fit (values {@code <= 0}
     *     mean only the initial fit and pseudo-labeling are performed)
     * @param useRawPrediction whether to use the raw prediction as the pseudo-label instead of
     *     a one-hot encoding of its largest entry
     * @throws NullPointerException if {@code regressor} or {@code unlabeledData} is null
     */
    public SemiSupervisedEM(Regressor regressor, ListDataSet unlabeledData, int iterations,
            boolean useRawPrediction) {
        // Fail fast: a null data set would otherwise only surface as an NPE inside trainAll.
        this.algorithm = Objects.requireNonNull(regressor, "regressor");
        this.unlabeledData = Objects.requireNonNull(unlabeledData, "unlabeledData");
        this.useRawPrediction = useRawPrediction;
        this.iterations = iterations;
        setLabel("SemiSupervisedEM [" + regressor.toString() + "]");
    }

    /** Delegates the prediction of a single input to the wrapped regressor. */
    @Override // org.jdmp.core.algorithm.regression.Regressor
    public Matrix predictOne(Matrix matrix) {
        return this.algorithm.predictOne(matrix);
    }

    /**
     * Resets the wrapped regressor.
     *
     * <p>NOTE(review): pseudo-labels already written to the unlabeled samples are not cleared.
     */
    @Override // org.jdmp.core.algorithm.regression.Regressor
    public void reset() {
        this.algorithm.reset();
    }

    /**
     * Runs the self-training loop described in the class documentation.
     *
     * @param listDataSet the labeled training data; it is not modified by this method
     */
    @Override // org.jdmp.core.algorithm.regression.Regressor
    public void trainAll(ListDataSet listDataSet) {
        // getClassCount is inherited (presumably from AbstractRegressor); used for one-hot width.
        int classCount = getClassCount(listDataSet);

        // Step 0: supervised fit on labeled data only, then initial pseudo-labels.
        System.out.println("Step 0");
        this.algorithm.reset();
        this.algorithm.trainAll(listDataSet);
        assignPseudoTargets(classCount);

        for (int step = 0; step < this.iterations; step++) {
            System.out.println("Step " + (step + 1));
            // Retrain from scratch on labeled + currently pseudo-labeled samples.
            ListDataSet combined = new DefaultListDataSet();
            combined.addAll(listDataSet);
            combined.addAll(this.unlabeledData);
            this.algorithm.reset();
            this.algorithm.trainAll(combined);
            assignPseudoTargets(classCount);
        }
    }

    /**
     * Predicts every unlabeled sample with the wrapped regressor and writes the resulting
     * pseudo-label into each sample's {@code "Target"} entry.
     *
     * @param classCount width of the one-hot target used when raw predictions are not wanted
     */
    private void assignPseudoTargets(int classCount) {
        // Assumes predictAll stores each sample's prediction under "Predicted" — verify.
        this.algorithm.predictAll(this.unlabeledData);
        for (Sample sample : this.unlabeledData) {
            Matrix predicted = sample.getAsMatrix("Predicted");
            Matrix target = this.useRawPrediction ? predicted : toOneHot(predicted, classCount);
            sample.put("Target", target);
        }
    }

    /**
     * Builds a {@code 1 x classCount} matrix that is all zeros except for a 1.0 at the column of
     * the largest entry of {@code predicted}.
     *
     * <p>Only row 0 of the per-row argmax (dimension 1) is used, so this assumes
     * {@code predicted} is a single-row matrix — TODO confirm.
     */
    private static Matrix toOneHot(Matrix predicted, int classCount) {
        int bestClass = (int) predicted.indexOfMax(Calculation.Ret.NEW, 1)
                .getAsDouble(new long[] {0, 0});
        Matrix oneHot = Matrix.Factory.zeros(1L, classCount);
        oneHot.setAsDouble(1.0d, new long[] {0, bestClass});
        return oneHot;
    }

    /**
     * Returns an untrained copy that wraps an empty copy of the underlying regressor.
     *
     * <p>NOTE(review): the copy shares the same unlabeled data set instance, so training either
     * instance overwrites the same samples' {@code "Target"} entries.
     */
    @Override // org.jdmp.core.algorithm.regression.Regressor
    public Regressor emptyCopy() {
        return new SemiSupervisedEM(this.algorithm.emptyCopy(), this.unlabeledData,
                this.iterations, this.useRawPrediction);
    }
}
