/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers.svm;

import java.util.Collections;
import java.util.Iterator;
import java.util.Random;
import jsat.SingleWeightVectorModel;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.DataPoint;
import jsat.classifiers.calibration.BinaryScoreClassifier;
import jsat.classifiers.svm.DCDs;
import jsat.exceptions.FailedToFitException;
import jsat.exceptions.UntrainedModelException;
import jsat.linear.DenseVector;
import jsat.linear.Vec;
import jsat.parameters.Parameterized;
import jsat.regression.RegressionDataSet;
import jsat.regression.Regressor;
import jsat.utils.IntList;
import jsat.utils.ListUtils;
import jsat.utils.random.RandomUtil;

/**
 * Linear SVM trained by dual coordinate descent (DCD). It can be used both as a
 * binary classifier and as an epsilon-insensitive regressor.
 * <p>
 * The loss is selected by {@link #setUseL1(boolean)}:
 * <ul>
 * <li>With the L1 loss, the dual variables are boxed to {@code [0, C]} for
 * classification and {@code [-C, C]} for regression.</li>
 * <li>With the L2 (squared) loss, the upper bound is removed and a diagonal term
 * {@code 1/(2C)} is added to the dual Hessian.</li>
 * </ul>
 * When {@link #setUseBias(boolean)} is enabled, the bias is learned as the weight of
 * an implicit constant feature equal to 1. That is why {@code 1} is added to every
 * diagonal Hessian entry.
 * <p>
 * The classification solver follows Hsieh et al., "A Dual Coordinate Descent Method
 * for Large-scale Linear SVM" (2008). The regression solver appears to follow Ho &amp;
 * Lin, "Large-scale Linear Support Vector Regression" (2012); its stopping rule is
 * delegated to {@code DCDs.eq24}.
 */
public class DCD implements BinaryScoreClassifier, Regressor, Parameterized, SingleWeightVectorModel {
    private static final long serialVersionUID = -1489225034030922798L;
    /** Number of training iterations. Each one is either a full shuffled pass or a single random update (see {@link #onlineVersion}). */
    private int maxIterations;
    /** Training inputs. Only populated while a {@code train} call is running. */
    private Vec[] vecs;
    /** Dual variables, one per training point. Only populated while training. */
    private double[] alpha;
    /** Training targets: +/-1 labels or regression values. Only populated while training. */
    private double[] y;
    /** Learned bias term. Stays 0 when {@link #useBias} is false. */
    private double bias;
    /** Learned weight vector. {@code null} until the model has been trained. */
    private Vec w;
    /** Penalty parameter. */
    private double C;
    /** {@code true} for the L1 loss, {@code false} for the L2 (squared) loss. */
    private boolean useL1;
    /** If true, classification training does one random coordinate update per iteration instead of a full shuffled pass. */
    private boolean onlineVersion = false;
    /** Width of the epsilon-insensitive tube. Only used by regression training. */
    private double eps = 0.001;
    /** Whether to learn a bias term. */
    private boolean useBias = true;

    /**
     * Creates a DCD learner with 10000 iterations, C = 1, and the L2 loss.
     */
    public DCD() {
        this(10000, false);
    }

    /**
     * Creates a DCD learner with C = 1.
     *
     * @param maxIterations number of training iterations
     * @param useL1 {@code true} for the L1 loss, {@code false} for the L2 loss
     */
    public DCD(int maxIterations, boolean useL1) {
        this(maxIterations, 1.0, useL1);
    }

    /**
     * Creates a DCD learner.
     *
     * @param maxIterations number of training iterations
     * @param C penalty parameter
     * @param useL1 {@code true} for the L1 loss, {@code false} for the L2 loss
     */
    public DCD(int maxIterations, double C, boolean useL1) {
        this.maxIterations = maxIterations;
        this.C = C;
        this.useL1 = useL1;
    }

    /**
     * Chooses between full shuffled passes ({@code false}) and single random
     * coordinate updates ({@code true}) per iteration. Only affects classification training.
     *
     * @param onlineVersion {@code true} to use single random updates per iteration
     */
    public void setOnlineVersion(boolean onlineVersion) {
        this.onlineVersion = onlineVersion;
    }

