/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers.linear.kernelized;

import java.util.ArrayList;
import java.util.List;
import jsat.DataSet;
import jsat.classifiers.BaseUpdateableClassifier;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.DataPoint;
import jsat.classifiers.calibration.BinaryScoreClassifier;
import jsat.distributions.Distribution;
import jsat.distributions.Uniform;
import jsat.distributions.kernels.KernelTrick;
import jsat.exceptions.FailedToFitException;
import jsat.linear.ScaledVector;
import jsat.linear.Vec;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.utils.DoubleList;
import jsat.utils.IntList;

/**
 * Kernelized ALMA2: the Approximate Large Margin Algorithm of Gentile with norm
 * {@code p = 2}, used as an online binary classifier.
 * <p>
 * The weight vector {@code w} is never formed explicitly. It is stored as the list of inputs
 * on which an update occurred (the supports), together with each update's signed step size,
 * the feature-space norm {@code sqrt(K(x, x))} of the support, and the score
 * {@code w . phi_hat(x)} that the hypothesis gave to the normalized support just before it
 * was added. The stored scores let the squared norm of {@code w} be tracked through every
 * {@code w / max(1, ||w||)} normalization step without kernel products between supports.
 * <p>
 * Reference: C. Gentile, "A New Approximate Maximal Margin Classification Algorithm",
 * Journal of Machine Learning Research 2 (2001).
 */
public class ALMA2K extends BaseUpdateableClassifier implements BinaryScoreClassifier, Parameterized {
    private static final long serialVersionUID = 7247320234799227009L;
    /** The norm parameter of ALMA_p; this class implements the p = 2 case. */
    private static final double p = 2.0;
    /** Approximation parameter in (0, 1]; the required margin is (1 - alpha) * gamma_k. */
    private double alpha;
    /** Margin scaling constant; {@link #setAlpha(double)} resets it to 1 / alpha. */
    private double B;
    /** Step-size scaling constant. */
    private double C = Math.sqrt(2.0);
    /** Update counter k of the paper: one more than the number of updates made so far. */
    private int k;
    /** Number of rounds the current hypothesis has survived since the last update. */
    private int curRounds;
    @Parameter.ParameterHolder
    private KernelTrick K;
    /** Inputs on which an update was made. */
    private List<Vec> supports;
    /** eta_k * y for each support. */
    private List<Double> signedEtas;
    /** Score of the normalized support under the hypothesis that existed just before it was added. */
    private List<Double> associatedScores;
    /** sqrt(K(x, x)) for each support, used to normalize it in feature space. */
    private List<Double> normalizers;
    /** Rounds survived by the hypothesis that existed just before each support was added. */
    private List<Integer> rounds;
    /** Whether scores come from the averaged hypothesis instead of the last one. */
    private boolean averaged = false;

    /**
     * Creates a new kernelized ALMA2 learner.
     *
     * @param kernel the kernel trick to use
     * @param alpha the approximation parameter in (0, 1]; B is set to 1 / alpha
     * @throws ArithmeticException if alpha is not in (0, 1]
     */
    public ALMA2K(KernelTrick kernel, double alpha) {
        this.setKernelTrick(kernel);
        this.setAlpha(alpha);
    }

    /**
     * Deep-copy constructor.
     *
     * @param other the learner to copy, including its training progress
     */
    protected ALMA2K(ALMA2K other) {
        this.alpha = other.alpha;
        this.B = other.B;
        this.C = other.C;
        this.k = other.k;
        // the copy must resume with the same survival count, or averaged scores diverge
        this.curRounds = other.curRounds;
        this.K = other.K == null ? null : other.K.clone();
        this.averaged = other.averaged;
        if (other.supports != null) {
            this.supports = new ArrayList<Vec>(other.supports.size());
            for (Vec v : other.supports) {
                this.supports.add(v.clone());
            }
            this.signedEtas = new DoubleList(other.signedEtas);
            this.associatedScores = new DoubleList(other.associatedScores);
            this.normalizers = new DoubleList(other.normalizers);
            this.rounds = new IntList(other.rounds);
        }
    }

