/*
 * Decompiled with CFR 0.152.
 */
package org.encog.ml.train.strategy.end;

import java.io.Serializable;
import org.encog.ml.MLRegression;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.train.MLTrain;
import org.encog.ml.train.strategy.end.EndTrainingStrategy;
import org.encog.util.obj.SerializeObject;
import org.encog.util.simple.EncogUtility;

/**
 * End-training strategy that periodically measures the regression error of the model
 * being trained against a separate validation set, and requests that training stop once
 * the validation error has failed to improve (by at least {@link #getMinimumImprovement()})
 * for more than {@link #getAllowedStagnantIterations()} iterations, or once the validation
 * error becomes NaN or infinite.
 *
 * <p>Optionally (see {@link #setSaveBest(boolean)}) a serialized clone of the model is kept
 * each time a new best validation error is observed.
 */
public class EarlyStoppingStrategy implements EndTrainingStrategy {
    /** Data used to measure generalisation error; never trained on by this class. */
    private MLDataSet validationSet;
    /** The trainer this strategy is attached to (set by {@link #init(MLTrain)}). */
    private MLTrain train;
    /** Set once training should end; reported by {@link #shouldStop()}. */
    private boolean stop;
    /** Training error reported by the trainer after the most recent iteration. */
    private double trainingError;
    /** Validation error from the most recent check; +infinity until the first check. */
    private double lastValidationError = Double.POSITIVE_INFINITY;
    /** The model being trained, obtained from the trainer. */
    private MLRegression model;
    /** A validation check runs once more than this many iterations have elapsed since the last one. */
    private int checkFrequency;
    /** Iterations elapsed since the last validation check. */
    private int lastCheck;
    /** Training stops once {@link #stagnantIterations} exceeds this value. */
    private int allowedStagnantIterations;
    /** Iterations accumulated since the validation error last improved. */
    private int stagnantIterations;
    /** Clone of the model at its best validation error (only maintained when {@link #saveBest}). */
    private MLRegression bestModel;
    /** Whether to keep a clone of the best model seen so far. */
    private boolean saveBest;
    /** Lowest validation error seen so far; +infinity until the first check. */
    private double bestValidationError = Double.POSITIVE_INFINITY;
    /** Smallest decrease in validation error that counts as an improvement. */
    private double minimumImprovement = 1.0E-13;

    /**
     * Creates a strategy that checks the validation set after every 5 iterations (i.e. once
     * more than 5 have elapsed) and allows 50 stagnant iterations.
     *
     * @param theValidationSet the data set used to measure validation error
     */
    public EarlyStoppingStrategy(MLDataSet theValidationSet) {
        this(theValidationSet, 5, 50);
    }

    /**
     * Creates a strategy with explicit check frequency and stagnation tolerance.
     *
     * @param theValidationSet             the data set used to measure validation error
     * @param theCheckFrequency            a check runs once more than this many iterations
     *                                     have elapsed since the previous check
     * @param theAllowedStagnantIterations training stops once the number of iterations without
     *                                     improvement exceeds this value
     */
    public EarlyStoppingStrategy(MLDataSet theValidationSet, int theCheckFrequency, int theAllowedStagnantIterations) {
        this.validationSet = theValidationSet;
        this.checkFrequency = theCheckFrequency;
        this.allowedStagnantIterations = theAllowedStagnantIterations;
    }

    /**
     * Attaches this strategy to a trainer and resets all per-run state.
     *
     * @param theTrain the trainer; its method must be an {@link MLRegression}
     *                 (otherwise a {@link ClassCastException} is thrown)
     */
    @Override
    public void init(MLTrain theTrain) {
        this.train = theTrain;
        this.model = (MLRegression) this.train.getMethod();
        this.stop = false;
        this.lastCheck = 0;
        this.lastValidationError = Double.POSITIVE_INFINITY;
        // Reset the run state so a reused strategy does not carry results from a previous run.
        this.bestValidationError = Double.POSITIVE_INFINITY;
        this.bestModel = null;
        this.stagnantIterations = 0;
    }

