/*
 * Decompiled with CFR 0.152.
 */
package jhplot;

import java.awt.Dimension;
import java.io.File;
import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import jhplot.P0D;
import jhplot.PND;
import jhplot.SPlot;
import jhplot.gui.HelpBrowser;
import org.encog.engine.network.activation.ActivationFunction;
import org.encog.engine.network.activation.ActivationSigmoid;
import org.encog.ml.data.MLData;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.train.MLTrain;
import org.encog.neural.data.basic.BasicNeuralData;
import org.encog.neural.data.basic.BasicNeuralDataSet;
import org.encog.neural.networks.BasicNetwork;
import org.encog.neural.networks.ContainsFlat;
import org.encog.neural.networks.layers.BasicLayer;
import org.encog.neural.networks.layers.Layer;
import org.encog.neural.networks.structure.AnalyzeNetwork;
import org.encog.neural.networks.training.propagation.back.Backpropagation;
import org.encog.util.obj.SerializeObject;
import org.encog.visualize.NetworkVisualizeFrame;
import org.encog.visualize.NetworkWeightsFrame;
import org.encog.workbench.EncogWorkBench;
import org.encog.workbench.frames.document.EncogDocumentFrame;

/**
 * Wrapper around an Encog {@link BasicNetwork}.
 *
 * <p>It builds a network layer by layer (sigmoid activation) and trains it with
 * backpropagation. It also makes predictions on raw Encog data and on jhplot
 * containers ({@link P0D}, {@link PND}).
 */
public class HNeuralNet {
    /** The wrapped Encog network; replaced by {@link #read(String)} only on a successful load. */
    private BasicNetwork network = new BasicNetwork();
    /** Training data set by one of the {@code setData} methods; {@code null} until then. */
    private BasicNeuralDataSet data;
    /** Trainer created by the most recent call to {@link #trainBackpropagation}. */
    private MLTrain train;
    /** Training error recorded after every iteration of the most recent training run. */
    private ArrayList<Double> epochError = new ArrayList<>();

    /**
     * Finalizes the layer structure added so far, then calls {@link BasicNetwork#reset()}.
     * Presumably intended to be called once all layers have been added and before training.
     */
    public void reset() {
        network.getStructure().finalizeStructure();
        network.reset();
    }

    /**
     * Appends a layer with a sigmoid activation function and no bias neuron.
     *
     * @param neuronCount number of neurons in the new layer
     */
    public void addFeedForwardLayer(int neuronCount) {
        network.addLayer(new BasicLayer(new ActivationSigmoid(), false, neuronCount));
    }

    /**
     * Appends a layer with a sigmoid activation function and a bias neuron.
     *
     * @param neuronCount number of neurons in the new layer
     */
    public void addFeedForwardLayerWithBias(int neuronCount) {
        network.addLayer(new BasicLayer(new ActivationSigmoid(), true, neuronCount));
    }

    /**
     * Sets supervised training data.
     *
     * @param input input vectors, one row per sample
     * @param ideal expected output vectors, one row per sample
     */
    public void setData(double[][] input, double[][] ideal) {
        data = new BasicNeuralDataSet(input, ideal);
    }

    /**
     * Sets training data without ideal outputs.
     *
     * @param input input vectors, one row per sample
     */
    public void setData(double[][] input) {
        data = new BasicNeuralDataSet(input, null);
    }

    /**
     * Sets supervised training data from jhplot containers.
     *
     * @param input input vectors, one row per sample
     * @param ideal expected output vectors, one row per sample
     */
    public void setData(PND input, PND ideal) {
        data = new BasicNeuralDataSet(input.getArray(), ideal.getArray());
    }

    /**
     * Sets training data without ideal outputs from a jhplot container.
     *
     * @param input input vectors, one row per sample
     */
    public void setData(PND input) {
        data = new BasicNeuralDataSet(input.getArray(), null);
    }

    /**
     * Convenience delegate to {@link PND#standardize()}; does not touch the network.
     * NOTE(review): whether {@code input} itself is modified depends on PND — confirm.
     *
     * @param input data to standardize
     * @return the result of {@code input.standardize()}
     */
    public PND standardize(PND input) {
        return input.standardize();
    }

    /** @return the current training data set, or {@code null} if none was set */
    public BasicNeuralDataSet getData() {
        return data;
    }

    /**
     * Runs the network on one Encog input vector.
     *
     * @param input input vector
     * @return the network output
     */
    public MLData predict(MLData input) {
        return network.compute(input);
    }

    /**
     * Runs the network on one input vector held in a {@link P0D}.
     *
     * @param input input vector
     * @return the network output, in a P0D titled "prediction"
     */
    public P0D predict(P0D input) {
        MLData output = network.compute(new BasicNeuralData(input.getArray()));
        return new P0D("prediction", output.getData());
    }

    /**
     * Runs the network on every row of {@code input}.
     *
     * @param input input vectors, one row per sample
     * @return a PND titled "Predicted" holding exactly one output row per input row
     */
    public PND predict(PND input) {
        PND result = new PND("Predicted");
        for (int i = 0; i < input.size(); i++) {
            double[] row = predict(input.getRow(i)).getArray();
            // One row per sample, whatever the number of output neurons.
            if (row.length > 0) {
                result.add(row.clone());
            }
        }
        return result;
    }

