/*
 * Decompiled with CFR 0.152.
 */
package org.joone.samples.engine.xor.rbf;

import org.joone.engine.BiasedLinearLayer;
import org.joone.engine.FullSynapse;
import org.joone.engine.LinearLayer;
import org.joone.engine.Monitor;
import org.joone.engine.NeuralNetEvent;
import org.joone.engine.NeuralNetListener;
import org.joone.engine.RbfGaussianLayer;
import org.joone.engine.RbfGaussianParameters;
import org.joone.engine.RbfInputSynapse;
import org.joone.engine.learning.TeachingSynapse;
import org.joone.io.MemoryInputSynapse;
import org.joone.io.MemoryOutputSynapse;
import org.joone.net.NeuralNet;

/**
 * Sample that trains a small radial-basis-function (RBF) network with Joone on a
 * two-input boolean truth table. It then runs the trained network once with learning
 * disabled and prints the outputs and the hidden layer's Gaussian parameters.
 *
 * <p>Topology: a 2-row {@link LinearLayer} input, a 2-row {@link RbfGaussianLayer}
 * hidden layer fed through an {@link RbfInputSynapse}, and a 1-row
 * {@link BiasedLinearLayer} output fed through a {@link FullSynapse}. Training data
 * is supplied from in-memory arrays via {@link MemoryInputSynapse}s.
 *
 * <p>Usage: {@link #initNeuralNet()}, then {@link #train()}, then {@link #test()}.
 */
