package edu.hitsz.c102c.cnn;

import edu.hitsz.c102c.cnn.Dataset;
import edu.hitsz.c102c.cnn.Layer;
import edu.hitsz.c102c.util.ConcurenceRunner;
import edu.hitsz.c102c.util.Log;
import edu.hitsz.c102c.util.Util;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.PrintWriter;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;

/* loaded from: input_file:edu/hitsz/c102c/cnn/CNN.class */
public class CNN implements Serializable {
    /** Version id used by Java serialization when a network is written or read back. */
    private static final long serialVersionUID = 1;
    /**
     * Step size multiplied into every bias and kernel update. It is static (shared by
     * all networks) and deliberately not final: {@code train} shrinks it while training.
     */
    private static double ALPHA = 0.85d;
    /** Coefficient of the {@code 1 - LAMBDA * ALPHA} factor applied to kernels on update; 0 makes that factor 1. */
    protected static final double LAMBDA = 0.0d;
    /** Layers in network order; index 0 is the layer that receives the input record. */
    private List<Layer> layers;
    /** Number of layers, captured from {@code layers.size()} in the constructor. */
    private int layerNum;
    /** Number of records processed between two parameter updates. */
    private int batchSize;
    /** Operator that divides its argument by {@code batchSize}; created in {@code initPerator}. */
    private Util.Operator divide_batchSize;
    /** Operator that multiplies its argument by {@code ALPHA}; created in {@code initPerator}. */
    private Util.Operator multiply_alpha;
    /** Operator that multiplies its argument by {@code 1 - LAMBDA * ALPHA}; created in {@code initPerator}. */
    private Util.Operator multiply_lambda;
    /** Stop request flag: replaced by every new {@code Lisenter} and polled by {@code train} before each pass. */
    private static AtomicBoolean stopTrain;

    /* loaded from: input_file:edu/hitsz/c102c/cnn/CNN$LayerBuilder.class */
    /**
     * Fluent collector for the ordered list of layers a {@link CNN} is built from.
     *
     * <p>The CNN constructor takes over the collected list as-is, and the network
     * treats the layer at index 0 as the one that receives the input record.
     */
    public static class LayerBuilder {
        /** Layers in insertion order; read directly by the enclosing class. */
        private List<Layer> mLayers;

        /** Creates a builder with no layers. */
        public LayerBuilder() {
            // Diamond instead of the raw ArrayList keeps the element type checked.
            this.mLayers = new ArrayList<>();
        }

        /**
         * Creates a builder that already contains one layer.
         *
         * @param layer first layer of the network
         */
        public LayerBuilder(Layer layer) {
            this();
            this.mLayers.add(layer);
        }

        /**
         * Appends a layer after the ones added so far.
         *
         * @param layer layer to append
         * @return this builder, so calls can be chained
         */
        public LayerBuilder addLayer(Layer layer) {
            this.mLayers.add(layer);
            return this;
        }
    }

    /* loaded from: input_file:edu/hitsz/c102c/cnn/CNN$Lisenter.class */
    /**
     * Daemon thread that watches standard input and requests that training stop as
     * soon as the user types {@code &}.
     *
     * <p>Constructing a listener replaces {@code CNN.stopTrain} with a fresh, unset
     * flag, so the stop request is cleared every time a listener is created.
     */
    static class Lisenter extends Thread {
        Lisenter() {
            setDaemon(true);
            CNN.stopTrain = new AtomicBoolean(false);
        }

        /**
         * Reads standard input one byte at a time until {@code &} arrives, then raises
         * the stop flag and ends. The thread also ends, without raising the flag, when
         * standard input reaches end of stream or cannot be read: polling a dead stream
         * would otherwise spin forever.
         */
        @Override
        public void run() {
            System.out.println("Input & to stop train.");
            while (true) {
                int ch;
                try {
                    // The read must sit inside the try: it is the only statement that
                    // can throw the IOException handled below.
                    ch = System.in.read();
                } catch (IOException e) {
                    e.printStackTrace();
                    return;
                }
                if (ch == '&') {
                    CNN.stopTrain.compareAndSet(false, true);
                    System.out.println("Lisenter stop");
                    return;
                }
                if (ch < 0) {
                    // End of stream: no further input can ever arrive.
                    return;
                }
            }
        }
    }

    public CNN(LayerBuilder layerBuilder, int i) {
        this.layers = layerBuilder.mLayers;
        this.layerNum = this.layers.size();
        this.batchSize = i;
        setup(i);
        initPerator();
    }

