/*
 * Decompiled with CFR 0.152.
 */
package jhpro.nnet.jknnl;

import java.util.Arrays;
import java.util.Iterator;
import java.util.Map;
import java.util.TreeMap;
import jhpro.nnet.jknnl.LearningDataModel;
import jhpro.nnet.jknnl.LearningFactorFunctionalModel;
import jhpro.nnet.jknnl.MetricModel;
import jhpro.nnet.jknnl.NeighbourhoodFunctionModel;
import jhpro.nnet.jknnl.NetworkModel;
import jhpro.nnet.jknnl.NeuronModel;
import jhpro.nnet.jknnl.TopologyModel;

/**
 * Winner-takes-most (WTM) style learning for a neural network.
 *
 * <p>For every learning vector, {@link #learn()} finds the neuron whose weight vector is
 * closest to it (according to {@link #metrics}). It then pulls every neuron in that winner's
 * topological neighbourhood towards the vector using
 * {@code w[i] += learningFactor(iteration) * neighbourhood(distance) * (x[i] - w[i])}.
 *
 * <p>Neuron numbers handled by this class are 0-based indices into the network model. Log
 * output prints them 1-based.
 */
public class WTMLearningFunction {
    /** Metric used to measure the distance between a neuron's weights and an input vector. */
    protected MetricModel metrics;
    /** Network whose neuron weights are trained. */
    protected NetworkModel networkModel;
    /** Number of passes over the whole learning data set performed by {@link #learn()}. */
    protected int maxIteration;
    /** Source of the learning vectors. */
    protected LearningDataModel learningData;
    /** Learning-factor function, evaluated with the current iteration number. */
    protected LearningFactorFunctionalModel functionalModel;
    /** Topology of {@link #networkModel}; supplies each neuron's neighbourhood. */
    protected TopologyModel topology;
    /** Neighbourhood function, evaluated with the topological distance to the winner. */
    protected NeighbourhoodFunctionModel neighboorhoodFunction;
    /** When {@code true}, progress and weight changes are printed to {@code System.out}. */
    private boolean showComments = false;

    /**
     * Creates a learning function.
     *
     * @param networkModel network to train; its topology is captured here
     * @param maxIteration number of passes over the learning data
     * @param metrics distance metric used to pick the best (winning) neuron
     * @param learningData learning vectors
     * @param functionalModel learning-factor function of the iteration number
     * @param neighboorhoodFunction neighbourhood function of the topological distance
     */
    public WTMLearningFunction(NetworkModel networkModel, int maxIteration, MetricModel metrics, LearningDataModel learningData, LearningFactorFunctionalModel functionalModel, NeighbourhoodFunctionModel neighboorhoodFunction) {
        this.maxIteration = maxIteration;
        this.networkModel = networkModel;
        this.metrics = metrics;
        this.learningData = learningData;
        this.functionalModel = functionalModel;
        this.topology = networkModel.getTopology();
        this.neighboorhoodFunction = neighboorhoodFunction;
    }

    /** @return whether diagnostic output is printed during learning */
    public boolean isShowComments() {
        return this.showComments;
    }

    /** @param showComments {@code true} to print diagnostic output during learning */
    public void setShowComments(boolean showComments) {
        this.showComments = showComments;
    }

    /** @param neighboorhoodFunction new neighbourhood function */
    public void setNeighboorhoodFunction(NeighbourhoodFunctionModel neighboorhoodFunction) {
        this.neighboorhoodFunction = neighboorhoodFunction;
    }

    /** @return the current neighbourhood function */
    public NeighbourhoodFunctionModel getNeighboorhoodFunction() {
        return this.neighboorhoodFunction;
    }

    /** @return the distance metric */
    public MetricModel getMetrics() {
        return this.metrics;
    }

    /** @param metrics new distance metric */
    public void setMetrics(MetricModel metrics) {
        this.metrics = metrics;
    }

    /**
     * Replaces the trained network.
     *
     * <p>The cached topology is refreshed from the new network. Otherwise neighbourhoods would
     * still be computed from the previous network's topology.
     *
     * @param networkModel new network; if {@code null}, the cached topology is left unchanged
     */
    public void setNetworkModel(NetworkModel networkModel) {
        this.networkModel = networkModel;
        if (networkModel != null) {
            this.topology = networkModel.getTopology();
        }
    }

    /** @return the trained network */
    public NetworkModel getNetworkModel() {
        return this.networkModel;
    }

    /** @param maxIteration number of passes over the learning data */
    public void setMaxIteration(int maxIteration) {
        this.maxIteration = maxIteration;
    }

    /** @return number of passes over the learning data */
    public int getMaxIteration() {
        return this.maxIteration;
    }

