package jhplot;

import java.awt.Dimension;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import jhplot.gui.HelpBrowser;
import jyplot.BaseChartPanel;
import org.encog.engine.network.activation.ActivationSigmoid;
import org.encog.ml.data.MLData;
import org.encog.ml.train.MLTrain;
import org.encog.neural.data.basic.BasicNeuralData;
import org.encog.neural.data.basic.BasicNeuralDataSet;
import org.encog.neural.networks.BasicNetwork;
import org.encog.neural.networks.layers.BasicLayer;
import org.encog.neural.networks.structure.AnalyzeNetwork;
import org.encog.neural.networks.training.propagation.back.Backpropagation;
import org.encog.util.obj.SerializeObject;
import org.encog.visualize.NetworkVisualizeFrame;
import org.encog.visualize.NetworkWeightsFrame;
import org.encog.workbench.EncogWorkBench;
import org.encog.workbench.frames.document.EncogDocumentFrame;

/* loaded from: input_file:jhplot/HNeuralNet.class */
/**
 * Convenience wrapper around an Encog {@link BasicNetwork} for building, training and
 * evaluating simple feed-forward neural networks with jhplot data containers
 * ({@link P0D}, {@link PND}).
 *
 * <p>Typical use: add layers with {@link #addFeedForwardLayer(int)} /
 * {@link #addFeedForwardLayerWithBias(int)}, call {@link #reset()} to finalize the structure,
 * supply training data with one of the {@code setData} overloads, then call
 * {@link #trainBackpropagation(boolean, int, double, double, double)} and {@code predict}.
 *
 * <p>This class performs no synchronization.
 */
public class HNeuralNet {
    /** The training error is added to the live plot once every this many epochs. */
    private static final int PLOT_INTERVAL = 100;

    /** Underlying Encog network; replaced by {@link #read(String)} on a successful load. */
    private BasicNetwork network = new BasicNetwork();
    /** Training data set; {@code null} until one of the {@code setData} overloads is called. */
    private BasicNeuralDataSet data;
    /** Trainer created by the most recent call to {@code trainBackpropagation}. */
    private MLTrain train;
    /** Training error recorded after each epoch of the most recent training run. */
    private ArrayList<Double> epochError = new ArrayList<>();

    /**
     * Finalizes the network structure (call after all layers were added) and re-initializes
     * the network weights via {@link BasicNetwork#reset()}.
     */
    public void reset() {
        this.network.getStructure().finalizeStructure();
        this.network.reset();
    }

    /**
     * Appends a layer with a sigmoid activation function and no bias neuron.
     *
     * @param i number of neurons in the new layer
     */
    public void addFeedForwardLayer(int i) {
        this.network.addLayer(new BasicLayer(new ActivationSigmoid(), false, i));
    }

    /**
     * Appends a layer with a sigmoid activation function and a bias neuron.
     *
     * @param i number of neurons in the new layer
     */
    public void addFeedForwardLayerWithBias(int i) {
        this.network.addLayer(new BasicLayer(new ActivationSigmoid(), true, i));
    }

    /**
     * Sets supervised training data.
     *
     * @param dArr  input rows
     * @param dArr2 ideal (target) output rows, one per input row
     */
    public void setData(double[][] dArr, double[][] dArr2) {
        this.data = new BasicNeuralDataSet(dArr, dArr2);
    }

    /**
     * Sets training data consisting only of input rows (no ideal output values).
     *
     * @param dArr input rows
     */
    public void setData(double[][] dArr) {
        this.data = new BasicNeuralDataSet(dArr, (double[][]) null);
    }

    /**
     * Sets supervised training data from jhplot containers.
     *
     * @param pnd  input rows
     * @param pnd2 ideal (target) output rows, one per input row
     */
    public void setData(PND pnd, PND pnd2) {
        this.data = new BasicNeuralDataSet(pnd.getArray(), pnd2.getArray());
    }

    /**
     * Sets training data consisting only of input rows (no ideal output values).
     *
     * @param pnd input rows
     */
    public void setData(PND pnd) {
        this.data = new BasicNeuralDataSet(pnd.getArray(), (double[][]) null);
    }

    /**
     * Standardizes a data container by delegating to {@link PND#standardize()}. Neither the
     * network nor the stored training data is modified.
     *
     * @param pnd data to standardize
     * @return the result of {@code pnd.standardize()}
     */
    public PND standardize(PND pnd) {
        return pnd.standardize();
    }

    /**
     * @return the current training data set, or {@code null} if none was set yet
     */
    public BasicNeuralDataSet getData() {
        return this.data;
    }

    /**
     * Computes the network output for a raw Encog input vector.
     *
     * @param mLData input vector
     * @return network output
     */
    public MLData predict(MLData mLData) {
        return this.network.compute(mLData);
    }

    /**
     * Computes the network output for a single input row.
     *
     * @param p0d input row
     * @return network output wrapped in a P0D titled {@code "prediction"}
     */
    public P0D predict(P0D p0d) {
        return new P0D("prediction", this.network.compute(new BasicNeuralData(p0d.getArray())).getData());
    }

    /**
     * Computes the network output for every row of {@code pnd}.
     *
     * @param pnd input rows
     * @return a PND titled {@code "Predicted"} holding exactly one output row per input row
     */
    public PND predict(PND pnd) {
        PND predicted = new PND("Predicted");
        for (int row = 0; row < pnd.size(); row++) {
            double[] output = predict(pnd.getRow(row)).getArray();
            // One row per input, regardless of how many output neurons the network has.
            predicted.add(output.clone());
        }
        return predicted;
    }