    /**
     * Trains the network with backpropagation on the data set by {@code setData}.
     *
     * <p>At least one iteration is always run. Training stops when the epoch counter
     * reaches {@code maxEpoch} or the training error drops to {@code errorMinEpoch} or below.
     * The error after every iteration is recorded and available from {@link #getEpochError()}.
     *
     * @param isShow        if true, plot the training error every 100 epochs in an {@link SPlot}
     * @param maxEpoch      upper bound on the epoch counter
     * @param learnRate     backpropagation learning rate
     * @param momentum      backpropagation momentum
     * @param errorMinEpoch target training error
     * @return the final epoch counter. It starts at 1 and is incremented after every
     *         iteration, so it equals the number of iterations performed plus one.
     */
    public int trainBackpropagation(boolean isShow, int maxEpoch, double learnRate, double momentum, double errorMinEpoch) {
        SPlot plot = null;
        if (isShow) {
            plot = new SPlot();
            plot.visible();
            plot.setMarksStyle("various");
            plot.setConnected(true, 0);
            plot.setNameX("Epoch");
            plot.setNameY("Train error");
        }
        epochError = new ArrayList<>();
        train = new Backpropagation(network, data, learnRate, momentum);
        int epoch = 1;
        double error;
        do {
            train.iteration();
            error = train.getError();
            epochError.add(error);
            // Advance the counter on every iteration, not only when plotting;
            // otherwise maxEpoch would never stop the loop.
            ++epoch;
            if (isShow && epoch % 100 == 0) {
                plot.addPoint(0, epoch, error, true);
                plot.setAutoRange();
            }
        } while (epoch < maxEpoch && error > errorMinEpoch);
        if (isShow) {
            plot.setAutoRange();
            plot.update();
        }
        return epoch;
    }

    /**
     * Serializes the network to a file with {@link SerializeObject}.
     *
     * @param file destination path
     * @return "NN saved to &lt;file&gt;" on success, otherwise a message describing the failure
     */
    public String save(String file) {
        try {
            SerializeObject.save(new File(file), network);
        } catch (IOException e) {
            e.printStackTrace();
            return "NN not saved to " + file + ": " + e.getMessage();
        }
        return "NN saved to " + file;
    }

    /**
     * Loads a network previously written by {@link #save(String)}.
     *
     * <p>SerializeObject works with {@code java.io.Serializable} objects (Java serialization).
     * Never load files from untrusted sources.
     *
     * @param file source path
     * @return 0 on success; -1 if the file does not exist; -2 if it could not be read or
     *         does not contain a {@link BasicNetwork} (the current network is kept in that case)
     */
    public int read(String file) {
        File f = new File(file);
        if (!f.exists()) {
            return -1;
        }
        Object loaded;
        try {
            loaded = SerializeObject.load(f);
        } catch (IOException | ClassNotFoundException e) {
            e.printStackTrace();
            return -2;
        }
        if (!(loaded instanceof BasicNetwork)) {
            return -2;
        }
        network = (BasicNetwork) loaded;
        return 0;
    }

    /** @return the wrapped Encog network (live reference) */
    public BasicNetwork getNetwork() {
        return network;
    }

    /** Opens an Encog window that visualizes the network (minimum size 600x500). */
    public void showNetwork() {
        NetworkVisualizeFrame frame = new NetworkVisualizeFrame(network);
        frame.setMinimumSize(new Dimension(600, 500));
        frame.setVisible(true);
    }

    /** Opens an Encog window that shows the network weights (minimum size 600x500). */
    public void showWeights() {
        NetworkWeightsFrame frame = new NetworkWeightsFrame(network);
        frame.setMinimumSize(new Dimension(600, 500));
        frame.setVisible(true);
    }

    /** @return a new Encog {@link AnalyzeNetwork} for the current network */
    public AnalyzeNetwork analyzeNetwork() {
        return new AnalyzeNetwork(network);
    }

    /** @return the wrapped Encog network for direct editing (same reference as {@link #getNetwork()}) */
    public BasicNetwork editNetwork() {
        return network;
    }

    /** Initializes the Encog workbench singleton with a new document frame and shows its main window. */
    public void show() {
        EncogWorkBench workBench = EncogWorkBench.getInstance();
        workBench.setMainWindow(new EncogDocumentFrame());
        workBench.init();
        workBench.getMainWindow().setVisible(true);
    }

    /**
     * @return the training error after each iteration of the most recent
     *         {@link #trainBackpropagation} run. The list is empty before any training,
     *         and it is the live internal list.
     */
    public ArrayList<Double> getEpochError() {
        return epochError;
    }

    /** Opens the online API documentation page for this class in a {@link HelpBrowser}. */
    public void doc() {
        String page = getClass().getName().replace(".", "/") + ".html";
        new HelpBrowser("https://datamelt.org/api/doc.php/" + page);
    }
}

