/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers.svm;

import java.util.Collections;
import java.util.Iterator;
import jsat.DataSet;
import jsat.SingleWeightVectorModel;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.DataPoint;
import jsat.classifiers.calibration.BinaryScoreClassifier;
import jsat.distributions.Distribution;
import jsat.distributions.LogUniform;
import jsat.exceptions.FailedToFitException;
import jsat.linear.DenseVector;
import jsat.linear.ScaledVector;
import jsat.linear.Vec;
import jsat.linear.VecWithNorm;
import jsat.parameters.Parameterized;
import jsat.utils.IntList;
import jsat.utils.ListUtils;

/**
 * Linear binary classifier trained by mini-batch stochastic sub-gradient
 * steps with a decaying step size of {@code 1 / (reg * t)}.
 *
 * <p>Training keeps a single weight vector and a bias term. At every
 * iteration a mini-batch is drawn from a shuffled ordering of the data, the
 * points that already satisfy a margin of at least 1 are discarded, the weight
 * vector is shrunk, and the remaining points are added in with their sign.
 * An optional projection step rescales the solution so that its 2-norm does
 * not exceed {@code 1 / sqrt(reg)}.
 *
 * <p>Class index 1 is treated as the positive class; every other class index
 * is treated as the negative class.
 */
public class Pegasos
implements BinaryScoreClassifier,
Parameterized,
SingleWeightVectorModel {
    private static final long serialVersionUID = -2145631476467081171L;
    /** Number of full passes over the training data. */
    private int epochs;
    /** Regularization constant; always finite and strictly positive. */
    private double reg;
    /** Number of data points considered in each update. */
    private int batchSize;
    /** Whether the solution is rescaled after every update. */
    private boolean projectionStep = false;
    /** Learned weight vector; {@code null} until {@link #train(ClassificationDataSet)} has run. */
    private Vec w;
    /** Learned bias term. */
    private double bias;
    /** Number of epochs used by the no-argument constructor. */
    public static final int DEFAULT_EPOCHS = 5;
    /** Regularization constant used by the no-argument constructor. */
    public static final double DEFAULT_REG = 1.0E-4;
    /** Batch size used by the no-argument constructor. */
    public static final int DEFAULT_BATCH_SIZE = 1;

    /**
     * Creates a learner using {@link #DEFAULT_EPOCHS}, {@link #DEFAULT_REG}
     * and {@link #DEFAULT_BATCH_SIZE}.
     */
    public Pegasos() {
        this(DEFAULT_EPOCHS, DEFAULT_REG, DEFAULT_BATCH_SIZE);
    }

    /**
     * Creates a learner with the given settings. Each value is validated by
     * the corresponding setter.
     *
     * @param epochs number of passes over the training data, at least 1
     * @param reg regularization constant, finite and strictly positive
     * @param batchSize number of points per update, at least 1
     * @throws ArithmeticException if any argument is rejected by its setter
     */
    public Pegasos(int epochs, double reg, int batchSize) {
        this.setEpochs(epochs);
        this.setRegularization(reg);
        this.setBatchSize(batchSize);
    }

    /**
     * Copy constructor. The weight vector, if present, is cloned so the copy
     * does not share it with the source object.
     *
     * @param toCopy the object to copy
     */
    public Pegasos(Pegasos toCopy) {
        this.epochs = toCopy.epochs;
        this.reg = toCopy.reg;
        this.batchSize = toCopy.batchSize;
        if (toCopy.w != null) {
            this.w = toCopy.w.clone();
        }
        this.bias = toCopy.bias;
        this.projectionStep = toCopy.projectionStep;
    }

    /**
     * Sets how many data points are considered in each update.
     *
     * @param batchSize the batch size, at least 1
     * @throws ArithmeticException if {@code batchSize} is less than 1
     */
    public void setBatchSize(int batchSize) {
        if (batchSize < 1) {
            throw new ArithmeticException("At least one sample must be take at each iteration");
        }
        this.batchSize = batchSize;
    }

    /**
     * @return the number of data points considered in each update
     */
    public int getBatchSize() {
        return this.batchSize;
    }

    /**
     * Sets the number of full passes made over the training data.
     *
     * @param epochs the number of passes, at least 1
     * @throws ArithmeticException if {@code epochs} is less than 1
     */
    public void setEpochs(int epochs) {
        if (epochs < 1) {
            throw new ArithmeticException("Must perform a positive number of epochs");
        }
        this.epochs = epochs;
    }

    /**
     * @return the number of training passes; the integer value is widened to
     *         {@code double} by this getter's declared return type
     */
    public double getEpochs() {
        return this.epochs;
    }

    /**
     * Enables or disables the rescaling performed after every update during
     * training.
     *
     * @param projectionStep {@code true} to rescale after each update
     */
    public void setProjectionStep(boolean projectionStep) {
        this.projectionStep = projectionStep;
    }

    /**
     * @return {@code true} if the solution is rescaled after every update
     */
    public boolean isProjectionStep() {
        return this.projectionStep;
    }

    /**
     * Sets the regularization constant.
     *
     * @param reg the regularization constant; must be finite and greater than 0
     * @throws ArithmeticException if {@code reg} is infinite, NaN, zero or negative
     */
    public void setRegularization(double reg) {
        if (Double.isInfinite(reg) || Double.isNaN(reg) || reg <= 0.0) {
            throw new ArithmeticException("Pegasos requires a positive regularization cosntant");
        }
        this.reg = reg;
    }

    /**
     * @return the regularization constant
     */
    public double getRegularization() {
        return this.reg;
    }

    /**
     * @return the learned weight vector itself (not a copy), or {@code null}
     *         if the model has not been trained
     */
    @Override
    public Vec getRawWeight() {
        return this.w;
    }

    /**
     * @return the learned bias term
     */
    @Override
    public double getBias() {
        return this.bias;
    }

    /**
     * Returns the weight vector at the given index. This model has exactly
     * one weight vector, so 0 is the only valid index.
     *
     * @param index the weight vector index; must be 0
     * @return the same value as {@link #getRawWeight()}
     * @throws IndexOutOfBoundsException if {@code index} is not 0
     */
    @Override
    public Vec getRawWeight(int index) {
        // Negative indices are rejected as well as indices past the single vector.
        if (index != 0) {
            throw new IndexOutOfBoundsException("Model has only 1 weight vector");
        }
        return this.getRawWeight();
    }

    /**
     * Returns the bias term at the given index. This model has exactly one
     * bias term, so 0 is the only valid index.
     *
     * @param index the bias index; must be 0
     * @return the same value as {@link #getBias()}
     * @throws IndexOutOfBoundsException if {@code index} is not 0
     */
    @Override
    public double getBias(int index) {
        // Negative indices are rejected as well as indices past the single bias.
        if (index != 0) {
            throw new IndexOutOfBoundsException("Model has only 1 weight vector");
        }
        return this.getBias();
    }

    /**
     * @return always 1
     */
    @Override
    public int numWeightsVecs() {
        return 1;
    }

    /**
     * @return a copy of this object made with the copy constructor
     */
    @Override
    public Pegasos clone() {
        return new Pegasos(this);
    }

    /**
     * Classifies a point from the sign of its score: a negative score assigns
     * probability 1 to class 0, any other score assigns probability 1 to
     * class 1.
     *
     * @param data the point to classify
     * @return a two-class result with all probability mass on one class
     */
    @Override
    public CategoricalResults classify(DataPoint data) {
        CategoricalResults cr = new CategoricalResults(2);
        int predicted = this.getScore(data) < 0.0 ? 0 : 1;
        cr.setProb(predicted, 1.0);
        return cr;
    }

    /**
     * Computes the raw decision value {@code w . x + bias} for a point.
     *
     * @param dp the point to score
     * @return the dot product of the weight vector with the point's numerical
     *         values, plus the bias
     */
    @Override
    public double getScore(DataPoint dp) {
        return this.w.dot(dp.getNumericalValues()) + this.bias;
    }

    /**
     * Trains the model. The {@code parallel} flag is ignored; this simply
     * delegates to {@link #train(ClassificationDataSet)}.
     *
     * @param dataSet the training data
     * @param parallel ignored
     */
    @Override
    public void train(ClassificationDataSet dataSet, boolean parallel) {
        this.train(dataSet);
    }

    /**
     * Trains the model from scratch on the given data, replacing any
     * previously learned weight vector and bias.
     *
     * @param dataSet the training data; must have exactly two classes
     * @throws FailedToFitException if the data set does not have exactly two classes
     */
    @Override
    public void train(ClassificationDataSet dataSet) {
        if (dataSet.getClassSize() != 2) {
            throw new FailedToFitException("SVM only supports binary classificaiton problems");
        }
        int m = dataSet.getSampleSize();
        this.w = new DenseVector(dataSet.getNumNumericalVars());
        if (this.projectionStep) {
            // presumably lets the norm needed by the projection be tracked
            // incrementally rather than recomputed — confirm against VecWithNorm
            this.w = new VecWithNorm(this.w, 0.0);
        }
        // presumably makes the per-iteration scalar multiplications cheap —
        // confirm against ScaledVector
        this.w = new ScaledVector(this.w);
        this.bias = 0.0;
        IntList miniBatch = new IntList(this.batchSize);
        IntList randOrder = new IntList(m);
        ListUtils.addRange(randOrder, 0, m, 1);
        // Counts updates across all epochs. Kept as a long so that a large
        // number of updates cannot wrap around and make the step size negative.
        long t = 0;
        for (int epoch = 0; epoch < this.epochs; ++epoch) {
            Collections.shuffle(randOrder);
            for (int indx = 0; indx < m; indx += this.batchSize) {
                ++t;
                miniBatch.clear();
                miniBatch.addAll(randOrder.subList(indx, Math.min(indx + this.batchSize, m)));
                // Drop the points whose margin is already at least 1 under the
                // current solution; only the remaining points drive the update.
                Iterator<Integer> iter = miniBatch.iterator();
                while (iter.hasNext()) {
                    int i = iter.next();
                    double margin = this.getSign(dataSet, i) * (this.w.dot(this.getX(dataSet, i)) + this.bias);
                    if (margin >= 1.0) {
                        iter.remove();
                    }
                }
                // Step size decays with the update count.
                double nt = 1.0 / (this.reg * (double)t);
                this.w.mutableMultiply(1.0 - nt * this.reg);
                for (int i : miniBatch) {
                    double sign = this.getSign(dataSet, i);
                    Vec x = this.getX(dataSet, i);
                    // NOTE(review): the step is divided by the configured batchSize
                    // even when the final mini-batch of an epoch is shorter —
                    // confirm this is intended.
                    double s = sign * nt / (double)this.batchSize;
                    this.w.mutableAdd(s, x);
                    this.bias += s;
                }
                if (this.projectionStep) {
                    // Shrink the solution when its 2-norm exceeds 1 / sqrt(reg);
                    // the multiplier is capped at 1 so it is never enlarged.
                    double norm = this.w.pNorm(2.0);
                    double mult = Math.min(1.0, 1.0 / (Math.sqrt(this.reg) * norm));
                    this.w.mutableMultiply(mult);
                    this.bias *= mult;
                }
            }
        }
    }

    /**
     * @return always {@code false}
     */
    @Override
    public boolean supportsWeightedData() {
        return false;
    }

    /**
     * Fetches the numerical feature vector of the i'th data point.
     *
     * @param dataSet the data set to read from
     * @param i the index of the data point
     * @return the numerical values of that data point
     */
    private Vec getX(ClassificationDataSet dataSet, int i) {
        return dataSet.getDataPoint(i).getNumericalValues();
    }

    /**
     * Maps the class of the i'th data point to a sign.
     *
     * @param dataSet the data set to read from
     * @param i the index of the data point
     * @return {@code 1.0} if the point's category is 1, otherwise {@code -1.0}
     */
    private double getSign(ClassificationDataSet dataSet, int i) {
        return dataSet.getDataPointCategory(i) == 1 ? 1.0 : -1.0;
    }

    /**
     * Provides a distribution of candidate regularization values. The data
     * set argument is not consulted.
     *
     * @param d unused
     * @return a {@link LogUniform} distribution constructed with the
     *         arguments {@code 1.0E-7} and {@code 0.01}
     */
    public static Distribution guessRegularization(DataSet d) {
        return new LogUniform(1.0E-7, 0.01);
    }
}