    /**
     * Trains the network with back-propagation.
     *
     * <p>Training stops after {@code i} epochs or as soon as the training error is no longer
     * greater than {@code d3}, whichever comes first. At least one epoch is always run. The
     * error after each epoch is recorded and can be read via {@link #getEpochError()}.
     *
     * @param z  if {@code true}, show the training error versus epoch in an {@link SPlot}
     * @param i  maximum number of epochs
     * @param d  learning rate
     * @param d2 momentum
     * @param d3 target error; training stops once the error drops to or below this value
     * @return the number of epochs actually performed
     * @throws IllegalStateException if no training data was set
     */
    public int trainBackpropagation(boolean z, int i, double d, double d2, double d3) {
        final boolean show = z;
        final int maxEpochs = i;
        final double learningRate = d;
        final double momentum = d2;
        final double minError = d3;

        if (this.data == null) {
            throw new IllegalStateException(
                    "No training data: call setData(...) before trainBackpropagation()");
        }

        SPlot plot = null;
        if (show) {
            plot = new SPlot();
            plot.visible();
            plot.setMarksStyle("various");
            plot.setConnected(true, 0);
            plot.setNameX("Epoch");
            plot.setNameY("Train error");
        }

        this.epochError = new ArrayList<>();
        this.train = new Backpropagation(this.network, this.data, learningRate, momentum);

        int epoch = 0;
        double error;
        do {
            this.train.iteration();
            error = this.train.getError();
            this.epochError.add(error);
            epoch++;
            if (plot != null && epoch % PLOT_INTERVAL == 0) {
                plot.addPoint(0, epoch, error, true);
                plot.setAutoRange();
            }
        } while (epoch < maxEpochs && error > minError);

        if (plot != null) {
            if (epoch % PLOT_INTERVAL != 0) {
                // Ensure the final error is shown, also for runs shorter than PLOT_INTERVAL.
                plot.addPoint(0, epoch, error, true);
            }
            plot.setAutoRange();
            plot.update();
        }
        return epoch;
    }

    /**
     * Serializes the network to a file.
     *
     * @param str path of the output file
     * @return a human-readable status message, which differs for success and failure
     */
    public String save(String str) {
        try {
            SerializeObject.save(new File(str), this.network);
            return "NN saved to " + str;
        } catch (IOException e) {
            e.printStackTrace();
            return "Error: NN could not be saved to " + str + ": " + e.getMessage();
        }
    }

    /**
     * Loads a network previously written by {@link #save(String)}. On failure, the current
     * network is left unchanged.
     *
     * <p>NOTE(review): this uses Java native deserialization; never load files from untrusted
     * sources.
     *
     * @param str path of the file to read
     * @return 0 on success, -1 if the file does not exist, -2 if it could not be read or
     *         does not contain a {@link BasicNetwork}
     */
    public int read(String str) {
        File file = new File(str);
        if (!file.exists()) {
            return -1;
        }
        Object loaded;
        try {
            loaded = SerializeObject.load(file);
        } catch (IOException | ClassNotFoundException e) {
            e.printStackTrace();
            return -2;
        }
        // instanceof also rejects null, so a failed load never clobbers the working network.
        if (!(loaded instanceof BasicNetwork)) {
            return -2;
        }
        this.network = (BasicNetwork) loaded;
        return 0;
    }

    /**
     * @return the underlying Encog network (not a copy)
     */
    public BasicNetwork getNetwork() {
        return this.network;
    }

    /** Opens a window that visualizes the network structure. */
    public void showNetwork() {
        NetworkVisualizeFrame networkVisualizeFrame = new NetworkVisualizeFrame(this.network);
        networkVisualizeFrame.setMinimumSize(new Dimension(BaseChartPanel.DEFAULT_MAXIMUM_DRAW_HEIGHT, 500));
        networkVisualizeFrame.setVisible(true);
    }

    /** Opens a window that displays the network weights. */
    public void showWeights() {
        NetworkWeightsFrame networkWeightsFrame = new NetworkWeightsFrame(this.network);
        networkWeightsFrame.setMinimumSize(new Dimension(BaseChartPanel.DEFAULT_MAXIMUM_DRAW_HEIGHT, 500));
        networkWeightsFrame.setVisible(true);
    }

    /**
     * @return a new Encog {@link AnalyzeNetwork} for the current network
     */
    public AnalyzeNetwork analyzeNetwork() {
        return new AnalyzeNetwork(this.network);
    }

    /**
     * Returns the underlying network for direct modification; same as {@link #getNetwork()}.
     *
     * @return the underlying Encog network (not a copy)
     */
    public BasicNetwork editNetwork() {
        return this.network;
    }

    /**
     * Opens the Encog workbench with a new document frame. The current network is not passed
     * to the workbench.
     */
    public void show() {
        EncogWorkBench encogWorkBench = EncogWorkBench.getInstance();
        encogWorkBench.setMainWindow(new EncogDocumentFrame());
        encogWorkBench.init();
        encogWorkBench.getMainWindow().setVisible(true);
    }

    /**
     * @return the training error after each epoch of the last training run; empty if no
     *         training was performed yet
     */
    public ArrayList<Double> getEpochError() {
        return this.epochError;
    }

    /** Opens the online documentation page of this class in the jhplot help browser. */
    public void doc() {
        new HelpBrowser(HelpBrowser.JHPLOT_HTTP + (getClass().getName().replace(".", "/") + ".html"));
    }
}
