/*
 * Decompiled with CFR 0.152.
 */
package org.neuroph.nnet.learning;

import java.util.Iterator;
import org.neuroph.core.NeuralNetwork;
import org.neuroph.core.data.DataSet;
import org.neuroph.core.data.DataSetRow;
import org.neuroph.core.learning.SupervisedLearning;
import org.neuroph.util.NeuralNetworkCODEC;

/**
 * Supervised learning rule that trains a network by simulated annealing.
 *
 * <p>The network's weights are flattened into a {@code double[]} via {@link NeuralNetworkCODEC}.
 * Each learning epoch performs {@code cycles} trials. A trial randomly perturbs a subset of the
 * weights (see {@link #randomize(double)}) and re-evaluates the training-set error. The
 * perturbation is kept only if it strictly lowers the best error seen so far in the epoch;
 * otherwise the previous best weights are restored. The perturbation scale ("temperature") is
 * cooled geometrically, from {@code startTemperature} on the first trial to
 * {@code stopTemperature} on the last.
 */
public class SimulatedAnnealingLearning
extends SupervisedLearning {
    private static final long serialVersionUID = 1L;
    /** Temperature used for the first trial of every epoch. */
    private double startTemperature;
    /** Temperature used for the last trial of an epoch (when {@code cycles > 1}). */
    private double stopTemperature;
    /** Number of perturb-and-evaluate trials per epoch. */
    private int cycles;
    /** Current temperature; scales the perturbations applied by {@link #randomize(double)}. */
    protected double temperature;
    /** Working (candidate) weight vector, in {@link NeuralNetworkCODEC} order. */
    private double[] weights;
    /** Best weight vector found so far in the current epoch. */
    private double[] bestWeights;

    /**
     * Creates the learning rule and snapshots the network's current weights.
     *
     * @param network   the network to train
     * @param startTemp temperature for the first trial of each epoch
     * @param stopTemp  temperature for the last trial of each epoch
     * @param cycles    number of trials per epoch
     */
    public SimulatedAnnealingLearning(NeuralNetwork network, double startTemp, double stopTemp, int cycles) {
        this.setNeuralNetwork(network);
        this.temperature = startTemp;
        this.startTemperature = startTemp;
        this.stopTemperature = stopTemp;
        this.cycles = cycles;
        int size = NeuralNetworkCODEC.determineArraySize(network);
        this.weights = new double[size];
        this.bestWeights = new double[size];
        NeuralNetworkCODEC.network2array(network, this.weights);
        System.arraycopy(this.weights, 0, this.bestWeights, 0, size);
    }

    /**
     * Creates the learning rule with start temperature 10.0, stop temperature 2.0
     * and 1000 trials per epoch.
     *
     * @param network the network to train
     */
    public SimulatedAnnealingLearning(NeuralNetwork network) {
        this(network, 10.0, 2.0, 1000);
    }

    /**
     * Returns the network being trained.
     *
     * @return the same value as {@code getNeuralNetwork()}
     */
    public NeuralNetwork getNetwork() {
        return this.getNeuralNetwork();
    }

    /**
     * Randomly perturbs the working weights and writes them into the network.
     *
     * <p>Each weight is independently selected with probability {@code randomChance}. A selected
     * weight is shifted by a uniform value in (-0.5, 0.5], multiplied by
     * {@code temperature / startTemperature}. Perturbations therefore shrink as the schedule cools.
     *
     * @param randomChance probability in [0, 1] that any individual weight is perturbed
     */
    public void randomize(double randomChance) {
        for (int i = 0; i < this.weights.length; ++i) {
            if (Math.random() < randomChance) {
                double offset = (0.5 - Math.random()) / this.startTemperature * this.temperature;
                this.weights[i] += offset;
            }
        }
        NeuralNetworkCODEC.array2network(this.weights, this.getNetwork());
    }

    /**
     * Evaluates the network on every row of {@code trainingSet}.
     *
     * <p>Each row contributes the sum of its squared pattern errors divided by
     * {@code 2 * patternError.length}. Iteration ends early if learning has been stopped, in which
     * case the returned value is only a partial sum.
     *
     * @param trainingSet rows to evaluate
     * @return the accumulated error over the rows that were evaluated
     */
    private double determineError(DataSet trainingSet) {
        double result = 0.0;
        Iterator<DataSetRow> iterator = trainingSet.iterator();
        while (iterator.hasNext() && !this.isStopped()) {
            DataSetRow trainingSetRow = iterator.next();
            double[] input = trainingSetRow.getInput();
            this.getNetwork().setInput(input);
            this.getNetwork().calculate();
            double[] output = this.getNetwork().getOutput();
            double[] desiredOutput = trainingSetRow.getDesiredOutput();
            // NOTE(review): arguments are passed as (desired, actual); the sign of each element
            // does not matter here because it is squared below. addPatternError presumably also
            // accumulates into the error function's total — confirm against ErrorFunction.
            double[] patternError = this.getErrorFunction().addPatternError(desiredOutput, output);
            double sqrErrorSum = 0.0;
            for (double error : patternError) {
                sqrErrorSum += error * error;
            }
            result += sqrErrorSum / (double)(2 * patternError.length);
        }
        return result;
    }

    /**
     * Runs one annealing epoch in which each weight is perturbed with probability 0.5.
     *
     * @param trainingSet rows used to evaluate each trial
     */
    @Override
    public void doLearningEpoch(DataSet trainingSet) {
        this.doLearningEpoch(trainingSet, 0.5);
    }

    /**
     * Runs one annealing epoch of {@code cycles} trials. When it returns, the network holds the
     * best weights found.
     *
     * @param trainingSet  rows used to evaluate each trial
     * @param randomChance per-weight perturbation probability passed to {@link #randomize(double)}
     */
    public void doLearningEpoch(DataSet trainingSet, double randomChance) {
        // Start from what the network actually holds, so bestWeights matches the weights that
        // bestError is measured on, even if the network was modified since the last epoch.
        this.syncWeightsFromNetwork();
        System.arraycopy(this.weights, 0, this.bestWeights, 0, this.weights.length);
        double bestError = this.determineError(trainingSet);
        this.temperature = this.startTemperature;
        // Loop-invariant geometric cooling factor: after (cycles - 1) multiplications the
        // temperature reaches stopTemperature. With fewer than two trials there is nothing to
        // interpolate, so keep the temperature constant instead of dividing by zero, which would
        // leave temperature at 0, Infinity or NaN for later randomize() calls.
        double ratio = this.cycles > 1
                ? Math.exp(Math.log(this.stopTemperature / this.startTemperature) / (double)(this.cycles - 1))
                : 1.0;
        for (int i = 0; i < this.cycles; ++i) {
            this.randomize(randomChance);
            double currentError = this.determineError(trainingSet);
            if (this.isStopped()) {
                // determineError() bailed out early, so currentError is a partial (too small) sum
                // and must not be taken as an improvement: discard this trial and stop.
                System.arraycopy(this.bestWeights, 0, this.weights, 0, this.weights.length);
                NeuralNetworkCODEC.array2network(this.bestWeights, this.getNetwork());
                break;
            }
            if (currentError < bestError) {
                System.arraycopy(this.weights, 0, this.bestWeights, 0, this.weights.length);
                bestError = currentError;
            } else {
                System.arraycopy(this.bestWeights, 0, this.weights, 0, this.weights.length);
            }
            NeuralNetworkCODEC.array2network(this.bestWeights, this.getNetwork());
            this.temperature *= ratio;
        }
        // NOTE(review): this reads the error function's running total rather than bestError.
        // This class never resets the error function, so the total presumably covers every trial
        // evaluated — confirm this is what SupervisedLearning's stop condition expects.
        this.previousEpochError = this.getErrorFunction().getTotalError();
        if (this.hasReachedStopCondition()) {
            this.stopLearning();
        }
    }

    /**
     * Copies the network's current weights into {@link #weights}. The working arrays are
     * reallocated first if the network's weight count no longer matches their length (e.g. if a
     * different network has been set).
     */
    private void syncWeightsFromNetwork() {
        NeuralNetwork network = this.getNetwork();
        int size = NeuralNetworkCODEC.determineArraySize(network);
        if (this.weights.length != size) {
            this.weights = new double[size];
            this.bestWeights = new double[size];
        }
        NeuralNetworkCODEC.network2array(network, this.weights);
    }

    /**
     * Intentionally empty. This rule changes weights by random perturbation in
     * {@link #doLearningEpoch(DataSet, double)}, not from per-pattern error gradients.
     *
     * @param patternError unused
     */
    @Override
    protected void calculateWeightChanges(double[] patternError) {
    }
}

