/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers.linear.kernelized;

import java.util.List;
import jsat.DataSet;
import jsat.classifiers.BaseUpdateableClassifier;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.DataPoint;
import jsat.classifiers.UpdateableClassifier;
import jsat.distributions.Distribution;
import jsat.distributions.LogUniform;
import jsat.distributions.kernels.KernelPoint;
import jsat.distributions.kernels.KernelPoints;
import jsat.distributions.kernels.KernelTrick;
import jsat.distributions.kernels.RBFKernel;
import jsat.exceptions.FailedToFitException;
import jsat.linear.DenseVector;
import jsat.linear.Vec;
import jsat.lossfunctions.LossC;
import jsat.lossfunctions.LossFunc;
import jsat.lossfunctions.LossMC;
import jsat.lossfunctions.LossR;
import jsat.lossfunctions.SoftmaxLoss;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.regression.BaseUpdateableRegressor;
import jsat.regression.RegressionDataSet;
import jsat.regression.UpdateableRegressor;

/**
 * Kernelized stochastic gradient descent learner for classification and
 * regression. Which tasks are supported depends on the {@link LossFunc} in
 * use: a {@link LossC} for binary classification, a {@link LossMC} for
 * multi-class classification, and a {@link LossR} for regression.
 * <p>
 * Each update first multiplies the current model by
 * {@code 1 - eta_t * lambda} (L2 weight decay). It then takes a gradient step
 * on the loss at the new example. The step size at update {@code t}
 * (counting from 1 after {@code setUp}) is
 * {@code eta / (lambda * (t + 2 / lambda))}, i.e. {@code eta / (lambda * t + 2)}.
 * <p>
 * The model is stored as a {@link KernelPoint} for binary classification and
 * regression, or as a {@link KernelPoints} for multi-class classification.
 * Either one is configured with the budget strategy, error tolerance and
 * maximum budget held by this object at the time {@code setUp} is called.
 * {@code setUp} must be called before any {@code update}, {@code classify} or
 * {@code regress} call.
 */
public class KernelSGD implements UpdateableClassifier, UpdateableRegressor, Parameterized {
    private static final long serialVersionUID = -4956596506787859023L;

    /** Loss being minimized; decides whether classification and/or regression is supported. */
    private LossFunc loss;
    /** Kernel used by the underlying kernel point(s). */
    @Parameter.ParameterHolder
    private KernelTrick kernel;
    /** Regularization strength; always strictly positive and finite. */
    private double lambda;
    /** Base learning rate fed into the step-size schedule of {@link #getNextEta()}. */
    private double eta;
    /** Strategy handed to the kernel point(s) for keeping within the budget. */
    private KernelPoint.BudgetStrategy budgetStrategy;
    /** Maximum budget handed to the kernel point(s); always at least 1. */
    private int budgetSize;
    /** Error tolerance handed to the kernel point(s); always in [0, 1]. */
    private double errorTolerance;
    /** Number of updates performed since the last {@code setUp}. */
    private int time;
    /** Model for binary classification and regression; {@code null} when {@link #kpoints} is used. */
    private KernelPoint kpoint;
    /** Model for multi-class classification; {@code null} when {@link #kpoint} is used. */
    private KernelPoints kpoints;
    /** Number of passes over the data made by the batch {@code train} methods. */
    private int epochs = 1;

    /**
     * Creates a learner with the following defaults:
     * <ul>
     * <li>a {@link SoftmaxLoss} and an {@link RBFKernel};</li>
     * <li>{@code lambda = 1e-4};</li>
     * <li>the {@link KernelPoint.BudgetStrategy#MERGE_RBF} budget strategy with a budget of 300;</li>
     * <li>{@code eta = 1.0} and an error tolerance of {@code 0.05}.</li>
     * </ul>
     */
    public KernelSGD() {
        this(new SoftmaxLoss(), new RBFKernel(), 1.0E-4, KernelPoint.BudgetStrategy.MERGE_RBF, 300);
    }

    /**
     * Creates a learner with {@code eta = 1.0} and an error tolerance of {@code 0.05}.
     *
     * @param loss           the loss function to minimize
     * @param kernel         the kernel to use
     * @param lambda         the regularization strength, strictly positive
     * @param budgetStrategy the budget maintenance strategy
     * @param budgetSize     the maximum budget, at least 1
     */
    public KernelSGD(LossFunc loss, KernelTrick kernel, double lambda, KernelPoint.BudgetStrategy budgetStrategy, int budgetSize) {
        this(loss, kernel, lambda, budgetStrategy, budgetSize, 1.0, 0.05);
    }

