package org.encog.workbench.tabs.training;

import java.awt.BorderLayout;
import java.awt.Color;
import java.awt.Font;
import java.awt.FontMetrics;
import java.awt.Graphics;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import java.io.File;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.Date;
import java.util.concurrent.atomic.AtomicInteger;
import javax.swing.JButton;
import javax.swing.JComboBox;
import javax.swing.JPanel;
import org.encog.StatusReportable;
import org.encog.ml.MLError;
import org.encog.ml.MLMethod;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.ea.train.EvolutionaryAlgorithm;
import org.encog.ml.train.MLTrain;
import org.encog.util.Format;
import org.encog.util.file.FileUtil;
import org.encog.util.validate.ValidateNetwork;
import org.encog.workbench.EncogWorkBench;
import org.encog.workbench.frames.document.tree.ProjectEGFile;
import org.encog.workbench.tabs.EncogCommonTab;
import org.encog.workbench.util.EncogFonts;
import org.encog.workbench.util.TimeSpanFormatter;

/* loaded from: input_file:org/encog/workbench/tabs/training/BasicTrainingProgress.class */
/**
 * Workbench tab that drives an {@link MLTrain} and shows its progress.
 *
 * <p>The tab contains three areas:
 * <ul>
 *   <li>a status panel showing iteration, current/validation error, error improvement, last
 *       status message, elapsed time, performance and, for evolutionary trainers, species count
 *       and best-genome size;</li>
 *   <li>an error chart;</li>
 *   <li>a button bar with Start, Stop, Close and Iteration, plus a combo box for requesting a
 *       reset or perturbation of the method being trained.</li>
 * </ul>
 *
 * <p>Continuous training runs in {@link #run()} on a thread created by {@code performStart()}.
 * A single step requested through the Iteration button runs on the thread that delivers the
 * button's action event.
 */
public class BasicTrainingProgress extends EncogCommonTab implements Runnable, ActionListener, StatusReportable {

    /**
     * Entries of the reset/perturb combo box. Selecting entry {@code i} stores {@code i - 1} in
     * {@link #resetOption}, so the leading placeholder maps to -1 (the field's initial value).
     */
    private static final String[] RESET_OPTIONS = {
        "<Select Option>",
        "Reset",
        "Perturb 1%",
        "Perturb 5%",
        "Perturb 10%",
        "Perturb 15%",
        "Perturb 20%",
        "Perturb 50%"
    };

    private final JComboBox comboReset;
    private final JButton buttonStart;
    private final JButton buttonStop;
    private final JButton buttonClose;
    private final JButton buttonIteration;
    private final JPanel panelBody;
    private final JPanel panelButtons;

    /** Background training thread; {@code null} when no training run is active. */
    private Thread thread;

    /**
     * Stop request for the training loop. It is set from UI callbacks (performStop/close) but
     * consumed on the training thread, hence volatile.
     */
    private volatile boolean cancel;

    protected TrainingStatusPanel statusPanel;
    protected ChartPane chartPanel;
    private MLTrain train;
    private MLDataSet trainingData;
    private double maxError;

    /** Number of iterations performed so far; shown in the status panel. */
    private int iteration;
    private Font headFont;
    private Font bodyFont;

    /** Training error after the most recent iteration. */
    private double currentError;

    /** Training error before the most recent iteration. */
    private double lastError;

    /** Relative error change of the most recent iteration; 100.0 when it was undefined. */
    private double errorImprovement;

    /** Time the current run was started; {@code null} until Start is pressed. */
    private Date started;
    private long lastUpdate;
    private final NumberFormat nf;
    private final NumberFormat nfShort;

    /** Performance figure for the status panel; -1 means it is still being calculated. */
    private int performanceCount;
    private Date performanceLast;
    private int performanceLastIteration;

    /** Short run-state text shown after the iteration count (e.g. "Started", "Canceled"). */
    private String status;

    /** Set by close() while training is running; volatile because it crosses threads. */
    private volatile boolean shouldExit;

    /** Pending reset/perturb request from the combo box; -1 when none is pending. */
    private final AtomicInteger resetOption;

    /** When set, performClose() skips the save/revert prompt. */
    private boolean error;

    /** Last message received through {@link #report(int, int, String)}; never {@code null}. */
    private String lastMessage;

    /** Optional validation set; {@code null} when no validation error is tracked. */
    private MLDataSet validationData;
    private double validationError;

