/*
 * Decompiled with CFR 0.152.
 */
package smile.demo.classification;

import java.awt.Dimension;
import javax.swing.JFrame;
import javax.swing.JLabel;
import javax.swing.JOptionPane;
import javax.swing.JTextField;
import smile.classification.NeuralNetwork;
import smile.demo.classification.ClassificationDemo;
import smile.math.Math;

public class NeuralNetworkDemo
extends ClassificationDemo {
    private int units = 10;
    private int epochs = 20;
    private JTextField unitsField = new JTextField(Integer.toString(this.units), 5);
    private JTextField epochsField;

    public NeuralNetworkDemo() {
        this.optionPane.add(new JLabel("Hidden Neurons:"));
        this.optionPane.add(this.unitsField);
        this.epochsField = new JTextField(Integer.toString(this.epochs), 5);
        this.optionPane.add(new JLabel("Epochs:"));
        this.optionPane.add(this.epochsField);
    }

    @Override
    public double[][] learn(double[] x, double[] y) {
        int i;
        try {
            this.units = Integer.parseInt(this.unitsField.getText().trim());
            if (this.units <= 0) {
                JOptionPane.showMessageDialog(this, "Invalid number of hidden neurons: " + this.units, "Error", 0);
                return null;
            }
        }
        catch (Exception ex) {
            JOptionPane.showMessageDialog(this, "Invalid number of hidden neurons: " + this.unitsField.getText(), "Error", 0);
            return null;
        }
        try {
            this.epochs = Integer.parseInt(this.epochsField.getText().trim());
            if (this.epochs <= 0) {
                JOptionPane.showMessageDialog(this, "Invalid number of epochs: " + this.epochs, "Error", 0);
                return null;
            }
        }
        catch (Exception ex) {
            JOptionPane.showMessageDialog(this, "Invalid number of epochs: " + this.epochsField.getText(), "Error", 0);
            return null;
        }
        double[][] data = (double[][])dataset[datasetIndex].toArray((E[])new double[dataset[datasetIndex].size()][]);
        int[] label = dataset[datasetIndex].toArray(new int[dataset[datasetIndex].size()]);
        int k = Math.max(label) + 1;
        NeuralNetwork net = null;
        net = k == 2 ? new NeuralNetwork(NeuralNetwork.ErrorFunction.CROSS_ENTROPY, NeuralNetwork.ActivationFunction.LOGISTIC_SIGMOID, data[0].length, this.units, 1) : new NeuralNetwork(NeuralNetwork.ErrorFunction.CROSS_ENTROPY, NeuralNetwork.ActivationFunction.SOFTMAX, data[0].length, this.units, k);
        for (i = 0; i < this.epochs; ++i) {
            net.learn(data, label);
        }
        for (i = 0; i < label.length; ++i) {
            label[i] = net.predict(data[i]);
        }
        double trainError = this.error(label, label);
        System.out.format("training error = %.2f%%\n", 100.0 * trainError);
        double[][] z = new double[y.length][x.length];
        for (int i2 = 0; i2 < y.length; ++i2) {
            for (int j = 0; j < x.length; ++j) {
                double[] p = new double[]{x[j], y[i2]};
                z[i2][j] = net.predict(p);
            }
        }
        return z;
    }

    @Override
    public String toString() {
        return "Neural Network";
    }

    public static void main(String[] argv) {
        NeuralNetworkDemo demo = new NeuralNetworkDemo();
        JFrame f = new JFrame("Neural Network");
        f.setSize(new Dimension(1000, 1000));
        f.setLocationRelativeTo(null);
        f.setDefaultCloseOperation(3);
        f.getContentPane().add(demo);
        f.setVisible(true);
    }
}