    /**
     * Creates a fully specified learner. Every argument goes through the
     * matching validating setter.
     *
     * @param loss           the loss function to minimize
     * @param kernel         the kernel to use
     * @param lambda         the regularization strength, strictly positive
     * @param budgetStrategy the budget maintenance strategy
     * @param budgetSize     the maximum budget, at least 1
     * @param eta            the base learning rate, strictly positive
     * @param errorTolerance the error tolerance for the kernel point(s), in [0, 1]
     */
    public KernelSGD(LossFunc loss, KernelTrick kernel, double lambda, KernelPoint.BudgetStrategy budgetStrategy, int budgetSize, double eta, double errorTolerance) {
        this.setLoss(loss);
        this.setKernel(kernel);
        this.setLambda(lambda);
        this.setEta(eta);
        this.setBudgetStrategy(budgetStrategy);
        this.setErrorTolerance(errorTolerance);
        this.setBudgetSize(budgetSize);
    }

    /**
     * Copy constructor. The loss, kernel and any trained model are cloned, so
     * the copy can be updated independently of the original.
     *
     * @param toCopy the learner to copy
     */
    public KernelSGD(KernelSGD toCopy) {
        this.loss = toCopy.loss.clone();
        this.kernel = toCopy.kernel.clone();
        this.lambda = toCopy.lambda;
        this.eta = toCopy.eta;
        this.budgetStrategy = toCopy.budgetStrategy;
        this.budgetSize = toCopy.budgetSize;
        this.errorTolerance = toCopy.errorTolerance;
        this.time = toCopy.time;
        this.epochs = toCopy.epochs;
        if (toCopy.kpoint != null) {
            this.kpoint = toCopy.kpoint.clone();
        }
        if (toCopy.kpoints != null) {
            this.kpoints = toCopy.kpoints.clone();
        }
    }

    /**
     * Sets the number of passes over the data made by the batch {@code train} methods.
     *
     * @param epochs the number of epochs, at least 1
     * @throws IllegalArgumentException if {@code epochs < 1}
     */
    public void setEpochs(int epochs) {
        if (epochs < 1) {
            throw new IllegalArgumentException("Epochs must be a positive constant, not " + epochs);
        }
        this.epochs = epochs;
    }

    /**
     * @return the number of passes over the data made by the batch {@code train} methods
     */
    public int getEpochs() {
        return this.epochs;
    }

    /**
     * Sets the loss function to minimize.
     *
     * @param loss the loss function
     * @throws NullPointerException if {@code loss} is {@code null}
     */
    public void setLoss(LossFunc loss) {
        if (loss == null) {
            throw new NullPointerException("Loss may not be null");
        }
        this.loss = loss;
    }

    /**
     * @return the loss function being minimized
     */
    public LossFunc getLoss() {
        return this.loss;
    }

    /**
     * Sets the regularization strength. It scales the per-update weight decay
     * and appears in the step-size schedule.
     *
     * @param lambda the regularization strength, strictly positive and finite
     * @throws IllegalArgumentException if {@code lambda} is not positive or not finite
     */
    public void setLambda(double lambda) {
        if (lambda <= 0.0 || Double.isNaN(lambda) || Double.isInfinite(lambda)) {
            throw new IllegalArgumentException("lambda must be a positive constant, not " + lambda);
        }
        this.lambda = lambda;
    }

    /**
     * @return the regularization strength
     */
    public double getLambda() {
        return this.lambda;
    }

    /**
     * Sets the error tolerance handed to the kernel point(s) on the next {@code setUp}.
     *
     * @param errorTolerance the tolerance, in [0, 1]
     * @throws IllegalArgumentException if {@code errorTolerance} is outside [0, 1] or NaN
     */
    public void setErrorTolerance(double errorTolerance) {
        if (errorTolerance < 0.0 || errorTolerance > 1.0 || Double.isNaN(errorTolerance)) {
            throw new IllegalArgumentException("Error tolerance must be in [0, 1], not " + errorTolerance);
        }
        this.errorTolerance = errorTolerance;
    }

    /**
     * @return the error tolerance handed to the kernel point(s)
     */
    public double getErrorTolerance() {
        return this.errorTolerance;
    }

    /**
     * Sets the maximum budget handed to the kernel point(s) on the next {@code setUp}.
     *
     * @param budgetSize the maximum budget, at least 1
     * @throws IllegalArgumentException if {@code budgetSize < 1}
     */
    public void setBudgetSize(int budgetSize) {
        if (budgetSize < 1) {
            throw new IllegalArgumentException("Budget size must be a positive constant, not " + budgetSize);
        }
        this.budgetSize = budgetSize;
    }

    /**
     * @return the maximum budget handed to the kernel point(s)
     */
    public int getBudgetSize() {
        return this.budgetSize;
    }

