/*
 * Decompiled with CFR 0.152.
 */
package jsat.distributions.kernels;

import java.util.List;
import jsat.DataSet;
import jsat.distributions.Distribution;
import jsat.distributions.kernels.BaseL2Kernel;
import jsat.distributions.kernels.GeneralRBFKernel;
import jsat.linear.Vec;
import jsat.linear.distancemetrics.EuclideanDistance;

/**
 * The Radial Basis Function (Gaussian) kernel
 * {@code k(a, b) = exp(-||a - b||^2 / (2 * sigma^2))}.
 * <p>
 * The width {@code sigma} must be a positive, finite number. The equivalent
 * "gamma" parameterization is {@code gamma = 1 / (2 * sigma^2)}; see
 * {@link #sigmaToGamma(double)} and {@link #gammaToSigma(double)}.
 * <p>
 * The kernel is normalized: {@code k(x, x) = 1} for every {@code x}.
 */
public class RBFKernel
extends BaseL2Kernel {
    private static final long serialVersionUID = -6733691081172950067L;
    /** Kernel width; always positive and finite once set through {@link #setSigma(double)}. */
    private double sigma;
    /** Cached {@code 1 / (2 * sigma^2)} (i.e. gamma), kept in sync with {@link #sigma} by {@link #setSigma(double)}. */
    private double sigmaSqrd2Inv;

    /**
     * Creates an RBF kernel with {@code sigma = 1}.
     */
    public RBFKernel() {
        this(1.0);
    }

    /**
     * Creates an RBF kernel with the given width.
     *
     * @param sigma the kernel width, must be positive and finite
     * @throws IllegalArgumentException if {@code sigma} is not a positive finite number
     */
    public RBFKernel(double sigma) {
        this.setSigma(sigma);
    }

    /**
     * Evaluates the kernel on two vectors.
     *
     * @param a the first vector
     * @param b the second vector
     * @return {@code exp(-||a - b||^2 / (2 * sigma^2))}; exactly {@code 1.0} when
     *         {@code a} and {@code b} are the same object
     */
    @Override
    public double eval(Vec a, Vec b) {
        if (a == b) {
            // identical reference: distance is zero, so k = exp(0) = 1
            return 1.0;
        }
        // pNormDist(2, .) is the Euclidean distance; square it directly instead of Math.pow(d, 2)
        double dist = a.pNormDist(2.0, b);
        return Math.exp(-dist * dist * this.sigmaSqrd2Inv);
    }

    /**
     * Evaluates the kernel between two training points identified by index.
     * The squared distance comes from the inherited {@code getSqrdNorm}, which
     * presumably uses {@code cache} to avoid recomputing norms (see BaseL2Kernel).
     *
     * @param a index of the first point in {@code trainingSet}
     * @param b index of the second point in {@code trainingSet}
     * @param trainingSet the training vectors
     * @param cache the accelerating cache for {@code trainingSet}
     * @return the kernel value; exactly {@code 1.0} when {@code a == b}
     */
    @Override
    public double eval(int a, int b, List<? extends Vec> trainingSet, List<Double> cache) {
        if (a == b) {
            return 1.0;
        }
        return Math.exp(-this.getSqrdNorm(a, b, trainingSet, cache) * this.sigmaSqrd2Inv);
    }

    /**
     * Evaluates the kernel between training point {@code a} and an arbitrary
     * vector {@code b}, using the inherited {@code getSqrdNorm} with the query
     * info {@code qi} and the training cache.
     *
     * @param a index of the training point in {@code vecs}
     * @param b the query vector
     * @param qi query information computed for {@code b}
     * @param vecs the training vectors
     * @param cache the accelerating cache for {@code vecs}
     * @return the kernel value
     */
    @Override
    public double eval(int a, Vec b, List<Double> qi, List<? extends Vec> vecs, List<Double> cache) {
        return Math.exp(-this.getSqrdNorm(a, b, qi, vecs, cache) * this.sigmaSqrd2Inv);
    }

    /**
     * Sets the kernel width and updates the cached {@code 1 / (2 * sigma^2)} factor.
     *
     * @param sigma the new width, must be positive and finite
     * @throws IllegalArgumentException if {@code sigma} is zero, negative, NaN, or infinite
     */
    public void setSigma(double sigma) {
        // !(sigma > 0) also rejects NaN (all comparisons with NaN are false);
        // an infinite sigma would make the kernel the constant 1.
        if (!(sigma > 0.0) || Double.isInfinite(sigma)) {
            throw new IllegalArgumentException("Sigma must be a positive finite constant, not " + sigma);
        }
        this.sigma = sigma;
        this.sigmaSqrd2Inv = 0.5 / (sigma * sigma);
    }

    /**
     * Returns the kernel width.
     *
     * @return the current {@code sigma}
     */
    public double getSigma() {
        return this.sigma;
    }

    @Override
    public String toString() {
        return "RBF Kernel( \u03c3 = " + this.sigma + ")";
    }

    /**
     * Returns a new kernel with the same {@code sigma}.
     *
     * @return an independent copy of this kernel
     */
    @Override
    public RBFKernel clone() {
        return new RBFKernel(this.sigma);
    }

    /**
     * Converts a sigma width into the equivalent gamma, {@code 1 / (2 * sigma^2)}.
     *
     * @param sigma the width, must be positive and finite
     * @return the equivalent gamma value
     * @throws IllegalArgumentException if {@code sigma} is not a positive finite number
     */
    public static double sigmaToGamma(double sigma) {
        if (sigma <= 0.0 || Double.isNaN(sigma) || Double.isInfinite(sigma)) {
            throw new IllegalArgumentException("sigma must be positive, not " + sigma);
        }
        return 1.0 / (2.0 * sigma * sigma);
    }

    /**
     * Converts a gamma value into the equivalent sigma width, {@code 1 / sqrt(2 * gamma)}.
     *
     * @param gamma the gamma value, must be positive and finite
     * @return the equivalent sigma
     * @throws IllegalArgumentException if {@code gamma} is not a positive finite number
     */
    public static double gammaToSigma(double gamma) {
        if (gamma <= 0.0 || Double.isNaN(gamma) || Double.isInfinite(gamma)) {
            throw new IllegalArgumentException("gamma must be positive, not " + gamma);
        }
        return 1.0 / Math.sqrt(2.0 * gamma);
    }

    /**
     * Misspelled alias of {@link #gammaToSigma(double)}, kept so existing callers
     * continue to compile.
     *
     * @param gamma the gamma value, must be positive and finite
     * @return the equivalent sigma
     * @throws IllegalArgumentException if {@code gamma} is not a positive finite number
     */
    public static double gammToSigma(double gamma) {
        return gammaToSigma(gamma);
    }

    /**
     * Produces a distribution of plausible sigma values for the given data set by
     * delegating to {@code GeneralRBFKernel.guessSigma} with the Euclidean distance.
     *
     * @param d the data set to inspect
     * @return the distribution returned by {@code GeneralRBFKernel.guessSigma}
     */
    public static Distribution guessSigma(DataSet d) {
        return GeneralRBFKernel.guessSigma(d, new EuclideanDistance());
    }

    /**
     * The RBF kernel always satisfies {@code k(x, x) = 1}.
     *
     * @return {@code true}
     */
    @Override
    public boolean normalized() {
        return true;
    }
}

