/*
 * Decompiled with CFR 0.152.
 */
package jsat.datatransform.kernel;

import java.util.Random;
import jsat.DataSet;
import jsat.classifiers.DataPoint;
import jsat.datatransform.DataTransformBase;
import jsat.distributions.Distribution;
import jsat.distributions.kernels.RBFKernel;
import jsat.linear.DenseVector;
import jsat.linear.Matrix;
import jsat.linear.RandomMatrix;
import jsat.linear.RandomVector;
import jsat.linear.Vec;
import jsat.utils.random.RandomUtil;

public class RFF_RBF
extends DataTransformBase {
    private static final long serialVersionUID = -3478216020648280477L;
    /** Random projection matrix of shape (numeric features) x {@link #dim}; null until fit. */
    private Matrix transform;
    /** Random phase offsets in [0, 2&pi;), one per output dimension; null until fit. */
    private Vec offsets;
    private double sigma;
    private int dim;
    private boolean inMemory;

    /**
     * Creates a transform with {@code sigma = 1.0}, 512 output dimensions, stored in memory.
     */
    public RFF_RBF() {
        this(1.0);
    }

    /**
     * Creates a transform with the given sigma, 512 output dimensions, stored in memory.
     *
     * @param sigma the positive, finite sigma parameter
     */
    public RFF_RBF(double sigma) {
        this(sigma, 512);
    }

    /**
     * Creates a transform with the given sigma and output dimension, stored in memory.
     *
     * @param sigma the positive, finite sigma parameter
     * @param dim the number of output dimensions (must be at least 1)
     */
    public RFF_RBF(double sigma, int dim) {
        this(sigma, dim, true);
    }

    /**
     * Creates an unfit transform. {@link #fit(DataSet)} must be called before
     * {@link #transform(DataPoint)}.
     *
     * @param sigma the positive, finite sigma parameter
     * @param dim the number of output dimensions (must be at least 1)
     * @param inMemory {@code true} to materialize the random matrix/vector as dense
     *        objects, {@code false} to keep the lazily-generated random forms
     */
    public RFF_RBF(double sigma, int dim, boolean inMemory) {
        this.setSigma(sigma);
        this.setDimensions(dim);
        this.setInMemory(inMemory);
    }

    /**
     * Creates a transform that is ready to use immediately for inputs with
     * {@code featurSize} numeric features.
     *
     * @param featurSize the number of numeric input features (must be positive)
     * @param sigma the positive, finite sigma parameter
     * @param dim the number of output dimensions (must be at least 1)
     * @param rand the source of randomness used to seed the projection and offsets
     * @param inMemory whether to materialize the random matrix/vector as dense objects
     * @throws IllegalArgumentException if {@code featurSize} is not positive or
     *         {@code sigma} is invalid
     * @throws ArithmeticException if {@code dim} is less than 1
     */
    public RFF_RBF(int featurSize, double sigma, int dim, Random rand, boolean inMemory) {
        // sigma and dim are already validated by the delegated constructor via
        // setSigma/setDimensions, so only the feature count needs checking here.
        this(sigma, dim, inMemory);
        this.initRandomFeatures(featurSize, rand);
    }

    /**
     * Draws a fresh random projection sized for the numeric features of {@code data}.
     *
     * @param data the data set whose numeric feature count determines the input size
     * @throws IllegalArgumentException if {@code data} has no numeric features
     */
    @Override
    public void fit(DataSet data) {
        this.initRandomFeatures(data.getNumNumericalVars(), RandomUtil.getRandom());
    }

    /**
     * Builds the random projection matrix and phase offsets from the current
     * {@link #sigma}, {@link #dim}, and {@link #inMemory} settings.
     *
     * @param featureSize the number of numeric input features (must be positive)
     * @param rand the source of the two seeds for the matrix and offset vector
     * @throws IllegalArgumentException if {@code featureSize} is not positive
     */
    private void initRandomFeatures(int featureSize, Random rand) {
        if (featureSize <= 0) {
            throw new IllegalArgumentException("The number of numeric features must be positive, not " + featureSize);
        }
        // NOTE(review): the Gaussian weights get standard deviation sqrt(0.5)/sigma.
        // For the Rahimi-Recht features of exp(-||x-y||^2 / (2 sigma^2)) the standard
        // deviation would be 1/sigma -- confirm which form RBFKernel uses.
        this.transform = new RandomMatrixRFF_RBF(Math.sqrt(0.5 / (this.sigma * this.sigma)), featureSize, this.dim, rand.nextLong());
        this.offsets = new RandomVectorRFF_RBF(this.dim, rand.nextLong());
        if (this.inMemory) {
            // add(0.0) returns a new matrix holding every generated value, turning the
            // lazily-generated random matrix into a materialized copy.
            this.transform = this.transform.add(0.0);
            this.offsets = new DenseVector(this.offsets);
        }
    }

    /**
     * Copy constructor.
     *
     * @param toCopy the transform to copy
     */
    protected RFF_RBF(RFF_RBF toCopy) {
        if (toCopy.transform != null) {
            this.transform = toCopy.transform.clone();
        }
        if (toCopy.offsets != null) {
            this.offsets = toCopy.offsets.clone();
        }
        this.dim = toCopy.dim;
        this.inMemory = toCopy.inMemory;
        this.sigma = toCopy.sigma;
    }

    /**
     * Maps the numeric values of {@code dp} to
     * {@code sqrt(2/D) * cos(x * W + b)}, where D is the number of projection columns.
     * Categorical values and weight are carried over unchanged.
     *
     * @param dp the data point to transform
     * @return a new data point holding the random Fourier features
     * @throws IllegalStateException if this transform has not been fit yet
     */
    @Override
    public DataPoint transform(DataPoint dp) {
        if (this.transform == null || this.offsets == null) {
            throw new IllegalStateException("RFF_RBF has not been fit to a data set yet");
        }
        Vec oldX = dp.getNumericalValues();
        Vec newX = oldX.multiply(this.transform);
        double coef = Math.sqrt(2.0 / (double)this.transform.cols());
        for (int i = 0; i < newX.length(); ++i) {
            newX.set(i, Math.cos(newX.get(i) + this.offsets.get(i)) * coef);
        }
        return new DataPoint(newX, dp.getCategoricalValues(), dp.getCategoricalData(), dp.getWeight());
    }

    @Override
    public RFF_RBF clone() {
        return new RFF_RBF(this);
    }

    /**
     * Sets whether the random matrix/vector are materialized during fitting.
     * Has no effect on an already-fit transform until {@link #fit(DataSet)} is called again.
     *
     * @param inMemory {@code true} to materialize, {@code false} to keep the random forms
     */
    public void setInMemory(boolean inMemory) {
        this.inMemory = inMemory;
    }

    /**
     * @return whether the random matrix/vector are materialized during fitting
     */
    public boolean isInMemory() {
        return this.inMemory;
    }

    /**
     * Sets the number of output dimensions used by the next fit.
     *
     * @param dimensions the number of output dimensions (must be at least 1)
     * @throws ArithmeticException if {@code dimensions} is less than 1 (this exception
     *         type is kept for compatibility with existing callers)
     */
    public void setDimensions(int dimensions) {
        if (dimensions < 1) {
            throw new ArithmeticException("Number of dimensions must be a positive value, not " + dimensions);
        }
        this.dim = dimensions;
    }

    /**
     * @return the number of output dimensions
     */
    public int getDimensions() {
        return this.dim;
    }

    /**
     * Sets the sigma parameter used by the next fit.
     *
     * @param sigma a positive, finite value
     * @throws IllegalArgumentException if {@code sigma} is not positive, infinite, or NaN
     */
    public void setSigma(double sigma) {
        if (sigma <= 0.0 || Double.isInfinite(sigma) || Double.isNaN(sigma)) {
            throw new IllegalArgumentException("Sigma must be a positive value, not " + sigma);
        }
        this.sigma = sigma;
    }

    /**
     * @return the sigma parameter
     */
    public double getSigma() {
        return this.sigma;
    }

    /**
     * Delegates to {@link RBFKernel#guessSigma(DataSet)}. Does not read any state of this object.
     *
     * @param d the data set to estimate sigma from
     * @return the distribution returned by {@code RBFKernel.guessSigma}
     */
    public Distribution guessSigma(DataSet d) {
        return RBFKernel.guessSigma(d);
    }

    /**
     * Random vector whose entries are drawn uniformly from [0, 2&pi;).
     */
    private static class RandomVectorRFF_RBF
    extends RandomVector {
        private static final long serialVersionUID = -6132378281909907937L;

        public RandomVectorRFF_RBF(int length, long seedMult) {
            super(length, seedMult);
        }

        @Override
        protected double getVal(Random rand) {
            return rand.nextDouble() * 2.0 * Math.PI;
        }

        /**
         * Returns {@code this}; this class overrides no mutators, so the instance is shared.
         */
        @Override
        public Vec clone() {
            return this;
        }
    }

    /**
     * Random matrix whose entries are standard Gaussian draws scaled by {@code coef}.
     */
    private static class RandomMatrixRFF_RBF
    extends RandomMatrix {
        private static final long serialVersionUID = 4702514384718636893L;
        private double coef;

        public RandomMatrixRFF_RBF(double coef, int rows, int cols, long seedMult) {
            super(rows, cols, seedMult);
            this.coef = coef;
        }

        @Override
        protected double getVal(Random rand) {
            return this.coef * rand.nextGaussian();
        }
    }
}

