package jsat.datatransform;

import java.util.ArrayList;
import java.util.List;
import jsat.DataSet;
import jsat.classifiers.DataPoint;
import jsat.exceptions.FailedToFitException;
import jsat.linear.DenseMatrix;
import jsat.linear.DenseVector;
import jsat.linear.Matrix;
import jsat.linear.MatrixOfVecs;
import jsat.linear.SingularValueDecomposition;
import jsat.linear.Vec;

/* loaded from: input_file:jsat/datatransform/FastICA.class */
/**
 * Deflationary FastICA: estimates {@link #getC() C} independent components one at a time.
 * <p>
 * Each component {@code w} is found with the fixed-point update
 * {@code w <- E[x g(w'x)] - E[g'(w'x)] w}. After each update, {@code w} is Gram-Schmidt
 * orthogonalized against the components already found and renormalized.
 * <p>
 * Unless the data is declared {@link #setPreWhitened(boolean) pre-whitened}, {@link #fit(DataSet)}
 * first centers it with a {@link ZeroMeanTransform} and whitens it with a {@link WhitenedPCA}.
 * {@link #transform(DataPoint)} multiplies by the unmixing matrix. {@link #inverse(DataPoint)}
 * multiplies by the unmixing matrix's pseudo-inverse and then undoes the centering.
 */
public class FastICA implements InvertibleTransform {
    private static final long serialVersionUID = -8644025740457515563L;

    /**
     * Iteration cap per component. The do/while in {@link #fit(DataSet)} runs at most this many
     * times plus one.
     */
    private static final int MAX_ITERATIONS = 500;

    /** A component has converged once {@code |1 - |w_new . w_old||} is at most this value. */
    private static final double CONVERGENCE_TOLERANCE = 1e-6;

    /** Number of independent components to extract (always >= 1). */
    private int C;
    /** Contrast function supplying g (first derivative) and g' (second derivative). */
    private NegEntropyFunc G;
    /** When true, {@link #fit(DataSet)} uses the data as-is, with no centering or whitening. */
    private boolean preWhitened;
    /** Centering learned by the last non-prewhitened fit; null otherwise. */
    private ZeroMeanTransform zeroMean;
    /** Maps (centered) input vectors to component activations. */
    private Matrix unmixing;
    /** Pseudo-inverse of {@link #unmixing}, used by {@link #inverse(DataPoint)}. */
    private Matrix mixing;

    /** Built-in contrast functions, each supplying g = G' and g' = G''. */
    public enum DefaultNegEntropyFunc implements NegEntropyFunc {
        /** G(u) = log cosh(u): g(u) = tanh(u), g'(u) = 1 - tanh(u)^2. */
        LOG_COSH {
            @Override
            public double deriv1(double d) {
                return Math.tanh(d);
            }

            @Override
            public double deriv2(double d, double d2) {
                // d2 is tanh(d), so reuse it instead of recomputing.
                return 1.0d - (d2 * d2);
            }
        },
        /** G(u) = -exp(-u^2/2): g(u) = u exp(-u^2/2), g'(u) = (1 - u^2) exp(-u^2/2). */
        EXP {
            @Override
            public double deriv1(double d) {
                return d * Math.exp(((-d) * d) / 2.0d);
            }

            @Override
            public double deriv2(double d, double d2) {
                // exp(-u^2/2) == g(u)/u, which is undefined at u == 0; the limit there is 1.
                if (d == 0.0d) {
                    return 1.0d;
                }
                return (1.0d - (d * d)) * (d2 / d);
            }
        },
        /** G(u) = u^4/4: g(u) = u^3, g'(u) = 3u^2. */
        KURTOSIS {
            @Override
            public double deriv1(double d) {
                return d * d * d;
            }

            @Override
            public double deriv2(double d, double d2) {
                return d * d * 3.0d;
            }
        }
    }

    /** Contrast function used by FastICA's fixed-point iteration. */
    public interface NegEntropyFunc {
        /**
         * Evaluates g, the first derivative of the contrast function.
         *
         * @param d the projection value {@code w'x}
         * @return g(d)
         */
        double deriv1(double d);

