/*
 * Decompiled with CFR 0.152.
 */
package jsat.datatransform;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import jsat.DataSet;
import jsat.classifiers.DataPoint;
import jsat.datatransform.InPlaceTransform;
import jsat.linear.IndexValue;
import jsat.linear.Vec;
import jsat.math.IndexFunction;
import jsat.math.OnLineStatistics;
import jsat.utils.DoubleList;

public class AutoDeskewTransform
implements InPlaceTransform {
    private static final long serialVersionUID = -4894242802345656448L;
    private double[] finalLambdas;
    private double[] mins;
    private final IndexFunction transform = new IndexFunction(){
        private static final long serialVersionUID = -404316813485246422L;

        @Override
        public double indexFunc(double value, int index) {
            if (index < 0) {
                return 0.0;
            }
            return AutoDeskewTransform.transform(value, AutoDeskewTransform.this.finalLambdas[index], AutoDeskewTransform.this.mins[index]);
        }
    };
    private static final DoubleList defaultList = new DoubleList(7);
    private List<Double> lambdas;
    private boolean ignorZeros;

    public AutoDeskewTransform() {
        this(true, (List<Double>)defaultList);
    }

    public AutoDeskewTransform(double ... lambdas) {
        this(true, (List<Double>)DoubleList.view(lambdas, lambdas.length));
    }

    public AutoDeskewTransform(List<Double> lambdas) {
        this(true, lambdas);
    }

    public AutoDeskewTransform(boolean ignorZeros, List<Double> lambdas) {
        this.ignorZeros = ignorZeros;
        this.lambdas = lambdas;
    }

    public AutoDeskewTransform(DataSet dataSet) {
        this(dataSet, (List<Double>)defaultList);
    }

    public AutoDeskewTransform(DataSet dataSet, List<Double> lambdas) {
        this(dataSet, true, lambdas);
    }

    public AutoDeskewTransform(DataSet dataSet, boolean ignorZeros, List<Double> lambdas) {
        this(ignorZeros, lambdas);
        this.fit(dataSet);
    }

    @Override
    public void fit(DataSet dataSet) {
        Vec x;
        int i;
        if (!this.lambdas.contains(1.0)) {
            this.lambdas.add(1.0);
        }
        OnLineStatistics[][] stats = new OnLineStatistics[this.lambdas.size()][dataSet.getNumNumericalVars()];
        for (int i2 = 0; i2 < stats.length; ++i2) {
            for (int j = 0; j < stats[i2].length; ++j) {
                stats[i2][j] = new OnLineStatistics();
            }
        }
        this.mins = new double[dataSet.getNumNumericalVars()];
        Arrays.fill(this.mins, Double.POSITIVE_INFINITY);
        boolean containsSparseVecs = false;
        for (i = 0; i < dataSet.getSampleSize(); ++i) {
            x = dataSet.getDataPoint(i).getNumericalValues();
            if (x.isSparse()) {
                containsSparseVecs = true;
            }
            for (IndexValue iv : x) {
                int indx = iv.getIndex();
                double val = iv.getValue();
                this.mins[indx] = Math.min(val, this.mins[indx]);
            }
        }
        if (containsSparseVecs) {
            for (i = 0; i < this.mins.length; ++i) {
                this.mins[i] = Math.min(0.0, this.mins[i]);
            }
        }
        for (i = 0; i < dataSet.getSampleSize(); ++i) {
            x = dataSet.getDataPoint(i).getNumericalValues();
            double weight = dataSet.getDataPoint(i).getWeight();
            int lastIndx = -1;
            for (IndexValue iv : x) {
                int indx = iv.getIndex();
                double val = iv.getValue();
                this.updateStats(this.lambdas, stats, indx, val, this.mins, weight);
                if (!this.ignorZeros) {
                    for (int prevIndx = lastIndx + 1; prevIndx < indx; ++prevIndx) {
                        this.updateStats(this.lambdas, stats, prevIndx, 0.0, this.mins, weight);
                    }
                }
                lastIndx = indx;
            }
            if (this.ignorZeros) continue;
            for (int prevIndx = lastIndx + 1; prevIndx < this.mins.length; ++prevIndx) {
                this.updateStats(this.lambdas, stats, prevIndx, 0.0, this.mins, weight);
            }
        }
        this.finalLambdas = new double[this.mins.length];
        int lambdaOneIndex = this.lambdas.indexOf(1.0);
        for (int d = 0; d < this.finalLambdas.length; ++d) {
            double minSkew = Double.POSITIVE_INFINITY;
            double bestLambda = 1.0;
            for (int k = 0; k < this.lambdas.size(); ++k) {
                double skew = Math.abs(stats[k][d].getSkewness());
                if (!(skew < minSkew)) continue;
                minSkew = skew;
                bestLambda = this.lambdas.get(k);
            }
            double origSkew = Math.abs(stats[lambdaOneIndex][d].getSkewness());
            this.finalLambdas[d] = origSkew > minSkew * 1.05 ? bestLambda : 1.0;
        }
    }

    protected AutoDeskewTransform(AutoDeskewTransform toCopy) {
        this.finalLambdas = Arrays.copyOf(toCopy.finalLambdas, toCopy.finalLambdas.length);
        this.mins = Arrays.copyOf(toCopy.mins, toCopy.mins.length);
    }

    private static double transform(double val, double lambda, double min) {
        if (val == 0.0) {
            return 0.0;
        }
        if (lambda == 2.0) {
            return val * val;
        }
        if (lambda == 1.0) {
            return val;
        }
        if (lambda == 0.5) {
            return Math.sqrt(val - min);
        }
        if (lambda == 0.0) {
            return Math.log(val + 1.0 - min);
        }
        if (lambda == -0.5) {
            return 1.0 / Math.sqrt(val - min);
        }
        if (lambda == -1.0) {
            return 1.0 / val;
        }
        if (lambda == -2.0) {
            return 1.0 / (val * val);
        }
        return Math.pow(val, lambda) / lambda;
    }

    @Override
    public DataPoint transform(DataPoint dp) {
        DataPoint newDP = dp.clone();
        this.mutableTransform(newDP);
        return newDP;
    }

    @Override
    public void mutableTransform(DataPoint dp) {
        dp.getNumericalValues().applyIndexFunction(this.transform);
    }

    @Override
    public AutoDeskewTransform clone() {
        return new AutoDeskewTransform(this);
    }

    private void updateStats(List<Double> lambdas, OnLineStatistics[][] stats, int indx, double val, double[] mins, double weight) {
        for (int k = 0; k < lambdas.size(); ++k) {
            stats[k][indx].add(AutoDeskewTransform.transform(val, lambdas.get(k), mins[indx]), weight);
        }
    }

    @Override
    public boolean mutatesNominal() {
        return false;
    }

    static {
        defaultList.add(-1.0);
        defaultList.add(-0.5);
        defaultList.add(0.0);
        defaultList.add(0.5);
        defaultList.add(1.0);
    }
}

