/*
 * Decompiled with CFR 0.152.
 */
package org.encog.neural.networks.training.lma;

import org.encog.mathutil.error.ErrorCalculation;
import org.encog.mathutil.matrices.decomposition.LUDecomposition;
import org.encog.mathutil.matrices.hessian.ComputeHessian;
import org.encog.mathutil.matrices.hessian.HessianCR;
import org.encog.ml.MLMethod;
import org.encog.ml.TrainingImplementationType;
import org.encog.ml.data.MLData;
import org.encog.ml.data.MLDataPair;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.basic.BasicMLData;
import org.encog.ml.data.basic.BasicMLDataPair;
import org.encog.ml.train.BasicTraining;
import org.encog.neural.networks.BasicNetwork;
import org.encog.neural.networks.structure.NetworkCODEC;
import org.encog.neural.networks.training.TrainingError;
import org.encog.neural.networks.training.propagation.TrainingContinuation;
import org.encog.util.concurrency.MultiThreadable;
import org.encog.util.validate.ValidateNetwork;

/**
 * Levenberg-Marquardt style trainer for a {@link BasicNetwork}.
 *
 * <p>Each call to {@link #iteration()} asks the configured {@link ComputeHessian}
 * for a Hessian, gradients and a sum-of-squares error, adds a damping term
 * ({@code lambda}) to the Hessian diagonal, solves the resulting linear system
 * with an {@link LUDecomposition}, and adds the solution to the weights that the
 * network had at the start of the iteration. {@code lambda} is scaled up and the
 * step retried while the error does not improve; it is scaled down once a step
 * lowers the error.
 */
