/*
 * Decompiled with CFR 0.152.
 */
package edu.hitsz.c102c.cnn;

import edu.hitsz.c102c.cnn.Dataset;
import edu.hitsz.c102c.cnn.Layer;
import edu.hitsz.c102c.util.ConcurenceRunner;
import edu.hitsz.c102c.util.Log;
import edu.hitsz.c102c.util.Util;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.PrintWriter;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;

/**
 * A convolutional neural network assembled from a list of {@link Layer}s and trained with
 * mini-batch back-propagation.
 *
 * <p>Layer 0 is filled from the record attributes and the last layer is the output layer. Each
 * output map contributes the single value {@code map[0][0]}, and the predicted class is the index
 * of the largest such value. Instances can be persisted with {@link #saveModel(String)} and
 * restored with {@link #loadModel(String)} using Java serialization.
 *
 * <p>Per-map work is dispatched through {@code ConcurenceRunner.TaskManager}, which presumably
 * runs {@code process(start, end)} over sub-ranges of {@code [0, mapNum)} and returns once they
 * finish. TODO confirm against ConcurenceRunner.
 */
public class CNN
implements Serializable {
    private static final long serialVersionUID = 1L;
    /**
     * Step size applied to the batch-averaged updates.
     *
     * <p>It is static, so it is shared by every instance and is not part of the serialized model.
     * {@link #train(Dataset, int)} decays it in place.
     */
    private static double ALPHA = 0.85;
    /**
     * Weight-decay coefficient used by {@code multiply_lambda}, which computes
     * {@code value * (1 - LAMBDA * ALPHA)}. A value of 0 disables decay.
     */
    protected static final double LAMBDA = 0.0;
    private List<Layer> layers;
    private int layerNum;
    private int batchSize;
    private Util.Operator divide_batchSize;
    private Util.Operator multiply_alpha;
    private Util.Operator multiply_lambda;
    /** Early-stop flag. Each new {@link Lisenter} replaces it, and it is set when the user types '&'. */
    private static AtomicBoolean stopTrain;

    /**
     * Builds the network from the layers collected by {@code layerBuilder}.
     *
     * <p>It initialises the layers' output maps, kernels, biases and error buffers for
     * mini-batches of {@code batchSize} records.
     *
     * @param layerBuilder the layers, input layer first and output layer last (the builder's list
     *                     is used directly, not copied)
     * @param batchSize    number of records per mini-batch
     */
    public CNN(LayerBuilder layerBuilder, int batchSize) {
        this.layers = layerBuilder.mLayers;
        this.layerNum = this.layers.size();
        this.batchSize = batchSize;
        this.setup(batchSize);
        this.initPerator();
    }

    /**
     * Creates the element-wise operators used by {@link #updateKernels(Layer, Layer)}.
     *
     * <p>NOTE: these three anonymous classes live in fields and are therefore serialized with the
     * model. javac names anonymous classes by declaration order ({@code CNN$1}, {@code CNN$2},
     * {@code CNN$3}). They must stay the first anonymous classes in this file, in this order, and
     * keep their serialVersionUIDs, or previously saved models will fail to load.
     */
    private void initPerator() {
        this.divide_batchSize = new Util.Operator(){
            private static final long serialVersionUID = 7424011281732651055L;

            @Override
            public double process(double value) {
                return value / (double)CNN.this.batchSize;
            }
        };
        this.multiply_alpha = new Util.Operator(){
            private static final long serialVersionUID = 5761368499808006552L;

            @Override
            public double process(double value) {
                return value * ALPHA;
            }
        };
        this.multiply_lambda = new Util.Operator(){
            private static final long serialVersionUID = 4499087728362870577L;

            @Override
            public double process(double value) {
                return value * (1.0 - LAMBDA * ALPHA);
            }
        };
    }

    /**
     * Trains the network for up to {@code repeat} passes over {@code trainset}.
     *
     * <p>Each pass runs ceil(size / batchSize) mini-batches. For every mini-batch:
     * <ul>
     *   <li>{@code Util.randomPerm(size, batchSize)} supplies the record indices (presumably a
     *       random sample);</li>
     *   <li>each record is propagated forward and backward;</li>
     *   <li>then the kernels and biases are updated.</li>
     * </ul>
     *
     * <p>After each pass the training precision is logged. On passes where
     * {@code t % 10 == 1} and precision exceeds 0.96, {@link #ALPHA} is decayed. Typing '&' on
     * standard input stops training once the current pass finishes.
     *
     * @param trainset training records
     * @param repeat   maximum number of passes
     */
    public void train(Dataset trainset, int repeat) {
        new Lisenter().start();
        int size = trainset.size();
        int epochsNum = size / this.batchSize;
        if (size % this.batchSize != 0) {
            ++epochsNum;
        }
        for (int t = 0; t < repeat && !stopTrain.get(); ++t) {
            Log.i("");
            Log.i(t + "th iter epochsNum:" + epochsNum);
            int right = 0;
            int count = 0;
            for (int i = 0; i < epochsNum; ++i) {
                int[] randPerm = Util.randomPerm(size, this.batchSize);
                Layer.prepareForNewBatch();
                for (int index : randPerm) {
                    if (this.train(trainset.getRecord(index))) {
                        ++right;
                    }
                    ++count;
                    Layer.prepareForNewRecord();
                }
                this.updateParas();
                // Progress indicator: ".." every 50 batches, plus a newline after the last one.
                if (i % 50 == 0) {
                    System.out.print("..");
                    if (i + 50 > epochsNum) {
                        System.out.println();
                    }
                }
            }
            double p = (double)right / (double)count;
            if (t % 10 == 1 && p > 0.96) {
                ALPHA = 0.001 + ALPHA * 0.9;
                Log.i("Set alpha = " + ALPHA);
            }
            Log.i("precision " + right + "/" + count + "=" + p);
        }
    }

    /**
     * Runs every record of {@code trainset} through the network and logs the fraction classified
     * correctly.
     *
     * @param trainset records to evaluate (any labelled dataset)
     * @return correct / total. This is NaN for an empty dataset.
     */
    public double test(Dataset trainset) {
        Layer.prepareForNewBatch();
        Iterator<Dataset.Record> iter = trainset.iter();
        int right = 0;
        while (iter.hasNext()) {
            Dataset.Record record = iter.next();
            this.forward(record);
            if (record.getLable().intValue() == Util.getMaxIndex(this.readOutputValues())) {
                ++right;
            }
        }
        double p = (double)right / (double)trainset.size();
        Log.i("precision", p + "");
        return p;
    }

    /**
     * Classifies every record of {@code testset} and writes one predicted class index per line
     * to {@code fileName}.
     *
     * @param testset  records to classify
     * @param fileName output file; it is created or truncated
     * @throws RuntimeException wrapping the {@link IOException} if the file cannot be opened or
     *                          written
     */
    public void predict(Dataset testset, String fileName) {
        Log.i("begin predict");
        try (PrintWriter writer = new PrintWriter(new File(fileName))) {
            Layer.prepareForNewBatch();
            Iterator<Dataset.Record> iter = testset.iter();
            while (iter.hasNext()) {
                Dataset.Record record = iter.next();
                this.forward(record);
                int label = Util.getMaxIndex(this.readOutputValues());
                writer.write(label + "\n");
            }
            // PrintWriter never throws on write; checkError() flushes and reports any failure.
            if (writer.checkError()) {
                throw new IOException("Failed to write predictions to " + fileName);
            }
        }
        catch (IOException e) {
            throw new RuntimeException(e);
        }
        Log.i("end predict");
    }

    /**
     * Reads the network output left by the last {@link #forward(Dataset.Record)}.
     *
     * @return value {@code [0][0]} of each output-layer map, indexed by map number
     */
    private double[] readOutputValues() {
        Layer outputLayer = this.layers.get(this.layerNum - 1);
        int mapNum = outputLayer.getOutMapNum();
        double[] out = new double[mapNum];
        for (int m = 0; m < mapNum; ++m) {
            out[m] = outputLayer.getMap(m)[0][0];
        }
        return out;
    }

    /** Forward- and back-propagates one record; returns whether it was classified correctly. */
    private boolean train(Dataset.Record record) {
        this.forward(record);
        return this.backPropagation(record);
    }

    /** Computes the output-layer error, then the hidden-layer errors from back to front. */
    private boolean backPropagation(Dataset.Record record) {
        boolean result = this.setOutLayerErrors(record);
        this.setHiddenLayerErrors();
        return result;
    }

    /** Applies the accumulated mini-batch updates to every layer that has kernels and biases. */
    private void updateParas() {
        for (int l = 1; l < this.layerNum; ++l) {
            Layer layer = this.layers.get(l);
            Layer lastLayer = this.layers.get(l - 1);
            switch (layer.getType()) {
                case conv:
                case output:
                    this.updateKernels(layer, lastLayer);
                    this.updateBias(layer, lastLayer);
                    break;
                default:
                    // Input and sampling layers get no kernels or biases in setup().
                    break;
            }
        }
    }

    /**
     * Updates each bias j as {@code bias_j + ALPHA * sum(error_j) / batchSize}.
     *
     * <p>The errors are summed over the mini-batch via {@code Util.sum(errors, j)}.
     */
    private void updateBias(final Layer layer, Layer lastLayer) {
        final double[][][][] errors = layer.getErrors();
        int mapNum = layer.getOutMapNum();
        new ConcurenceRunner.TaskManager(mapNum){

            @Override
            public void process(int start, int end) {
                for (int j = start; j < end; ++j) {
                    double[][] error = Util.sum(errors, j);
                    double deltaBias = Util.sum(error) / (double)CNN.this.batchSize;
                    double bias = layer.getBias(j) + ALPHA * deltaBias;
                    layer.setBias(j, bias);
                }
            }
        }.start();
    }

    /**
     * Updates each kernel (i, j) from the batch-averaged gradient.
     *
     * <p>The gradient is the sum over the batch of {@code convnValid(lastMap_i, error_j)},
     * divided by batchSize. It is combined with the old kernel using {@code multiply_lambda} and
     * {@code multiply_alpha} (presumably kernel * (1 - LAMBDA*ALPHA) + ALPHA * delta; confirm
     * against Util.matrixOp).
     */
    private void updateKernels(final Layer layer, final Layer lastLayer) {
        int mapNum = layer.getOutMapNum();
        final int lastMapNum = lastLayer.getOutMapNum();
        new ConcurenceRunner.TaskManager(mapNum){

            @Override
            public void process(int start, int end) {
                for (int j = start; j < end; ++j) {
                    for (int i = 0; i < lastMapNum; ++i) {
                        double[][] deltaKernel = null;
                        for (int r = 0; r < CNN.this.batchSize; ++r) {
                            double[][] error = layer.getError(r, j);
                            double[][] grad = Util.convnValid(lastLayer.getMap(r, i), error);
                            deltaKernel = deltaKernel == null
                                    ? grad
                                    : Util.matrixOp(grad, deltaKernel, null, null, Util.plus);
                        }
                        deltaKernel = Util.matrixOp(deltaKernel, CNN.this.divide_batchSize);
                        double[][] kernel = layer.getKernel(i, j);
                        deltaKernel = Util.matrixOp(kernel, deltaKernel, CNN.this.multiply_lambda, CNN.this.multiply_alpha, Util.plus);
                        layer.setKernel(i, j, deltaKernel);
                    }
                }
            }
        }.start();
    }

    /** Back-propagates errors through the hidden layers, from the last hidden layer to layer 1. */
    private void setHiddenLayerErrors() {
        for (int l = this.layerNum - 2; l > 0; --l) {
            Layer layer = this.layers.get(l);
            Layer nextLayer = this.layers.get(l + 1);
            switch (layer.getType()) {
                case samp:
                    this.setSampErrors(layer, nextLayer);
                    break;
                case conv:
                    this.setConvErrors(layer, nextLayer);
                    break;
                default:
                    break;
            }
        }
    }

    /**
     * Computes the error of a sampling layer from the convolution layer that follows it.
     *
     * <p>error_i = sum over j of {@code convnFull(nextError_j, rot180(kernel_ij))}.
     */
    private void setSampErrors(final Layer layer, final Layer nextLayer) {
        int mapNum = layer.getOutMapNum();
        final int nextMapNum = nextLayer.getOutMapNum();
        new ConcurenceRunner.TaskManager(mapNum){

            @Override
            public void process(int start, int end) {
                for (int i = start; i < end; ++i) {
                    double[][] sum = null;
                    for (int j = 0; j < nextMapNum; ++j) {
                        double[][] nextError = nextLayer.getError(j);
                        double[][] kernel = nextLayer.getKernel(i, j);
                        double[][] term = Util.convnFull(nextError, Util.rot180(kernel));
                        sum = sum == null ? term : Util.matrixOp(term, sum, null, null, Util.plus);
                    }
                    layer.setError(i, sum);
                }
            }
        }.start();
    }

    /**
     * Computes the error of a convolution layer from the sampling layer that follows it.
     *
     * <p>The next layer's error is expanded with {@code Util.kronecker} by its scale size. It is
     * then multiplied element-wise by {@code map .* one_value(map)}, which is presumably the
     * sigmoid derivative map * (1 - map).
     */
    private void setConvErrors(final Layer layer, final Layer nextLayer) {
        int mapNum = layer.getOutMapNum();
        new ConcurenceRunner.TaskManager(mapNum){

            @Override
            public void process(int start, int end) {
                Layer.Size scale = nextLayer.getScaleSize();
                for (int m = start; m < end; ++m) {
                    double[][] nextError = nextLayer.getError(m);
                    double[][] map = layer.getMap(m);
                    double[][] outMatrix = Util.matrixOp(map, Util.cloneMatrix(map), null, Util.one_value, Util.multiply);
                    outMatrix = Util.matrixOp(outMatrix, Util.kronecker(nextError, scale), null, null, Util.multiply);
                    layer.setError(m, outMatrix);
                }
            }
        }.start();
    }

    /**
     * Sets the output-layer error to {@code out * (1 - out) * (target - out)} per map.
     *
     * <p>The target is one-hot on the record's label.
     *
     * @return whether the largest output matches the label
     */
    private boolean setOutLayerErrors(Dataset.Record record) {
        Layer outputLayer = this.layers.get(this.layerNum - 1);
        double[] outmaps = this.readOutputValues();
        int mapNum = outmaps.length;
        double[] target = new double[mapNum];
        int label = record.getLable().intValue();
        target[label] = 1.0;
        for (int m = 0; m < mapNum; ++m) {
            outputLayer.setError(m, 0, 0, outmaps[m] * (1.0 - outmaps[m]) * (target[m] - outmaps[m]));
        }
        return label == Util.getMaxIndex(outmaps);
    }

    /** Propagates one record from the input layer through every subsequent layer. */
    private void forward(Dataset.Record record) {
        this.setInLayerOutput(record);
        for (int l = 1; l < this.layers.size(); ++l) {
            Layer layer = this.layers.get(l);
            Layer lastLayer = this.layers.get(l - 1);
            switch (layer.getType()) {
                case conv:
                case output:
                    // The output layer is computed exactly like a convolution layer.
                    this.setConvOutput(layer, lastLayer);
                    break;
                case samp:
                    this.setSampOutput(layer, lastLayer);
                    break;
                default:
                    break;
            }
        }
    }

    /**
     * Copies the record attributes into input map 0, reading them in row-major order.
     *
     * @throws RuntimeException if the attribute count differs from {@code mapSize.x * mapSize.y}
     */
    private void setInLayerOutput(Dataset.Record record) {
        Layer inputLayer = this.layers.get(0);
        Layer.Size mapSize = inputLayer.getMapSize();
        double[] attr = record.getAttrs();
        if (attr.length != mapSize.x * mapSize.y) {
            throw new RuntimeException("The size of the data record does not match the defined map size");
        }
        for (int i = 0; i < mapSize.x; ++i) {
            for (int j = 0; j < mapSize.y; ++j) {
                // Row stride is the row length (y). Using x here broke non-square maps and could
                // index past the end of attr when x > y.
                inputLayer.setMapValue(0, i, j, attr[mapSize.y * i + j]);
            }
        }
    }

    /**
     * Computes each convolution/output map j.
     *
     * <p>map_j = {@code sigmod(sum over i of convnValid(lastMap_i, kernel_ij) + bias_j)}.
     */
    private void setConvOutput(final Layer layer, final Layer lastLayer) {
        int mapNum = layer.getOutMapNum();
        final int lastMapNum = lastLayer.getOutMapNum();
        new ConcurenceRunner.TaskManager(mapNum){

            @Override
            public void process(int start, int end) {
                for (int j = start; j < end; ++j) {
                    double[][] sum = null;
                    for (int i = 0; i < lastMapNum; ++i) {
                        double[][] lastMap = lastLayer.getMap(i);
                        double[][] kernel = layer.getKernel(i, j);
                        double[][] term = Util.convnValid(lastMap, kernel);
                        sum = sum == null ? term : Util.matrixOp(term, sum, null, null, Util.plus);
                    }
                    final double bias = layer.getBias(j);
                    sum = Util.matrixOp(sum, new Util.Operator(){
                        private static final long serialVersionUID = 2469461972825890810L;

                        @Override
                        public double process(double value) {
                            return Util.sigmod(value + bias);
                        }
                    });
                    layer.setMapValue(j, sum);
                }
            }
        }.start();
    }

    /** Computes each sampling map i as {@code Util.scaleMatrix(lastMap_i, scaleSize)}. */
    private void setSampOutput(final Layer layer, final Layer lastLayer) {
        int lastMapNum = lastLayer.getOutMapNum();
        new ConcurenceRunner.TaskManager(lastMapNum){

            @Override
            public void process(int start, int end) {
                Layer.Size scaleSize = layer.getScaleSize();
                for (int i = start; i < end; ++i) {
                    double[][] lastMap = lastLayer.getMap(i);
                    double[][] sampMatrix = Util.scaleMatrix(lastMap, scaleSize);
                    layer.setMapValue(i, sampMatrix);
                }
            }
        }.start();
    }

    /**
     * Derives each layer's map size from its predecessor and allocates its kernels, biases, error
     * buffers and output maps for {@code batchSize} records.
     *
     * @param batchSize number of records per mini-batch
     */
    public void setup(int batchSize) {
        Layer inputLayer = this.layers.get(0);
        inputLayer.initOutmaps(batchSize);
        for (int i = 1; i < this.layers.size(); ++i) {
            Layer layer = this.layers.get(i);
            Layer frontLayer = this.layers.get(i - 1);
            int frontMapNum = frontLayer.getOutMapNum();
            switch (layer.getType()) {
                case conv:
                    layer.setMapSize(frontLayer.getMapSize().subtract(layer.getKernelSize(), 1));
                    layer.initKernel(frontMapNum);
                    layer.initBias(frontMapNum);
                    layer.initErros(batchSize);
                    layer.initOutmaps(batchSize);
                    break;
                case samp:
                    // Sampling keeps the map count and shrinks each map by the scale size.
                    layer.setOutMapNum(frontMapNum);
                    layer.setMapSize(frontLayer.getMapSize().divide(layer.getScaleSize()));
                    layer.initErros(batchSize);
                    layer.initOutmaps(batchSize);
                    break;
                case output:
                    layer.initOutputKerkel(frontMapNum, frontLayer.getMapSize());
                    layer.initBias(frontMapNum);
                    layer.initErros(batchSize);
                    layer.initOutmaps(batchSize);
                    break;
                default:
                    // input (or any other type) needs no setup past layer 0.
                    break;
            }
        }
    }

    /**
     * Serializes this network to {@code fileName}.
     *
     * <p>Errors are printed and otherwise ignored.
     */
    public void saveModel(String fileName) {
        try (FileOutputStream fos = new FileOutputStream(fileName);
             ObjectOutputStream oos = new ObjectOutputStream(fos)) {
            oos.writeObject(this);
            oos.flush();
        }
        catch (IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * Deserializes a network previously written by {@link #saveModel(String)}.
     *
     * <p>SECURITY: Java deserialization can execute code from a crafted file. Only load model
     * files from trusted sources.
     *
     * @return the loaded network, or {@code null} if reading failed (the error is printed)
     */
    public static CNN loadModel(String fileName) {
        try (FileInputStream fis = new FileInputStream(fileName);
             ObjectInputStream in = new ObjectInputStream(fis)) {
            return (CNN)in.readObject();
        }
        catch (IOException | ClassNotFoundException e) {
            e.printStackTrace();
            return null;
        }
    }

    /** Collects layers, in order, for the {@link CNN} constructor. */
    public static class LayerBuilder {
        private List<Layer> mLayers = new ArrayList<>();

        public LayerBuilder() {
        }

        /** Starts the list with {@code layer} (normally the input layer). */
        public LayerBuilder(Layer layer) {
            this();
            this.mLayers.add(layer);
        }

        /** Appends {@code layer} and returns this builder for chaining. */
        public LayerBuilder addLayer(Layer layer) {
            this.mLayers.add(layer);
            return this;
        }
    }

    /**
     * Daemon thread that watches standard input and requests an early stop of training when the
     * user types '&'.
     *
     * <p>Constructing it installs a fresh, unset {@link #stopTrain} flag.
     */
    static class Lisenter
    extends Thread {
        Lisenter() {
            this.setDaemon(true);
            stopTrain = new AtomicBoolean(false);
        }

        @Override
        public void run() {
            System.out.println("Input & to stop train.");
            try {
                int ch;
                // read() keeps returning -1 once stdin is closed, so stop listening at end of
                // stream rather than spinning on it.
                while ((ch = System.in.read()) != -1) {
                    if (ch == '&') {
                        stopTrain.compareAndSet(false, true);
                        break;
                    }
                }
            }
            catch (IOException e) {
                // Give up listening; training simply continues without an early-stop option.
                e.printStackTrace();
            }
            System.out.println("Lisenter stop");
        }
    }
}