        /**
         * Evaluates g', the second derivative of the contrast function.
         *
         * @param d  the projection value {@code w'x}
         * @param d2 the already-computed value {@code deriv1(d)}, supplied so it can be reused
         * @return g'(d)
         */
        double deriv2(double d, double d2);
    }

    /** Creates a transform extracting 10 components with {@link DefaultNegEntropyFunc#LOG_COSH}. */
    public FastICA() {
        this(10);
    }

    /**
     * Creates a transform using {@link DefaultNegEntropyFunc#LOG_COSH} on non-prewhitened data.
     *
     * @param i the number of components to extract
     */
    public FastICA(int i) {
        this(i, DefaultNegEntropyFunc.LOG_COSH, false);
    }

    /**
     * Creates a transform using {@link DefaultNegEntropyFunc#LOG_COSH} and fits it immediately.
     *
     * @param dataSet the data to fit on
     * @param i       the number of components to extract
     */
    public FastICA(DataSet dataSet, int i) {
        this(dataSet, i, DefaultNegEntropyFunc.LOG_COSH, false);
    }

    /**
     * @param i              the number of components to extract
     * @param negEntropyFunc the contrast function
     * @param z              whether the data is already centered and whitened
     */
    public FastICA(int i, NegEntropyFunc negEntropyFunc, boolean z) {
        setC(i);
        setNegEntropyFunction(negEntropyFunc);
        setPreWhitened(z);
    }

    /**
     * Creates a transform and fits it immediately.
     *
     * @param dataSet        the data to fit on
     * @param i              the number of components to extract
     * @param negEntropyFunc the contrast function
     * @param z              whether the data is already centered and whitened
     */
    public FastICA(DataSet dataSet, int i, NegEntropyFunc negEntropyFunc, boolean z) {
        this(i, negEntropyFunc, z);
        fit(dataSet);
    }

    /**
     * Learns the unmixing and mixing matrices from {@code dataSet}.
     * <p>
     * NOTE(review): deflation can yield at most as many orthogonal directions as the (whitened)
     * data has columns, so requesting more components than that produces degenerate rows.
     *
     * @param dataSet the data to fit on (not modified; a shallow clone is transformed)
     * @throws FailedToFitException if g or g' evaluates to NaN or infinity
     */
    @Override
    public void fit(DataSet dataSet) {
        final int n = dataSet.getSampleSize();
        final Matrix x;
        WhitenedPCA whiten = null;
        if (this.preWhitened) {
            // Drop centering left over from an earlier non-prewhitened fit.
            this.zeroMean = null;
            x = dataSet.getDataMatrixView();
        } else {
            this.zeroMean = new ZeroMeanTransform(dataSet);
            DataSet centered = dataSet.shallowClone2();
            centered.applyTransform(this.zeroMean);
            whiten = new WhitenedPCA(centered);
            centered.applyTransform(whiten);
            x = centered.getDataMatrixView();
        }

        // The whitened space may have fewer columns than the input if the data is low rank.
        final int dims = x.cols();
        final DenseVector projections = new DenseVector(n);
        final DenseVector previous = new DenseVector(dims);
        final List<Vec> components = new ArrayList<Vec>(this.C);

        for (int p = 0; p < this.C; p++) {
            Vec w = Vec.random(dims);
            w.normalize();
            int iter = 0;
            do {
                w.copyTo(previous);

                // projections = X w, then replaced element-wise by g(w'x); accumulate g'(w'x).
                projections.zeroOut();
                x.multiply(w, 1.0d, projections);
                double derivSum = 0.0d;
                for (int i = 0; i < projections.length(); i++) {
                    double wx = projections.get(i);
                    double g = this.G.deriv1(wx);
                    double gPrime = this.G.deriv2(wx, g);
                    if (Double.isNaN(g) || Double.isInfinite(g)
                            || Double.isNaN(gPrime) || Double.isInfinite(gPrime)) {
                        throw new FailedToFitException("Encountered NaN or Inf in calculation");
                    }
                    projections.set(i, g);
                    derivSum += gPrime;
                }

                // w <- (1/n) X' g(Xw) - mean(g'(Xw)) w
                w.mutableMultiply(-(derivSum / n));
                x.transposeMultiply(1.0d / n, projections, w);

                // Deflation: remove the projections onto previously found components.
                // All coefficients are computed first, then subtracted.
                double[] coeffs = new double[components.size()];
                for (int j = 0; j < coeffs.length; j++) {
                    coeffs[j] = w.dot(components.get(j));
                }
                for (int j = 0; j < coeffs.length; j++) {
                    w.mutableAdd(-coeffs[j], components.get(j));
                }
                w.normalize();

                // Converged when old and new directions agree up to sign.
                if (Math.abs(1.0d - Math.abs(w.dot(previous))) <= CONVERGENCE_TOLERANCE) {
                    break;
                }
            } while (iter++ < MAX_ITERATIONS);
            // If the cap is hit without converging, the last iterate is kept.
            components.add(w);
        }

        Matrix basis = new MatrixOfVecs(components);
        if (this.preWhitened) {
            this.unmixing = new DenseMatrix(basis).transpose();
        } else {
            // Fold the whitening projection into the unmixing matrix.
            this.unmixing = basis.multiply(whiten.transform).transpose();
        }
        this.mixing = new SingularValueDecomposition(this.unmixing.mo171clone()).getPseudoInverse();
    }

