/*
 * Decompiled with CFR 0.152.
 */
package smile.classification;

import java.io.Serializable;
import java.util.Arrays;
import java.util.Objects;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.classification.ClassifierTrainer;
import smile.classification.OnlineClassifier;
import smile.classification.SoftClassifier;
import smile.math.Math;

/**
 * Multilayer perceptron classifier trained online by back-propagation.
 * <p>
 * The network is a stack of fully connected layers, input layer first. Each
 * layer's {@code output} buffer has one extra trailing slot fixed to 1.0, which
 * acts as the bias input of the next layer; every non-input layer therefore has
 * a weight matrix with one extra (bias) column. Hidden units always use the
 * logistic sigmoid; the output layer uses the configured
 * {@link ActivationFunction}.
 * <p>
 * Each call to {@code learn} runs a forward pass, computes the output error,
 * back-propagates it and immediately updates the weights, with optional
 * momentum and weight decay.
 * <p>
 * A network with a single output unit is a binary classifier: an output above
 * 0.5 predicts class 0, anything else predicts class 1.
 * <p>
 * Instances are not safe for concurrent use: {@code predict} and {@code learn}
 * write into per-instance layer and target buffers.
 */
public class NeuralNetwork implements OnlineClassifier<double[]>, SoftClassifier<double[]>, Serializable {
    private static final long serialVersionUID = 1L;
    private static final Logger logger = LoggerFactory.getLogger(NeuralNetwork.class);

    /** Error (loss) function minimized during training. */
    private ErrorFunction errorFunction = ErrorFunction.LEAST_MEAN_SQUARES;
    /** Activation function of the output layer; hidden layers are always logistic. */
    private ActivationFunction activationFunction = ActivationFunction.LOGISTIC_SIGMOID;
    /** Dimension of the input vectors (units in the input layer). */
    private int p;
    /** Number of classes: 2 when the output layer has a single unit, else its unit count. */
    private int k;
    /** All layers, input layer first, output layer last. */
    private Layer[] net;
    private Layer inputLayer;
    private Layer outputLayer;
    /** Learning rate. */
    private double eta = 0.1;
    /** Momentum factor, in [0, 1). */
    private double alpha = 0.0;
    /** Weight decay factor, in [0, 0.1]. */
    private double lambda = 0.0;
    /** Scratch buffer holding the target vector built from an integer class label. */
    private double[] target;

    /**
     * Constructs a network whose output activation is chosen to match the
     * error function: softmax for multi-class cross entropy, logistic sigmoid
     * otherwise.
     *
     * @param error the error function to minimize.
     * @param numUnits the number of units in each layer, input layer first.
     * @throws NullPointerException if {@code error} is null.
     * @throws IllegalArgumentException if the layer configuration is invalid.
     */
    public NeuralNetwork(ErrorFunction error, int ... numUnits) {
        this(error, NeuralNetwork.natural(error, NeuralNetwork.outputUnits(numUnits)), numUnits);
    }

    /**
     * Returns the output activation that pairs naturally with the error
     * function for a network with {@code k} output units.
     */
    private static ActivationFunction natural(ErrorFunction error, int k) {
        if (error == ErrorFunction.CROSS_ENTROPY && k != 1) {
            return ActivationFunction.SOFTMAX;
        }
        return ActivationFunction.LOGISTIC_SIGMOID;
    }

    /**
     * Returns the size of the last layer, or 0 for an empty configuration so
     * that argument validation reports the bad layer count instead of the
     * caller failing with an index error.
     */
    private static int outputUnits(int[] numUnits) {
        return numUnits.length == 0 ? 0 : numUnits[numUnits.length - 1];
    }

