/*
 * Decompiled with CFR 0.152.
 */
package org.ea.javacnn;

import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import org.ea.javacnn.data.BackPropResult;
import org.ea.javacnn.data.DataBlock;
import org.ea.javacnn.layers.Layer;
import org.ea.javacnn.losslayers.LossLayer;

/**
 * A neural network made of an ordered list of {@link Layer}s.
 *
 * <p>Activations flow from the first layer to the last in {@link #forward}. The last layer
 * must be a {@link LossLayer}; {@link #getCostLoss}, {@link #backward} and
 * {@link #getPrediction} rely on that. The model can be persisted with Java serialization via
 * {@link #saveModel} and restored with {@link #loadModel}.
 */
public class JavaCNN implements Serializable {
    private static final long serialVersionUID = 1L;

    /** Layers in forward order; the last one is treated as the loss layer. */
    private final List<Layer> layers;

    /**
     * Creates a network from the given layers.
     *
     * @param layers the layers in forward order, last one expected to be a {@link LossLayer};
     *               the list is stored by reference, not copied
     */
    public JavaCNN(List<Layer> layers) {
        this.layers = layers;
    }

    /**
     * Runs the input through every layer in order.
     *
     * @param db       the input to the first layer
     * @param training forwarded unchanged to each layer's {@code forward}
     * @return the output of the last layer, or {@code db} itself if the network has no layers
     */
    public DataBlock forward(DataBlock db, boolean training) {
        DataBlock act = db;
        for (Layer layer : this.layers) {
            act = layer.forward(act, training);
        }
        return act;
    }

    /**
     * Runs a non-training forward pass on {@code db} and returns the loss for label {@code y}.
     *
     * <p>This invokes the loss layer's {@code backward(y)} to obtain the loss, so whatever state
     * that call updates on the loss layer is updated here too; the other layers' backward passes
     * are not run.
     *
     * @param db the input sample
     * @param y  the expected class index
     * @return the loss reported by the loss layer
     * @throws IllegalStateException if the network has no layers or the last is not a LossLayer
     */
    public double getCostLoss(DataBlock db, int y) {
        this.forward(db, false);
        return this.lossLayer().backward(y);
    }

    /**
     * Runs the backward pass: first the loss layer with label {@code y}, then every other layer
     * from the output side towards the input.
     *
     * @param y the expected class index
     * @return the loss reported by the loss layer
     * @throws IllegalStateException if the network has no layers or the last is not a LossLayer
     */
    public double backward(int y) {
        double loss = this.lossLayer().backward(y);
        // Reverse order: presumably each layer consumes the gradient left by the layer after it.
        for (int i = this.layers.size() - 2; i >= 0; --i) {
            this.layers.get(i).backward();
        }
        return loss;
    }

    /**
     * Collects the back-propagation results of all layers, in layer order.
     *
     * @return a new mutable list concatenating each layer's {@code getBackPropagationResult()}
     */
    public List<BackPropResult> getBackPropagationResult() {
        List<BackPropResult> result = new ArrayList<>();
        for (Layer layer : this.layers) {
            result.addAll(layer.getBackPropagationResult());
        }
        return result;
    }

    /**
     * Returns the index of the largest value in the loss layer's output activations
     * (arg-max). On ties the lowest index wins.
     *
     * @return the predicted class index
     * @throws IllegalStateException if the network has no layers or the last is not a LossLayer
     */
    public int getPrediction() {
        double[] p = this.lossLayer().getOutAct().getWeights();
        double maxv = p[0];
        int maxi = 0;
        for (int i = 1; i < p.length; ++i) {
            if (p[i] > maxv) {
                maxv = p[i];
                maxi = i;
            }
        }
        return maxi;
    }

    /**
     * Prints one accuracy line per class to {@code System.out} and returns an overall summary.
     *
     * <p>A class with no samples (or {@code totalSize == 0}) is reported as 0% instead of
     * {@code NaN%}.
     *
     * @param correctPredictions  number of correct predictions per class
     * @param numberDistribution  number of samples per class
     * @param totalSize           total number of samples
     * @param numOfClasses        number of classes to report; both arrays must be at least this long
     * @return the summary line for all classes combined
     */
    public String getPredictions(int[] correctPredictions, int[] numberDistribution, int totalSize, int numOfClasses) {
        int sumCorrectPredictions = 0;
        for (int i = 0; i < numOfClasses; ++i) {
            StringBuilder sb = new StringBuilder();
            sb.append("Number ");
            sb.append(i);
            sb.append(" has predictions ");
            sb.append(correctPredictions[i]);
            sb.append("/");
            sb.append(numberDistribution[i]);
            sb.append("\t\t");
            sb.append(percentage(correctPredictions[i], numberDistribution[i]));
            sb.append("%");
            System.out.println(sb.toString());
            sumCorrectPredictions += correctPredictions[i];
        }
        StringBuilder sb = new StringBuilder();
        sb.append("Total correct predictions ");
        sb.append(sumCorrectPredictions);
        sb.append("/");
        sb.append(totalSize);
        sb.append("\t\t");
        sb.append(percentage(sumCorrectPredictions, totalSize));
        sb.append("%");
        return sb.toString();
    }

    /**
     * Serializes this network to {@code fileName} using Java serialization.
     * I/O errors are printed to stderr and otherwise ignored.
     *
     * @param fileName the destination file path
     */
    public void saveModel(String fileName) {
        // Both streams are resources so the file handle is released even when the
        // ObjectOutputStream constructor (which writes the stream header) or writeObject fails.
        try (FileOutputStream fos = new FileOutputStream(fileName);
             ObjectOutputStream oos = new ObjectOutputStream(fos)) {
            oos.writeObject(this);
            oos.flush();
        }
        catch (IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * Loads a network previously written by {@link #saveModel}.
     *
     * <p>SECURITY: this uses Java native deserialization, which can be exploited by crafted
     * input. Only load model files from trusted sources.
     *
     * @param fileName the source file path
     * @return the loaded network, or {@code null} if reading failed (the error is printed to stderr)
     */
    public static JavaCNN loadModel(String fileName) {
        try (FileInputStream fis = new FileInputStream(fileName);
             ObjectInputStream in = new ObjectInputStream(fis)) {
            return (JavaCNN) in.readObject();
        }
        catch (IOException | ClassNotFoundException e) {
            e.printStackTrace();
            return null;
        }
    }

    /**
     * Returns the last layer as a {@link LossLayer}.
     *
     * @throws IllegalStateException if there are no layers or the last one is not a LossLayer
     */
    private LossLayer lossLayer() {
        if (this.layers.isEmpty()) {
            throw new IllegalStateException("Network has no layers");
        }
        Layer last = this.layers.get(this.layers.size() - 1);
        if (!(last instanceof LossLayer)) {
            throw new IllegalStateException("Last layer is not a LossLayer: "
                    + (last == null ? "null" : last.getClass().getName()));
        }
        return (LossLayer) last;
    }

    /** Returns {@code part / whole * 100}, or 0 when {@code whole} is 0 (avoids NaN output). */
    private static float percentage(int part, int whole) {
        if (whole == 0) {
            return 0.0f;
        }
        return (float) part / (float) whole * 100.0f;
    }
}