    /**
     * Builds the progress tab and its UI. Training does not start until Start is pressed.
     *
     * <p>If the object wrapped by {@code projectEGFile} is an {@link MLMethod}, it is first
     * validated against the training data via {@code ValidateNetwork.validateMethodToData}.
     *
     * @param mLTrain the trainer to drive
     * @param projectEGFile the project file whose wrapped object is the method being trained;
     *     the {@code null} checks elsewhere in this class suggest it may be {@code null}
     * @param mLDataSet the training data
     * @param mLDataSet2 optional validation data; when {@code null}, no validation error is
     *     computed and the chart pane is constructed with {@code false}
     */
    public BasicTrainingProgress(MLTrain mLTrain, ProjectEGFile projectEGFile, MLDataSet mLDataSet, MLDataSet mLDataSet2) {
        super(projectEGFile);
        this.nf = NumberFormat.getInstance();
        this.nfShort = NumberFormat.getInstance();
        this.resetOption = new AtomicInteger(-1);
        this.error = false;
        this.lastMessage = "";

        // The method is the object wrapped by the project file, so that wrapped object is what
        // has to be type-checked. Testing the project file itself never matched, which meant
        // validation was silently skipped.
        Object method = (projectEGFile == null) ? null : projectEGFile.getObject();
        if (method instanceof MLMethod) {
            ValidateNetwork.validateMethodToData((MLMethod) method, mLDataSet);
        }

        this.comboReset = new JComboBox(RESET_OPTIONS);
        this.validationData = mLDataSet2;
        this.train = mLTrain;
        this.trainingData = mLDataSet;

        this.buttonStart = new JButton("Start");
        this.buttonStop = new JButton("Stop");
        this.buttonClose = new JButton("Close");
        this.buttonIteration = new JButton("Iteration");
        this.buttonStart.addActionListener(this);
        this.buttonStop.addActionListener(this);
        this.buttonClose.addActionListener(this);
        this.comboReset.addActionListener(this);
        this.buttonIteration.addActionListener(this);

        setLayout(new BorderLayout());
        this.panelBody = new JPanel();
        this.panelButtons = new JPanel();
        this.panelButtons.add(this.buttonStart);
        this.panelButtons.add(this.buttonStop);
        this.panelButtons.add(this.buttonClose);
        this.panelButtons.add(this.buttonIteration);
        this.panelButtons.add(this.comboReset);
        add(this.panelBody, BorderLayout.CENTER);
        add(this.panelButtons, BorderLayout.SOUTH);

        this.panelBody.setLayout(new BorderLayout());
        this.statusPanel = new TrainingStatusPanel(this);
        this.panelBody.add(this.statusPanel, BorderLayout.NORTH);
        // validationData must already be assigned here: it decides how the chart is built.
        this.chartPanel = new ChartPane(this.validationData != null);
        this.panelBody.add(this.chartPanel, BorderLayout.CENTER);

        this.buttonStop.setEnabled(false);
        this.shouldExit = false;
        this.bodyFont = EncogFonts.getInstance().getBodyFont();
        this.headFont = EncogFonts.getInstance().getHeadFont();
        this.status = "Ready to Start";
    }

    /**
     * Persists the training result. The steps are:
     * <ol>
     *   <li>store the trainer's method in the project file;</li>
     *   <li>if there is a parent tab, push the object to it, refresh it, mark it dirty and save
     *       it;</li>
     *   <li>if the trainer can continue, save its paused state under the project file's name
     *       with a {@code -cont} suffix and {@code eg} extension, then refresh the workbench.</li>
     * </ol>
     */
    private void saveMLMethod() {
        ((ProjectEGFile) getEncogObject()).save(this.train.getMethod());
        if (getParentTab() != null) {
            getParentTab().setEncogObject(getEncogObject());
            getParentTab().refresh();
            getParentTab().setDirty(true);
            getParentTab().save();
        }
        if (this.train.canContinue()) {
            EncogWorkBench.getInstance().save(
                    new File(FileUtil.forceExtension(FileUtil.getFileName(getEncogObject().getFile()) + "-cont", "eg")),
                    this.train.pause());
            EncogWorkBench.getInstance().refresh();
        }
    }

    /**
     * Offers to save the training result when the tab is closed while idle.
     *
     * <p>Does nothing if {@code error} is set. Otherwise asks the user. On "yes" it saves the
     * method (when there is a project object) and refreshes the workbench. On "no" it reverts
     * the project file.
     */
    private void performClose() {
        if (this.error) {
            return;
        }
        if (EncogWorkBench.askQuestion("Training", "Save the training?")) {
            if (getEncogObject() != null) {
                saveMLMethod();
            }
            EncogWorkBench.getInstance().refresh();
        } else if (getEncogObject() != null) {
            ((ProjectEGFile) getEncogObject()).revert();
        }
    }

