/*
 * Decompiled with CFR 0.152.
 */
package org.neuroph.samples.convolution.util;

import java.awt.BorderLayout;
import java.awt.Color;
import java.awt.Component;
import java.awt.Dimension;
import java.awt.Image;
import java.awt.image.BufferedImage;
import java.util.ArrayList;
import java.util.List;
import javax.swing.ImageIcon;
import javax.swing.JFrame;
import javax.swing.JLabel;
import org.neuroph.core.Connection;
import org.neuroph.core.Neuron;
import org.neuroph.nnet.comp.Kernel;
import org.neuroph.nnet.comp.layer.FeatureMapLayer;
import org.neuroph.nnet.comp.neuron.BiasNeuron;

/**
 * Displays the input weights of the first neuron of a {@link FeatureMapLayer} as grayscale
 * images, opening one Swing window per kernel.
 *
 * <p>The neuron's non-bias input weights are split into consecutive groups of
 * {@code kernel.getArea()} values. Each group is min/max-normalised to [0, 1] and drawn as a
 * {@code kernel.getWidth()} x {@code kernel.getHeight()} gray image, magnified by {@link #RATIO}.
 */
public class WeightVisualiser {
    /** Magnification factor from one kernel cell to on-screen pixels. */
    private static final int RATIO = 20;

    /** One list of weights per kernel, in input-connection order. */
    private final List<List<Double>> featureDetector;
    /** Kernel geometry used to group the weights and lay out each image. */
    private final Kernel kernel;

    /**
     * Creates a visualiser and immediately collects the weights to display.
     *
     * @param map    feature map whose first neuron's input weights are visualised
     * @param kernel kernel geometry used to group and lay out the weights
     */
    public WeightVisualiser(FeatureMapLayer map, Kernel kernel) {
        this.kernel = kernel;
        this.featureDetector = new ArrayList<List<Double>>();
        this.initWeights(map);
    }

    /**
     * Collects the non-bias input weights of the map's first neuron into groups of
     * {@code kernel.getArea()} consecutive values, appending each group to
     * {@link #featureDetector}.
     */
    private void initWeights(FeatureMapLayer map) {
        int area = this.kernel.getArea();
        Neuron neuron = map.getNeuronAt(0);
        List<Double> weights = new ArrayList<Double>();
        for (Connection conn : neuron.getInputConnections()) {
            if (conn.getFromNeuron() instanceof BiasNeuron) {
                continue; // bias weights are not part of the spatial kernel
            }
            if (weights.size() >= area) {
                // Current group is full: store it and start the next kernel.
                this.featureDetector.add(weights);
                weights = new ArrayList<Double>();
            }
            weights.add(conn.getWeight().getValue());
        }
        // Avoid registering an empty group (no non-bias inputs), which would open a blank window.
        if (!weights.isEmpty()) {
            this.featureDetector.add(weights);
        }
    }

    /** Opens one window per collected kernel showing its weights as a grayscale image. */
    public void displayWeights() {
        for (List<Double> currentKernel : this.featureDetector) {
            this.displayWeight(currentKernel);
        }
    }

    /**
     * Renders one kernel's weights into a magnified grayscale image and shows it in a new frame.
     *
     * <p>Note: {@code currentKernel} is normalised in place as a side effect.
     */
    private void displayWeight(List<Double> currentKernel) {
        int width = this.kernel.getWidth();
        int height = this.kernel.getHeight();
        int scaledWidth = width * RATIO;
        int scaledHeight = height * RATIO;

        BufferedImage image = new BufferedImage(width, height, BufferedImage.TYPE_BYTE_GRAY);
        int[] rgb = this.convertWeightToRGB(currentKernel);
        image.setRGB(0, 0, width, height, rgb, 0, width);

        JLabel label = new JLabel();
        Dimension d = new Dimension(scaledWidth, scaledHeight);
        label.setSize(d);
        label.setPreferredSize(d);
        label.setIcon(new ImageIcon(
                image.getScaledInstance(scaledWidth, scaledHeight, Image.SCALE_SMOOTH)));

        JFrame frame = new JFrame("Weight Visualiser: ");
        frame.getContentPane().add(label, BorderLayout.CENTER);
        // pack() sizes the frame to the label's preferred size, so no explicit setSize is needed.
        frame.pack();
        frame.setVisible(true);
    }

    /**
     * Normalises {@code weights} in place and maps each value to an opaque gray RGB pixel.
     *
     * @param weights one kernel's weights; presumably at most width*height values (TODO confirm
     *                that {@code kernel.getArea()} equals width*height)
     * @return packed RGB pixels in row-major order; pixels beyond {@code weights.size()} stay 0
     */
    private int[] convertWeightToRGB(List<Double> weights) {
        this.normalizeWeights(weights);
        int[] data = new int[this.kernel.getWidth() * this.kernel.getHeight()];
        int i = 0;
        for (double weight : weights) {
            int val = (int) (weight * 255.0);
            data[i++] = new Color(val, val, val).getRGB();
        }
        return data;
    }

    /**
     * Rescales {@code weights} in place to the range [0, 1] using min/max normalisation.
     * A list whose values are all equal has no contrast and is mapped to all zeros.
     */
    private void normalizeWeights(List<Double> weights) {
        // Infinite seeds work for any sign; Double.MIN_VALUE is the smallest *positive* double
        // and would break the maximum when every weight is <= 0.
        double min = Double.POSITIVE_INFINITY;
        double max = Double.NEGATIVE_INFINITY;
        for (double weight : weights) {
            min = Math.min(min, weight);
            max = Math.max(max, weight);
        }
        double range = max - min;
        for (int i = 0; i < weights.size(); ++i) {
            // Guard the flat-kernel case instead of producing NaN from 0/0.
            double value = range > 0.0 ? (weights.get(i) - min) / range : 0.0;
            weights.set(i, value);
        }
    }
}

