/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers.linear;

import java.util.Random;
import jsat.SingleWeightVectorModel;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.exceptions.UntrainedModelException;
import jsat.linear.DenseVector;
import jsat.linear.IndexValue;
import jsat.linear.Vec;
import jsat.lossfunctions.LossC;
import jsat.lossfunctions.LossFunc;
import jsat.lossfunctions.LossR;
import jsat.parameters.Parameterized;
import jsat.regression.RegressionDataSet;
import jsat.regression.Regressor;
import jsat.utils.random.RandomUtil;

/**
 * Linear model with L<sub>1</sub> regularization trained by stochastic
 * coordinate descent (SCD). Each training iteration picks one feature
 * uniformly at random and applies a soft-thresholded update to that single
 * weight, so weights can be driven exactly to zero.
 * <p>
 * The step size is {@code 1 / beta}, where {@code beta} is the loss function's
 * maximal second derivative ({@link LossFunc#getDeriv2Max()}). Only losses for
 * which that value is finite and positive are accepted. The model has no bias
 * term; {@link #getBias()} always returns 0.
 * <p>
 * The model can act as a binary {@link Classifier} when the loss is a
 * {@link LossC}, or as a {@link Regressor} when the loss is a {@link LossR}.
 * <p>
 * NOTE(review): the update appears to follow Shalev-Shwartz &amp; Tewari,
 * "Stochastic Methods for l1-regularized Loss Minimization" (JMLR 2011). That
 * analysis assumes features scaled to [-1, 1]; confirm against the original
 * project documentation.
 */