    /**
     * Dispatches clicks on the tab's buttons and selections in the reset combo box.
     *
     * @param actionEvent the Swing action event
     */
    @Override
    public void actionPerformed(ActionEvent actionEvent) {
        Object source = actionEvent.getSource();
        if (source == this.buttonClose) {
            dispose();
        } else if (source == this.buttonStart) {
            performStart();
        } else if (source == this.buttonStop) {
            performStop();
        } else if (source == this.buttonIteration) {
            performIteration();
        } else if (source == this.comboReset) {
            // Record the request for the training loop (the placeholder at index 0 maps to -1),
            // then snap the combo back to the placeholder so the same option can be re-chosen.
            this.resetOption.set(this.comboReset.getSelectedIndex() - 1);
            this.comboReset.setSelectedIndex(0);
        }
    }

    /**
     * Handles a request to close this tab.
     *
     * @return {@code true} if no training thread was running; the save prompt has then been
     *     shown. {@code false} if training is running: the training loop is asked to cancel and
     *     exit instead (presumably vetoing the close until it finishes).
     */
    @Override // org.encog.workbench.tabs.EncogCommonTab
    public boolean close() {
        if (this.thread == null) {
            performClose();
            return true;
        }
        this.shouldExit = true;
        this.cancel = true;
        return false;
    }

    /** @return the trainer driven by this tab */
    public MLTrain getTrain() {
        return this.train;
    }

    /** @return the training data set */
    public MLDataSet getTrainingData() {
        return this.trainingData;
    }

    /**
     * Paints the status area. The output is laid out as follows:
     * <ul>
     *   <li>left column: labels at x=10 and values at x=150;</li>
     *   <li>right column: labels at x=400 and values at x=500;</li>
     *   <li>one line per row, with the line height taken from the heading font.</li>
     * </ul>
     *
     * @param graphics the graphics context to draw into
     */
    public void paintStatus(Graphics graphics) {
        graphics.setColor(Color.white);
        graphics.fillRect(0, 0, getWidth(), getHeight());
        graphics.setColor(Color.black);

        graphics.setFont(this.headFont);
        // Metrics come from the heading font and are reused for the value rows, so labels and
        // values share the same baselines.
        FontMetrics fontMetrics = graphics.getFontMetrics();
        int lineHeight = fontMetrics.getHeight();

        // Left column labels.
        int y = lineHeight;
        graphics.drawString("Iteration:", 10, y);
        y += lineHeight;
        graphics.drawString("Current Error:", 10, y);
        y += lineHeight;
        graphics.drawString("Validation Error:", 10, y);
        y += lineHeight;
        graphics.drawString("Error Improvement:", 10, y);
        y += lineHeight;
        graphics.drawString("Message:", 10, y);

        // Right column labels.
        y = lineHeight;
        graphics.drawString("Elapsed Time:", 400, y);
        y += lineHeight;
        graphics.drawString("Performance:", 400, y);
        if (this.train instanceof EvolutionaryAlgorithm) {
            y += lineHeight;
            graphics.drawString("Species Counts:", 400, y);
            y += lineHeight;
            graphics.drawString("Best Genome Size:", 400, y);
        }

        graphics.setFont(this.bodyFont);

        // Left column values.
        y = lineHeight;
        graphics.drawString(this.nf.format(this.iteration) + " (" + this.status + ")", 150, y);
        y += lineHeight;
        graphics.drawString(Format.formatPercent(this.currentError), 150, y);
        y += lineHeight;
        graphics.drawString(this.validationData != null ? Format.formatPercent(this.validationError) : "n/a", 150, y);
        y += lineHeight;
        graphics.drawString(Format.formatPercent(this.errorImprovement), 150, y);
        y += lineHeight;
        graphics.drawString(this.lastMessage, 150, y);

        // Right column values.
        y = lineHeight;
        long elapsedSeconds = 0;
        if (this.started != null) {
            elapsedSeconds = (System.currentTimeMillis() - this.started.getTime()) / 1000;
        }
        graphics.drawString(TimeSpanFormatter.formatTime(elapsedSeconds), 500, y);
        y += lineHeight;
        String performance = (this.performanceCount == -1)
                ? "  (calculating performance)"
                : "  (" + this.nfShort.format(this.performanceCount / 60.0d) + "/sec)";
        graphics.drawString(performance, 500, y);

        if (this.train instanceof EvolutionaryAlgorithm) {
            EvolutionaryAlgorithm evolutionaryAlgorithm = (EvolutionaryAlgorithm) this.train;
            y += lineHeight;
            if (evolutionaryAlgorithm.getPopulation() != null) {
                graphics.drawString(Format.formatInteger(evolutionaryAlgorithm.getPopulation().getSpecies().size()), 500, y);
                y += lineHeight;
                if (evolutionaryAlgorithm.getBestGenome() != null) {
                    graphics.drawString(Format.formatInteger(evolutionaryAlgorithm.getBestGenome().size()), 500, y);
                }
            }
        }
    }