    /**
     * Copy constructor. Matrices and the centering transform are deep-copied; the contrast
     * function is shared.
     *
     * @param fastICA the instance to copy
     */
    public FastICA(FastICA fastICA) {
        this.C = fastICA.C;
        this.G = fastICA.G;
        this.preWhitened = fastICA.preWhitened;
        if (fastICA.zeroMean != null) {
            this.zeroMean = fastICA.zeroMean.clone();
        }
        if (fastICA.unmixing != null) {
            this.unmixing = fastICA.unmixing.mo171clone();
        }
        if (fastICA.mixing != null) {
            this.mixing = fastICA.mixing.mo171clone();
        }
    }

    /**
     * @param i the number of components to extract
     * @throws IllegalArgumentException if {@code i < 1}
     */
    public void setC(int i) {
        if (i < 1) {
            throw new IllegalArgumentException("Number of components must be positive, not " + i);
        }
        this.C = i;
    }

    /** @return the number of components to extract */
    public int getC() {
        return this.C;
    }

    /**
     * @param negEntropyFunc the contrast function
     * @throws NullPointerException if {@code negEntropyFunc} is null
     */
    public void setNegEntropyFunction(NegEntropyFunc negEntropyFunc) {
        if (negEntropyFunc == null) {
            throw new NullPointerException("Negative Entropy function must be non-null");
        }
        this.G = negEntropyFunc;
    }

    /** @return the contrast function */
    public NegEntropyFunc getNegEntropyFunction() {
        return this.G;
    }

    /** @param z true if the data given to {@link #fit(DataSet)} is already centered and whitened */
    public void setPreWhitened(boolean z) {
        this.preWhitened = z;
    }

    /** @return whether the data is treated as already centered and whitened */
    public boolean isPreWhitened() {
        return this.preWhitened;
    }

    /**
     * Centers the point (if a centering was learned) and multiplies by the unmixing matrix.
     * Categorical values and weight are carried over unchanged.
     */
    @Override
    public DataPoint transform(DataPoint dataPoint) {
        Vec numeric = this.zeroMean != null
                ? this.zeroMean.transform(dataPoint).getNumericalValues()
                : dataPoint.getNumericalValues();
        return new DataPoint(numeric.multiply(this.unmixing), dataPoint.getCategoricalValues(),
                dataPoint.getCategoricalData(), dataPoint.getWeight());
    }

    /**
     * Multiplies by the mixing matrix (pseudo-inverse of the unmixing matrix), then undoes the
     * centering if one was learned.
     */
    @Override
    public DataPoint inverse(DataPoint dataPoint) {
        DataPoint restored = new DataPoint(dataPoint.getNumericalValues().multiply(this.mixing),
                dataPoint.getCategoricalValues(), dataPoint.getCategoricalData(), dataPoint.getWeight());
        if (this.zeroMean != null) {
            this.zeroMean.mutableInverse(restored);
        }
        return restored;
    }

    @Override
    public FastICA clone() {
        return new FastICA(this);
    }
}