    /**
     * @return whether classification training uses single random updates per iteration
     */
    public boolean isOnlineVersion() {
        return this.onlineVersion;
    }

    /**
     * Sets the width of the epsilon-insensitive tube used by regression training.
     *
     * @param eps a finite, non-negative value
     * @throws IllegalArgumentException if {@code eps} is negative, NaN, or infinite
     */
    public void setEps(double eps) {
        if (Double.isNaN(eps) || eps < 0.0 || Double.isInfinite(eps)) {
            throw new IllegalArgumentException("eps must be non-negative, not " + eps);
        }
        this.eps = eps;
    }

    /**
     * @return the epsilon-insensitive tube width used by regression training
     */
    public double getEps() {
        return this.eps;
    }

    /**
     * Sets the penalty parameter.
     *
     * @param C a finite, strictly positive value
     * @throws ArithmeticException if {@code C} is not positive, NaN, or infinite
     */
    public void setC(double C) {
        if (Double.isNaN(C) || Double.isInfinite(C) || C <= 0.0) {
            throw new ArithmeticException("Penalty parameter must be a positive value, not " + C);
        }
        this.C = C;
    }

    /**
     * @return the penalty parameter
     */
    public double getC() {
        return this.C;
    }

    /**
     * @param useL1 {@code true} for the L1 loss, {@code false} for the L2 (squared) loss
     */
    public void setUseL1(boolean useL1) {
        this.useL1 = useL1;
    }

    /**
     * @return {@code true} if the L1 loss is used, {@code false} for the L2 loss
     */
    public boolean isUseL1() {
        return this.useL1;
    }

    /**
     * Sets the number of training iterations.
     *
     * @param maxIterations a strictly positive iteration count
     * @throws IllegalArgumentException if {@code maxIterations <= 0}
     */
    public void setMaxIterations(int maxIterations) {
        if (maxIterations <= 0) {
            throw new IllegalArgumentException("Number of iterations must be positive, not " + maxIterations);
        }
        this.maxIterations = maxIterations;
    }

    /**
     * @return the number of training iterations
     */
    public int getMaxIterations() {
        return this.maxIterations;
    }

    /**
     * @param useBias whether to learn a bias term
     */
    public void setUseBias(boolean useBias) {
        this.useBias = useBias;
    }

    /**
     * @return whether a bias term is learned
     */
    public boolean isUseBias() {
        return this.useBias;
    }

    @Override
    public Vec getRawWeight() {
        return this.w;
    }

    @Override
    public double getBias() {
        return this.bias;
    }

    /**
     * @param index must be 0; this model has exactly one weight vector
     * @return the weight vector
     * @throws IndexOutOfBoundsException if {@code index != 0}
     */
    @Override
    public Vec getRawWeight(int index) {
        if (index == 0) {
            return this.getRawWeight();
        }
        throw new IndexOutOfBoundsException("Model has only 1 weight vector");
    }

    /**
     * @param index must be 0; this model has exactly one bias term
     * @return the bias term
     * @throws IndexOutOfBoundsException if {@code index != 0}
     */
    @Override
    public double getBias(int index) {
        if (index == 0) {
            return this.getBias();
        }
        throw new IndexOutOfBoundsException("Model has only 1 weight vector");
    }

    @Override
    public int numWeightsVecs() {
        return 1;
    }

    /**
     * Hard-assigns {@code data} to class 0 if its score is negative, otherwise to class 1.
     *
     * @throws UntrainedModelException if the model has not been trained
     */
    @Override
    public CategoricalResults classify(DataPoint data) {
        this.checkTrained();
        CategoricalResults cr = new CategoricalResults(2);
        if (this.getScore(data) < 0.0) {
            cr.setProb(0, 1.0);
        } else {
            cr.setProb(1, 1.0);
        }
        return cr;
    }

    /**
     * Computes the linear decision value {@code w . x + bias}.
     *
     * @throws UntrainedModelException if the model has not been trained
     */
    @Override
    public double getScore(DataPoint dp) {
        this.checkTrained();
        return this.w.dot(dp.getNumericalValues()) + this.bias;
    }

    /**
     * Trains the classifier. The {@code parallel} flag is ignored; training is always sequential.
     */
    @Override
    public void train(ClassificationDataSet dataSet, boolean parallel) {
        this.train(dataSet);
    }