    /**
     * Begins a continuous training run. It resets the timing/performance counters, switches
     * the buttons to their "running" state and starts a new thread executing {@link #run()}.
     */
    private void performStart() {
        this.started = new Date();
        this.performanceLast = this.started;
        this.performanceCount = -1;
        this.performanceLastIteration = 0;
        this.buttonStart.setEnabled(false);
        this.buttonStop.setEnabled(true);
        this.buttonIteration.setEnabled(false);
        this.cancel = false;
        this.status = "Started";
        this.thread = new Thread(this);
        this.thread.start();
    }

    /** Marks the run as canceled and raises the stop flag for the training loop. */
    private void performStop() {
        this.status = "Canceled";
        this.cancel = true;
    }

    /**
     * Requests a repaint of the status panel, records the update time and appends the current
     * iteration, training error, error improvement and validation error to the chart.
     */
    public void redraw() {
        this.statusPanel.repaint();
        this.lastUpdate = System.currentTimeMillis();
        this.chartPanel.addData(this.iteration, this.train.getError(), this.errorImprovement, this.validationError);
    }

    /**
     * Training loop executed on the thread started by {@code performStart()}.
     *
     * <p>NOTE(review): the body of this method could not be decompiled (the decompiler skipped
     * 683 instructions) and currently only throws. It must be recovered from the original
     * sources or bytecode. The decompiler's residue shows the loop calling
     * {@code resetOption.set(-1)}, i.e. it consumes pending reset requests.
     */
    @Override // java.lang.Runnable
    public void run() {
        throw new UnsupportedOperationException("Method not decompiled: org.encog.workbench.tabs.training.BasicTrainingProgress.run():void");
    }

    /** @param d the maximum error value to store for this run */
    public void setMaxError(double d) {
        this.maxError = d;
    }

    /** @param mLTrain the trainer to drive */
    public void setTrain(MLTrain mLTrain) {
        this.train = mLTrain;
    }

    /** @param mLDataSet the training data set */
    public void setTrainingData(MLDataSet mLDataSet) {
        this.trainingData = mLDataSet;
    }

    /** Hook with no default behavior; subclasses may override it. */
    public void shutdown() {
    }

    /** Hook with no default behavior; subclasses may override it. */
    public void startup() {
    }

    /**
     * Returns the UI to its idle state after a run. It clears the thread reference, re-enables
     * Start and Iteration, disables Stop and raises the cancel flag.
     */
    private void stopped() {
        this.thread = null;
        this.buttonIteration.setEnabled(true);
        this.buttonStart.setEnabled(true);
        this.buttonStop.setEnabled(false);
        this.cancel = true;
    }

    @Override // org.encog.workbench.tabs.EncogCommonTab
    public String getName() {
        return "Training Progress";
    }

    /**
     * Receives a status message and redraws the tab.
     *
     * @param i unused here
     * @param i2 unused here
     * @param str the message to display; {@code null} is shown as an empty message, because
     *     {@code Graphics.drawString} rejects {@code null}
     */
    @Override // org.encog.StatusReportable
    public void report(int i, int i2, String str) {
        this.lastMessage = (str == null) ? "" : str;
        redraw();
    }

    /**
     * Runs the configured number of training iterations on the calling thread, then redraws.
     *
     * <p>For every iteration this updates:
     * <ul>
     *   <li>the iteration count;</li>
     *   <li>the errors before and after the step;</li>
     *   <li>the relative improvement;</li>
     *   <li>the validation error, when validation data is present and the method implements
     *       {@link MLError}.</li>
     * </ul>
     */
    public void performIteration() {
        for (int i = 0; i < EncogWorkBench.getInstance().getConfig().getIterationStepCount(); i++) {
            this.iteration++;
            this.lastError = this.train.getError();
            this.train.iteration();
            this.currentError = this.train.getError();

            // Relative improvement. A previous error of 0 gives NaN or +/-Infinity, which is
            // reported as 100.0 instead.
            this.errorImprovement = (this.lastError - this.currentError) / this.lastError;
            if (Double.isInfinite(this.errorImprovement) || Double.isNaN(this.errorImprovement)) {
                this.errorImprovement = 100.0d;
            }

            if (this.validationData != null) {
                MLMethod method = this.train.getMethod();
                if (method instanceof MLError) {
                    this.validationError = ((MLError) method).calculateError(this.validationData);
                }
            }
        }
        redraw();
    }
}
