/*
 * Decompiled with CFR 0.152.
 */
package org.jdmp.core.algorithm.classification.meta;

import org.jdmp.core.algorithm.regression.AbstractRegressor;
import org.jdmp.core.algorithm.regression.Regressor;
import org.jdmp.core.dataset.DefaultListDataSet;
import org.jdmp.core.dataset.ListDataSet;
import org.jdmp.core.sample.Sample;
import org.ujmp.core.DenseMatrix;
import org.ujmp.core.Matrix;
import org.ujmp.core.calculation.Calculation;

/**
 * Semi-supervised wrapper that iteratively self-labels a pool of unlabeled samples with a wrapped
 * {@link Regressor}.
 *
 * <p>Training proceeds as follows:
 *
 * <ol>
 *   <li>Step 0: the wrapped algorithm is reset and trained on the labeled data only. It then
 *       predicts every sample in {@code unlabeledData}, and each of those samples gets a new
 *       {@code "Target"}.
 *   <li>Steps 1..{@code iterations}: the wrapped algorithm is reset and retrained on the labeled
 *       samples plus the (now self-labeled) unlabeled samples. The unlabeled targets are then
 *       recomputed from the new predictions.
 * </ol>
 *
 * <p>The new target is either the raw {@code "Predicted"} matrix ({@code useRawPrediction ==
 * true}) or a {@code 1 x classCount} one-hot row at the arg-max of the prediction.
 *
 * <p>Side effects: {@link #trainAll(ListDataSet)} overwrites the {@code "Target"} entry of every
 * sample in {@code unlabeledData}, and the wrapped algorithm's {@code predictAll} call writes
 * their {@code "Predicted"} entries. {@link #emptyCopy()} shares the same {@code unlabeledData}
 * instance, so copies mutate the same samples. Progress is printed to {@code System.out}.
 */
public class SemiSupervisedEM
extends AbstractRegressor {
    private static final long serialVersionUID = 7362798845466035645L;
    /** Number of retraining rounds after the initial labeled-only step (0 or less means none). */
    private final int iterations;
    /** Wrapped learner that is repeatedly reset, trained and used for prediction. */
    private final Regressor algorithm;
    /** Pool of samples whose "Target" entries are (re)assigned from predictions. */
    private final ListDataSet unlabeledData;
    /** If true, use the raw prediction as target; otherwise a one-hot row at its arg-max. */
    private final boolean useRawPrediction;

    /**
     * Creates the semi-supervised wrapper.
     *
     * @param algorithm the wrapped learner; its {@code toString()} is used in this object's label
     * @param unlabeledData samples to self-label; their "Target" entries are overwritten during
     *     training
     * @param iterations number of retraining rounds after the initial labeled-only round
     * @param useRawPrediction whether to store the raw prediction (true) or a one-hot arg-max row
     *     (false) as the target of each unlabeled sample
     */
    public SemiSupervisedEM(Regressor algorithm, ListDataSet unlabeledData, int iterations, boolean useRawPrediction) {
        this.setLabel("SemiSupervisedEM [" + algorithm.toString() + "]");
        this.algorithm = algorithm;
        this.unlabeledData = unlabeledData;
        this.useRawPrediction = useRawPrediction;
        this.iterations = iterations;
    }

    /** Delegates prediction to the wrapped algorithm. */
    @Override
    public Matrix predictOne(Matrix input) {
        return this.algorithm.predictOne(input);
    }

    /** Resets the wrapped algorithm; unlabeled samples keep whatever targets they were given. */
    @Override
    public void reset() {
        this.algorithm.reset();
    }

    /**
     * Runs the self-labeling loop described in the class documentation.
     *
     * @param labeledData the labeled training samples; also used to determine the class count
     */
    @Override
    public void trainAll(ListDataSet labeledData) {
        int classCount = this.getClassCount(labeledData);
        // Bootstrap round: only the labeled data is available for training.
        System.out.println("Step 0");
        this.trainAndAssignTargets(labeledData, classCount);
        for (int i = 0; i < this.iterations; ++i) {
            System.out.println("Step " + (i + 1));
            // Retrain on labeled + self-labeled samples; the set is rebuilt every round.
            DefaultListDataSet completeData = new DefaultListDataSet();
            completeData.addAll(labeledData);
            completeData.addAll(this.unlabeledData);
            this.trainAndAssignTargets(completeData, classCount);
        }
    }

    /**
     * Runs one round: reset and train the wrapped algorithm, predict the unlabeled pool, and
     * overwrite each unlabeled sample's "Target" with the target derived from its prediction.
     */
    private void trainAndAssignTargets(ListDataSet trainingData, int classCount) {
        this.algorithm.reset();
        this.algorithm.trainAll(trainingData);
        this.algorithm.predictAll(this.unlabeledData);
        for (Sample s : this.unlabeledData) {
            Matrix predicted = s.getAsMatrix("Predicted");
            s.put("Target", this.toTarget(predicted, classCount));
        }
    }

    /**
     * Converts a prediction into a training target: the prediction itself when
     * {@code useRawPrediction} is set, otherwise a {@code 1 x classCount} one-hot row.
     */
    private Matrix toTarget(Matrix predicted, int classCount) {
        if (this.useRawPrediction) {
            return predicted;
        }
        // NOTE(review): assumes the prediction is a single row of per-class scores, so the index of
        // the maximum along dimension 1 is the winning class column — confirm against the learner.
        int max = (int)predicted.indexOfMax(Calculation.Ret.NEW, 1).getAsDouble(0L, 0L);
        DenseMatrix target = Matrix.Factory.zeros(1L, (long)classCount);
        target.setAsDouble(1.0, 0L, max);
        return target;
    }

    /**
     * Returns an untrained copy that wraps an empty copy of the algorithm. The copy shares this
     * instance's {@code unlabeledData}, so training it rewrites the same samples' targets.
     */
    @Override
    public Regressor emptyCopy() {
        return new SemiSupervisedEM(this.algorithm.emptyCopy(), this.unlabeledData, this.iterations, this.useRawPrediction);
    }
}

