package org.neuroph.samples.convolution.util;

import java.awt.BorderLayout;
import java.awt.Color;
import java.awt.Dimension;
import java.awt.Image;
import java.awt.image.BufferedImage;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import javax.swing.ImageIcon;
import javax.swing.JFrame;
import javax.swing.JLabel;
import org.neuroph.core.Neuron;
import org.neuroph.nnet.comp.Dimension2D;
import org.neuroph.nnet.comp.layer.FeatureMapLayer;
import org.neuroph.nnet.comp.layer.FeatureMapsLayer;

/* loaded from: input_file:org/neuroph/samples/convolution/util/LayerVisialize.class */
/**
 * Renders the neuron outputs of each feature map in a {@link FeatureMapsLayer} as a
 * gray-scale image, opening one Swing window per feature map.
 *
 * <p>Outputs are copied once, when the visualizer is constructed. Later changes to the
 * layer are not reflected. Each map is min-max normalized on its own before display, so
 * within a map the smallest output is drawn black and the largest is drawn white.
 */
public class LayerVisialize {
    /** Number of screen pixels used to draw one feature-map cell along each axis. */
    private static final int RATIO = 20;

    /**
     * Snapshot of every feature map's neuron outputs, in neuron iteration order.
     * Rendering never modifies it, so repeated displays show the same images.
     */
    private final List<List<Double>> layerMaps = new ArrayList<>();
    private final FeatureMapsLayer featureMapsLayer;
    private final Dimension2D mapDimensions;

    /**
     * Creates a visualizer and immediately snapshots the layer's current neuron outputs.
     *
     * @param featureMapsLayer layer whose feature maps are displayed; its
     *     {@code getMapDimensions()} supplies the width and height used for every map
     */
    public LayerVisialize(FeatureMapsLayer featureMapsLayer) {
        this.featureMapsLayer = featureMapsLayer;
        this.mapDimensions = featureMapsLayer.getMapDimensions();
        initWeights();
    }

    /**
     * Copies the current output of every neuron, grouped per feature map, into
     * {@link #layerMaps}. Despite the method name, these are neuron outputs, not weights.
     */
    private void initWeights() {
        for (FeatureMapLayer featureMap : featureMapsLayer.getFeatureMaps()) {
            List<Double> outputs = new ArrayList<>();
            Iterator<Neuron> neurons = featureMap.getNeurons().iterator();
            while (neurons.hasNext()) {
                outputs.add(neurons.next().getOutput());
            }
            layerMaps.add(outputs);
        }
    }

    /** Opens one window for each captured feature map. */
    public void displayWeights() {
        for (List<Double> map : layerMaps) {
            displayWeight(map);
        }
    }

    /**
     * Opens a window showing a single feature map, with each cell scaled up by {@link #RATIO}.
     *
     * @param list the map's values in row-major order
     */
    private void displayWeight(List<Double> list) {
        int width = mapDimensions.getWidth();
        int height = mapDimensions.getHeight();
        Dimension scaledSize = new Dimension(width * RATIO, height * RATIO);

        JFrame frame = new JFrame("Weight Visualiser: ");
        JLabel label = new JLabel();
        label.setSize(scaledSize);
        label.setPreferredSize(scaledSize);
        frame.getContentPane().add(label, BorderLayout.CENTER);
        // pack() sizes the frame from the label's preferred size, so no explicit setSize is needed.
        frame.pack();
        frame.setVisible(true);

        BufferedImage image = new BufferedImage(width, height, BufferedImage.TYPE_BYTE_GRAY);
        image.setRGB(0, 0, width, height, convertWeightToRGB(list), 0, width);
        label.setIcon(new ImageIcon(
                image.getScaledInstance(width * RATIO, height * RATIO, Image.SCALE_SMOOTH)));
    }

    /**
     * Converts one map's values into packed opaque-gray ARGB pixels.
     *
     * <p>Pixel {@code i} corresponds to image position {@code (i % width, i / width)}, which
     * matches the {@code setRGB} scan layout used by the caller. If the list holds fewer values
     * than {@code width * height}, the remaining pixels stay 0.
     *
     * @param list the map's values in row-major order; not modified
     * @return a new array of length {@code width * height}
     * @throws IllegalArgumentException if {@code list} has more values than the map has cells
     */
    private int[] convertWeightToRGB(List<Double> list) {
        int cellCount = mapDimensions.getWidth() * mapDimensions.getHeight();
        if (list.size() > cellCount) {
            throw new IllegalArgumentException("Feature map has " + list.size()
                    + " values but its dimensions allow only " + cellCount);
        }
        double[] normalized = normalizeWeights(list);
        int[] pixels = new int[cellCount];
        for (int i = 0; i < normalized.length; i++) {
            int gray = (int) (normalized[i] * 255.0d);
            pixels[i] = new Color(gray, gray, gray).getRGB();
        }
        return pixels;
    }

    /**
     * Min-max normalizes {@code values} into a new array with entries in [0, 1].
     *
     * <p>If the values have no usable spread (all equal, or the range is NaN), every entry is 0.
     *
     * @param values the values to normalize; not modified
     * @return a new array of the same length as {@code values}
     */
    private static double[] normalizeWeights(List<Double> values) {
        // NEGATIVE_INFINITY, not Double.MIN_VALUE: MIN_VALUE is the smallest *positive*
        // double and would be a wrong starting maximum for all-negative maps.
        double min = Double.POSITIVE_INFINITY;
        double max = Double.NEGATIVE_INFINITY;
        for (double value : values) {
            min = Math.min(min, value);
            max = Math.max(max, value);
        }
        double range = max - min;
        double[] result = new double[values.size()];
        for (int i = 0; i < result.length; i++) {
            // Avoid 0/0 = NaN for constant maps; draw them uniformly black.
            result[i] = range > 0 ? (values.get(i) - min) / range : 0.0d;
        }
        return result;
    }
}
