/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers.linear;

import java.util.Arrays;
import java.util.Random;
import jsat.DataSet;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.DataPoint;
import jsat.classifiers.linear.StochasticSTLinearL1;
import jsat.exceptions.FailedToFitException;
import jsat.exceptions.UntrainedModelException;
import jsat.linear.DenseVector;
import jsat.linear.IndexValue;
import jsat.linear.SparseVector;
import jsat.linear.Vec;
import jsat.regression.RegressionDataSet;
import jsat.utils.random.RandomUtil;

/**
 * L1-regularized linear model fit by stochastic coordinate descent: each
 * iteration picks one coordinate (a feature weight or the bias) uniformly at
 * random and applies a soft-thresholded step to it.
 *
 * <p>Training works on a column-major copy of the data (one vector per
 * feature) so that the partial derivative for a single coordinate only touches
 * the stored values of that feature.
 *
 * <p>Only binary classification and regression are supported, and data point
 * weights are ignored (see {@link #supportsWeightedData()}).
 */
public class LinearL1SCD
extends StochasticSTLinearL1 {
    private static final long serialVersionUID = 3135562347568407186L;

    /**
     * Creates a learner with 1000 iterations, a regularization of 1e-14 and
     * the default loss of the parent class.
     */
    public LinearL1SCD() {
        this(1000, 1.0E-14, DEFAULT_LOSS);
    }

    /**
     * Creates a learner with feature rescaling enabled.
     *
     * @param epochs the number of single-coordinate update iterations
     * @param lambda the L1 regularization strength
     * @param loss the loss function to minimize
     */
    public LinearL1SCD(int epochs, double lambda, StochasticSTLinearL1.Loss loss) {
        this(epochs, lambda, loss, true);
    }

    /**
     * Creates a learner.
     *
     * @param epochs the number of single-coordinate update iterations
     * @param lambda the L1 regularization strength
     * @param loss the loss function to minimize
     * @param reScale {@code true} to rescale every feature into
     *        [minScaled, maxScaled] before training; {@code false} to require
     *        that the input already lies in [-1, 1]
     */
    public LinearL1SCD(int epochs, double lambda, StochasticSTLinearL1.Loss loss, boolean reScale) {
        this.setEpochs(epochs);
        this.setLambda(lambda);
        this.setLoss(loss);
        this.setReScale(reScale);
    }

    /**
     * Classifies a data point from the inherited {@code wDot} score of its
     * numeric values.
     *
     * @param data the data point to classify
     * @return the classification produced by the configured loss
     * @throws UntrainedModelException if the model has not been trained
     */
    @Override
    public CategoricalResults classify(DataPoint data) {
        if (this.w == null) {
            throw new UntrainedModelException("Model has not been trained");
        }
        Vec x = data.getNumericalValues();
        return this.loss.classify(this.wDot(x));
    }

    /**
     * Predicts a regression value from the inherited {@code wDot} score of the
     * data point's numeric values.
     *
     * @param data the data point to regress
     * @return the value produced by the configured loss
     * @throws UntrainedModelException if the model has not been trained
     */
    @Override
    public double regress(DataPoint data) {
        if (this.w == null) {
            throw new UntrainedModelException("Model has not been trained");
        }
        Vec x = data.getNumericalValues();
        return this.loss.regress(this.wDot(x));
    }

    /**
     * Either rescales every feature column in place (when {@code reScale} is
     * set) or verifies that the observed values already lie in [-1, 1].
     *
     * @param featureVals the per-feature value vectors; entries may be
     *        replaced by dense copies
     * @param m the number of data points
     * @throws FailedToFitException if rescaling is disabled and some observed
     *         value lies outside [-1, 1]
     */
    private void featureScaleCheck(Vec[] featureVals, int m) throws FailedToFitException {
        if (!this.reScale) {
            for (int j = 0; j < this.obvMin.length; ++j) {
                if (this.obvMax[j] > 1.0 || this.obvMin[j] < -1.0) {
                    throw new FailedToFitException("All feature values must be in the range [-1,1]");
                }
            }
            return;
        }
        for (int j = 0; j < featureVals.length; ++j) {
            if (this.obvMin[j] == 0.0 && this.minScaled == 0.0) {
                // Zero maps to zero, so a single multiply is enough and keeps
                // sparse columns sparse.
                featureVals[j].mutableMultiply(this.maxScaled / this.obvMax[j]);
            } else {
                featureVals[j].mutableSubtract(this.obvMin[j]);
                featureVals[j].mutableMultiply((this.maxScaled - this.minScaled) / (this.obvMax[j] - this.obvMin[j]));
                featureVals[j].mutableAdd(this.minScaled);
            }
            // A sparse column that is mostly populated is switched to a dense one.
            if (featureVals[j].isSparse() && (double) featureVals[j].nnz() > (double) m * 0.75) {
                featureVals[j] = new DenseVector(featureVals[j]);
            }
        }
    }

    /**
     * Fills {@code featureVals} with one vector per feature (a column-major
     * view of the data set) and records the observed minimum and maximum of
     * every feature in {@code obvMin} / {@code obvMax}.
     *
     * @param featureVals the output array, one slot per numeric feature
     * @param sparse whether the columns should be stored sparsely
     * @param m the number of data points
     * @param dataSet the data to read the values from
     */
    private void setUpFeatureVals(Vec[] featureVals, boolean sparse, int m, DataSet dataSet) {
        this.obvMin = new double[featureVals.length];
        this.obvMax = new double[featureVals.length];
        if (sparse) {
            // Only explicitly stored values are visited below, so the implicit
            // zeros have to be accounted for in BOTH bounds. Leaving obvMax
            // unseeded lets a feature whose stored values are all the same
            // negative number end up with obvMax == obvMin, which turns the
            // rescaling factor into a division by zero.
            Arrays.fill(this.obvMin, 0.0);
            Arrays.fill(this.obvMax, 0.0);
        } else {
            Arrays.fill(this.obvMin, Double.POSITIVE_INFINITY);
            Arrays.fill(this.obvMax, Double.NEGATIVE_INFINITY);
        }
        for (int j = 0; j < featureVals.length; ++j) {
            featureVals[j] = sparse ? new SparseVector(m) : new DenseVector(m);
        }
        for (int i = 0; i < dataSet.getSampleSize(); ++i) {
            Vec x = dataSet.getDataPoint(i).getNumericalValues();
            for (IndexValue iv : x) {
                int j = iv.getIndex();
                double v = iv.getValue();
                featureVals[j].set(i, v);
                this.obvMax[j] = Math.max(this.obvMax[j], v);
                this.obvMin[j] = Math.min(this.obvMin[j], v);
            }
        }
    }

    /**
     * Trains on a regression data set; the {@code parallel} flag is ignored
     * and training is delegated to {@link #train(RegressionDataSet)}.
     *
     * @param dataSet the training data
     * @param parallel ignored
     */
    @Override
    public void train(RegressionDataSet dataSet, boolean parallel) {
        this.train(dataSet);
    }

    /**
     * Trains on a regression data set using the raw target values.
     *
     * @param dataSet the training data
     * @throws FailedToFitException if rescaling is disabled and a feature
     *         value lies outside [-1, 1]
     */
    @Override
    public void train(RegressionDataSet dataSet) {
        boolean sparse = dataSet.getDataPoint(0).getNumericalValues().isSparse();
        int m = dataSet.getSampleSize();
        // setUpFeatureVals allocates every column itself.
        Vec[] featureVals = new Vec[dataSet.getNumNumericalVars()];
        this.setUpFeatureVals(featureVals, sparse, m, dataSet);
        this.featureScaleCheck(featureVals, m);
        double[] target = new double[m];
        for (int i = 0; i < m; ++i) {
            target[i] = dataSet.getTargetValue(i);
        }
        this.train(featureVals, target);
    }

    /**
     * Trains on a classification data set; the {@code parallel} flag is
     * ignored and training is delegated to
     * {@link #train(ClassificationDataSet)}.
     *
     * @param dataSet the training data
     * @param parallel ignored
     */
    @Override
    public void train(ClassificationDataSet dataSet, boolean parallel) {
        this.train(dataSet);
    }

    /**
     * Trains on a binary classification data set. Class 0 is encoded as the
     * target -1 and class 1 as the target +1.
     *
     * @param dataSet the training data
     * @throws FailedToFitException if the problem is not binary, or if
     *         rescaling is disabled and a feature value lies outside [-1, 1]
     */
    @Override
    public void train(ClassificationDataSet dataSet) {
        if (dataSet.getClassSize() != 2) {
            throw new FailedToFitException("Only binary classification problems are supported");
        }
        boolean sparse = dataSet.getDataPoint(0).getNumericalValues().isSparse();
        int m = dataSet.getSampleSize();
        Vec[] featureVals = new Vec[dataSet.getNumNumericalVars()];
        this.setUpFeatureVals(featureVals, sparse, m, dataSet);
        this.featureScaleCheck(featureVals, m);
        double[] target = new double[m];
        for (int i = 0; i < m; ++i) {
            target[i] = dataSet.getDataPointCategory(i) * 2 - 1;
        }
        this.train(featureVals, target);
    }

    /**
     * Runs the coordinate descent iterations.
     *
     * @param featureVals one vector per feature holding that feature's value
     *        for every data point
     * @param target the target value of every data point
     */
    private void train(Vec[] featureVals, double[] target) {
        int d = featureVals.length;
        int m = target.length;
        this.w = new DenseVector(d);
        // z is initialised to all zeros, which is only the prediction of the
        // current model if the bias starts at zero as well (matters on retrain).
        this.bias = 0.0;
        // z[i] caches the current raw prediction for data point i.
        double[] z = new double[m];
        double beta = this.loss.beta();
        double threshold = this.lambda / beta;
        Random rand = RandomUtil.getRandom();
        for (int t = 1; t <= this.epochs; ++t) {
            // Index d stands for the bias term.
            int j = rand.nextInt(d + 1);
            boolean updateBias = j == d;

            // Average partial derivative of the loss w.r.t. coordinate j.
            double g = 0.0;
            if (updateBias) {
                for (int i = 0; i < m; ++i) {
                    g += this.loss.deriv(z[i], target[i]);
                }
            } else {
                for (IndexValue iv : featureVals[j]) {
                    int i = iv.getIndex();
                    g += this.loss.deriv(z[i], target[i]) * iv.getValue();
                }
            }
            // Normalise before ANY use of g: the thresholding test and every
            // branch of the step must all see the same averaged gradient.
            g /= (double) m;

            double w_j = updateBias ? this.bias : this.w.get(j);
            double shifted = w_j - g / beta;
            double eta;
            if (shifted > threshold) {
                eta = -g / beta - threshold;
            } else if (shifted < -threshold) {
                eta = -g / beta + threshold;
            } else {
                // Inside the threshold band the coordinate is driven to zero.
                eta = -w_j;
            }

            // Apply the step and keep the cached predictions in sync.
            if (updateBias) {
                this.bias += eta;
                for (int i = 0; i < m; ++i) {
                    z[i] += eta;
                }
            } else {
                this.w.increment(j, eta);
                for (IndexValue iv : featureVals[j]) {
                    z[iv.getIndex()] += eta * iv.getValue();
                }
            }
        }
    }

    /**
     * @return {@code false}; data point weights are not used during training
     */
    @Override
    public boolean supportsWeightedData() {
        return false;
    }

    /**
     * Creates a copy of this learner, including the learned weights, bias and
     * feature scaling information if present.
     *
     * @return an independent copy of this object
     */
    @Override
    public LinearL1SCD clone() {
        LinearL1SCD clone = new LinearL1SCD(this.epochs, this.lambda, this.loss, this.reScale);
        if (this.w != null) {
            clone.w = this.w.clone();
        }
        clone.bias = this.bias;
        clone.minScaled = this.minScaled;
        clone.maxScaled = this.maxScaled;
        if (this.obvMin != null) {
            clone.obvMin = Arrays.copyOf(this.obvMin, this.obvMin.length);
        }
        if (this.obvMax != null) {
            clone.obvMax = Arrays.copyOf(this.obvMax, this.obvMax.length);
        }
        return clone;
    }
}

