/*
 * Decompiled with CFR 0.152.
 */
package org.encog.neural.cpn.training;

import org.encog.mathutil.BoundMath;
import org.encog.ml.TrainingImplementationType;
import org.encog.ml.data.MLData;
import org.encog.ml.data.MLDataPair;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.train.BasicTraining;
import org.encog.neural.NeuralNetworkError;
import org.encog.neural.cpn.CPN;
import org.encog.neural.networks.training.LearningRate;
import org.encog.neural.networks.training.propagation.TrainingContinuation;
import org.encog.util.EngineArray;

/**
 * Trains the input-to-instar weights of a {@link CPN}.
 *
 * <p>Each call to {@link #iteration()} presents every training element to the
 * network, selects the instar with the largest output as the winner, and moves
 * only that instar's weight column toward the input by
 * {@code learningRate * (input - weight)}. The error reported through
 * {@code setError} is the largest Euclidean distance, over the whole training
 * set, between an input and its winning instar's weights, measured before that
 * instar is updated.
 *
 * <p>Only the input half of each training pair is read; ideal values are
 * ignored. This trainer cannot be paused and resumed.
 */
public class TrainInstar extends BasicTraining implements LearningRate {

    /** The network whose instar weights are trained. */
    private final CPN network;

    /** The training set; only the inputs of each pair are used. */
    private final MLDataSet training;

    /** Fraction of the input-minus-weight difference applied to the winner per step. */
    private double learningRate;

    /** When true, the next iteration first copies the training inputs into the instar weights. */
    private boolean mustInit;

    /**
     * Creates an instar trainer.
     *
     * @param theNetwork      the network whose input-to-instar weights are trained
     * @param theTraining     the training set; only the input of each pair is read
     * @param theLearningRate the fraction of the distance the winning instar is
     *                        moved toward each input
     * @param theInitWeights  if true, the first call to {@link #iteration()}
     *                        seeds the instar weights from the training inputs
     *                        (one training element per instar) before training
     */
    public TrainInstar(CPN theNetwork, MLDataSet theTraining, double theLearningRate, boolean theInitWeights) {
        super(TrainingImplementationType.Iterative);
        this.network = theNetwork;
        this.training = theTraining;
        this.learningRate = theLearningRate;
        this.mustInit = theInitWeights;
    }

    /**
     * Reports that this trainer cannot be paused and resumed.
     *
     * @return always {@code false}
     */
    @Override
    public boolean canContinue() {
        return false;
    }

    /**
     * @return the current learning rate
     */
    @Override
    public double getLearningRate() {
        return this.learningRate;
    }

    /**
     * @return the network being trained
     */
    @Override
    public CPN getMethod() {
        return this.network;
    }

    /**
     * Copies the i-th training input into column i of the input-to-instar
     * weight matrix, then clears the pending-initialization flag.
     *
     * @throws NeuralNetworkError if the number of training records differs from
     *                            the network's instar count
     */
    private void initWeights() {
        if (this.training.getRecordCount() != (long) this.network.getInstarCount()) {
            throw new NeuralNetworkError("If the weights are to be set from the training data, then there must be one instar neuron for each training element.");
        }
        int instar = 0;
        for (MLDataPair pair : this.training) {
            MLData input = pair.getInput();
            for (int row = 0; row < this.network.getInputCount(); row++) {
                this.network.getWeightsInputToInstar().set(row, instar, input.getData(row));
            }
            instar++;
        }
        this.mustInit = false;
    }

    /**
     * Performs one winner-take-all training pass over the whole training set
     * and records the worst (largest) input-to-winner distance as the error.
     * An empty training set yields an error of 0.
     */
    @Override
    public void iteration() {
        if (this.mustInit) {
            this.initWeights();
        }

        // Euclidean distances are non-negative, so a zero seed gives the true
        // maximum for a non-empty set and a meaningful error of 0 for an empty
        // one. A NEGATIVE_INFINITY seed would be reported as the error when
        // there is no data.
        double worstDistance = 0.0;

        for (MLDataPair pair : this.training) {
            MLData input = pair.getInput();
            MLData out = this.network.computeInstar(input);

            // Only the instar with the largest output is trained.
            int winner = EngineArray.indexOfLargest(out.getData());

            // Distance from the input to the winner's weights, taken before
            // the winner is moved.
            double distance = 0.0;
            for (int i = 0; i < input.size(); i++) {
                double diff = input.getData(i) - this.network.getWeightsInputToInstar().get(i, winner);
                distance += diff * diff;
            }
            distance = BoundMath.sqrt(distance);
            if (distance > worstDistance) {
                worstDistance = distance;
            }

            // Move the winner's weight column a fraction of the way toward the input.
            for (int j = 0; j < this.network.getInputCount(); j++) {
                double delta = this.learningRate
                        * (input.getData(j) - this.network.getWeightsInputToInstar().get(j, winner));
                this.network.getWeightsInputToInstar().add(j, winner, delta);
            }
        }

        this.setError(worstDistance);
    }

    /**
     * Pausing is not supported by this trainer.
     *
     * @return always {@code null}
     */
    @Override
    public TrainingContinuation pause() {
        return null;
    }

    /**
     * Resuming is not supported by this trainer; the state is ignored.
     *
     * @param state ignored
     */
    @Override
    public void resume(TrainingContinuation state) {
        // Intentionally empty: canContinue() reports false.
    }

    /**
     * @param rate the new learning rate
     */
    @Override
    public void setLearningRate(double rate) {
        this.learningRate = rate;
    }
}