    /** @param learningData new learning data */
    public void setLearningData(LearningDataModel learningData) {
        this.learningData = learningData;
    }

    /** @return the learning data */
    public LearningDataModel getLearningData() {
        return this.learningData;
    }

    /** @param functionalModel new learning-factor function */
    public void setFunctionalModel(LearningFactorFunctionalModel functionalModel) {
        this.functionalModel = functionalModel;
    }

    /** @return the learning-factor function */
    public LearningFactorFunctionalModel getFunctionalModel() {
        return this.functionalModel;
    }

    /**
     * Finds the neuron whose weights are closest to {@code vector}.
     *
     * <p>{@code null} neurons are skipped. On ties, the lowest index wins.
     *
     * @param vector input vector
     * @return 0-based index of the closest neuron, or {@code 0} if the network has no non-null
     *     neuron
     */
    public int getBestNeuron(double[] vector) {
        int networkSize = this.networkModel.getNumbersOfNeurons();
        int bestNeuron = 0;
        double bestDistance = 0.0;
        // Explicit flag instead of a magic sentinel distance, so no real distance can collide with it.
        boolean found = false;
        for (int i = 0; i < networkSize; ++i) {
            NeuronModel neuron = this.networkModel.getNeuron(i);
            if (neuron == null) {
                continue;
            }
            double distance = this.metrics.getDistance(neuron.getWeight(), vector);
            if (!found || distance < bestDistance) {
                bestDistance = distance;
                bestNeuron = i;
                found = true;
            }
        }
        return bestNeuron;
    }

    /**
     * Moves one neuron's weights towards {@code vector}.
     *
     * <p>Each weight is updated as {@code w[i] += learningFactor(iteration) *
     * neighbourhood(distance) * (vector[i] - w[i])}. The result is stored back with
     * {@code setWeight}.
     *
     * @param neuronNumber 0-based index of the neuron to update
     * @param vector input vector; assumed to be at least as long as the neuron's weight vector
     * @param iteration current iteration, passed to the learning-factor function
     * @param distance topological distance from the winning neuron
     */
    protected void changeNeuronWeight(int neuronNumber, double[] vector, int iteration, int distance) {
        // Read and write the SAME neuron. Previously the weights were read from neuronNumber - 1
        // but written to neuronNumber, corrupting a neighbouring neuron.
        NeuronModel neuron = this.networkModel.getNeuron(neuronNumber);
        double[] weightList = neuron.getWeight();
        if (this.showComments) {
            System.out.println("Vector: " + Arrays.toString(vector));
            System.out.println("Neuron " + (neuronNumber + 1) + " weight before change: " + Arrays.toString(weightList));
        }
        // Both factors depend only on iteration/distance, so they are evaluated once per neuron.
        double factor = this.functionalModel.getValue(iteration) * this.neighboorhoodFunction.getValue(distance);
        for (int i = 0; i < weightList.length; ++i) {
            weightList[i] += factor * (vector[i] - weightList[i]);
        }
        neuron.setWeight(weightList);
        if (this.showComments) {
            System.out.println("Neuron " + (neuronNumber + 1) + " weight after change: " + Arrays.toString(weightList));
        }
    }

    /**
     * Updates every neuron in the topological neighbourhood of {@code neuronNumber}.
     *
     * <p>The neighbourhood map goes from neuron number to its distance from the winner. Whether
     * it contains the winner itself depends on the {@link TopologyModel} implementation.
     *
     * @param neuronNumber 0-based index of the winning neuron
     * @param vector input vector
     * @param iteration current iteration
     */
    public void changeWeight(int neuronNumber, double[] vector, int iteration) {
        Map<Integer, Integer> neighbourhood = this.topology.getNeighbourhood(neuronNumber);
        for (Map.Entry<Integer, Integer> entry : neighbourhood.entrySet()) {
            this.changeNeuronWeight(entry.getKey(), vector, iteration, entry.getValue());
        }
    }

    /**
     * Runs {@link #maxIteration} passes over the learning data.
     *
     * <p>For each vector, the best neuron is found and its neighbourhood is updated.
     */
    public void learn() {
        int dataSize = this.learningData.getDataSize();
        for (int i = 0; i < this.maxIteration; ++i) {
            if (this.showComments) {
                System.out.println("Iteration number: " + (i + 1));
            }
            for (int j = 0; j < dataSize; ++j) {
                double[] vector = this.learningData.getData(j);
                int bestNeuron = this.getBestNeuron(vector);
                if (this.showComments) {
                    System.out.println("Best neuron number: " + (bestNeuron + 1));
                }
                this.changeWeight(bestNeuron, vector, i);
            }
        }
    }
}