    @Override
    public ALMA2K clone() {
        return new ALMA2K(this);
    }

    /**
     * Sets whether {@link #getScore(DataPoint)} and {@link #classify(DataPoint)} use the
     * averaged hypothesis (every hypothesis weighted by the number of rounds it survived
     * without an update) instead of the last hypothesis. Training always uses the last one.
     *
     * @param averaged {@code true} to use the averaged hypothesis
     */
    public void setAveraged(boolean averaged) {
        this.averaged = averaged;
    }

    /**
     * @return {@code true} if predictions use the averaged hypothesis
     */
    public boolean isAveraged() {
        return this.averaged;
    }

    /**
     * @param K the kernel trick to use
     */
    public void setKernelTrick(KernelTrick K) {
        this.K = K;
    }

    /**
     * @return the kernel trick in use
     */
    public KernelTrick getKernelTrick() {
        return this.K;
    }

    /**
     * Sets the approximation parameter and resets B to {@code 1 / alpha}.
     *
     * @param alpha a value in (0, 1]
     * @throws ArithmeticException if alpha is NaN or outside (0, 1]
     */
    public void setAlpha(double alpha) {
        if (alpha <= 0.0 || alpha > 1.0 || Double.isNaN(alpha)) {
            throw new ArithmeticException("alpha must be in (0, 1], not " + alpha);
        }
        this.alpha = alpha;
        this.setB(1.0 / alpha);
    }

    /**
     * @return the approximation parameter alpha
     */
    public double getAlpha() {
        return this.alpha;
    }

    /**
     * Sets the margin scaling constant B. No validation is performed.
     *
     * @param B the new value of B
     */
    public void setB(double B) {
        this.B = B;
    }

    /**
     * @return the margin scaling constant B
     */
    public double getB() {
        return this.B;
    }

    /**
     * Sets the step-size scaling constant C.
     *
     * @param C a positive, finite value
     * @throws ArithmeticException if C is not positive and finite
     */
    public void setC(double C) {
        if (C <= 0.0 || Double.isInfinite(C) || Double.isNaN(C)) {
            throw new ArithmeticException("C must be a positive constant, not " + C);
        }
        this.C = C;
    }

    /**
     * @return the step-size scaling constant C
     */
    public double getC() {
        return this.C;
    }

    /**
     * Clears all learned state and prepares for online training.
     *
     * @throws FailedToFitException if there are no numeric attributes or the target does
     *         not have exactly two categories
     */
    @Override
    public void setUp(CategoricalData[] categoricalAttributes, int numericAttributes, CategoricalData predicting) {
        if (numericAttributes <= 0) {
            throw new FailedToFitException("ALMA2 requires numeric features");
        }
        if (predicting.getNumOfCategories() != 2) {
            throw new FailedToFitException("ALMA2 works only for binary classification");
        }
        this.supports = new ArrayList<Vec>();
        this.signedEtas = new DoubleList();
        this.associatedScores = new DoubleList();
        this.normalizers = new DoubleList();
        this.rounds = new IntList();
        this.k = 1;
        this.curRounds = 0;
    }

