package org.encog.ensemble.training;

import org.encog.ensemble.EnsembleTrainFactory;
import org.encog.ml.MLMethod;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.factory.MLTrainFactory;
import org.encog.ml.train.MLTrain;
import org.encog.neural.networks.BasicNetwork;
import org.encog.neural.networks.training.propagation.scg.ScaledConjugateGradient;

/* loaded from: input_file:org/encog/ensemble/training/ScaledConjugateGradientFactory.class */
/**
 * {@link EnsembleTrainFactory} implementation that creates
 * {@link ScaledConjugateGradient} trainers.
 *
 * <p>The factory stores a dropout rate (initially {@code 0.0}) that is passed
 * to every trainer created through the two-argument
 * {@link #getTraining(MLMethod, MLDataSet)} overload and that is also
 * reflected in {@link #getLabel()}.
 */
public class ScaledConjugateGradientFactory implements EnsembleTrainFactory {

    /** Dropout rate used when the caller does not supply one explicitly. */
    private double dropoutRate = 0.0;

    /**
     * Replaces the dropout rate stored in this factory.
     *
     * @param d the new dropout rate; stored as given, without validation
     */
    @Override
    public void setDropoutRate(double d) {
        dropoutRate = d;
    }

    /**
     * Builds the label for this factory.
     *
     * @return {@link MLTrainFactory#TYPE_SCG} on its own when the stored
     *         dropout rate is not greater than zero; otherwise that constant
     *         followed by {@code "-"} and the stored dropout rate
     */
    @Override
    public String getLabel() {
        final String base = MLTrainFactory.TYPE_SCG;
        // Only a strictly positive rate is appended to the label.
        return dropoutRate > 0.0 ? base + "-" + dropoutRate : base;
    }

    /**
     * Creates a trainer using the dropout rate currently stored in this factory.
     *
     * @param mLMethod  the method to train; cast to {@link BasicNetwork} by the
     *                  three-argument overload
     * @param mLDataSet the training data handed to the trainer
     * @return a new {@link ScaledConjugateGradient} trainer
     */
    @Override
    public MLTrain getTraining(MLMethod mLMethod, MLDataSet mLDataSet) {
        return getTraining(mLMethod, mLDataSet, dropoutRate);
    }

    /**
     * Creates a trainer using an explicitly supplied dropout rate. The rate
     * stored in this factory is neither read nor modified.
     *
     * @param mLMethod  the method to train; it is cast to {@link BasicNetwork}
     *                  without a prior type check
     * @param mLDataSet the training data handed to the trainer
     * @param d         the dropout rate to set on the new trainer
     * @return a new {@link ScaledConjugateGradient} trainer
     */
    @Override
    public MLTrain getTraining(MLMethod mLMethod, MLDataSet mLDataSet, double d) {
        final BasicNetwork network = (BasicNetwork) mLMethod;
        final ScaledConjugateGradient trainer = new ScaledConjugateGradient(network, mLDataSet);
        // The setter name is spelled this way in the trainer's own API.
        trainer.setDroupoutRate(d);
        return trainer;
    }
}
