/*
 * Decompiled with CFR 0.152.
 */
package org.neuroph.contrib.model.modelselection;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.ExecutionException;
import java.util.logging.Level;
import org.neuroph.contrib.model.modelselection.NeurophModelOptimizer;
import org.neuroph.core.NeuralNetwork;
import org.neuroph.core.data.DataSet;
import org.neuroph.core.events.LearningEvent;
import org.neuroph.core.events.LearningEventListener;
import org.neuroph.eval.ClassifierEvaluator;
import org.neuroph.eval.CrossValidation;
import org.neuroph.eval.classification.ClassificationMetrics;
import org.neuroph.eval.classification.ConfusionMatrix;
import org.neuroph.nnet.MultiLayerPerceptron;
import org.neuroph.nnet.learning.BackPropagation;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Brute-force search over multilayer-perceptron hidden-layer topologies.
 *
 * <p>Every hidden-layer configuration with between 1 and {@code maxLayers} layers is enumerated. Each
 * layer holds {@code minNeuronsPerLayer}..{@code maxNeuronsPerLayer} neurons, stepping by
 * {@code neuronIncrement}. Each candidate network is evaluated with 10-fold cross-validation. The
 * network whose first class-metrics entry has the highest F-measure is kept as the optimum.
 *
 * <p>Instances hold mutable search state and are not safe for concurrent use.
 *
 * @param <T> learning-rule type bound (currently not used by the implementation)
 */