    /**
     * Trains a binary linear SVM by dual coordinate descent.
     * Class index 0 is mapped to label -1 and class index 1 to label +1.
     *
     * @throws FailedToFitException if the data set does not have exactly 2 classes or is empty
     */
    @Override
    public void train(ClassificationDataSet dataSet) {
        if (dataSet.getClassSize() != 2) {
            throw new FailedToFitException("SVM only supports binary classificaiton problems");
        }
        final int n = dataSet.getSampleSize();
        if (n == 0) {
            throw new FailedToFitException("Can not train on an empty data set");
        }
        this.vecs = new Vec[n];
        this.alpha = new double[n];
        this.y = new double[n];
        this.bias = 0.0;
        final double[] Qhs = new double[n];
        final double U = this.getU();
        final double D = this.getD();
        for (int i = 0; i < n; ++i) {
            this.vecs[i] = dataSet.getDataPoint(i).getNumericalValues();
            this.y[i] = dataSet.getDataPointCategory(i) * 2 - 1;
            // Diagonal of the dual Hessian: ||x_i||^2 + D, plus 1 for the implicit bias feature.
            Qhs[i] = this.vecs[i].dot(this.vecs[i]) + D;
            if (this.useBias) {
                Qhs[i] += 1.0;
            }
        }
        this.w = new DenseVector(this.vecs[0].length());
        IntList order = new IntList(n);
        ListUtils.addRange(order, 0, n, 1);
        Random rand = RandomUtil.getRandom();
        try {
            for (int t = 0; t < this.maxIterations; ++t) {
                if (this.onlineVersion) {
                    int i = rand.nextInt(n);
                    this.performUpdate(i, D, U, Qhs[i]);
                    continue;
                }
                Collections.shuffle(order, rand);
                for (int i : order) {
                    this.performUpdate(i, D, U, Qhs[i]);
                }
            }
        } finally {
            this.releaseTrainingState();
        }
    }

    /**
     * Performs one dual coordinate-descent step on variable {@code i} for classification.
     * {@code alpha[i]} is moved along the negative gradient, scaled by the Hessian diagonal
     * and clipped to {@code [0, U]}. The change is then folded into {@code w} (and into
     * {@code bias} when enabled).
     *
     * @param i index of the training point to update
     * @param D diagonal regularization term (0 for L1, 1/(2C) for L2)
     * @param U upper bound on the dual variables (C for L1, +infinity for L2)
     * @param Qh_ii diagonal Hessian entry for point {@code i}
     */
    private void performUpdate(int i, double D, double U, double Qh_ii) {
        double G = this.y[i] * (this.w.dot(this.vecs[i]) + this.bias) - 1.0 + D * this.alpha[i];
        // Projected gradient: zero when alpha is at a bound and the step would leave the feasible box.
        double PG = this.alpha[i] == 0.0 ? Math.min(G, 0.0) : (this.alpha[i] == U ? Math.max(G, 0.0) : G);
        if (PG != 0.0) {
            double alphaOld = this.alpha[i];
            this.alpha[i] = Math.min(Math.max(this.alpha[i] - G / Qh_ii, 0.0), U);
            double scale = (this.alpha[i] - alphaOld) * this.y[i];
            this.w.mutableAdd(scale, this.vecs[i]);
            if (this.useBias) {
                this.bias += scale;
            }
        }
    }

    /**
     * Predicts {@code w . x + bias}.
     *
     * @throws UntrainedModelException if the model has not been trained
     */
    @Override
    public double regress(DataPoint data) {
        this.checkTrained();
        return this.w.dot(data.getNumericalValues()) + this.bias;
    }

    /**
     * Trains the regressor. The {@code parallel} flag is ignored; training is always sequential.
     */
    @Override
    public void train(RegressionDataSet dataSet, boolean parallel) {
        this.train(dataSet);
    }