    /**
     * Validates a network configuration; shared by the network and Trainer constructors.
     *
     * @throws NullPointerException if {@code error} or {@code activation} is null.
     * @throws IllegalArgumentException if there are fewer than two layers, a
     *         layer has no units, or the error/activation pair is incompatible
     *         with the output layer size.
     */
    private static void checkArguments(ErrorFunction error, ActivationFunction activation, int[] numUnits) {
        Objects.requireNonNull(error, "error function is null");
        Objects.requireNonNull(activation, "activation function is null");
        int numLayers = numUnits.length;
        if (numLayers < 2) {
            throw new IllegalArgumentException("Invalid number of layers: " + numLayers);
        }
        for (int i = 0; i < numLayers; ++i) {
            if (numUnits[i] < 1) {
                throw new IllegalArgumentException(String.format("Invalid number of units of layer %d: %d", i + 1, numUnits[i]));
            }
        }
        int outputUnits = numUnits[numLayers - 1];
        if (error == ErrorFunction.LEAST_MEAN_SQUARES && activation == ActivationFunction.SOFTMAX) {
            throw new IllegalArgumentException("Softmax activation function is invalid for least mean squares error.");
        }
        if (error == ErrorFunction.CROSS_ENTROPY) {
            if (activation == ActivationFunction.LINEAR) {
                throw new IllegalArgumentException("Linear activation function is invalid with cross entropy error.");
            }
            if (activation == ActivationFunction.SOFTMAX && outputUnits == 1) {
                throw new IllegalArgumentException("Softmax activation function is for multi-class.");
            }
            if (activation == ActivationFunction.LOGISTIC_SIGMOID && outputUnits != 1) {
                throw new IllegalArgumentException("For cross entropy error, logistic sigmoid output is for binary classification.");
            }
        }
    }

    /**
     * Constructs a network with an explicit output activation function.
     * Weights (including bias weights) are initialized with
     * {@code Math.random(-r, r)} where {@code r = 1 / sqrt(fan-in)}.
     *
     * @param error the error function to minimize.
     * @param activation the activation function of the output layer.
     * @param numUnits the number of units in each layer, input layer first.
     * @throws NullPointerException if {@code error} or {@code activation} is null.
     * @throws IllegalArgumentException if the configuration is invalid.
     */
    public NeuralNetwork(ErrorFunction error, ActivationFunction activation, int ... numUnits) {
        NeuralNetwork.checkArguments(error, activation, numUnits);
        int numLayers = numUnits.length;
        int outputUnits = numUnits[numLayers - 1];
        this.errorFunction = error;
        this.activationFunction = activation;
        this.p = numUnits[0];
        this.k = outputUnits == 1 ? 2 : outputUnits;
        this.target = new double[outputUnits];
        this.net = new Layer[numLayers];
        for (int l = 0; l < numLayers; ++l) {
            Layer layer = new Layer();
            layer.units = numUnits[l];
            // The extra trailing slot is the constant bias input fed to the next layer.
            layer.output = new double[numUnits[l] + 1];
            layer.error = new double[numUnits[l] + 1];
            layer.output[numUnits[l]] = 1.0;
            this.net[l] = layer;
        }
        this.inputLayer = this.net[0];
        this.outputLayer = this.net[numLayers - 1];
        for (int l = 1; l < numLayers; ++l) {
            Layer lower = this.net[l - 1];
            Layer upper = this.net[l];
            upper.weight = new double[upper.units][lower.units + 1];
            upper.delta = new double[upper.units][lower.units + 1];
            double r = 1.0 / Math.sqrt(lower.units);
            for (int i = 0; i < upper.units; ++i) {
                for (int j = 0; j <= lower.units; ++j) {
                    upper.weight[i][j] = Math.random(-r, r);
                }
            }
        }
    }

    /** Bare constructor used by {@link #clone()}. */
    private NeuralNetwork() {
    }

    /**
     * Returns a deep copy of this network: hyper-parameters, weights, momentum
     * terms and layer buffers are all copied.
     */
    @Override
    public NeuralNetwork clone() {
        NeuralNetwork copycat = new NeuralNetwork();
        copycat.errorFunction = this.errorFunction;
        copycat.activationFunction = this.activationFunction;
        copycat.p = this.p;
        copycat.k = this.k;
        copycat.eta = this.eta;
        copycat.alpha = this.alpha;
        copycat.lambda = this.lambda;
        copycat.target = this.target.clone();
        int numLayers = this.net.length;
        copycat.net = new Layer[numLayers];
        for (int l = 0; l < numLayers; ++l) {
            Layer source = this.net[l];
            Layer copy = new Layer();
            copy.units = source.units;
            copy.output = source.output.clone();
            copy.error = source.error.clone();
            // The input layer has no incoming weights.
            if (l > 0) {
                copy.weight = Math.clone(source.weight);
                copy.delta = Math.clone(source.delta);
            }
            copycat.net[l] = copy;
        }
        copycat.inputLayer = copycat.net[0];
        copycat.outputLayer = copycat.net[numLayers - 1];
        return copycat;
    }