    /** No work is required before an iteration. */
    @Override
    public void preIteration() {
    }

    /**
     * Records the training error and, when a check is due, evaluates the validation set and
     * updates the stagnation bookkeeping (possibly requesting that training stop).
     */
    @Override
    public void postIteration() {
        ++this.lastCheck;
        this.trainingError = this.train.getError();

        // Always check on the first iteration (no validation error yet); afterwards only once
        // more than checkFrequency iterations have elapsed since the last check.
        if (this.lastCheck <= this.checkFrequency && !Double.isInfinite(this.lastValidationError)) {
            return;
        }

        double currentValidationError = EncogUtility.calculateRegressionError(this.model, this.validationSet);

        if (Double.isInfinite(currentValidationError) || Double.isNaN(currentValidationError)) {
            // The model has diverged; further training is pointless.
            this.stop = true;
        } else if (isImprovement(currentValidationError)) {
            if (this.saveBest) {
                this.bestModel = (MLRegression) SerializeObject.serializeClone((Serializable) this.model);
            }
            this.bestValidationError = currentValidationError;
            this.stagnantIterations = 0;
        } else {
            // Every iteration since the previous check counts as stagnant.
            this.stagnantIterations += this.lastCheck;
            if (this.stagnantIterations > this.allowedStagnantIterations) {
                this.stop = true;
            }
        }

        this.lastValidationError = currentValidationError;
        this.lastCheck = 0;
    }

    /**
     * Decides whether a (finite) validation error counts as a new best. The first check always
     * does. After that, the error must be strictly below the best so far by at least
     * {@link #minimumImprovement}. The best is compared rather than the previous check, so that
     * a run of tiny gains cannot keep resetting the stagnation counter forever.
     */
    private boolean isImprovement(double currentValidationError) {
        if (Double.isInfinite(this.lastValidationError)) {
            return true;
        }
        double improvement = this.bestValidationError - currentValidationError;
        return currentValidationError < this.bestValidationError && improvement >= this.minimumImprovement;
    }

    /** @return {@code true} once training should end */
    @Override
    public boolean shouldStop() {
        return this.stop;
    }

    /** @return the training error reported after the most recent iteration */
    public double getTrainingError() {
        return this.trainingError;
    }

    /** @return the validation error from the most recent check (+infinity before the first check) */
    public double getValidationError() {
        return this.lastValidationError;
    }

    /** @return the number of iterations accumulated since the validation error last improved */
    public int getStagnantIterations() {
        return this.stagnantIterations;
    }

    /** @param stagnantIterations overrides the current stagnant-iteration count */
    public void setStagnantIterations(int stagnantIterations) {
        this.stagnantIterations = stagnantIterations;
    }

    /** @return the stagnant-iteration count that, once exceeded, stops training */
    public int getAllowedStagnantIterations() {
        return this.allowedStagnantIterations;
    }

    /** @param allowedStagnantIterations the stagnant-iteration count that, once exceeded, stops training */
    public void setAllowedStagnantIterations(int allowedStagnantIterations) {
        this.allowedStagnantIterations = allowedStagnantIterations;
    }

    /** @return whether a clone of the best model is kept */
    public boolean isSaveBest() {
        return this.saveBest;
    }

    /** @param saveBest whether to keep a serialized clone of the best model */
    public void setSaveBest(boolean saveBest) {
        this.saveBest = saveBest;
    }

    /** @return the clone of the best model, or {@code null} if none was saved */
    public MLRegression getBestModel() {
        return this.bestModel;
    }

    /** @return the lowest validation error seen so far (+infinity before the first check) */
    public double getBestValidationError() {
        return this.bestValidationError;
    }

    /** @return the smallest decrease in validation error that counts as an improvement */
    public double getMinimumImprovement() {
        return this.minimumImprovement;
    }

    /** @param minimumImprovement the smallest decrease in validation error that counts as an improvement */
    public void setMinimumImprovement(double minimumImprovement) {
        this.minimumImprovement = minimumImprovement;
    }
}