public class LevenbergMarquardtTraining
extends BasicTraining
implements MultiThreadable {
    /** Factor by which lambda is multiplied after a failed step and divided after a successful one. */
    public static final double SCALE_LAMBDA = 10.0;

    /** Upper bound for lambda; once lambda would exceed it, lambda is clamped and the iteration ends. */
    public static final double LAMBDA_MAX = 1.0E25;

    /** Supplies the Hessian, gradients and SSE used by every iteration. */
    private final ComputeHessian hessian;

    /** The network whose weights are being trained. */
    private final BasicNetwork network;

    /** The training set, accessed by record index in {@link #calculateError()}. */
    private final MLDataSet indexableTraining;

    /** Number of training records, truncated to {@code int}. */
    private final int trainingLength;

    /** Number of weights, as reported by the network structure. */
    private final int weightCount;

    /** Snapshot of the network weights taken at the start of the current iteration. */
    private double[] weights;

    /** Current damping value added to the Hessian diagonal. */
    private double lambda;

    /** Copy of the undamped Hessian diagonal for the current iteration. */
    private final double[] diagonal;

    /** Most recent weight deltas produced by the linear solve. */
    private double[] deltas;

    /** Reusable pair that training records are read into. */
    private final MLDataPair pair;

    /** Whether {@link ComputeHessian#init} has already been called. */
    private boolean initComplete;

    /**
     * Creates a trainer that uses a {@link HessianCR} to compute the Hessian.
     *
     * @param network  the network to train
     * @param training the training data
     */
    public LevenbergMarquardtTraining(BasicNetwork network, MLDataSet training) {
        this(network, training, new HessianCR());
    }

    /**
     * Creates a trainer with an explicit Hessian calculator.
     *
     * @param network  the network to train
     * @param training the training data
     * @param h        the Hessian calculator to use
     */
    public LevenbergMarquardtTraining(BasicNetwork network, MLDataSet training, ComputeHessian h) {
        super(TrainingImplementationType.Iterative);
        ValidateNetwork.validateMethodToData(network, training);
        this.setTraining(training);
        this.indexableTraining = this.getTraining();
        this.network = network;
        this.trainingLength = (int) this.indexableTraining.getRecordCount();
        this.weightCount = this.network.getStructure().calculateSize();
        this.lambda = 0.1;
        this.deltas = new double[this.weightCount];
        this.diagonal = new double[this.weightCount];
        BasicMLData input = new BasicMLData(this.indexableTraining.getInputSize());
        BasicMLData ideal = new BasicMLData(this.indexableTraining.getIdealSize());
        this.pair = new BasicMLDataPair(input, ideal);
        this.hessian = h;
    }

    /**
     * Copies the current Hessian diagonal into {@link #diagonal} so that
     * {@link #applyLambda()} can re-derive the damped diagonal from the
     * undamped values on every retry.
     */
    private void saveDiagonal() {
        double[][] h = this.hessian.getHessian();
        for (int i = 0; i < this.weightCount; i++) {
            this.diagonal[i] = h[i][i];
        }
    }

    /**
     * @return always {@code false}
     */
    @Override
    public boolean canContinue() {
        return false;
    }

    /**
     * @return the network being trained
     */
    @Override
    public MLMethod getMethod() {
        return this.network;
    }

    /**
     * Runs every training record through the network with its current weights
     * and accumulates the error.
     *
     * @return the value of {@link ErrorCalculation#calculateESS()} over the training set
     */
    private double calculateError() {
        ErrorCalculation result = new ErrorCalculation();
        for (int i = 0; i < this.trainingLength; i++) {
            this.indexableTraining.getRecord(i, this.pair);
            MLData actual = this.network.compute(this.pair.getInput());
            result.updateError(actual.getData(), this.pair.getIdeal().getData(), this.pair.getSignificance());
        }
        return result.calculateESS();
    }

    /**
     * Writes {@code savedDiagonal + lambda} into the Hessian diagonal, in place.
     */
    private void applyLambda() {
        double[][] h = this.hessian.getHessian();
        for (int i = 0; i < this.weightCount; i++) {
            h[i][i] = this.diagonal[i] + this.lambda;
        }
    }

    /**
     * Performs one training iteration.
     *
     * <p>The Hessian calculator is initialised lazily on the first call. The
     * damped system is then solved repeatedly: a step that lowers the error
     * divides lambda by {@link #SCALE_LAMBDA} and ends the iteration; a singular
     * matrix or a step that does not lower the error multiplies lambda by
     * {@link #SCALE_LAMBDA} and retries, until lambda would exceed
     * {@link #LAMBDA_MAX}.
     */
    @Override
    public void iteration() {
        if (!this.initComplete) {
            this.hessian.init(this.network, this.getTraining());
            this.initComplete = true;
        }
        this.preIteration();
        this.hessian.clear();
        this.weights = NetworkCODEC.networkToArray(this.network);
        this.hessian.compute();
        double currentError = this.hessian.getSSE();
        this.saveDiagonal();
        final double startingError = currentError;

        boolean done = false;
        while (!done) {
            this.applyLambda();
            LUDecomposition decomposition = new LUDecomposition(this.hessian.getHessianMatrix());
            boolean nonsingular = decomposition.isNonsingular();
            if (nonsingular) {
                // The system is solvable: try the step and measure the resulting error.
                this.deltas = decomposition.Solve(this.hessian.getGradients());
                this.updateWeights();
                currentError = this.calculateError();
            }
            if (!nonsingular || currentError >= startingError) {
                // No usable step or no improvement: increase damping and retry.
                this.lambda *= SCALE_LAMBDA;
                if (this.lambda > LAMBDA_MAX) {
                    // NOTE(review): on this exit the network keeps the weights of the last
                    // attempted step (they are not restored to the iteration's starting
                    // weights) - confirm this is intended.
                    this.lambda = LAMBDA_MAX;
                    done = true;
                }
            } else {
                // The step lowered the error: relax damping for the next iteration.
                this.lambda /= SCALE_LAMBDA;
                done = true;
            }
        }
        this.setError(currentError);
        this.postIteration();
    }

    /**
     * @return always {@code null}
     */
    @Override
    public TrainingContinuation pause() {
        return null;
    }

    /**
     * Does nothing.
     *
     * @param state ignored
     */
    @Override
    public void resume(TrainingContinuation state) {
    }

    /**
     * Adds the current deltas to a copy of the weights captured at the start of
     * the iteration and writes the result into the network. The captured
     * weights themselves are left unchanged, so every retry within an iteration
     * steps from the same starting point.
     */
    public void updateWeights() {
        double[] w = this.weights.clone();
        for (int i = 0; i < w.length; i++) {
            w[i] += this.deltas[i];
        }
        NetworkCODEC.arrayToNetwork(w, this.network);
    }

    /**
     * @return the Hessian calculator in use
     */
    public ComputeHessian getHessian() {
        return this.hessian;
    }

    /**
     * @return the thread count of the Hessian calculator when it is
     *         {@link MultiThreadable}, otherwise {@code 1}
     */
    @Override
    public int getThreadCount() {
        if (this.hessian instanceof MultiThreadable) {
            return ((MultiThreadable) this.hessian).getThreadCount();
        }
        return 1;
    }

    /**
     * Forwards the thread count to the Hessian calculator when it is
     * {@link MultiThreadable}.
     *
     * @param numThreads the requested thread count
     * @throws TrainingError if the Hessian calculator is not {@link MultiThreadable}
     *         and {@code numThreads} is neither {@code 0} nor {@code 1}
     */
    @Override
    public void setThreadCount(int numThreads) {
        if (this.hessian instanceof MultiThreadable) {
            ((MultiThreadable) this.hessian).setThreadCount(numThreads);
        } else if (numThreads != 1 && numThreads != 0) {
            throw new TrainingError("The Hessian object in use(" + this.hessian.getClass().toString() + ") does not support multi-threaded mode.");
        }
    }
}