    /**
     * Sets the budget strategy handed to the kernel point(s) on the next {@code setUp}.
     *
     * @param budgetStrategy the strategy
     * @throws NullPointerException if {@code budgetStrategy} is {@code null}
     */
    public void setBudgetStrategy(KernelPoint.BudgetStrategy budgetStrategy) {
        if (budgetStrategy == null) {
            throw new NullPointerException("Budget strategy must be non null");
        }
        this.budgetStrategy = budgetStrategy;
    }

    /**
     * @return the budget strategy handed to the kernel point(s)
     */
    public KernelPoint.BudgetStrategy getBudgetStrategy() {
        return this.budgetStrategy;
    }

    /**
     * Sets the base learning rate. The step size actually used at update
     * {@code t} is {@code eta / (lambda * t + 2)}.
     *
     * @param eta the base learning rate, strictly positive and finite
     * @throws IllegalArgumentException if {@code eta} is not positive or not finite
     */
    public void setEta(double eta) {
        // A non-positive eta stalls or reverses the gradient step, and a
        // non-finite one poisons the model with NaN/Inf. Validate it like lambda.
        if (eta <= 0.0 || Double.isNaN(eta) || Double.isInfinite(eta)) {
            throw new IllegalArgumentException("eta must be a positive constant, not " + eta);
        }
        this.eta = eta;
    }

    /**
     * @return the base learning rate
     */
    public double getEta() {
        return this.eta;
    }

    /**
     * Sets the kernel used by the kernel point(s) created on the next {@code setUp}.
     *
     * @param kernel the kernel
     * @throws NullPointerException if {@code kernel} is {@code null}
     */
    public void setKernel(KernelTrick kernel) {
        if (kernel == null) {
            throw new NullPointerException("kernel trick must be non null");
        }
        this.kernel = kernel;
    }

    /**
     * @return the kernel in use
     */
    public KernelTrick getKernel() {
        return this.kernel;
    }

    /**
     * Prepares a fresh classification model and resets the update counter.
     * <ul>
     * <li>Two categories: a single {@link KernelPoint} is used.</li>
     * <li>Any other number of categories: a {@link KernelPoints} with one entry
     * per category count is used, and the loss must also be a {@link LossMC}.</li>
     * </ul>
     *
     * @throws FailedToFitException if the loss does not support the requested task
     */
    @Override
    public void setUp(CategoricalData[] categoricalAttributes, int numericAttributes, CategoricalData predicting) {
        if (!(this.loss instanceof LossC)) {
            throw new FailedToFitException("Loss in use (" + this.loss.getClass().getSimpleName() + ") does not support classification");
        }
        final int numCategories = predicting.getNumOfCategories();
        if (numCategories == 2) {
            this.kpoint = this.newConfiguredKernelPoint();
            this.kpoints = null;
        } else {
            if (!(this.loss instanceof LossMC)) {
                throw new FailedToFitException("Loss in use (" + this.loss.getClass().getSimpleName() + ") does not support multi-class classification");
            }
            this.kpoint = null;
            this.kpoints = new KernelPoints(this.kernel, numCategories, this.errorTolerance);
            this.kpoints.setBudgetStrategy(this.budgetStrategy);
            this.kpoints.setErrorTolerance(this.errorTolerance);
            this.kpoints.setMaxBudget(this.budgetSize);
        }
        this.time = 0;
    }

    /**
     * Prepares a fresh regression model (a single {@link KernelPoint}) and
     * resets the update counter.
     *
     * @throws FailedToFitException if the loss is not a {@link LossR}
     */
    @Override
    public void setUp(CategoricalData[] categoricalAttributes, int numericAttributes) {
        if (!(this.loss instanceof LossR)) {
            throw new FailedToFitException("Loss in use (" + this.loss.getClass().getSimpleName() + ") does not support regression");
        }
        this.kpoint = this.newConfiguredKernelPoint();
        this.kpoints = null;
        this.time = 0;
    }

    /**
     * Creates a {@link KernelPoint} configured with this learner's kernel,
     * budget strategy, error tolerance and maximum budget.
     */
    private KernelPoint newConfiguredKernelPoint() {
        KernelPoint kp = new KernelPoint(this.kernel, this.errorTolerance);
        kp.setBudgetStrategy(this.budgetStrategy);
        kp.setErrorTolerance(this.errorTolerance);
        kp.setMaxBudget(this.budgetSize);
        return kp;
    }