public class MultilayerPerceptronOptimazer<T extends BackPropagation>
implements NeurophModelOptimizer {
    private static final Logger LOG = LoggerFactory.getLogger(MultilayerPerceptronOptimazer.class);

    /** Number of cross-validation folds used to score every candidate architecture. */
    private static final int CROSS_VALIDATION_FOLDS = 10;

    /** Hidden-layer configurations found by {@link #findArchitectures}; insertion-ordered for reproducible runs. */
    private final Set<List<Integer>> allArchitectures = new LinkedHashSet<List<Integer>>();
    /** Full layer list (input, hidden..., output) of the best network found so far. */
    private List<Integer> optimalArchitecture;
    private NeuralNetwork<BackPropagation> optimalClassifier;
    private ClassificationMetrics optimalResult;
    private CrossValidation errorEstimationMethod;
    private BackPropagation learningRule;
    private int maxLayers = 1;
    private int minNeuronsPerLayer = 1;
    private int maxNeuronsPerLayer = 30;
    private int neuronIncrement = 1;

    /**
     * Sets the maximum number of hidden layers to try.
     *
     * @param maxLayers maximum hidden-layer count
     * @return this optimizer, for chaining
     */
    public MultilayerPerceptronOptimazer<T> withMaxLayers(int maxLayers) {
        this.maxLayers = maxLayers;
        return this;
    }

    /**
     * Sets the step by which the neuron count of a hidden layer is increased.
     *
     * @param neuronIncrement strictly positive step size
     * @return this optimizer, for chaining
     * @throws IllegalArgumentException if {@code neuronIncrement <= 0}. A non-positive step would make the
     *     enumeration in {@link #createOptimalModel} recurse forever.
     */
    public MultilayerPerceptronOptimazer<T> withNeuronIncrement(int neuronIncrement) {
        if (neuronIncrement <= 0) {
            throw new IllegalArgumentException("neuronIncrement must be > 0, was " + neuronIncrement);
        }
        this.neuronIncrement = neuronIncrement;
        return this;
    }

    /**
     * Sets the maximum number of neurons per hidden layer.
     *
     * @param maxNeurons upper bound (inclusive) for a hidden layer's neuron count
     * @return this optimizer, for chaining
     */
    public MultilayerPerceptronOptimazer<T> withMaxNeurons(int maxNeurons) {
        this.maxNeuronsPerLayer = maxNeurons;
        return this;
    }

    /**
     * Sets the starting (minimum) number of neurons for every hidden layer.
     *
     * @param minNeurons lower bound for a hidden layer's neuron count
     * @return this optimizer, for chaining
     */
    public MultilayerPerceptronOptimazer<T> withMinNeurons(int minNeurons) {
        this.minNeuronsPerLayer = minNeurons;
        return this;
    }

    /**
     * Stores an error-estimation method.
     *
     * <p>NOTE(review): {@link #createOptimalModel} currently overwrites this field with a fresh
     * {@code CrossValidation} for every candidate network, so the configured value is not used for scoring.
     *
     * @param errorEstimationMethod cross-validation instance
     * @return this optimizer, for chaining
     */
    public MultilayerPerceptronOptimazer<T> withErrorEstimationMethod(CrossValidation errorEstimationMethod) {
        this.errorEstimationMethod = errorEstimationMethod;
        return this;
    }

    /**
     * Sets the learning rule used to train every candidate network. The rule is required.
     *
     * @param learningRule back-propagation rule shared by all candidates
     * @return this optimizer, for chaining
     */
    public MultilayerPerceptronOptimazer<T> withLearningRule(BackPropagation learningRule) {
        this.learningRule = learningRule;
        return this;
    }

    /**
     * Enumerates all hidden-layer topologies, cross-validates each, and returns the best network.
     *
     * <p>Each call starts a fresh search. Results of earlier calls are discarded.
     *
     * @param dataSet data used for cross-validation; its input/output sizes fix the outer layers
     * @return the network with the highest F-measure (first class-metrics entry). Returns {@code null} if
     *     no candidate could be evaluated.
     * @throws IllegalStateException if no learning rule was configured via {@link #withLearningRule}
     */
    @Override
    public NeuralNetwork createOptimalModel(DataSet dataSet) {
        if (this.learningRule == null) {
            throw new IllegalStateException("No learning rule configured; call withLearningRule(...) first");
        }
        // Reset search state so a second call neither compares against nor returns a previous winner.
        this.allArchitectures.clear();
        this.optimalArchitecture = null;
        this.optimalClassifier = null;
        this.optimalResult = null;

        List<Integer> neurons = new ArrayList<Integer>();
        neurons.add(this.minNeuronsPerLayer);
        this.findArchitectures(1, this.minNeuronsPerLayer, neurons);
        LOG.info("Total [{}] different network topologies found", (Object) this.allArchitectures.size());

        for (List<Integer> hiddenLayers : this.allArchitectures) {
            // Build a separate list: mutating a set element would change its hashCode and corrupt the set.
            List<Integer> architecture = new ArrayList<Integer>(hiddenLayers.size() + 2);
            architecture.add(dataSet.getInputSize());
            architecture.addAll(hiddenLayers);
            architecture.add(dataSet.getOutputSize());
            LOG.info("Architecture: [{}]", architecture);
            try {
                this.evaluateArchitecture(architecture, dataSet);
            } catch (InterruptedException ex) {
                // Restore the interrupt flag and stop; continuing would ignore the caller's cancellation.
                Thread.currentThread().interrupt();
                LOG.error("Interrupted while evaluating architecture [{}]; stopping search", architecture, ex);
                break;
            } catch (ExecutionException ex) {
                LOG.error("Evaluation of architecture [{}] failed", architecture, ex);
            }
        }
        LOG.info("Optimal Architecture: {}", this.optimalArchitecture);
        return this.optimalClassifier;
    }

    /**
     * Trains and cross-validates one candidate topology, and records it if it beats the current optimum.
     *
     * @param architecture full layer list (input, hidden..., output)
     * @param dataSet data used for cross-validation
     */
    private void evaluateArchitecture(List<Integer> architecture, DataSet dataSet)
            throws InterruptedException, ExecutionException {
        MultiLayerPerceptron network = new MultiLayerPerceptron(architecture);
        LearningListener listener = new LearningListener(CROSS_VALIDATION_FOLDS, this.learningRule.getMaxIterations());
        this.learningRule.addListener(listener);
        try {
            network.setLearningRule(this.learningRule);
            this.errorEstimationMethod = new CrossValidation(network, dataSet, CROSS_VALIDATION_FOLDS);
            this.errorEstimationMethod.run();
            ClassificationMetrics[] result = ClassificationMetrics.createFromMatrix((ConfusionMatrix) this.errorEstimationMethod.getEvaluator(ClassifierEvaluator.MultiClass.class).getResult());
            if (result == null || result.length == 0) {
                LOG.warn("No classification metrics produced for architecture [{}]", architecture);
                return;
            }
            if (this.isImprovement(result[0])) {
                LOG.info("Architecture [{}] became optimal architecture  with metrics {}", architecture, (Object) result);
                this.optimalResult = result[0];
                this.optimalClassifier = network;
                this.optimalArchitecture = architecture;
            }
            LOG.info("#################################################################");
        } finally {
            // The rule is shared across candidates; without removal, listeners would accumulate on it.
            this.learningRule.removeListener(listener);
        }
    }

    /**
     * Decides whether {@code candidate} should replace the current optimum (higher F-measure wins).
     *
     * <p>A NaN score never wins against an existing result. A NaN incumbent always loses to a
     * numeric score. With plain {@code <}, a NaN incumbent could never be replaced.
     */
    private boolean isImprovement(ClassificationMetrics candidate) {
        if (this.optimalResult == null) {
            return true;
        }
        double best = this.optimalResult.getFMeasure();
        double score = candidate.getFMeasure();
        if (Double.isNaN(score)) {
            return false;
        }
        return Double.isNaN(best) || score > best;
    }

    /**
     * Recursively records every hidden-layer configuration reachable from {@code neurons}.
     *
     * <p>Two moves are possible from a configuration. "Widen" raises the last layer by
     * {@code neuronIncrement} while it stays {@code <= maxNeuronsPerLayer}. "Deepen" appends a new layer
     * of {@code minNeuronsPerLayer} neurons while {@code currentLayer < maxLayers}.
     *
     * @param currentLayer number of hidden layers in {@code neurons}
     * @param lastLayerNeuronCount neuron count of the last hidden layer
     * @param neurons current configuration (not modified)
     */
    private void findArchitectures(int currentLayer, int lastLayerNeuronCount, List<Integer> neurons) {
        this.allArchitectures.add(new ArrayList<Integer>(neurons));
        int widenedCount = lastLayerNeuronCount + this.neuronIncrement;
        if (widenedCount <= this.maxNeuronsPerLayer) {
            List<Integer> widened = new ArrayList<Integer>(neurons);
            widened.set(widened.size() - 1, widenedCount);
            this.findArchitectures(currentLayer, widenedCount, widened);
        }
        if (currentLayer < this.maxLayers) {
            List<Integer> deepened = new ArrayList<Integer>(neurons);
            // The new layer must start at the same count passed as lastLayerNeuronCount below.
            deepened.add(this.minNeuronsPerLayer);
            this.findArchitectures(currentLayer + 1, this.minNeuronsPerLayer, deepened);
        }
    }

    /**
     * Learning listener that adds {@code totalNetworkError / foldSize} into a per-iteration slot on every
     * event it receives. In this class {@code foldSize} is the cross-validation fold count.
     *
     * <p>NOTE(review): {@code foldErrors} is written but never read in this file. Its size comes from the
     * rule's {@code getMaxIterations()}, which may be very large if iterations are effectively
     * unbounded. Confirm before relying on this listener.
     */
    static class LearningListener
    implements LearningEventListener {
        private final double[] foldErrors;
        private final int foldSize;

        /**
         * @param foldSize divisor applied to each error sample (the fold count, as used in this class)
         * @param maxIterations number of per-iteration slots to allocate; negative values yield none
         */
        public LearningListener(int foldSize, int maxIterations) {
            this.foldSize = foldSize;
            this.foldErrors = new double[Math.max(0, maxIterations)];
        }

        @Override
        public void handleLearningEvent(LearningEvent event) {
            BackPropagation bp = (BackPropagation) event.getSource();
            int n = bp.getCurrentIteration() - 1;
            // Ignore iterations outside the pre-sized buffer instead of throwing from inside training.
            if (n < 0 || n >= this.foldErrors.length) {
                return;
            }
            this.foldErrors[n] += bp.getTotalNetworkError() / (double) this.foldSize;
        }
    }
}