    /**
     * Trains a linear epsilon-insensitive support vector regressor by dual coordinate descent.
     * Training stops after {@code maxIterations} passes, or earlier once the summed
     * {@code DCDs.eq24} values of a pass drop below 1e-4 of their value at the all-zero start.
     *
     * @throws FailedToFitException if the data set is empty
     */
    @Override
    public void train(RegressionDataSet dataSet) {
        final int n = dataSet.getSampleSize();
        if (n == 0) {
            throw new FailedToFitException("Can not train on an empty data set");
        }
        this.vecs = new Vec[n];
        this.alpha = new double[n];
        this.y = new double[n];
        this.bias = 0.0;
        final double[] Qhs = new double[n];
        final double U = this.getU();
        final double lambda = this.getD();
        // Reference value for the relative stopping criterion, evaluated at alpha = 0 and w = 0.
        double v_0 = 0.0;
        for (int i = 0; i < n; ++i) {
            this.vecs[i] = dataSet.getDataPoint(i).getNumericalValues();
            this.y[i] = dataSet.getTargetValue(i);
            Qhs[i] = this.vecs[i].dot(this.vecs[i]) + lambda;
            if (this.useBias) {
                Qhs[i] += 1.0;
            }
            v_0 += Math.abs(DCDs.eq24(0.0, -this.y[i] - this.eps, -this.y[i] + this.eps, U));
        }
        this.w = new DenseVector(this.vecs[0].length());
        IntList activeSet = new IntList(n);
        ListUtils.addRange(activeSet, 0, n, 1);
        Random rand = RandomUtil.getRandom();
        try {
            for (int iteration = 0; iteration < this.maxIterations; ++iteration) {
                double vKSum = 0.0;
                Collections.shuffle(activeSet, rand);
                for (int i : activeSet) {
                    double y_i = this.y[i];
                    Vec x_i = this.vecs[i];
                    double alpha_i = this.alpha[i];
                    double g = -y_i + this.w.dot(x_i) + this.bias + lambda * alpha_i;
                    double gP = g + this.eps;
                    double gN = g - this.eps;
                    vKSum += Math.abs(DCDs.eq24(alpha_i, gN, gP, U));
                    double Q_ii = Qhs[i];
                    double d;
                    if (gP < Q_ii * alpha_i) {
                        d = -gP / Q_ii;
                    } else if (gN > Q_ii * alpha_i) {
                        d = -gN / Q_ii;
                    } else {
                        d = -alpha_i;
                    }
                    if (Math.abs(d) < 1.0E-14) {
                        continue;
                    }
                    double s = Math.max(-U, Math.min(U, alpha_i + d));
                    double step = s - alpha_i;
                    this.w.mutableAdd(step, x_i);
                    if (this.useBias) {
                        this.bias += step;
                    }
                    this.alpha[i] = s;
                }
                // v_0 == 0 would make the ratio 0/0 = NaN, which never satisfies the
                // test, so the solver would run every remaining iteration for nothing.
                if (v_0 == 0.0 || vKSum / v_0 < 1.0E-4) {
                    break;
                }
            }
        } finally {
            this.releaseTrainingState();
        }
    }

    /** Throws if no weight vector has been learned yet. */
    private void checkTrained() {
        if (this.w == null) {
            throw new UntrainedModelException("The model has not been trained");
        }
    }

    /**
     * Drops the per-run training buffers. Otherwise the fitted model (and any serialized
     * copy of it) keeps references to every training vector.
     */
    private void releaseTrainingState() {
        this.vecs = null;
        this.alpha = null;
        this.y = null;
    }

    /** Upper bound on the dual variables: C for the L1 loss, unbounded for the L2 loss. */
    private double getU() {
        if (this.useL1) {
            return this.C;
        }
        return Double.POSITIVE_INFINITY;
    }

    /** Diagonal Hessian regularizer: 0 for the L1 loss, 1/(2C) for the L2 loss. */
    private double getD() {
        if (this.useL1) {
            return 0.0;
        }
        return 1.0 / (2.0 * this.C);
    }

    @Override
    public boolean supportsWeightedData() {
        return false;
    }

    /**
     * Returns an independent copy carrying every hyper-parameter and the learned model.
     */
    @Override
    public DCD clone() {
        DCD clone = new DCD(this.maxIterations, this.C, this.useL1);
        clone.onlineVersion = this.onlineVersion;
        clone.eps = this.eps;
        clone.bias = this.bias;
        clone.useBias = this.useBias;
        if (this.w != null) {
            clone.w = this.w.clone();
        }
        return clone;
    }
}