    /**
     * Creates the three element-wise operators used when parameters are updated.
     *
     * <p>NOTE(review): these operators are stored in serializable fields, so the
     * anonymous classes are kept as anonymous classes, in this order and with these
     * {@code serialVersionUID}s; reordering them or turning them into lambdas would
     * presumably break models that were already saved.
     */
    private void initPerator() {
        // value / batchSize
        this.divide_batchSize = new Util.Operator() {
            private static final long serialVersionUID = 7424011281732651055L;

            @Override
            public double process(double value) {
                final int batch = CNN.this.batchSize;
                return value / batch;
            }
        };

        // value * ALPHA (ALPHA is read at call time, so later decay is honoured)
        this.multiply_alpha = new Util.Operator() {
            private static final long serialVersionUID = 5761368499808006552L;

            @Override
            public double process(double value) {
                return value * CNN.ALPHA;
            }
        };

        // value * (1 - LAMBDA * ALPHA)
        this.multiply_lambda = new Util.Operator() {
            private static final long serialVersionUID = 4499087728362870577L;

            @Override
            public double process(double value) {
                final double keep = 1.0d - (CNN.LAMBDA * CNN.ALPHA);
                return value * keep;
            }
        };
    }

    /**
     * Trains the network on {@code dataset}.
     *
     * <p>A {@code Lisenter} is started first so training can be interrupted from the
     * console. Each pass draws one random batch of record indices per batch slot,
     * trains on every drawn record and then applies one parameter update. After a
     * pass whose index satisfies {@code index % 10 == 1} with precision above 0.96,
     * {@code ALPHA} is replaced by {@code 0.001 + 0.9 * ALPHA}.
     *
     * @param dataset records to train on
     * @param i       maximum number of passes; fewer are run if a stop is requested
     */
    public void train(Dataset dataset, int i) {
        new Lisenter().start();
        int pass = 0;
        while (pass < i && !stopTrain.get()) {
            // Number of batches per pass: ceil(size / batchSize).
            final int total = dataset.size();
            int batches = total / this.batchSize;
            if (total % this.batchSize != 0) {
                batches++;
            }
            Log.i("");
            Log.i(pass + "th iter epochsNum:" + batches);

            int right = 0;
            int seen = 0;
            for (int batch = 0; batch < batches; batch++) {
                final int[] picked = Util.randomPerm(dataset.size(), this.batchSize);
                Layer.prepareForNewBatch();
                for (int k = 0; k < picked.length; k++) {
                    final boolean hit = train(dataset.getRecord(picked[k]));
                    if (hit) {
                        right++;
                    }
                    seen++;
                    Layer.prepareForNewRecord();
                }
                updateParas();

                // Progress dots every 50 batches, with a line break after the last group.
                if (batch % 50 == 0) {
                    System.out.print("..");
                    if (batch + 50 > batches) {
                        System.out.println();
                    }
                }
            }

            final double precision = (1.0d * right) / seen;
            if (pass % 10 == 1 && precision > 0.96d) {
                ALPHA = 0.001d + (ALPHA * 0.9d);
                Log.i("Set alpha = " + ALPHA);
            }
            Log.i("precision " + right + "/" + seen + "=" + precision);
            pass++;
        }
    }

    /**
     * Measures how often the network's strongest output matches the record label.
     *
     * <p>For every record the network is run forward, the value at {@code [0][0]} of
     * each output map of the last layer is collected, and the index of the largest
     * value is compared with the label.
     *
     * @param dataset labelled records to evaluate
     * @return matches divided by {@code dataset.size()}
     */
    public double test(Dataset dataset) {
        Layer.prepareForNewBatch();
        int hits = 0;
        for (Iterator<Dataset.Record> it = dataset.iter(); it.hasNext(); ) {
            final Dataset.Record record = it.next();
            forward(record);

            final Layer outputLayer = this.layers.get(this.layerNum - 1);
            final int mapNum = outputLayer.getOutMapNum();
            final double[] scores = new double[mapNum];
            for (int m = 0; m < mapNum; m++) {
                scores[m] = outputLayer.getMap(m)[0][0];
            }

            final int label = record.getLable().intValue();
            if (label == Util.getMaxIndex(scores)) {
                hits++;
            }
        }
        final double precision = (1.0d * hits) / dataset.size();
        Log.i("precision", precision + "");
        return precision;
    }

