package org.ea.javacnn;

import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.ea.javacnn.data.BackPropResult;
import org.ea.javacnn.data.DataBlock;
import org.ea.javacnn.layers.Layer;
import org.ea.javacnn.losslayers.LossLayer;

/* loaded from: input_file:org/ea/javacnn/JavaCNN.class */
/**
 * A neural network composed of an ordered list of {@link Layer}s.
 *
 * <p>Data flows through the layers in list order on {@link #forward(DataBlock, boolean)} and in
 * reverse order on {@link #backward(int)}. The last layer in the list must be a {@link LossLayer}:
 * {@link #getCostLoss}, {@link #backward} and {@link #getPrediction} cast it unconditionally and
 * will throw {@link ClassCastException} otherwise.
 *
 * <p>Models can be persisted with {@link #saveModel(String)} and restored with
 * {@link #loadModel(String)} using Java native serialization.
 */
public class JavaCNN implements Serializable {
    private static final long serialVersionUID = 1;

    /** Layers in forward order; the final element is expected to be a {@link LossLayer}. */
    private final List<Layer> layers;

    /**
     * Creates a network from the given layers.
     *
     * <p>The list is stored by reference, not copied, so later changes to it are visible to this
     * network.
     *
     * @param list layers in forward order; the last one must be a {@link LossLayer}
     */
    public JavaCNN(List<Layer> list) {
        this.layers = list;
    }

    /**
     * Runs {@code dataBlock} through every layer in order and returns the final layer's output.
     *
     * <p>If the network has no layers, the input is returned unchanged.
     *
     * @param dataBlock input to the first layer
     * @param z flag forwarded unchanged to every layer's {@code forward}; presumably a
     *     training-mode switch ({@link #getCostLoss} passes {@code false}) — TODO confirm
     * @return the output of the last layer
     */
    public DataBlock forward(DataBlock dataBlock, boolean z) {
        DataBlock activation = dataBlock;
        // Iterate rather than index, so non-random-access lists are not traversed repeatedly.
        for (Layer layer : layers) {
            activation = layer.forward(activation, z);
        }
        return activation;
    }

    /**
     * Runs a forward pass with {@code z == false} and returns the loss layer's loss for label
     * {@code i}.
     *
     * <p>This calls {@link LossLayer#backward(int)} on the last layer only. Earlier layers are not
     * back-propagated.
     *
     * @param dataBlock input sample
     * @param i label passed to the loss layer
     * @return the value returned by the loss layer's {@code backward(i)}
     */
    public double getCostLoss(DataBlock dataBlock, int i) {
        forward(dataBlock, false);
        return lossLayer().backward(i);
    }

    /**
     * Back-propagates through the whole network.
     *
     * <p>First the loss layer is driven with label {@code i}. Then {@code backward()} is called on
     * every remaining layer, from last to first.
     *
     * @param i label passed to the loss layer
     * @return the value returned by the loss layer's {@code backward(i)}
     */
    public double backward(int i) {
        double loss = lossLayer().backward(i);
        // The loss layer has already been handled above; walk the rest in reverse order.
        for (int idx = layers.size() - 2; idx >= 0; idx--) {
            layers.get(idx).backward();
        }
        return loss;
    }

    /**
     * Collects the back-propagation results of all layers, in layer order.
     *
     * @return a new mutable list concatenating every layer's {@code getBackPropagationResult()}
     */
    public List<BackPropResult> getBackPropagationResult() {
        List<BackPropResult> results = new ArrayList<>();
        for (Layer layer : layers) {
            results.addAll(layer.getBackPropagationResult());
        }
        return results;
    }

    /**
     * Returns the index of the largest value in the loss layer's output activations.
     *
     * <p>Ties resolve to the lowest index. Reads the most recent activations, so call
     * {@link #forward} first.
     *
     * @return the arg-max index of {@code lossLayer.getOutAct().getWeights()}
     * @throws ArrayIndexOutOfBoundsException if the output weights array is empty
     */
    public int getPrediction() {
        double[] weights = lossLayer().getOutAct().getWeights();
        double best = weights[0];
        int bestIndex = 0;
        for (int idx = 1; idx < weights.length; idx++) {
            if (weights[idx] > best) {
                best = weights[idx];
                bestIndex = idx;
            }
        }
        return bestIndex;
    }

    /**
     * Prints per-class accuracy to {@code System.out} and returns an overall accuracy summary.
     *
     * <p>For each class index in {@code [0, i2)}, one line is printed with
     * {@code iArr[k]/iArr2[k]} and the corresponding percentage. The {@code iArr} entries are
     * summed and reported against {@code i}. A zero denominator yields a percentage of 0.0 instead
     * of an exception.
     *
     * @param iArr per-class numerators (presumably correct predictions per class)
     * @param iArr2 per-class denominators (presumably samples per class)
     * @param i overall denominator (presumably total number of samples)
     * @param i2 number of classes to report
     * @return the summary line {@code "Total correct predictions <sum>/<i>\t\t<pct>%"}
     */
    public String getPredictions(int[] iArr, int[] iArr2, int i, int i2) {
        int totalCorrect = 0;
        for (int k = 0; k < i2; k++) {
            System.out.println("Number " + k + " has predictions " + iArr[k] + "/" + iArr2[k]
                    + "\t\t" + percentage(iArr[k], iArr2[k]) + "%");
            totalCorrect += iArr[k];
        }
        return "Total correct predictions " + totalCorrect + "/" + i + "\t\t"
                + percentage(totalCorrect, i) + "%";
    }

    /**
     * Computes {@code part / whole * 100} in floating point.
     *
     * <p>The previous {@code int / int} form truncated the ratio to 0 or 1 and threw
     * {@link ArithmeticException} on a zero denominator.
     *
     * @return the percentage, or {@code 0.0f} when {@code whole == 0}
     */
    private static float percentage(int part, int whole) {
        if (whole == 0) {
            return 0.0f;
        }
        return (part / (float) whole) * 100.0f;
    }

    /**
     * Serializes this network to the file at {@code str} using Java native serialization.
     *
     * <p>Errors are best-effort: an {@link IOException} is printed to stderr and swallowed. The
     * streams are always closed.
     *
     * @param str destination file path
     */
    public void saveModel(String str) {
        // Declare the file stream separately so it is closed even if the ObjectOutputStream
        // constructor (which writes a stream header) throws.
        try (FileOutputStream fileOut = new FileOutputStream(str);
             ObjectOutputStream objectOut = new ObjectOutputStream(fileOut)) {
            objectOut.writeObject(this);
            objectOut.flush();
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * Deserializes a network previously written by {@link #saveModel(String)}.
     *
     * <p>SECURITY NOTE(review): this uses Java native deserialization with no
     * {@code ObjectInputFilter}. Only load files from trusted sources.
     *
     * @param str source file path
     * @return the loaded network, or {@code null} if reading or class resolution fails (the error
     *     is printed to stderr)
     * @throws ClassCastException if the file contains an object that is not a {@code JavaCNN}
     */
    public static JavaCNN loadModel(String str) {
        try (FileInputStream fileIn = new FileInputStream(str);
             ObjectInputStream objectIn = new ObjectInputStream(fileIn)) {
            return (JavaCNN) objectIn.readObject();
        } catch (IOException | ClassNotFoundException e) {
            e.printStackTrace();
            return null;
        }
    }

    /** Returns the last layer cast to {@link LossLayer}. */
    private LossLayer lossLayer() {
        return (LossLayer) layers.get(layers.size() - 1);
    }
}