    /**
     * Sets the learning rate.
     *
     * @param eta the learning rate; must be positive.
     * @throws IllegalArgumentException if {@code eta <= 0}.
     */
    public void setLearningRate(double eta) {
        if (eta <= 0.0) {
            throw new IllegalArgumentException("Invalid learning rate: " + eta);
        }
        this.eta = eta;
    }

    /** Returns the learning rate. */
    public double getLearningRate() {
        return this.eta;
    }

    /**
     * Sets the momentum factor.
     *
     * @param alpha the momentum factor, in [0, 1).
     * @throws IllegalArgumentException if {@code alpha} is outside [0, 1).
     */
    public void setMomentum(double alpha) {
        if (alpha < 0.0 || alpha >= 1.0) {
            throw new IllegalArgumentException("Invalid momentum factor: " + alpha);
        }
        this.alpha = alpha;
    }

    /** Returns the momentum factor. */
    public double getMomentum() {
        return this.alpha;
    }

    /**
     * Sets the weight decay factor. After each update, every non-bias weight
     * is multiplied by {@code 1 - eta * lambda}.
     *
     * @param lambda the weight decay factor, in [0, 0.1].
     * @throws IllegalArgumentException if {@code lambda} is outside [0, 0.1].
     */
    public void setWeightDecay(double lambda) {
        if (lambda < 0.0 || lambda > 0.1) {
            throw new IllegalArgumentException("Invalid weight decay factor: " + lambda);
        }
        this.lambda = lambda;
    }

    /** Returns the weight decay factor. */
    public double getWeightDecay() {
        return this.lambda;
    }

    /**
     * Returns the live (not copied) weight matrix feeding into the given layer.
     * Row {@code i} holds the weights of unit {@code i}; the last column is
     * the bias weight. Layer 0 (the input layer) has no weights, so
     * {@code null} is returned for it.
     *
     * @param layer the layer index, 0 being the input layer.
     */
    public double[][] getWeight(int layer) {
        return this.net[layer].weight;
    }

    /** Copies {@code x} into the input layer, leaving the bias slot untouched. */
    private void setInput(double[] x) {
        if (x.length != this.inputLayer.units) {
            throw new IllegalArgumentException(String.format("Invalid input vector size: %d, expected: %d", x.length, this.inputLayer.units));
        }
        System.arraycopy(x, 0, this.inputLayer.output, 0, this.inputLayer.units);
    }

    /** Copies the output layer's values (without the bias slot) into {@code y}. */
    private void getOutput(double[] y) {
        if (y.length != this.outputLayer.units) {
            throw new IllegalArgumentException(String.format("Invalid output vector size: %d, expected: %d", y.length, this.outputLayer.units));
        }
        System.arraycopy(this.outputLayer.output, 0, y, 0, this.outputLayer.units);
    }

    /**
     * Computes {@code upper}'s outputs from {@code lower}'s outputs (bias
     * included). Hidden layers use the logistic sigmoid; the output layer uses
     * the configured activation function.
     */
    private void propagate(Layer lower, Layer upper) {
        boolean isOutput = upper == this.outputLayer;
        for (int i = 0; i < upper.units; ++i) {
            double[] w = upper.weight[i];
            double sum = 0.0;
            for (int j = 0; j <= lower.units; ++j) {
                sum += w[j] * lower.output[j];
            }
            if (!isOutput) {
                upper.output[i] = Math.logistic(sum);
                continue;
            }
            switch (this.activationFunction) {
                case LOGISTIC_SIGMOID:
                    upper.output[i] = Math.logistic(sum);
                    break;
                case LINEAR:
                case SOFTMAX:
                    // Softmax is normalized over the whole layer below.
                    upper.output[i] = sum;
                    break;
                default:
                    throw new UnsupportedOperationException("Unsupported activation function.");
            }
        }
        if (isOutput && this.activationFunction == ActivationFunction.SOFTMAX) {
            this.softmax();
        }
    }

