/*
 * Decompiled with CFR 0.152.
 */
package smile.regression;

import java.io.Serializable;
import java.lang.reflect.Array;
import java.util.Arrays;
import smile.math.distance.Metric;
import smile.math.matrix.DenseMatrix;
import smile.math.matrix.Matrix;
import smile.math.matrix.QR;
import smile.math.rbf.GaussianRadialBasis;
import smile.math.rbf.RadialBasisFunction;
import smile.regression.Regression;
import smile.regression.RegressionTrainer;
import smile.util.SmileUtils;

public class RBFNetwork<T>
implements Regression<T>,
Serializable {
    private static final long serialVersionUID = 1L;
    private T[] centers;
    private double[] w;
    private Metric<T> distance;
    private RadialBasisFunction[] rbf;
    private boolean normalized;

    public RBFNetwork(T[] x, double[] y, Metric<T> distance, RadialBasisFunction rbf, T[] centers) {
        this(x, y, distance, rbf, centers, false);
    }

    public RBFNetwork(T[] x, double[] y, Metric<T> distance, RadialBasisFunction[] rbf, T[] centers) {
        this(x, y, distance, rbf, centers, false);
    }

    public RBFNetwork(T[] x, double[] y, Metric<T> distance, RadialBasisFunction rbf, T[] centers, boolean normalized) {
        this(x, y, distance, RBFNetwork.rep(rbf, centers.length), centers, normalized);
    }

    private static RadialBasisFunction[] rep(RadialBasisFunction rbf, int k) {
        Object[] arr = new RadialBasisFunction[k];
        Arrays.fill(arr, rbf);
        return arr;
    }

    public RBFNetwork(T[] x, double[] y, Metric<T> distance, RadialBasisFunction[] rbf, T[] centers, boolean normalized) {
        if (x.length != y.length) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", x.length, y.length));
        }
        if (rbf.length != centers.length) {
            throw new IllegalArgumentException(String.format("The sizes of RBF functions and centers don't match: %d != %d", rbf.length, centers.length));
        }
        this.centers = centers;
        this.distance = distance;
        this.rbf = rbf;
        this.normalized = normalized;
        int n = x.length;
        int m = rbf.length;
        DenseMatrix G = Matrix.zeros(n, m);
        double[] b = new double[n];
        for (int i = 0; i < n; ++i) {
            double sum = 0.0;
            for (int j = 0; j < m; ++j) {
                double r = rbf[j].f(distance.d(x[i], centers[j]));
                G.set(i, j, r);
                sum += r;
            }
            b[i] = normalized ? sum * y[i] : y[i];
        }
        this.w = new double[m];
        QR qr = G.qr();
        qr.solve(b, this.w);
    }

    @Override
    public double predict(T x) {
        double sum = 0.0;
        double sumw = 0.0;
        for (int i = 0; i < this.rbf.length; ++i) {
            double f = this.rbf[i].f(this.distance.d(x, this.centers[i]));
            sumw += this.w[i] * f;
            sum += f;
        }
        return this.normalized ? sumw / sum : sumw;
    }

    public static class Trainer<T>
    extends RegressionTrainer<T> {
        private int m = 10;
        private Metric<T> distance;
        private RadialBasisFunction[] rbf;
        private boolean normalized = false;

        public Trainer(Metric<T> distance) {
            this.distance = distance;
        }

        public Trainer<T> setRBF(RadialBasisFunction rbf, int m) {
            this.m = m;
            this.rbf = RBFNetwork.rep(rbf, m);
            return this;
        }

        public Trainer<T> setRBF(RadialBasisFunction[] rbf) {
            this.m = rbf.length;
            this.rbf = rbf;
            return this;
        }

        public Trainer<T> setNumCenters(int m) {
            this.m = m;
            return this;
        }

        public Trainer<T> setNormalized(boolean normalized) {
            this.normalized = normalized;
            return this;
        }

        @Override
        public RBFNetwork<T> train(T[] x, double[] y) {
            Object[] centers = (Object[])Array.newInstance(x.getClass().getComponentType(), this.m);
            GaussianRadialBasis gaussian = SmileUtils.learnGaussianRadialBasis(x, centers, this.distance);
            if (this.rbf == null) {
                return new RBFNetwork<Object>(x, y, this.distance, gaussian, centers, this.normalized);
            }
            return new RBFNetwork<Object>(x, y, this.distance, this.rbf, centers, this.normalized);
        }

        public RBFNetwork<T> train(T[] x, double[] y, T[] centers) {
            return new RBFNetwork<T>(x, y, this.distance, this.rbf, centers, this.normalized);
        }
    }
}