    /**
     * Performs one online ALMA2 round on the given example.
     * <p>
     * The example is normalized in feature space by {@code sqrt(K(x, x))}. If the current
     * hypothesis gives it a normalized margin of at most {@code (1 - alpha) * gamma_k}, with
     * {@code gamma_k = B * sqrt(p - 1) / sqrt(k)}, the example becomes a new support. Its
     * step size is {@code eta_k = C / (sqrt(p - 1) * sqrt(k))}, and k is incremented.
     * Otherwise the current hypothesis survives one more round.
     * Examples whose feature-space norm is zero are ignored.
     *
     * @param dataPoint the example; only its numeric values are used
     * @param targetClass the true class, 0 or 1
     */
    @Override
    public void update(DataPoint dataPoint, int targetClass) {
        Vec x_t = dataPoint.getNumericalValues();
        double norm = Math.sqrt(this.K.eval(x_t, x_t));
        if (!(norm > 0.0)) {
            // phi(x_t) has zero (or undefined) norm: it cannot be normalized, and storing it
            // would divide by zero and turn every later score into NaN.
            return;
        }
        double y_t = targetClass * 2 - 1;
        double gamma = this.B * Math.sqrt(p - 1.0) / Math.sqrt(this.k);
        // w . phi(x_t) / ||phi(x_t)|| is the score of the normalized example for ANY kernel
        // (scaling the raw input vector is only equivalent for the linear kernel). It is both
        // the margin tested here and the associated score needed to replay this update.
        double wxNormalized = this.score(x_t, false) / norm;
        if (y_t * wxNormalized <= (1.0 - this.alpha) * gamma) {
            double eta = this.C / Math.sqrt(p - 1.0) / Math.sqrt(this.k++);
            this.associatedScores.add(wxNormalized);
            this.supports.add(x_t);
            this.normalizers.add(norm);
            this.signedEtas.add(eta * y_t);
            this.rounds.add(this.curRounds);
            this.curRounds = 0;
        } else {
            ++this.curRounds;
        }
    }

    /**
     * Computes {@code w . phi(x)} by replaying the updates that produced the hypothesis.
     * <p>
     * Update i forms {@code w' = w_i + eta_i * phi_hat(s_i)}, whose squared norm is
     * {@code ||w_i||^2 + 2 * eta_i * (w_i . phi_hat(s_i)) + eta_i^2}, and then divides the
     * whole vector by {@code max(1, ||w'||)}. The running score and running squared norm are
     * rescaled by that same factor.
     *
     * @param x the input to score
     * @param useAverage if {@code true}, return the sum of every hypothesis' score weighted
     *        by the number of rounds it survived without an update; otherwise return the
     *        score of the last hypothesis
     * @return the requested score
     */
    private double score(Vec x, boolean useAverage) {
        double score = 0.0;
        double sqNorm = 0.0;
        double averagedScore = 0.0;
        for (int i = 0; i < this.supports.size(); ++i) {
            if (useAverage) {
                // rounds[i] counts the rounds survived by the hypothesis preceding update i
                averagedScore += score * this.rounds.get(i);
            }
            double eta_s = this.signedEtas.get(i);
            // clamp tiny negative rounding error so sqrt cannot yield NaN
            double unscaledSqNorm = Math.max(0.0,
                    sqNorm + 2.0 * eta_s * this.associatedScores.get(i) + eta_s * eta_s);
            double scale = Math.max(1.0, Math.sqrt(unscaledSqNorm));
            double contribution = eta_s * this.K.eval(this.supports.get(i), x) / this.normalizers.get(i);
            score = (score + contribution) / scale;
            sqNorm = unscaledSqNorm / (scale * scale);
        }
        if (useAverage) {
            // the current hypothesis has survived curRounds rounds since the last update
            return averagedScore + score * this.curRounds;
        }
        return score;
    }

    /**
     * Predicts class 0 when the score is negative and class 1 otherwise, with probability 1.
     *
     * @param data the point to classify
     * @return a hard two-class result
     */
    @Override
    public CategoricalResults classify(DataPoint data) {
        double wx = this.getScore(data);
        CategoricalResults cr = new CategoricalResults(2);
        if (wx < 0.0) {
            cr.setProb(0, 1.0);
        } else {
            cr.setProb(1, 1.0);
        }
        return cr;
    }

    /**
     * @param dp the point to score
     * @return the raw score of the averaged or last hypothesis, depending on
     *         {@link #isAveraged()}; negative values indicate class 0
     */
    @Override
    public double getScore(DataPoint dp) {
        return this.score(dp.getNumericalValues(), this.averaged);
    }

    /**
     * @return {@code false}; example weights are ignored by {@link #update(DataPoint, int)}
     */
    @Override
    public boolean supportsWeightedData() {
        return false;
    }

    /**
     * Guesses a search distribution for alpha.
     *
     * @param d the data set (currently unused)
     * @return a uniform distribution over [0.001, 1]
     */
    public static Distribution guessAlpha(DataSet d) {
        return new Uniform(0.001, 1.0);
    }
}

