/*
 * Decompiled with CFR 0.152.
 */
package smile.imputation;

import smile.imputation.KMeansImputation;
import smile.imputation.MissingValueImputation;
import smile.imputation.MissingValueImputationException;
import smile.math.matrix.DenseMatrix;
import smile.math.matrix.Matrix;
import smile.math.matrix.QR;
import smile.math.matrix.SVD;

/**
 * Missing value imputation driven by a singular value decomposition.
 *
 * <p>Missing entries (encoded as {@code NaN}) are first filled by
 * {@link KMeansImputation#columnAverageImpute(double[][])}. Then, for a fixed
 * number of iterations, the SVD of the completed matrix is computed and, for
 * every row that originally had missing values, the observed entries are
 * regressed against the first {@code k} columns of {@code V}; the fitted
 * coefficients are used to rebuild the missing entries of that row.
 */
public class SVDImputation implements MissingValueImputation {
    /** Number of leading columns of V (right singular vectors) used in the regression. */
    private final int k;

    /**
     * Creates an imputer.
     *
     * @param k the number of leading right singular vectors to use; must be at least 1.
     * @throws IllegalArgumentException if {@code k < 1}.
     */
    public SVDImputation(int k) {
        if (k < 1) {
            throw new IllegalArgumentException("Invalid number of eigenvectors for imputation: " + k);
        }
        this.k = k;
    }

    /**
     * Imputes the missing values in place using 10 iterations.
     *
     * @param data the data set; {@code NaN} marks a missing value.
     * @throws MissingValueImputationException if a whole row or column is missing.
     */
    @Override
    public void impute(double[][] data) throws MissingValueImputationException {
        impute(data, 10);
    }

    /**
     * Imputes the missing values in place.
     *
     * <p>The number of columns is taken from the first row; all rows are
     * expected to have that same length (not verified here).
     *
     * @param data the data set; {@code NaN} marks a missing value. On return the
     *             missing entries have been replaced; observed entries are untouched.
     * @param maxIter the number of SVD refinement passes; must be at least 1.
     * @throws IllegalArgumentException if {@code maxIter < 1}, if {@code data} is
     *         null or empty, or if values are missing and {@code k} exceeds the
     *         number of columns.
     * @throws MissingValueImputationException if a whole row or a whole column is missing.
     */
    public void impute(double[][] data, int maxIter) throws MissingValueImputationException {
        if (maxIter < 1) {
            throw new IllegalArgumentException("Invalid maximum number of iterations: " + maxIter);
        }
        if (data == null || data.length == 0) {
            throw new IllegalArgumentException("Empty data set");
        }

        int d = data[0].length;
        int[] missingPerColumn = new int[d];
        int totalMissing = 0;
        for (int i = 0; i < data.length; i++) {
            int missingInRow = 0;
            for (int j = 0; j < data[i].length; j++) {
                if (Double.isNaN(data[i][j])) {
                    missingInRow++;
                    missingPerColumn[j]++;
                }
            }
            if (missingInRow == data[i].length) {
                throw new MissingValueImputationException("The whole row " + i + " is missing");
            }
            totalMissing += missingInRow;
        }

        for (int j = 0; j < d; j++) {
            if (missingPerColumn[j] == data.length) {
                throw new MissingValueImputationException("The whole column " + j + " is missing");
            }
        }

        // Nothing to impute: every pass would only copy the observed values
        // back onto themselves, so skip the (expensive) decompositions.
        if (totalMissing == 0) {
            return;
        }

        // V has one row per column of the data and therefore at most d columns;
        // asking for more than d singular vectors cannot be satisfied.
        if (k > d) {
            throw new IllegalArgumentException(
                    "Number of eigenvectors (" + k + ") exceeds the number of columns (" + d + ")");
        }

        // Work on a copy so that 'data' keeps its NaN markers during the
        // iterations; svdImpute uses them to tell observed from missing.
        double[][] full = new double[data.length][];
        for (int i = 0; i < full.length; i++) {
            full[i] = data[i].clone();
        }

        KMeansImputation.columnAverageImpute(full);

        for (int iter = 0; iter < maxIter; iter++) {
            svdImpute(data, full);
        }

        for (int i = 0; i < data.length; i++) {
            System.arraycopy(full[i], 0, data[i], 0, data[i].length);
        }
    }

    /**
     * Performs one refinement pass: decomposes the current completed matrix and
     * re-estimates the originally missing entries row by row.
     *
     * @param raw the original data, with {@code NaN} for missing values (read only).
     * @param data the current completed data; updated in place.
     */
    private void svdImpute(double[][] raw, double[][] data) {
        SVD svd = Matrix.newInstance(data).svd();
        // Fetch V once instead of on every inner-loop iteration.
        DenseMatrix V = svd.getV();

        int d = data[0].length;
        for (int i = 0; i < raw.length; i++) {
            // Restore observed values and count the missing ones in this row.
            int missing = 0;
            for (int j = 0; j < d; j++) {
                if (Double.isNaN(raw[i][j])) {
                    missing++;
                } else {
                    data[i][j] = raw[i][j];
                }
            }

            if (missing == 0) {
                continue;
            }

            // Build the system A * s = b from the observed columns only:
            // A holds the matching rows of the first k columns of V, b the observed values.
            // NOTE(review): when a row has fewer than k observed values the system is
            // under-determined; how QR.solve behaves then is not visible here - confirm.
            int observed = d - missing;
            DenseMatrix A = Matrix.zeros(observed, k);
            double[] b = new double[observed];

            int row = 0;
            for (int j = 0; j < d; j++) {
                if (Double.isNaN(raw[i][j])) {
                    continue;
                }
                for (int l = 0; l < k; l++) {
                    A.set(row, l, V.get(j, l));
                }
                b[row++] = raw[i][j];
            }

            double[] s = new double[k];
            QR qr = A.qr();
            qr.solve(b, s);

            // Rebuild each missing entry as the combination of V's row with coefficients s.
            for (int j = 0; j < d; j++) {
                if (!Double.isNaN(raw[i][j])) {
                    continue;
                }
                double estimate = 0.0;
                for (int l = 0; l < k; l++) {
                    estimate += s[l] * V.get(j, l);
                }
                data[i][j] = estimate;
            }
        }
    }
}