    /**
     * Applies softmax in place to the output layer, subtracting the maximum
     * first so that {@code exp} cannot overflow.
     */
    private void softmax() {
        double[] out = this.outputLayer.output;
        int n = this.outputLayer.units;
        double max = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < n; ++i) {
            if (out[i] > max) {
                max = out[i];
            }
        }
        double sum = 0.0;
        for (int i = 0; i < n; ++i) {
            out[i] = Math.exp(out[i] - max);
            sum += out[i];
        }
        for (int i = 0; i < n; ++i) {
            out[i] /= sum;
        }
    }

    /** Runs a full forward pass from the input layer to the output layer. */
    private void propagate() {
        for (int l = 0; l < this.net.length - 1; ++l) {
            this.propagate(this.net[l], this.net[l + 1]);
        }
    }

    /**
     * Natural logarithm clamped for tiny arguments: values below 1e-300
     * return -690.7755 (approximately ln(1e-300)) instead of -Infinity.
     */
    private static double log(double x) {
        return x < 1.0E-300 ? -690.7755 : Math.log(x);
    }

    /** Computes the output error into the output layer's error buffer. */
    private double computeOutputError(double[] output) {
        return this.computeOutputError(output, this.outputLayer.error);
    }

    /**
     * Computes the loss of the current output against the target
     * {@code output} and writes the back-propagation gradient of each output
     * unit into {@code gradient}.
     *
     * @param output the target vector, one value per output unit.
     * @param gradient receives the per-unit gradient.
     * @return the loss value.
     * @throws IllegalArgumentException if {@code output} has the wrong size.
     */
    private double computeOutputError(double[] output, double[] gradient) {
        if (output.length != this.outputLayer.units) {
            throw new IllegalArgumentException(String.format("Invalid output vector size: %d, expected: %d", output.length, this.outputLayer.units));
        }
        double error = 0.0;
        for (int i = 0; i < this.outputLayer.units; ++i) {
            double out = this.outputLayer.output[i];
            double t = output[i];
            double g = t - out;
            if (this.errorFunction == ErrorFunction.LEAST_MEAN_SQUARES) {
                error += 0.5 * g * g;
                // Chain rule through the sigmoid: d(sigmoid)/dz = out * (1 - out).
                if (this.activationFunction == ActivationFunction.LOGISTIC_SIGMOID) {
                    g *= out * (1.0 - out);
                }
            } else if (this.errorFunction == ErrorFunction.CROSS_ENTROPY) {
                // With cross entropy the (t - out) gradient already includes the activation derivative.
                if (this.activationFunction == ActivationFunction.SOFTMAX) {
                    error -= t * NeuralNetwork.log(out);
                } else if (this.activationFunction == ActivationFunction.LOGISTIC_SIGMOID) {
                    error += -t * NeuralNetwork.log(out) - (1.0 - t) * NeuralNetwork.log(1.0 - out);
                }
            }
            gradient[i] = g;
        }
        return error;
    }

    /**
     * Propagates the error of {@code upper} back into {@code lower}'s
     * non-bias units, scaled by the logistic derivative {@code out * (1 - out)}.
     */
    private void backpropagate(Layer upper, Layer lower) {
        for (int i = 0; i < lower.units; ++i) {
            double err = 0.0;
            for (int j = 0; j < upper.units; ++j) {
                err += upper.weight[j][i] * upper.error[j];
            }
            double out = lower.output[i];
            lower.error[i] = out * (1.0 - out) * err;
        }
    }

    /**
     * Back-propagates the output error down to the first hidden layer. The
     * input layer's error is never read by {@link #adjustWeights()}, so it is
     * not computed.
     */
    private void backpropagate() {
        for (int l = this.net.length - 1; l > 1; --l) {
            this.backpropagate(this.net[l], this.net[l - 1]);
        }
    }

    /**
     * Updates all weights with momentum:
     * {@code delta = (1 - alpha) * eta * err * out + alpha * previousDelta},
     * followed by multiplicative weight decay on non-bias weights when
     * {@code lambda != 0}.
     */
    private void adjustWeights() {
        // Loop invariants; hoisting them does not change the arithmetic.
        double rate = (1.0 - this.alpha) * this.eta;
        double decay = 1.0 - this.eta * this.lambda;
        for (int l = 1; l < this.net.length; ++l) {
            Layer lower = this.net[l - 1];
            Layer upper = this.net[l];
            for (int i = 0; i < upper.units; ++i) {
                double err = upper.error[i];
                double[] w = upper.weight[i];
                double[] d = upper.delta[i];
                for (int j = 0; j <= lower.units; ++j) {
                    double delta = rate * err * lower.output[j] + this.alpha * d[j];
                    d[j] = delta;
                    w[j] += delta;
                    // The bias weight (last column) is exempt from weight decay.
                    if (this.lambda != 0.0 && j < lower.units) {
                        w[j] *= decay;
                    }
                }
            }
        }
    }

    /**
     * Maps the current output layer to a class label: threshold 0.5 for a
     * single output unit (above means class 0), argmax otherwise. Returns -1
     * if no output compares greater than negative infinity (e.g. all NaN).
     */
    private int label() {
        double[] out = this.outputLayer.output;
        if (this.outputLayer.units == 1) {
            return out[0] > 0.5 ? 0 : 1;
        }
        double max = Double.NEGATIVE_INFINITY;
        int best = -1;
        for (int i = 0; i < this.outputLayer.units; ++i) {
            if (out[i] > max) {
                max = out[i];
                best = i;
            }
        }
        return best;
    }

    /**
     * Predicts the class of {@code x} and writes the output values into {@code y}.
     * <p>
     * {@code y} normally has one slot per output unit. For a single-unit
     * (binary) network, {@code y} may instead have length 2, in which case it
     * receives {@code [out, 1 - out]} (class 0 and class 1 scores).
     *
     * @throws IllegalArgumentException if {@code x} or {@code y} has the wrong size.
     */
    @Override
    public int predict(double[] x, double[] y) {
        this.setInput(x);
        this.propagate();
        if (this.outputLayer.units == 1 && y.length == 2) {
            double out = this.outputLayer.output[0];
            y[0] = out;
            y[1] = 1.0 - out;
        } else {
            this.getOutput(y);
        }
        return this.label();
    }

    /**
     * Predicts the class of {@code x}.
     *
     * @throws IllegalArgumentException if {@code x} has the wrong size.
     */
    @Override
    public int predict(double[] x) {
        this.setInput(x);
        this.propagate();
        return this.label();
    }

    /**
     * Performs one weighted online update toward the target vector {@code y}.
     *
     * @param x the input vector.
     * @param y the target vector, one value per output unit.
     * @param weight the instance weight; scales both the returned loss and the gradient.
     * @return the weighted loss before the update.
     */
    public double learn(double[] x, double[] y, double weight) {
        this.setInput(x);
        this.propagate();
        double err = weight * this.computeOutputError(y);
        if (weight != 1.0) {
            for (int i = 0; i < this.outputLayer.units; ++i) {
                this.outputLayer.error[i] *= weight;
            }
        }
        this.backpropagate();
        this.adjustWeights();
        return err;
    }

    /** Performs one online update with unit weight. */
    @Override
    public void learn(double[] x, int y) {
        this.learn(x, y, 1.0);
    }

    /**
     * Performs one weighted online update for class label {@code y}.
     * <p>
     * Targets are 1/0 for cross entropy and 0.9/0.1 for least mean squares.
     * With a single output unit, the unit's target is the "high" value for
     * class 0 and the "low" value for class 1.
     *
     * @param weight the instance weight; zero-weight instances are skipped.
     * @throws IllegalArgumentException if {@code weight < 0} or {@code y} is out of range.
     */
    public void learn(double[] x, int y, double weight) {
        if (weight < 0.0) {
            throw new IllegalArgumentException("Invalid weight: " + weight);
        }
        if (weight == 0.0) {
            logger.info("Ignore the training instance with zero weight.");
            return;
        }
        if (y < 0) {
            throw new IllegalArgumentException("Invalid class label: " + y);
        }
        if (this.outputLayer.units == 1 && y > 1) {
            throw new IllegalArgumentException("Invalid class label: " + y);
        }
        if (this.outputLayer.units > 1 && y >= this.outputLayer.units) {
            throw new IllegalArgumentException("Invalid class label: " + y);
        }
        double high = this.errorFunction == ErrorFunction.CROSS_ENTROPY ? 1.0 : 0.9;
        double low = this.errorFunction == ErrorFunction.CROSS_ENTROPY ? 0.0 : 0.1;
        if (this.target.length == 1) {
            // Binary network: class 0 is encoded as a high output (see label()).
            this.target[0] = y == 0 ? high : low;
        } else {
            Arrays.fill(this.target, low);
            this.target[y] = high;
        }
        this.learn(x, this.target, weight);
    }

    /**
     * Runs one epoch of online learning over the data in a random order.
     *
     * @throws IllegalArgumentException if {@code x} and {@code y} differ in length.
     */
    public void learn(double[][] x, int[] y) {
        if (x.length != y.length) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", x.length, y.length));
        }
        int[] index = Math.permutate(x.length);
        for (int i : index) {
            this.learn(x[i], y[i]);
        }
    }

    /**
     * Trainer that builds a {@link NeuralNetwork} and runs a fixed number of
     * online-learning epochs (25 by default).
     */
    public static class Trainer extends ClassifierTrainer<double[]> {
        private ErrorFunction errorFunction = ErrorFunction.LEAST_MEAN_SQUARES;
        private ActivationFunction activationFunction = ActivationFunction.LOGISTIC_SIGMOID;
        private int[] numUnits;
        private double eta = 0.1;
        private double alpha = 0.0;
        private double lambda = 0.0;
        private int epochs = 25;

        /**
         * Constructs a trainer whose output activation is chosen to match the
         * error function (softmax for multi-class cross entropy, logistic otherwise).
         *
         * @throws NullPointerException if {@code error} is null.
         * @throws IllegalArgumentException if the layer configuration is invalid.
         */
        public Trainer(ErrorFunction error, int ... numUnits) {
            this(error, NeuralNetwork.natural(error, NeuralNetwork.outputUnits(numUnits)), numUnits);
        }

        /**
         * Constructs a trainer with an explicit output activation function.
         *
         * @throws NullPointerException if {@code error} or {@code activation} is null.
         * @throws IllegalArgumentException if the configuration is invalid.
         */
        public Trainer(ErrorFunction error, ActivationFunction activation, int ... numUnits) {
            NeuralNetwork.checkArguments(error, activation, numUnits);
            this.errorFunction = error;
            this.activationFunction = activation;
            // Defensive copy: the caller may reuse or mutate its array after validation.
            this.numUnits = numUnits.clone();
        }

        /** Sets the learning rate; must be positive. */
        public Trainer setLearningRate(double eta) {
            if (eta <= 0.0) {
                throw new IllegalArgumentException("Invalid learning rate: " + eta);
            }
            this.eta = eta;
            return this;
        }

        /** Sets the momentum factor, in [0, 1). */
        public Trainer setMomentum(double alpha) {
            if (alpha < 0.0 || alpha >= 1.0) {
                throw new IllegalArgumentException("Invalid momentum factor: " + alpha);
            }
            this.alpha = alpha;
            return this;
        }

        /** Sets the weight decay factor, in [0, 0.1]. */
        public Trainer setWeightDecay(double lambda) {
            if (lambda < 0.0 || lambda > 0.1) {
                throw new IllegalArgumentException("Invalid weight decay factor: " + lambda);
            }
            this.lambda = lambda;
            return this;
        }

        /** Sets the number of training epochs; must be at least 1. */
        public Trainer setNumEpochs(int epochs) {
            if (epochs < 1) {
                throw new IllegalArgumentException("Invalid number of epochs of stochastic learning:" + epochs);
            }
            this.epochs = epochs;
            return this;
        }

        /**
         * Builds a fresh network and trains it for the configured number of epochs.
         *
         * @throws IllegalArgumentException if {@code x} and {@code y} differ in length.
         */
        public NeuralNetwork train(double[][] x, int[] y) {
            NeuralNetwork model = new NeuralNetwork(this.errorFunction, this.activationFunction, this.numUnits);
            model.setLearningRate(this.eta);
            model.setMomentum(this.alpha);
            model.setWeightDecay(this.lambda);
            for (int epoch = 1; epoch <= this.epochs; ++epoch) {
                model.learn(x, y);
                logger.info("Neural network learns epoch {}", epoch);
            }
            return model;
        }
    }

    /**
     * One layer of the network. {@code output} and {@code error} have
     * {@code units + 1} slots (the last output slot is the constant bias 1.0);
     * {@code weight} and {@code delta} (previous update, for momentum) are
     * {@code units x (previousUnits + 1)} and are null for the input layer.
     * Static so it does not capture the enclosing network instance.
     */
    private static class Layer implements Serializable {
        private static final long serialVersionUID = 1L;
        int units;
        double[] output;
        double[] error;
        double[][] weight;
        double[][] delta;

        private Layer() {
        }
    }

    /** Activation function of the output layer. */
    public static enum ActivationFunction {
        LINEAR,
        LOGISTIC_SIGMOID,
        SOFTMAX;

    }

    /** Error function minimized during training. */
    public static enum ErrorFunction {
        LEAST_MEAN_SQUARES,
        CROSS_ENTROPY;

    }
}

