package smile.imputation;

import org.junit.After;
import org.junit.AfterClass;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Test;
import smile.data.AttributeDataset;
import smile.data.parser.ArffParser;
import smile.data.parser.DelimitedTextParser;
import smile.data.parser.IOUtils;
import smile.math.Math;

/* loaded from: input_file:smile/imputation/MissingValueImputationTest.class */
public class MissingValueImputationTest {
    /** Missing-value percentages exercised for the mean, k-means, KNN and LLS imputers. */
    private static final int[] PERCENTS = {1, 5, 10, 15, 20, 25};

    /** SVD imputation is only exercised at the lower missing-value percentages. */
    private static final int[] SVD_PERCENTS = {1, 5, 10};

    ArffParser arffParser = new ArffParser();
    DelimitedTextParser csvParser = new DelimitedTextParser();
    AttributeDataset movement;
    AttributeDataset control;
    AttributeDataset segment;
    AttributeDataset iris;
    AttributeDataset soybean;

    @BeforeClass
    public static void setUpClass() throws Exception {
    }

    @AfterClass
    public static void tearDownClass() throws Exception {
    }

    /**
     * Loads the test datasets into the instance fields.
     *
     * <p>Any lookup or parse failure is propagated so JUnit reports it as a setup error. The
     * previous version printed the exception and continued, which only surfaced later as a
     * NullPointerException on a null dataset field.
     *
     * @throws Exception if a test data file cannot be located or parsed
     */
    @Before
    public void setUp() throws Exception {
        this.arffParser.setResponseIndex(4);
        this.iris = this.arffParser.parse(IOUtils.getTestDataFile("weka/iris.arff"));
        this.arffParser.setResponseIndex(35);
        this.soybean = this.arffParser.parse(IOUtils.getTestDataFile("weka/soybean.arff"));
        this.arffParser.setResponseIndex(19);
        this.segment = this.arffParser.parse(IOUtils.getTestDataFile("weka/segment-challenge.arff"));
        this.csvParser.setDelimiter(",");
        this.movement = this.csvParser.parse("Movement", IOUtils.getTestDataFile("uci/movement_libras.data"));
        // The synthetic control file is whitespace-separated, possibly with runs of spaces.
        this.csvParser.setDelimiter(" +");
        this.control = this.csvParser.parse("Control", IOUtils.getTestDataFile("uci/synthetic_control.data"));
    }

    @After
    public void tearDown() {
    }

    /**
     * Masks a random subset of a dataset's values, imputes them and returns the resulting RMSE.
     *
     * <p>Each observed (non-NaN) cell of a row-wise copy of the data is independently replaced by
     * {@code NaN} with probability {@code d}. The imputer is then run on the copy; this code
     * relies on it filling the array in place. The root-mean-square error is computed over the
     * masked cells only.
     *
     * <p>Cells that are already NaN in the source data are never masked. There is no ground truth
     * to compare them against, and including them would turn the whole sum into NaN.
     *
     * @param attributeDataset the source data; this method does not write to it
     * @param missingValueImputation the imputation method under evaluation
     * @param d the probability of masking each observed cell
     * @return the RMSE over the masked cells, or {@code Double.NaN} if no cell was masked
     * @throws Exception if the imputer fails
     */
    private double impute(AttributeDataset attributeDataset, MissingValueImputation missingValueImputation, double d) throws Exception {
        double[][] truth = attributeDataset.toArray(new double[attributeDataset.size()][]);
        double[][] masked = new double[truth.length][];
        boolean[][] isMasked = new boolean[truth.length][];
        int maskedCount = 0;

        for (int i = 0; i < truth.length; i++) {
            // Clone each row so the imputer never writes into the dataset's own arrays.
            masked[i] = truth[i].clone();
            isMasked[i] = new boolean[truth[i].length];
            for (int j = 0; j < truth[i].length; j++) {
                if (!Double.isNaN(truth[i][j]) && Math.random() < d) {
                    masked[i][j] = Double.NaN;
                    isMasked[i][j] = true;
                    maskedCount++;
                }
            }
        }

        if (maskedCount == 0) {
            // Nothing was hidden, so the RMSE is undefined (previously an accidental 0.0 / 0).
            return Double.NaN;
        }

        missingValueImputation.impute(masked);

        double sumSquaredError = 0.0;
        for (int i = 0; i < truth.length; i++) {
            for (int j = 0; j < truth[i].length; j++) {
                if (isMasked[i][j]) {
                    double diff = truth[i][j] - masked[i][j];
                    sumSquaredError += diff * diff;
                }
            }
        }
        return Math.sqrt(sumSquaredError / maskedCount);
    }

    /**
     * Prints a label followed by the RMSE of one imputation method at each given percentage.
     *
     * @param dataset the data to evaluate on
     * @param label the heading printed before the results
     * @param imputation the imputation method under evaluation
     * @param percents the missing-value percentages to try, e.g. {@code 5} for 5%
     * @throws Exception if the imputer fails
     */
    private void printRmse(AttributeDataset dataset, String label, MissingValueImputation imputation, int[] percents) throws Exception {
        System.out.println(label);
        for (int percent : percents) {
            // Division is correctly rounded, so this yields the same doubles as 0.01, 0.05, ...
            double rate = percent / 100.0;
            System.out.println("RMSE of " + percent + "% missing values = " + impute(dataset, imputation, rate));
        }
    }

    /**
     * Prints the RMSE of each imputation method on the given dataset at several missing rates.
     *
     * <p>SVD and LLS imputation are only evaluated when the dataset has more than 15 attributes.
     *
     * @param attributeDataset the data to evaluate on
     * @throws Exception if an imputer fails
     */
    void impute(AttributeDataset attributeDataset) throws Exception {
        int length = attributeDataset.attributes().length;
        System.out.println("----------- " + attributeDataset.getName() + " ----------------");
        System.out.println("----------- " + attributeDataset.size() + " x " + length + " ----------------");
        printRmse(attributeDataset, "MeanImputation", new AverageImputation(), PERCENTS);
        printRmse(attributeDataset, "KMeansImputation", new KMeansImputation(10, 5), PERCENTS);
        printRmse(attributeDataset, "KNNImputation", new KNNImputation(10), PERCENTS);
        if (length > 15) {
            printRmse(attributeDataset, "SVDImputation", new SVDImputation(length / 5), SVD_PERCENTS);
            printRmse(attributeDataset, "LLSImputation", new LLSImputation(10), PERCENTS);
        }
    }

    /**
     * Runs the imputation report on the segment, movement and control datasets.
     *
     * <p>The iris and soybean datasets are loaded in {@link #setUp()} but are not evaluated here.
     */
    @Test
    public void testImpute() throws Exception {
        impute(this.segment);
        impute(this.movement);
        impute(this.control);
    }
}
