/*
 * Decompiled with CFR 0.152.
 */
package mikera.matrixx.algo;

import mikera.matrixx.AMatrix;
import mikera.matrixx.Matrix;
import mikera.matrixx.algo.IPLSResult;
import mikera.matrixx.algo.PseudoInverse;
import mikera.matrixx.impl.DiagonalMatrix;
import mikera.vectorz.AVector;
import mikera.vectorz.Vector;
import mikera.vectorz.impl.AStridedVector;

/**
 * Partial Least Squares (PLS) regression fitted with a NIPALS-style power iteration.
 *
 * <p>Instances are obtained through {@link #calculate(AMatrix, AMatrix, int)}, which fits the
 * model immediately and returns it as an {@link IPLSResult}. After fitting,
 * {@link #getCoefficients()} is an {@code m x p} matrix and {@link #getConstant()} a length-{@code p}
 * vector chosen so that {@code mean(X) * coefficients + constant == mean(Y)} (column means).
 */
public class PLS
implements IPLSResult {
    /** Maximum extra passes of the inner NIPALS loop (the loop runs at most MAX_ITERATIONS + 1 times). */
    private static final int MAX_ITERATIONS = 10;
    /** Stop iterating once the score vector moves less than this between passes. */
    private static final double CONVERGENCE_TOLERANCE = 1.0E-11;

    /** The X matrix exactly as supplied by the caller (not copied). */
    private final AMatrix origX;
    /** Working copy of X: column-centred, then deflated after each extracted factor. */
    private final Matrix X;
    /** Copy of the supplied Y matrix (not centred). */
    private final Matrix Y;
    /** X loadings, one column per factor (m x l). */
    private final Matrix P;
    /** Normalised Y loadings, one column per factor (p x l). */
    private final Matrix Q;
    /** Normalised X scores, one column per factor (n x l). */
    private final Matrix T;
    /** Y scores, one column per factor (n x l). */
    private final Matrix U;
    /** Normalised X weights, one column per factor (m x l). */
    private final Matrix W;
    /** Inner-relation coefficients t_i . u_i, one per factor. */
    private final Vector b;
    /** Diagonal matrix holding {@link #b} on its leading diagonal. */
    private final DiagonalMatrix B;
    /** Regression coefficients mapping X columns to Y columns (m x p). */
    private final Matrix coefficients;
    /** Intercept added to every predicted row (length p). */
    private final Vector constant;
    /** Number of factors to extract. */
    private final int l;
    /** Number of rows (observations) in X and Y. */
    private final int n;
    /** Number of columns in X. */
    private final int m;
    /** Number of columns in Y. */
    private final int p;

    @Override
    public AMatrix getX() {
        return this.origX;
    }

    @Override
    public AMatrix getY() {
        return this.Y;
    }

    @Override
    public AMatrix getT() {
        return this.T;
    }

    @Override
    public AMatrix getP() {
        return this.P;
    }

    @Override
    public AMatrix getQ() {
        return this.Q;
    }

    @Override
    public AMatrix getW() {
        return this.W;
    }

    @Override
    public AMatrix getB() {
        return this.B;
    }

    @Override
    public AMatrix getCoefficients() {
        return this.coefficients;
    }

    @Override
    public AVector getConstant() {
        return this.constant;
    }

    /**
     * Validates the inputs and allocates all result storage; does not perform the fit.
     *
     * @throws IllegalArgumentException if {@code nFactors} is negative or X and Y differ in row count
     */
    private PLS(AMatrix X, AMatrix Y, int nFactors) {
        if (nFactors < 0) {
            throw new IllegalArgumentException("PLS regression requires a non-negative number of factors, got " + nFactors);
        }
        if (Y.rowCount() != X.rowCount()) {
            throw new IllegalArgumentException("PLS regression requires equal number of rows in X and Y matrices (X has "
                    + X.rowCount() + ", Y has " + Y.rowCount() + ")");
        }
        this.origX = X;
        this.Y = Matrix.create(Y);
        this.X = Matrix.create(X);
        this.n = X.rowCount();
        this.m = X.columnCount();
        this.l = nFactors;
        this.p = Y.columnCount();
        this.T = Matrix.create(this.n, this.l);
        this.U = Matrix.create(this.n, this.l);
        this.P = Matrix.create(this.m, this.l);
        this.Q = Matrix.create(this.p, this.l);
        this.W = Matrix.create(this.m, this.l);
        this.b = Vector.createLength(this.l);
        this.B = DiagonalMatrix.createDimensions(this.l);
        this.coefficients = Matrix.create(this.m, this.p);
        this.constant = Vector.createLength(this.p);
    }

    /**
     * Fits a PLS regression model of Y on X.
     *
     * @param X predictor matrix (n x m); it is copied, not modified
     * @param Y response matrix (n x p); it is copied, not modified
     * @param nFactors number of latent factors to extract (must be non-negative)
     * @return the fitted model
     * @throws IllegalArgumentException if {@code nFactors} is negative or X and Y differ in row count
     */
    public static IPLSResult calculate(AMatrix X, AMatrix Y, int nFactors) {
        PLS pls = new PLS(X, Y, nFactors);
        pls.calcResult();
        return pls;
    }

    /**
     * Returns the index of the column of {@code A} with the largest sum of squares.
     * Returns 0 if every column has a zero sum of squares, or if {@code A} has no columns.
     */
    private int selectMaxSSColumn(AMatrix A) {
        int bestColumn = 0;
        double best = 0.0;
        int cols = A.columnCount();
        for (int i = 0; i < cols; ++i) {
            double ss = A.getColumn(i).elementSquaredSum();
            if (ss > best) {
                bestColumn = i;
                best = ss;
            }
        }
        return bestColumn;
    }

    /**
     * Performs the fit.
     *
     * <p>The steps are:
     * <ol>
     *   <li>Centre the working copy of X.</li>
     *   <li>Extract {@code l} factors by power iteration, deflating X after each one.</li>
     *   <li>Solve for the regression coefficients.</li>
     *   <li>Derive the intercept from the column means of X and Y.</li>
     * </ol>
     */
    private void calcResult() {
        Vector u = Vector.createLength(this.n);
        Vector w = Vector.createLength(this.m);
        Vector t = Vector.createLength(this.n);
        Vector tOld = Vector.createLength(this.n);
        Vector q = Vector.createLength(this.p);
        Vector pv = Vector.createLength(this.m);

        // Centre each column of the working X in place; keep the means for the intercept.
        Vector xMean = Vector.createLength(this.m);
        for (int j = 0; j < this.m; ++j) {
            AStridedVector col = this.X.getColumnView(j);
            double mean = col.elementSum() / (double) this.n;
            xMean.set(j, mean);
            col.add(-mean);
        }

        for (int i = 0; i < this.l; ++i) {
            // Seed u with the (deflated) X column carrying the most energy.
            u.set(this.X.getColumn(this.selectMaxSSColumn(this.X)));
            for (int iter = 0; iter <= MAX_ITERATIONS; ++iter) {
                w.setInnerProduct(u, this.X);
                w.normalise();
                t.setInnerProduct(this.X, w);
                t.normalise();
                q.setInnerProduct(t, this.Y);
                // NOTE(review): assumes normalise() returns the pre-normalisation length,
                // so 0 means Y has nothing left to explain along t.
                if (q.normalise() == 0.0) {
                    break;
                }
                u.setInnerProduct(this.Y, q);
                if (t.distance(tOld) < CONVERGENCE_TOLERANCE) {
                    break;
                }
                tOld.set(t);
            }
            this.U.setColumn(i, u);
            this.W.setColumn(i, w);
            this.T.setColumn(i, t);
            this.Q.setColumn(i, q);
            this.b.set(i, t.dotProduct(u));
            pv.setInnerProduct(t, this.X);
            this.P.setColumn(i, pv);
            // Deflate: X <- X - t * p^T, removing this factor before extracting the next.
            pv.negate();
            this.X.addOuterProduct(t, pv);
        }
        this.B.getLeadingDiagonal().set(this.b);

        // With no factors the coefficients stay zero; skip the (empty) pseudo-inverse.
        if (this.l > 0) {
            AMatrix ptinv = PseudoInverse.calculate(this.P.getTranspose());
            this.coefficients.setInnerProduct(ptinv, this.B.innerProduct(this.Q.getTranspose()));
        }

        // The model was fitted on column-centred X, so the intercept is
        // mean(Y) - mean(X) * coefficients.
        Vector yMean = Vector.createLength(this.p);
        for (int k = 0; k < this.p; ++k) {
            yMean.set(k, this.Y.getColumn(k).elementSum() / (double) this.n);
        }
        this.constant.set(yMean);
        this.constant.addInnerProduct(xMean, this.coefficients, -1.0);
    }
}