    /**
     * Runs the network on every record of {@code dataset} and writes the index of the
     * strongest output, one per line, to the file named {@code str}.
     *
     * <p>The writer is managed by try-with-resources, so the file is closed even when
     * a forward pass throws part-way through the data set.
     *
     * @param dataset records to classify
     * @param str     path of the output file; it is created or overwritten
     * @throws RuntimeException wrapping the {@link IOException} if the file cannot be opened
     */
    public void predict(Dataset dataset, String str) {
        Log.i("begin predict");
        // NOTE(review): the result of this call was already discarded; it is kept only
        // in case getClassNum() has side effects — confirm against Layer and remove.
        this.layers.get(this.layerNum - 1).getClassNum();
        try (PrintWriter writer = new PrintWriter(new File(str))) {
            Layer.prepareForNewBatch();
            Iterator<Dataset.Record> iter = dataset.iter();
            while (iter.hasNext()) {
                forward(iter.next());
                Layer outputLayer = this.layers.get(this.layerNum - 1);
                int mapNum = outputLayer.getOutMapNum();
                double[] scores = new double[mapNum];
                for (int m = 0; m < mapNum; m++) {
                    scores[m] = outputLayer.getMap(m)[0][0];
                }
                writer.write(Util.getMaxIndex(scores) + "\n");
            }
            writer.flush();
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
        Log.i("end predict");
    }

    /**
     * Tells whether two vectors agree element by element within a tolerance of 0.5.
     *
     * @param dArr  first vector; its length decides how many elements are compared
     * @param dArr2 second vector, expected to be at least as long as {@code dArr}
     * @return {@code false} as soon as one pair differs by more than 0.5, otherwise {@code true}
     */
    private boolean isSame(double[] dArr, double[] dArr2) {
        for (int k = 0; k < dArr.length; k++) {
            final double gap = Math.abs(dArr[k] - dArr2[k]);
            if (gap > 0.5d) {
                return false;
            }
        }
        return true;
    }

    /**
     * Trains on a single record: one forward pass followed by back-propagation.
     *
     * @param record labelled record to learn from
     * @return whatever {@code backPropagation} reports for this record
     */
    private boolean train(Dataset.Record record) {
        forward(record);
        final boolean predictedRight = backPropagation(record);
        return predictedRight;
    }

    /**
     * Propagates the error backwards: output layer first, then the hidden layers.
     *
     * @param record record whose label defines the expected output
     * @return {@code true} when the strongest output matched the record's label
     */
    private boolean backPropagation(Dataset.Record record) {
        final boolean predictedRight = setOutLayerErrors(record);
        // Hidden-layer errors depend on the output-layer errors computed just above.
        setHiddenLayerErrors();
        return predictedRight;
    }

    /**
     * Applies the accumulated batch errors to every layer that owns parameters.
     * Only {@code conv} and {@code output} layers are updated; other types are skipped.
     */
    private void updateParas() {
        for (int idx = 1; idx < this.layerNum; idx++) {
            final Layer current = this.layers.get(idx);
            final Layer previous = this.layers.get(idx - 1);
            switch (current.getType()) {
                case conv:
                case output:
                    // Kernels first, then biases.
                    updateKernels(current, previous);
                    updateBias(current, previous);
                    break;
                default:
                    break;
            }
        }
    }

    /**
     * Moves each bias of {@code layer} by {@code ALPHA} times the summed error of its
     * output map divided by the batch size. The work is split over output maps by a
     * {@code ConcurenceRunner.TaskManager}.
     *
     * @param layer     layer whose biases are updated
     * @param lastLayer preceding layer; not used by this method
     */
    private void updateBias(final Layer layer, Layer lastLayer) {
        final double[][][][] errors = layer.getErrors();
        final int mapNum = layer.getOutMapNum();
        new ConcurenceRunner.TaskManager(mapNum) {
            @Override
            public void process(int start, int end) {
                for (int j = start; j < end; j++) {
                    final double current = layer.getBias(j);
                    final double meanError = Util.sum(Util.sum(errors, j)) / CNN.this.batchSize;
                    layer.setBias(j, current + (CNN.ALPHA * meanError));
                }
            }
        }.start();
    }

    /**
     * Updates every kernel of {@code layer} from the errors gathered over the batch.
     *
     * <p>For each (input map, output map) pair, the valid convolution of the input map
     * with the output error is summed over all records of the batch. The sum is then
     * handed, together with the current kernel, to {@code Util.matrixOp} using the
     * {@code divide_batchSize}, {@code multiply_lambda}, {@code multiply_alpha} and
     * {@code Util.plus} operators.
     *
     * @param layer     layer whose kernels are updated
     * @param lastLayer preceding layer, which supplies the input maps
     */
    private void updateKernels(final Layer layer, final Layer lastLayer) {
        final int mapNum = layer.getOutMapNum();
        final int lastMapNum = lastLayer.getOutMapNum();
        new ConcurenceRunner.TaskManager(mapNum) {
            @Override
            public void process(int start, int end) {
                for (int j = start; j < end; j++) {
                    for (int in = 0; in < lastMapNum; in++) {
                        // Sum of per-record contributions for this kernel.
                        double[][] deltaKernel = null;
                        for (int r = 0; r < CNN.this.batchSize; r++) {
                            final double[][] error = layer.getError(r, j);
                            final double[][] contribution = Util.convnValid(lastLayer.getMap(r, in), error);
                            if (deltaKernel == null) {
                                deltaKernel = contribution;
                            } else {
                                deltaKernel = Util.matrixOp(contribution, deltaKernel, null, null, Util.plus);
                            }
                        }
                        layer.setKernel(
                                in,
                                j,
                                Util.matrixOp(
                                        layer.getKernel(in, j),
                                        Util.matrixOp(deltaKernel, CNN.this.divide_batchSize),
                                        CNN.this.multiply_lambda,
                                        CNN.this.multiply_alpha,
                                        Util.plus));
                    }
                }
            }
        }.start();
    }

    /**
     * Computes the errors of the hidden layers, walking from the layer just before the
     * output layer back to index 1. Layer 0 is never visited.
     */
    private void setHiddenLayerErrors() {
        int idx = this.layerNum - 2;
        while (idx > 0) {
            final Layer current = this.layers.get(idx);
            final Layer following = this.layers.get(idx + 1);
            switch (current.getType()) {
                case conv:
                    setConvErrors(current, following);
                    break;
                case samp:
                    setSampErrors(current, following);
                    break;
                default:
                    break;
            }
            idx--;
        }
    }

    /**
     * Computes the errors of a sampling layer from the layer that follows it.
     *
     * <p>For each map of {@code layer}, the full convolution of every following error
     * map with the 180-degree rotated kernel connecting the two maps is summed, and
     * the sum becomes that map's error.
     *
     * @param layer     sampling layer whose errors are set
     * @param nextLayer layer after it, which supplies the errors and kernels
     */
    private void setSampErrors(final Layer layer, final Layer nextLayer) {
        final int mapNum = layer.getOutMapNum();
        final int nextMapNum = nextLayer.getOutMapNum();
        new ConcurenceRunner.TaskManager(mapNum) {
            @Override
            public void process(int start, int end) {
                for (int j = start; j < end; j++) {
                    double[][] total = null;
                    for (int n = 0; n < nextMapNum; n++) {
                        final double[][] nextError = nextLayer.getError(n);
                        final double[][] kernel = nextLayer.getKernel(j, n);
                        final double[][] contribution = Util.convnFull(nextError, Util.rot180(kernel));
                        if (total == null) {
                            total = contribution;
                        } else {
                            total = Util.matrixOp(contribution, total, null, null, Util.plus);
                        }
                    }
                    layer.setError(j, total);
                }
            }
        }.start();
    }

    /**
     * Computes the errors of a convolution layer from the layer that follows it.
     *
     * <p>Each map's error is the element-wise product of two matrices: the map
     * combined with a clone of itself through {@code Util.one_value} and
     * {@code Util.multiply} (presumably {@code out * (1 - out)} — confirm against
     * Util), and the following layer's error expanded by {@code Util.kronecker} with
     * that layer's scale size.
     *
     * @param layer     convolution layer whose errors are set
     * @param nextLayer layer after it, which supplies the errors and the scale size
     */
    private void setConvErrors(final Layer layer, final Layer nextLayer) {
        final int mapNum = layer.getOutMapNum();
        new ConcurenceRunner.TaskManager(mapNum) {
            @Override
            public void process(int start, int end) {
                for (int j = start; j < end; j++) {
                    final Layer.Size scale = nextLayer.getScaleSize();
                    final double[][] nextError = nextLayer.getError(j);
                    final double[][] map = layer.getMap(j);
                    final double[][] derivative =
                            Util.matrixOp(map, Util.cloneMatrix(map), null, Util.one_value, Util.multiply);
                    layer.setError(
                            j,
                            Util.matrixOp(derivative, Util.kronecker(nextError, scale), null, null, Util.multiply));
                }
            }
        }.start();
    }

    /**
     * Sets the errors of the last layer for one record and reports whether the
     * prediction was right.
     *
     * <p>The expected output is 1 at the label's index and 0 elsewhere. With actual
     * output {@code o} and expected output {@code t}, the error stored at
     * {@code [0][0]} of each map is {@code o * (1 - o) * (t - o)}.
     *
     * @param record record supplying the label
     * @return {@code true} when the index of the largest output equals the label
     */
    private boolean setOutLayerErrors(Dataset.Record record) {
        final Layer outputLayer = this.layers.get(this.layerNum - 1);
        final int mapNum = outputLayer.getOutMapNum();
        final double[] target = new double[mapNum];
        final double[] actual = new double[mapNum];
        for (int m = 0; m < mapNum; m++) {
            actual[m] = outputLayer.getMap(m)[0][0];
        }

        final int label = record.getLable().intValue();
        target[label] = 1.0d;

        for (int m = 0; m < mapNum; m++) {
            final double out = actual[m];
            outputLayer.setError(m, 0, 0, out * (1.0d - out) * (target[m] - out));
        }
        return label == Util.getMaxIndex(actual);
    }

    /**
     * Runs one record through the network: loads it into layer 0, then computes each
     * later layer's output from the layer before it.
     *
     * @param record record to feed into the network
     */
    private void forward(Dataset.Record record) {
        setInLayerOutput(record);
        for (int idx = 1; idx < this.layers.size(); idx++) {
            final Layer current = this.layers.get(idx);
            final Layer previous = this.layers.get(idx - 1);
            switch (current.getType()) {
                case conv:
                case output:
                    // The output layer is computed exactly like a convolution layer.
                    setConvOutput(current, previous);
                    break;
                case samp:
                    setSampOutput(current, previous);
                    break;
                default:
                    break;
            }
        }
    }

    /**
     * Copies the flat attribute array of a record into map 0 of the input layer.
     *
     * <p>Position {@code (row, col)} of the map, with {@code row < mapSize.x} and
     * {@code col < mapSize.y}, reads attribute {@code row * mapSize.y + col}. The
     * stride has to be the length of the inner dimension ({@code y}); using
     * {@code x} only works for square maps and otherwise reads the wrong cells or
     * runs past the end of the array.
     *
     * @param record record whose attributes are loaded
     * @throws RuntimeException if the attribute count differs from {@code mapSize.x * mapSize.y}
     */
    private void setInLayerOutput(Dataset.Record record) {
        final Layer inputLayer = this.layers.get(0);
        final Layer.Size mapSize = inputLayer.getMapSize();
        final double[] attrs = record.getAttrs();
        if (attrs.length != mapSize.x * mapSize.y) {
            throw new RuntimeException("The size of the data record does not match the defined map size");
        }
        for (int row = 0; row < mapSize.x; row++) {
            // Offset of this row in the flat array: one full inner dimension per row.
            final int rowStart = mapSize.y * row;
            for (int col = 0; col < mapSize.y; col++) {
                inputLayer.setMapValue(0, row, col, attrs[rowStart + col]);
            }
        }
    }

    /**
     * Computes the output maps of a convolution (or output) layer.
     *
     * <p>For each output map, the valid convolutions of every input map with its
     * kernel are summed; {@code Util.sigmod} is then applied to every element after
     * adding the map's bias.
     *
     * <p>NOTE(review): the inner operator is an anonymous class carrying a
     * {@code serialVersionUID}; it is kept in that form on purpose.
     *
     * @param layer     layer whose output maps are computed
     * @param lastLayer preceding layer, which supplies the input maps
     */
    private void setConvOutput(final Layer layer, final Layer lastLayer) {
        final int mapNum = layer.getOutMapNum();
        final int lastMapNum = lastLayer.getOutMapNum();
        new ConcurenceRunner.TaskManager(mapNum) {
            @Override
            public void process(int start, int end) {
                for (int j = start; j < end; j++) {
                    double[][] total = null;
                    for (int in = 0; in < lastMapNum; in++) {
                        final double[][] inputMap = lastLayer.getMap(in);
                        final double[][] kernel = layer.getKernel(in, j);
                        final double[][] convolved = Util.convnValid(inputMap, kernel);
                        if (total == null) {
                            total = convolved;
                        } else {
                            total = Util.matrixOp(convolved, total, null, null, Util.plus);
                        }
                    }
                    final double bias = layer.getBias(j);
                    layer.setMapValue(j, Util.matrixOp(total, new Util.Operator() {
                        private static final long serialVersionUID = 2469461972825890810L;

                        @Override
                        public double process(double value) {
                            return Util.sigmod(value + bias);
                        }
                    }));
                }
            }
        }.start();
    }

    /**
     * Computes the output maps of a sampling layer: each map of the preceding layer
     * is passed through {@code Util.scaleMatrix} with this layer's scale size.
     *
     * @param layer     sampling layer whose output maps are computed
     * @param lastLayer preceding layer, which supplies the maps to scale
     */
    private void setSampOutput(final Layer layer, final Layer lastLayer) {
        final int lastMapNum = lastLayer.getOutMapNum();
        new ConcurenceRunner.TaskManager(lastMapNum) {
            @Override
            public void process(int start, int end) {
                for (int j = start; j < end; j++) {
                    final double[][] source = lastLayer.getMap(j);
                    final Layer.Size scale = layer.getScaleSize();
                    layer.setMapValue(j, Util.scaleMatrix(source, scale));
                }
            }
        }.start();
    }

    /**
     * Sizes and allocates every layer for the given batch size.
     *
     * <p>Layer 0 only gets its output maps. Each later layer is initialised from the
     * layer before it, depending on its type:
     * <ul>
     *   <li>{@code conv}: map size derived from the previous map size and the kernel
     *       size, then kernels, biases, errors and output maps;</li>
     *   <li>{@code output}: output kernels sized to the previous map size, then
     *       biases, errors and output maps;</li>
     *   <li>{@code samp}: same map count as the previous layer, map size divided by
     *       the scale size, then errors and output maps.</li>
     * </ul>
     * Layers of any other type are left untouched.
     *
     * @param i batch size the buffers are allocated for
     */
    public void setup(int i) {
        this.layers.get(0).initOutmaps(i);
        final int count = this.layers.size();
        for (int idx = 1; idx < count; idx++) {
            final Layer current = this.layers.get(idx);
            final Layer previous = this.layers.get(idx - 1);
            final int previousMapNum = previous.getOutMapNum();
            switch (current.getType()) {
                case conv:
                    current.setMapSize(previous.getMapSize().subtract(current.getKernelSize(), 1));
                    current.initKernel(previousMapNum);
                    current.initBias(previousMapNum);
                    current.initErros(i);
                    current.initOutmaps(i);
                    break;
                case output:
                    current.initOutputKerkel(previousMapNum, previous.getMapSize());
                    current.initBias(previousMapNum);
                    current.initErros(i);
                    current.initOutmaps(i);
                    break;
                case samp:
                    current.setOutMapNum(previousMapNum);
                    current.setMapSize(previous.getMapSize().divide(current.getScaleSize()));
                    current.initErros(i);
                    current.initOutmaps(i);
                    break;
                default:
                    break;
            }
        }
    }

    /**
     * Writes this network to a file using Java serialization.
     *
     * <p>Failures are reported on standard error and otherwise ignored, so a failed
     * save never interrupts the caller. Both streams are managed by
     * try-with-resources and are closed even when serialization fails half-way.
     *
     * @param str path of the file to create or overwrite
     */
    public void saveModel(String str) {
        try (FileOutputStream fileOut = new FileOutputStream(str);
                ObjectOutputStream objectOut = new ObjectOutputStream(fileOut)) {
            objectOut.writeObject(this);
            objectOut.flush();
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * Reads a network previously written by {@link #saveModel(String)}.
     *
     * <p>Both streams are managed by try-with-resources and are closed even when
     * deserialization fails.
     *
     * <p>NOTE(review): this uses Java native deserialization, which must only be
     * given files from a trusted source.
     *
     * @param str path of the file to read
     * @return the stored network, or {@code null} if the file cannot be read or a
     *         class it needs is missing (the failure is printed to standard error)
     */
    public static CNN loadModel(String str) {
        try (FileInputStream fileIn = new FileInputStream(str);
                ObjectInputStream objectIn = new ObjectInputStream(fileIn)) {
            return (CNN) objectIn.readObject();
        } catch (IOException | ClassNotFoundException e) {
            e.printStackTrace();
            return null;
        }
    }
}