public class SCD
implements Classifier,
Regressor,
Parameterized,
SingleWeightVectorModel {
    private static final long serialVersionUID = 3576901723216525618L;
    /** Learned weights, one per numeric feature; {@code null} until trained. */
    private Vec w;
    private LossFunc loss;
    private double reg;
    private int iterations;

    /**
     * Creates a new SCD learner.
     *
     * @param loss           the loss to minimize; its maximal second derivative
     *                       must be finite and positive
     * @param regularization the L<sub>1</sub> regularization strength, must be
     *                       finite and positive
     * @param iterations     the number of coordinate updates to perform, must be
     *                       at least 1
     * @throws IllegalArgumentException if any argument violates the above
     */
    public SCD(LossFunc loss, double regularization, int iterations) {
        double beta = loss.getDeriv2Max();
        if (Double.isNaN(beta) || Double.isInfinite(beta) || beta <= 0.0) {
            throw new IllegalArgumentException("SCD needs a loss function with a finite positive maximal second derivative");
        }
        this.loss = loss;
        this.setRegularization(regularization);
        this.setIterations(iterations);
    }

    /**
     * Copy constructor. Clones the loss and, if present, the trained weights.
     *
     * @param toCopy the learner to copy
     */
    public SCD(SCD toCopy) {
        this(toCopy.loss.clone(), toCopy.reg, toCopy.iterations);
        if (toCopy.w != null) {
            this.w = toCopy.w.clone();
        }
    }

    /**
     * Sets the number of coordinate updates performed during training.
     *
     * @param iterations the number of updates, must be at least 1
     * @throws IllegalArgumentException if {@code iterations < 1}
     */
    public void setIterations(int iterations) {
        if (iterations < 1) {
            throw new IllegalArgumentException("The iterations must be a positive value, not " + iterations);
        }
        this.iterations = iterations;
    }

    /**
     * @return the number of coordinate updates performed during training
     */
    public int getIterations() {
        return this.iterations;
    }

    /**
     * Sets the L<sub>1</sub> regularization strength.
     *
     * @param regularization the strength, must be finite and positive
     * @throws IllegalArgumentException if the value is NaN, infinite, or not positive
     */
    public void setRegularization(double regularization) {
        if (Double.isInfinite(regularization) || Double.isNaN(regularization) || regularization <= 0.0) {
            throw new IllegalArgumentException("Regularization must be a positive value");
        }
        this.reg = regularization;
    }

    /**
     * @return the L<sub>1</sub> regularization strength
     */
    public double getRegularization() {
        return this.reg;
    }

    @Override
    public Vec getRawWeight() {
        return this.w;
    }

    /** This model has no bias term, so the bias is always 0. */
    @Override
    public double getBias() {
        return 0.0;
    }

    @Override
    public Vec getRawWeight(int index) {
        // exactly one weight vector exists; negative indices are invalid too
        if (index == 0) {
            return this.getRawWeight();
        }
        throw new IndexOutOfBoundsException("Model has only 1 weight vector");
    }

    @Override
    public double getBias(int index) {
        if (index == 0) {
            return this.getBias();
        }
        throw new IndexOutOfBoundsException("Model has only 1 weight vector");
    }

    @Override
    public int numWeightsVecs() {
        return 1;
    }

    /**
     * Classifies a data point using the dot product of the weights and its
     * numeric features.
     *
     * @throws UntrainedModelException if the model is untrained or its loss is
     *                                 not a {@link LossC}
     */
    @Override
    public CategoricalResults classify(DataPoint data) {
        if (this.w != null && this.loss instanceof LossC) {
            return ((LossC)this.loss).getClassification(this.w.dot(data.getNumericalValues()));
        }
        throw new UntrainedModelException("Model was not trained with a classification function");
    }

    /** Training is single-threaded; the {@code parallel} flag is ignored. */
    @Override
    public void train(ClassificationDataSet dataSet, boolean parallel) {
        this.train(dataSet);
    }

    /**
     * Trains a binary classifier. Class 0 is mapped to target -1 and class 1 to
     * target +1.
     *
     * @throws IllegalArgumentException if any data point has a class index
     *                                  other than 0 or 1
     */
    @Override
    public void train(ClassificationDataSet dataSet) {
        double[] targets = new double[dataSet.getSampleSize()];
        for (int i = 0; i < targets.length; ++i) {
            int category = dataSet.getDataPointCategory(i);
            // the {-1, +1} encoding below is only meaningful for two classes
            if (category < 0 || category > 1) {
                throw new IllegalArgumentException("SCD supports only binary classification, found class index " + category);
            }
            targets[i] = category * 2 - 1;
        }
        this.train(dataSet.getNumericColumns(), targets);
    }

    @Override
    public boolean supportsWeightedData() {
        return false;
    }

    /**
     * Predicts a value using the dot product of the weights and the data
     * point's numeric features.
     *
     * @throws UntrainedModelException if the model is untrained or its loss is
     *                                 not a {@link LossR}
     */
    @Override
    public double regress(DataPoint data) {
        if (this.w != null && this.loss instanceof LossR) {
            return ((LossR)this.loss).getRegression(this.w.dot(data.getNumericalValues()));
        }
        throw new UntrainedModelException("Model was not trained with a regression function");
    }

    /** Training is single-threaded; the {@code parallel} flag is ignored. */
    @Override
    public void train(RegressionDataSet dataSet, boolean parallel) {
        this.train(dataSet);
    }

    @Override
    public void train(RegressionDataSet dataSet) {
        this.train(dataSet.getNumericColumns(), dataSet.getTargetValues().arrayCopy());
    }

    /**
     * Runs stochastic coordinate descent, starting from a freshly allocated
     * (zero) weight vector.
     *
     * @param columns the data in column-major form: {@code columns[j]} holds
     *                feature {@code j}, indexed by sample
     * @param y       the target value of each sample
     */
    private void train(Vec[] columns, double[] y) {
        final double beta = this.loss.getDeriv2Max();
        final double threshold = this.reg / beta;
        // z[i] caches the current margin w . x_i, so a single coordinate's
        // gradient can be computed from its column alone
        final double[] z = new double[y.length];
        this.w = new DenseVector(columns.length);
        if (columns.length == 0 || y.length == 0) {
            // nothing to fit; also avoids Random.nextInt(0) and a 0/0 gradient
            return;
        }
        final Random rand = RandomUtil.getRandom();
        for (int iter = 0; iter < this.iterations; ++iter) {
            final int j = rand.nextInt(columns.length);
            // partial derivative of the AVERAGE loss with respect to w_j; the
            // division must happen before thresholding so both branches agree
            double g = 0.0;
            for (IndexValue iv : columns[j]) {
                final int i = iv.getIndex();
                g += this.loss.getDeriv(z[i], y[i]) * iv.getValue();
            }
            g /= y.length;
            // soft-thresholding step: new w_j = S(w_j - g/beta, reg/beta);
            // eta is the change to apply to w_j
            final double w_j = this.w.get(j);
            final double candidate = w_j - g / beta;
            final double eta;
            if (candidate > threshold) {
                eta = -g / beta - threshold;
            } else if (candidate < -threshold) {
                eta = -g / beta + threshold;
            } else {
                eta = -w_j;
            }
            this.w.increment(j, eta);
            // keep the cached margins consistent with the updated weight
            for (IndexValue iv : columns[j]) {
                z[iv.getIndex()] += eta * iv.getValue();
            }
        }
    }

    @Override
    public SCD clone() {
        return new SCD(this);
    }
}

