/*
 * Decompiled with CFR 0.152.
 */
package org.encog.neural.networks.training.strategy;

import org.encog.EncogError;
import org.encog.ml.MLEncodable;
import org.encog.ml.train.MLTrain;
import org.encog.ml.train.strategy.Strategy;
import org.encog.util.EngineArray;

/**
 * Training strategy that adjusts the encoded form of the trained method after
 * every iteration: each post-iteration value has {@code lambda} times the
 * corresponding pre-iteration value subtracted from it, and the result is
 * decoded back into the method.
 *
 * <p>The method being trained must implement {@link MLEncodable}; otherwise
 * {@link #init(MLTrain)} throws an {@link EncogError}.
 */
public class RegularizationStrategy implements Strategy {
    /** Scale applied to the pre-iteration values before they are subtracted. */
    private double lambda;
    /** Trainer this strategy was initialised with. */
    private MLTrain train;
    /** Encoded snapshot taken in {@link #preIteration()}. */
    private double[] weights;
    /** Scratch buffer holding the encoded values after an iteration. */
    private double[] newWeights;
    /** The trained method, cached as an {@link MLEncodable} during {@link #init(MLTrain)}. */
    private MLEncodable encodable;

    /**
     * Creates the strategy.
     *
     * @param lambda factor by which the pre-iteration values are multiplied
     *               before being subtracted from the post-iteration values
     */
    public RegularizationStrategy(double lambda) {
        this.lambda = lambda;
    }

    /**
     * Binds the strategy to a trainer and allocates the two working buffers,
     * each sized by the method's {@code encodedArrayLength()}.
     *
     * @param train the trainer whose method will be adjusted
     * @throws EncogError if the trainer's method does not implement {@link MLEncodable}
     */
    @Override
    public void init(MLTrain train) {
        // The trainer reference is stored before validation, so it is retained
        // even when the check below fails.
        this.train = train;

        final boolean methodIsEncodable = train.getMethod() instanceof MLEncodable;
        if (!methodIsEncodable) {
            throw new EncogError("Method must implement MLEncodable to be used with regularization.");
        }

        this.encodable = (MLEncodable) train.getMethod();
        this.weights = new double[this.encodable.encodedArrayLength()];
        this.newWeights = new double[this.encodable.encodedArrayLength()];
    }

    /**
     * Captures the method's encoded state before the iteration runs.
     * The method is looked up through the trainer on every call rather than
     * through the reference cached in {@link #init(MLTrain)}.
     */
    @Override
    public void preIteration() {
        final MLEncodable current = (MLEncodable) this.train.getMethod();
        current.encodeToArray(this.weights);
    }

    /**
     * Encodes the method's state after the iteration, subtracts
     * {@code lambda * previous} element by element, and decodes the adjusted
     * values back into the method.
     */
    @Override
    public void postIteration() {
        final double[] updated = this.newWeights;
        final double[] previous = this.weights;

        this.encodable.encodeToArray(updated);
        for (int idx = 0; idx < updated.length; idx++) {
            updated[idx] -= this.lambda * previous[idx];
        }
        this.encodable.decodeFromArray(updated);

        // Presumably copies the adjusted values into the snapshot buffer
        // (source, target argument order) - confirm against EngineArray.
        EngineArray.arrayCopy(updated, previous);
    }
}