public class XOR_static_RBF
implements NeuralNetListener {
    /** Number of training cycles (epochs) passed to the monitor in {@link #train()}. */
    private static final int TRAINING_EPOCHS = 200;
    /** Progress is printed when the monitor's current cycle is a multiple of this value. */
    private static final int REPORT_INTERVAL = 100;

    private NeuralNet nnet = null;
    // Kept package-private (unlike the other fields) as in the original; other code may access it.
    RbfGaussianLayer hidden = null;
    private MemoryInputSynapse inputSynapse;
    private MemoryInputSynapse desiredOutputSynapse;
    private MemoryOutputSynapse outputSynapse;
    // Never reassigned in this class, so the fixed-centre branch of initNeuralNet() is the one taken.
    private boolean randomCenters = false;
    // One row per training pattern; each row holds the two boolean inputs.
    private double[][] inputArray = new double[][]{{0.0, 0.0}, {0.0, 1.0}, {1.0, 0.0}, {1.0, 1.0}};
    // NOTE(review): these targets are 1 when both inputs are equal and 0 otherwise,
    // i.e. XNOR (the complement of XOR) despite the class name -- confirm this is intended.
    private double[][] desiredOutputArray = new double[][]{{1.0}, {0.0}, {0.0}, {1.0}};

    /**
     * Entry point: builds the network, trains it, then prints its recall results.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        XOR_static_RBF xor = new XOR_static_RBF();
        xor.initNeuralNet();
        xor.train();
        xor.test();
    }

    /**
     * Loads the training tables into the memory synapses, configures the monitor and
     * starts training. The monitor settings are learning rate 0.3, momentum 0.8,
     * {@value #TRAINING_EPOCHS} cycles, and one training pattern per row of
     * {@code inputArray}. This object is registered as a listener so that progress
     * is printed from {@link #errorChanged(NeuralNetEvent)}.
     *
     * @throws IllegalStateException if {@link #initNeuralNet()} has not been called
     */
    public void train() {
        requireInitialized();
        this.inputSynapse.setInputArray(this.inputArray);
        this.inputSynapse.setAdvancedColumnSelector("1,2");
        this.desiredOutputSynapse.setInputArray(this.desiredOutputArray);
        this.desiredOutputSynapse.setAdvancedColumnSelector("1");
        Monitor monitor = this.nnet.getMonitor();
        monitor.setLearningRate(0.3);
        monitor.setMomentum(0.8);
        monitor.setTrainingPatterns(this.inputArray.length);
        monitor.setTotCicles(TRAINING_EPOCHS);
        monitor.setLearning(true);
        this.nnet.addNeuralNetListener(this);
        // NOTE(review): presumably go(true) runs in single-thread mode and returns once
        // training completes -- confirm against Joone's NeuralNet.go(boolean).
        this.nnet.go(true);
    }

    /**
     * Builds the layers and synapses and wires them into {@link #nnet}.
     *
     * <p>With {@code randomCenters == false}, the two hidden neurons get fixed centres
     * (0,0) and (1,1), each with standard deviation sqrt(0.5). Otherwise the hidden
     * layer is asked to choose centres via {@code useRandomCenter(inputSynapse)}.
     */
    protected void initNeuralNet() {
        LinearLayer input = new LinearLayer();
        this.hidden = new RbfGaussianLayer();
        BiasedLinearLayer output = new BiasedLinearLayer();
        input.setRows(2);
        this.hidden.setRows(2);
        output.setRows(1);
        if (!this.randomCenters) {
            RbfGaussianParameters[] myParameters = new RbfGaussianParameters[2];
            double[] myMean0 = new double[]{0.0, 0.0};
            myParameters[0] = new RbfGaussianParameters(myMean0, Math.sqrt(0.5));
            double[] myMean1 = new double[]{1.0, 1.0};
            myParameters[1] = new RbfGaussianParameters(myMean1, Math.sqrt(0.5));
            this.hidden.setGaussianParameters(myParameters);
        }
        // input -> (RBF synapse) -> hidden -> (full synapse) -> output
        RbfInputSynapse synapse_IH = new RbfInputSynapse();
        FullSynapse synapse_HO = new FullSynapse();
        input.addOutputSynapse(synapse_IH);
        this.hidden.addInputSynapse(synapse_IH);
        this.hidden.addOutputSynapse(synapse_HO);
        output.addInputSynapse(synapse_HO);
        this.inputSynapse = new MemoryInputSynapse();
        input.addInputSynapse(this.inputSynapse);
        // Random centres are requested only after the input synapse is attached,
        // because that synapse is passed to useRandomCenter.
        if (this.randomCenters) {
            this.hidden.useRandomCenter(this.inputSynapse);
        }
        this.desiredOutputSynapse = new MemoryInputSynapse();
        TeachingSynapse trainer = new TeachingSynapse();
        trainer.setDesired(this.desiredOutputSynapse);
        this.nnet = new NeuralNet();
        this.nnet.addLayer(input, 0);
        this.nnet.addLayer(this.hidden, 1);
        this.nnet.addLayer(output, 2);
        this.nnet.setTeacher(trainer);
        output.addOutputSynapse(trainer);
    }

    /**
     * Runs the network once over every input pattern with learning disabled, then prints
     * the first output value for each pattern and the hidden neurons' Gaussian parameters.
     *
     * <p>The number of patterns is taken from {@code inputArray}, consistent with
     * {@link #train()}. Each call attaches a new {@link MemoryOutputSynapse} to the
     * output layer.
     *
     * @throws IllegalStateException if {@link #initNeuralNet()} has not been called
     */
    public void test() {
        requireInitialized();
        final int patternCount = this.inputArray.length;
        this.outputSynapse = new MemoryOutputSynapse();
        this.nnet.getOutputLayer().addOutputSynapse(this.outputSynapse);
        Monitor monitor = this.nnet.getMonitor();
        monitor.setTotCicles(1);
        monitor.setTrainingPatterns(patternCount);
        monitor.setLearning(false);
        // Drop the training-progress listener so recall does not print epoch lines.
        this.nnet.removeAllListeners();
        this.nnet.go();
        System.out.println("Outputs");
        System.out.println("-------");
        for (int i = 0; i < patternCount; ++i) {
            double[] myPattern = this.outputSynapse.getNextPattern();
            System.out.println("Output: " + myPattern[0]);
        }
        printCenters();
    }

    /** Prints each hidden neuron's centre coordinates and standard deviation. */
    private void printCenters() {
        System.out.println("Centers RBF neurons: ");
        RbfGaussianParameters[] params = this.hidden.getGaussianParameters();
        for (int i = 0; i < params.length; ++i) {
            StringBuilder text = new StringBuilder();
            text.append(i + 1).append(": [center: ");
            for (double coordinate : params[i].getMean()) {
                text.append(coordinate).append(", ");
            }
            text.append("Std dev: ").append(params[i].getStdDeviation()).append(']');
            System.out.println(text);
        }
    }

    /** Fails fast with a clear message instead of an NPE when the net has not been built. */
    private void requireInitialized() {
        if (this.nnet == null) {
            throw new IllegalStateException("initNeuralNet() must be called before train() or test()");
        }
    }

    /** Not used by this sample. */
    @Override
    public void cicleTerminated(NeuralNetEvent e) {
    }

    /**
     * Prints the epoch number and global error when the monitor's current cycle is a
     * multiple of {@value #REPORT_INTERVAL}.
     *
     * @param e event whose source is the network's {@link Monitor}
     */
    @Override
    public void errorChanged(NeuralNetEvent e) {
        Monitor mon = (Monitor)e.getSource();
        if (mon.getCurrentCicle() % REPORT_INTERVAL == 0) {
            // NOTE(review): totCicles - currentCicle as the epoch number assumes the
            // monitor's current cycle counts down from totCicles -- confirm in Joone's Monitor.
            System.out.println("Epoch: " + (mon.getTotCicles() - mon.getCurrentCicle()) + " RMSE:" + mon.getGlobalError());
        }
    }

    /** Not used by this sample. */
    @Override
    public void netStarted(NeuralNetEvent e) {
    }

    /** Not used by this sample. */
    @Override
    public void netStopped(NeuralNetEvent e) {
    }

    /** Not used by this sample; errors reported here are ignored. */
    @Override
    public void netStoppedError(NeuralNetEvent e, String error) {
    }
}

