/*
 * Decompiled with CFR 0.152.
 */
package org.jdmp.core.algorithm.classification.mlp;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import org.jdmp.core.algorithm.classification.AbstractClassifier;
import org.jdmp.core.algorithm.classification.Classifier;
import org.jdmp.core.algorithm.classification.mlp.NetworkLayer;
import org.jdmp.core.dataset.ListDataSet;
import org.jdmp.core.sample.Sample;
import org.jdmp.core.variable.Variable;
import org.ujmp.core.DenseMatrix;
import org.ujmp.core.Matrix;
import org.ujmp.core.calculation.Calculation;
import org.ujmp.core.listmatrix.DefaultListMatrix;

public class MultiLayerNetwork
extends AbstractClassifier {
    private static final long serialVersionUID = -6875411556504865444L;
    private Aggregation aggregationInput = null;
    private Transfer transferInput = null;
    private BiasType biasInput = null;
    private Aggregation aggregationDefault = null;
    private Transfer transferDefault = null;
    private BiasType biasDefault = null;
    private Aggregation aggregationOutput = null;
    private Transfer transferOutput = null;
    private BiasType biasOutput = null;
    private int[] neuronCount = null;
    private final List<NetworkLayer> networkLayers = new ArrayList<NetworkLayer>();

    public MultiLayerNetwork(Aggregation aggregationInput, Transfer transferInput, BiasType biasInput, Aggregation aggregationDefault, Transfer transferDefault, BiasType biasDefault, Aggregation aggregationOutput, Transfer transferOutput, BiasType biasOutput, int ... hiddenNeurons) {
        this.aggregationInput = aggregationInput;
        this.transferInput = transferInput;
        this.biasInput = biasInput;
        this.aggregationDefault = aggregationDefault;
        this.transferDefault = transferDefault;
        this.biasDefault = biasDefault;
        this.aggregationOutput = aggregationOutput;
        this.transferOutput = transferOutput;
        this.biasOutput = biasOutput;
        this.neuronCount = hiddenNeurons;
        NetworkLayer previousLayer = null;
        if (hiddenNeurons != null) {
            NetworkLayer hiddenLayer = null;
            for (int i = 0; i < hiddenNeurons.length; ++i) {
                hiddenLayer = i == 0 ? new NetworkLayer(aggregationInput, transferInput, biasInput, hiddenNeurons[i]) : new NetworkLayer(aggregationDefault, transferDefault, biasDefault, hiddenNeurons[i]);
                this.getAlgorithmMap().put("hidden" + i + "-forward", hiddenLayer.getAlgorithmForward());
                this.getAlgorithmMap().put("hidden" + i + "-backward", hiddenLayer.getAlgorithmBackward());
                this.getAlgorithmMap().put("hidden" + i + "-update", hiddenLayer.getAlgorithmWeightUpdate());
                hiddenLayer.setLayer(i);
                this.networkLayers.add(hiddenLayer);
                if (previousLayer != null) {
                    previousLayer.setNextLayer(hiddenLayer);
                }
                hiddenLayer.setPreviousLayer(previousLayer);
                previousLayer = hiddenLayer;
            }
        }
        NetworkLayer outputLayer = new NetworkLayer(aggregationOutput, transferOutput, biasOutput);
        this.getAlgorithmMap().put("output-forward", outputLayer.getAlgorithmForward());
        this.getAlgorithmMap().put("output-backward", outputLayer.getAlgorithmBackward());
        this.getAlgorithmMap().put("output-update", outputLayer.getAlgorithmWeightUpdate());
        outputLayer.setLayer(hiddenNeurons.length);
        this.networkLayers.add(outputLayer);
        if (previousLayer != null) {
            previousLayer.setNextLayer(outputLayer);
            outputLayer.setPreviousLayer(previousLayer);
        }
        Variable output = Variable.Factory.labeledVariable("Output");
        this.setOutputVariable(output);
        Variable outputDeviation = Variable.Factory.labeledVariable("Output Deviation");
        this.setOutputDeviationVariable(outputDeviation);
        Variable input = Variable.Factory.labeledVariable("Input");
        this.setInputVariable(input);
        Variable desiredOutput = Variable.Factory.labeledVariable("Desired Output");
        this.setDesiredOutputVariable(desiredOutput);
    }

    public MultiLayerNetwork(int ... neuronCount) {
        this(Aggregation.SUM, Transfer.TANH, BiasType.SINGLE, Aggregation.SUM, Transfer.TANH, BiasType.SINGLE, Aggregation.SUM, Transfer.TANH, BiasType.SINGLE, neuronCount);
    }

    public MultiLayerNetwork(Aggregation aggregation, Transfer transfer, int ... neuronCount) {
        this(aggregation, transfer, BiasType.SINGLE, aggregation, transfer, BiasType.SINGLE, aggregation, transfer, BiasType.SINGLE, neuronCount);
    }

    public MultiLayerNetwork(Aggregation aggregation, BiasType bias, int ... neuronCount) {
        this(aggregation, Transfer.TANH, bias, aggregation, Transfer.TANH, bias, aggregation, Transfer.TANH, bias, neuronCount);
    }

    public MultiLayerNetwork(Aggregation aggregation, int ... neuronCount) {
        this(aggregation, Transfer.TANH, BiasType.SINGLE, aggregation, Transfer.TANH, BiasType.SINGLE, aggregation, Transfer.TANH, BiasType.SINGLE, neuronCount);
    }

    public MultiLayerNetwork(Aggregation aggregation, Transfer transfer, BiasType biasType, int ... neurons) {
        this(aggregation, transfer, biasType, aggregation, transfer, biasType, aggregation, transfer, biasType, neurons);
    }

    public void setOutputVariable(Variable v) {
        this.getOutputErrorAlgorithm().setVariable("Source 1", v);
        this.getOutputLayer().setOutputVariable(v);
    }

    public void setOutputDeviationVariable(Variable v) {
        this.getOutputErrorAlgorithm().setVariable("Target", v);
        this.getOutputLayer().setOutputDeviationVariable(v);
    }

    public void setLearningRate(double v) {
        for (NetworkLayer networkLayer : this.networkLayers) {
            networkLayer.setLearningRate(v);
        }
    }

    public double getLearningRate() {
        return this.networkLayers.get(0).getLearningRate();
    }

    @Override
    public void reset() {
        for (NetworkLayer networkLayer : this.networkLayers) {
            networkLayer.reset();
        }
    }

    public NetworkLayer getOutputLayer() {
        return this.networkLayers.get(this.networkLayers.size() - 1);
    }

    public NetworkLayer getInputLayer() {
        return this.networkLayers.get(0);
    }

    public Variable getInputVariable() {
        return this.getInputLayer().getInputVariable();
    }

    public void setInputVariable(Variable v) {
        this.getInputLayer().setInputVariable(v);
    }

    public void setDesiredOutputVariable(Variable v) {
        this.getOutputErrorAlgorithm().setVariable("Source 2", v);
    }

    public void addInputMatrix(Matrix m) {
        this.getInputLayer().addInputMatrix(m);
    }

    public void setSampleWeight(double weight) {
        for (NetworkLayer networkLayer : this.networkLayers) {
            networkLayer.setSampleWeight(weight);
        }
    }

    public void addDesiredOutputMatrix(Matrix m) {
        DenseMatrix z;
        this.getOutputErrorAlgorithm().getVariableMap().setMatrix("Source 2", m);
        if (this.getOutputVariable().isEmpty()) {
            z = Matrix.Factory.zeros(m.getRowCount(), m.getColumnCount());
            this.getOutputVariable().add(z);
        }
        if (this.getOutputDeviationVariable().isEmpty()) {
            z = Matrix.Factory.zeros(m.getRowCount(), m.getColumnCount());
            this.getOutputDeviationVariable().add(z);
        }
    }

    public Variable getOutputVariable() {
        return this.getOutputLayer().getOutputVariable();
    }

    public Variable getOutputDeviationVariable() {
        return this.getOutputLayer().getOutputDeviationVariable();
    }

    public List<NetworkLayer> getNetworkLayerList() {
        return this.networkLayers;
    }

    @Override
    public Matrix predictOne(Matrix input) {
        this.addInputMatrix(input);
        for (NetworkLayer networkLayer : this.getNetworkLayerList()) {
            networkLayer.calculateForward();
        }
        Matrix actualOutput = this.getOutputMatrix().transpose();
        return actualOutput;
    }

    public Matrix getOutputMatrix() {
        return (Matrix)this.getOutputVariable().getLast();
    }

    @Override
    public void trainOne(Matrix input, Matrix sampleWeight, Matrix desiredOutput) {
        int i;
        this.addDesiredOutputMatrix(desiredOutput.toRowVector(Calculation.Ret.NEW));
        if (sampleWeight == null) {
            sampleWeight = Matrix.Factory.linkToValue(1.0);
        }
        this.setSampleWeight(sampleWeight.doubleValue());
        this.predictOne(input);
        this.getOutputErrorAlgorithm().calculate();
        for (i = this.networkLayers.size() - 1; i != -1; --i) {
            this.networkLayers.get(i).calculateBackward();
        }
        for (i = this.networkLayers.size() - 1; i != -1; --i) {
            this.networkLayers.get(i).calculateWeightUpdate();
        }
    }

    public int determineOptimalTrainingDuration(ListDataSet dataSet, int numberOfSteps) throws Exception {
        long seed = System.currentTimeMillis();
        DefaultListMatrix duration = new DefaultListMatrix();
        for (int r = 0; r < 10; ++r) {
            List<ListDataSet> dss = dataSet.splitForStratifiedCV(10, r, seed);
            ListDataSet train = dss.get(0);
            ListDataSet test = dss.get(1);
            MultiLayerNetwork a = new MultiLayerNetwork(this.aggregationInput, this.transferInput, this.biasInput, this.aggregationDefault, this.transferDefault, this.biasDefault, this.aggregationOutput, this.transferOutput, this.biasOutput, this.neuronCount);
            a.setLearningRate(this.getLearningRate());
            for (int i = 0; i < 10000; ++i) {
                a.trainAll(train);
                a.predictAll(test);
            }
        }
        int mean = (int)(duration.getMeanValue() * 0.9);
        if (mean == 0) {
            mean = 1;
        }
        return mean;
    }

    @Override
    public void trainAll(ListDataSet dataSet) {
        int i;
        ArrayList<Sample> samples = new ArrayList<Sample>(dataSet);
        Collections.shuffle(samples);
        int last10Percent = (int)Math.ceil((double)samples.size() * 0.1);
        int first90Percent = samples.size() - last10Percent;
        for (int i2 = 0; i2 < first90Percent; ++i2) {
            this.trainOne((Sample)samples.get(i2));
        }
        double rmse = 0.0;
        for (i = first90Percent; i < samples.size(); ++i) {
            Sample rs = (Sample)samples.get(i);
            Matrix output = this.predictOne(rs.getAsMatrix(this.getInputLabel()));
            rmse += output.minus(rs.getAsMatrix(this.getTargetLabel())).getRMS();
            this.trainOne((Sample)samples.get(i));
        }
        System.out.println("RMSE on " + last10Percent + " Samples: " + (rmse /= (double)last10Percent));
        for (i = first90Percent; i < samples.size(); ++i) {
            this.trainOne((Sample)samples.get(i));
        }
    }

    public void trainOnce(ListDataSet dataSet) throws Exception {
        ArrayList<Sample> samples = new ArrayList<Sample>(dataSet);
        Collections.shuffle(samples);
        for (Sample s : samples) {
            this.trainOne(s);
        }
        dataSet.fireValueChanged();
    }

    @Override
    public Classifier emptyCopy() {
        return null;
    }

    public static enum BiasType {
        NONE,
        SINGLE,
        MULTIPLE;

    }

    public static enum Transfer {
        TANH,
        TANHPLUSONE,
        SIGMOID,
        LINEAR,
        SIN,
        LOG,
        GAUSS;

    }

    public static enum Aggregation {
        MEAN,
        SUM;

    }
}