    /**
     * Performs one SGD step on a labeled classification example.
     * <p>
     * For binary models, the label is mapped to {@code -1/+1}. The model is
     * updated only when the loss derivative is non-zero. If neither model has
     * been set up, only the update counter advances.
     *
     * @param dataPoint   the example
     * @param targetClass the index of the true class
     */
    @Override
    public void update(DataPoint dataPoint, int targetClass) {
        Vec x = dataPoint.getNumericalValues();
        List<Double> qi = this.kernel.getQueryInfo(x);
        double eta_t = this.getNextEta();
        if (this.kpoint != null) {
            // Weight decay from the L2 regularizer.
            this.kpoint.mutableMultiply(1.0 - eta_t * this.lambda);
            double y = targetClass * 2 - 1;
            double dot = this.kpoint.dot(x, qi);
            double lossD = ((LossC) this.loss).getDeriv(dot, y);
            if (lossD != 0.0) {
                this.kpoint.mutableAdd(-eta_t * lossD, x, qi);
            }
        } else if (this.kpoints != null) {
            this.kpoints.mutableMultiply(1.0 - eta_t * this.lambda);
            DenseVector pred = new DenseVector(this.kpoints.dot(x, qi));
            // The raw scores are processed, turned into per-class derivatives
            // in place, then scaled into per-class step sizes.
            ((LossMC) this.loss).process(pred, pred);
            ((LossMC) this.loss).deriv(pred, pred, targetClass);
            pred.mutableMultiply(-eta_t);
            this.kpoints.mutableAdd(x, pred, qi);
        }
    }

    /**
     * Performs one SGD step on a labeled regression example. The model is
     * updated only when the loss derivative is non-zero.
     *
     * @param dataPoint   the example
     * @param targetValue the true target value
     */
    @Override
    public void update(DataPoint dataPoint, double targetValue) {
        Vec x = dataPoint.getNumericalValues();
        List<Double> qi = this.kernel.getQueryInfo(x);
        double eta_t = this.getNextEta();
        this.kpoint.mutableMultiply(1.0 - eta_t * this.lambda);
        double dot = this.kpoint.dot(x, qi);
        double lossD = ((LossR) this.loss).getDeriv(dot, targetValue);
        if (lossD != 0.0) {
            this.kpoint.mutableAdd(-eta_t * lossD, x, qi);
        }
    }

    /**
     * Classifies an example using the binary model if present, otherwise the
     * multi-class model.
     *
     * @param data the example
     * @return the loss-specific classification result
     */
    @Override
    public CategoricalResults classify(DataPoint data) {
        Vec x = data.getNumericalValues();
        List<Double> qi = this.kernel.getQueryInfo(x);
        if (this.kpoint != null) {
            return ((LossC) this.loss).getClassification(this.kpoint.dot(x, qi));
        }
        DenseVector pred = new DenseVector(this.kpoints.dot(x, qi));
        ((LossMC) this.loss).process(pred, pred);
        return ((LossMC) this.loss).getClassification(pred);
    }

    /**
     * Predicts a regression value for an example.
     *
     * @param data the example
     * @return the loss-specific regression output for the model's score
     */
    @Override
    public double regress(DataPoint data) {
        Vec x = data.getNumericalValues();
        List<Double> qi = this.kernel.getQueryInfo(x);
        return ((LossR) this.loss).getRegression(this.kpoint.dot(x, qi));
    }

    /**
     * Trains on the data set; the {@code parallel} flag is ignored.
     */
    @Override
    public void train(ClassificationDataSet dataSet, boolean parallel) {
        this.train(dataSet);
    }

    /**
     * Trains for {@link #getEpochs()} passes over the data set.
     */
    @Override
    public void train(ClassificationDataSet dataSet) {
        BaseUpdateableClassifier.trainEpochs(dataSet, this, this.epochs);
    }

    /**
     * @return {@code false}; example weights are not used by the updates
     */
    @Override
    public boolean supportsWeightedData() {
        return false;
    }

    /**
     * Trains on the data set; the {@code parallel} flag is ignored.
     */
    @Override
    public void train(RegressionDataSet dataSet, boolean parallel) {
        this.train(dataSet);
    }

    /**
     * Trains for {@link #getEpochs()} passes over the data set.
     */
    @Override
    public void train(RegressionDataSet dataSet) {
        BaseUpdateableRegressor.trainEpochs(dataSet, this, this.epochs);
    }

    @Override
    public KernelSGD clone() {
        return new KernelSGD(this);
    }

    /**
     * Advances the update counter and returns the step size for that update:
     * {@code eta / (lambda * (t + 2 / lambda))}.
     */
    private double getNextEta() {
        return this.eta / (this.lambda * ((double) (++this.time) + 2.0 / this.lambda));
    }

    /**
     * Returns a search distribution for {@code lambda}: log-uniform over
     * [1e-7, 0.01]. The data set is not inspected.
     *
     * @param d the data set (unused)
     * @return the distribution to draw candidate {@code lambda} values from
     */
    public static Distribution guessLambda(DataSet d) {
        return new LogUniform(1.0E-7, 0.01);
    }
}

