/*
 * Decompiled with CFR 0.152.
 */
package org.encog.neural.prune;

import java.util.Arrays;

import org.encog.mathutil.randomize.BasicRandomizer;
import org.encog.mathutil.randomize.Distort;
import org.encog.mathutil.randomize.RangeRandomizer;
import org.encog.neural.NeuralNetworkError;
import org.encog.neural.flat.FlatNetwork;
import org.encog.neural.networks.BasicNetwork;
import org.encog.util.EngineArray;

/**
 * Selectively prunes, grows, randomizes and stimulates individual neurons of a
 * {@link BasicNetwork}.
 *
 * <p>Structural edits operate directly on the network's underlying {@link FlatNetwork}: the flat
 * weight array is rebuilt, and the flat layer bookkeeping (counts, indexes, output buffers,
 * input/output counts) is recomputed afterwards.
 *
 * <p>Layer numbers passed to the methods of this class are {@link BasicNetwork} layer numbers,
 * where layer 0 has no previous layer. The flat network stores its layers in the opposite order,
 * which is why {@code getLayerCount() - layer - 1} is used whenever the flat arrays are touched.
 */
public class PruneSelective {
    private final BasicNetwork network;

    /**
     * Creates a pruner that edits the given network in place.
     *
     * @param network the network whose neurons will be modified
     */
    public PruneSelective(BasicNetwork network) {
        this.network = network;
    }

    /**
     * Changes the number of neurons (as reported by {@code getLayerNeuronCount}) in a layer.
     *
     * <p>Shrinking removes the least significant neurons as ranked by
     * {@link #determineNeuronSignificance(int, int)}. Growing appends new neurons whose
     * connections start with a weight of zero. Requesting the current count is a no-op.
     *
     * @param layer the layer to resize
     * @param neuronCount the desired neuron count; must be positive
     * @throws NeuralNetworkError if {@code neuronCount} is zero or negative
     */
    public void changeNeuronCount(int layer, int neuronCount) {
        if (neuronCount == 0) {
            throw new NeuralNetworkError("Can't decrease to zero neurons.");
        }
        if (neuronCount < 0) {
            throw new NeuralNetworkError("Invalid neuron count " + neuronCount);
        }
        int currentCount = this.network.getLayerNeuronCount(layer);
        if (neuronCount > currentCount) {
            increaseNeuronCount(layer, neuronCount);
        } else if (neuronCount < currentCount) {
            decreaseNeuronCount(layer, neuronCount);
        }
    }

    /**
     * Removes the weakest neurons from {@code layer} until it holds {@code neuronCount} neurons.
     */
    private void decreaseNeuronCount(int layer, int neuronCount) {
        int lostNeuronCount = this.network.getLayerNeuronCount(layer) - neuronCount;
        int[] lostNeuron = findWeakestNeurons(layer, lostNeuronCount);
        // lostNeuron is sorted ascending. Each earlier prune removed exactly one neuron below the
        // current index, so the remaining targets shift down by i.
        for (int i = 0; i < lostNeuronCount; i++) {
            prune(layer, lostNeuron[i] - i);
        }
    }

    /**
     * Scores how significant a neuron is.
     *
     * <p>The score is the absolute value of the sum of all weights feeding into the neuron
     * (from every neuron of the previous layer, as counted by {@code getLayerTotalNeuronCount})
     * plus all weights leaving it (to the neurons of the next layer). Input-layer neurons have no
     * inbound term and output-layer neurons have no outbound term.
     *
     * <p>NOTE(review): weights are summed with their signs before taking the absolute value, so
     * large weights of opposite sign cancel. A strongly connected neuron can therefore score near
     * zero.
     *
     * @param layer the layer containing the neuron
     * @param neuron the neuron index within the layer
     * @return the neuron's significance, always non-negative
     */
    public double determineNeuronSignificance(int layer, int neuron) {
        this.network.validateNeuron(layer, neuron);
        double result = 0.0;
        if (layer > 0) {
            int prevLayer = layer - 1;
            int prevCount = this.network.getLayerTotalNeuronCount(prevLayer);
            for (int i = 0; i < prevCount; i++) {
                result += this.network.getWeight(prevLayer, i, neuron);
            }
        }
        if (layer < this.network.getLayerCount() - 1) {
            int nextCount = this.network.getLayerNeuronCount(layer + 1);
            for (int i = 0; i < nextCount; i++) {
                result += this.network.getWeight(layer, neuron, i);
            }
        }
        return Math.abs(result);
    }

    /**
     * Finds the {@code count} neurons in {@code layer} with the lowest significance.
     *
     * @param layer the layer to search
     * @param count how many neurons to select
     * @return the selected neuron indices, sorted ascending (callers that prune one at a time
     *     rely on this order)
     */
    private int[] findWeakestNeurons(int layer, int count) {
        if (count == 0) {
            return new int[0];
        }
        int layerSize = this.network.getLayerNeuronCount(layer);
        double[] keptSignificance = new double[count];
        int[] weakest = new int[count];
        // Seed the selection with the first `count` neurons.
        for (int i = 0; i < count; i++) {
            weakest[i] = i;
            keptSignificance[i] = determineNeuronSignificance(layer, i);
        }
        for (int i = count; i < layerSize; i++) {
            double significance = determineNeuronSignificance(layer, i);
            // Evict the *strongest* currently selected neuron. Evicting any other stronger entry
            // would leave a non-minimal selection.
            int strongest = 0;
            for (int j = 1; j < count; j++) {
                if (keptSignificance[j] > keptSignificance[strongest]) {
                    strongest = j;
                }
            }
            if (significance < keptSignificance[strongest]) {
                weakest[strongest] = i;
                keptSignificance[strongest] = significance;
            }
        }
        Arrays.sort(weakest);
        return weakest;
    }

    /**
     * Returns the network this pruner edits.
     *
     * @return the network supplied at construction
     */
    public BasicNetwork getNetwork() {
        return this.network;
    }

    /**
     * Grows {@code targetLayer} to {@code neuronCount} neurons.
     *
     * <p>New neurons are appended after the existing ones, and every connection to or from a new
     * neuron gets a weight of zero. Existing weights, including those of any trailing bias or
     * context neurons of the layer, are carried over unchanged.
     *
     * @throws NeuralNetworkError if the layer is out of range, the count is not positive, or the
     *     count is not larger than the current one
     */
    private void increaseNeuronCount(int targetLayer, int neuronCount) {
        if (targetLayer < 0 || targetLayer >= this.network.getLayerCount()) {
            throw new NeuralNetworkError("Invalid layer " + targetLayer);
        }
        if (neuronCount <= 0) {
            throw new NeuralNetworkError("Invalid neuron count " + neuronCount);
        }
        int oldNeuronCount = this.network.getLayerNeuronCount(targetLayer);
        int increaseBy = neuronCount - oldNeuronCount;
        if (increaseBy <= 0) {
            throw new NeuralNetworkError("New neuron count is either a decrease or no change: " + neuronCount);
        }
        FlatNetwork flat = this.network.getStructure().getFlat();
        double[] oldWeights = flat.getWeights();

        // Size the new weight array before the layer counts change, while the neighbour
        // counts still describe the old structure.
        int connections = oldWeights.length;
        if (targetLayer > 0) {
            connections += this.network.getLayerTotalNeuronCount(targetLayer - 1) * increaseBy;
        }
        if (targetLayer < this.network.getLayerCount() - 1) {
            connections += this.network.getLayerNeuronCount(targetLayer + 1) * increaseBy;
        }

        int flatLayer = this.network.getLayerCount() - targetLayer - 1;
        flat.getLayerCounts()[flatLayer] += increaseBy;
        flat.getLayerFeedCounts()[flatLayer] += increaseBy;

        // Walk the weights in flat order, with the neuron counts now reflecting the grown layer.
        // Weights touching a new neuron are zero; every other slot consumes the next old weight.
        double[] newWeights = new double[connections];
        int weightsIndex = 0;
        int oldWeightsIndex = 0;
        for (int fromLayer = flat.getLayerCounts().length - 2; fromLayer >= 0; fromLayer--) {
            int toLayer = fromLayer + 1;
            int fromNeuronCount = this.network.getLayerTotalNeuronCount(fromLayer);
            int toNeuronCount = this.network.getLayerNeuronCount(toLayer);
            for (int toNeuron = 0; toNeuron < toNeuronCount; toNeuron++) {
                boolean newToNeuron = toLayer == targetLayer && toNeuron >= oldNeuronCount;
                for (int fromNeuron = 0; fromNeuron < fromNeuronCount; fromNeuron++) {
                    // Only the newly inserted regular neurons are new. Indices at or beyond
                    // neuronCount are the layer's pre-existing extra (e.g. bias) neurons, which
                    // shifted up by increaseBy and must keep their old weights.
                    boolean newFromNeuron = fromLayer == targetLayer
                            && fromNeuron >= oldNeuronCount
                            && fromNeuron < neuronCount;
                    newWeights[weightsIndex++] =
                            (newToNeuron || newFromNeuron) ? 0.0 : oldWeights[oldWeightsIndex++];
                }
            }
        }
        flat.setWeights(newWeights);
        reindexNetwork();
    }

    /**
     * Removes a single neuron, together with all of its inbound and outbound weights.
     *
     * @param targetLayer the layer containing the neuron
     * @param neuron the neuron index to remove
     * @throws NeuralNetworkError if the layer would be left with no neurons
     */
    public void prune(int targetLayer, int neuron) {
        this.network.validateNeuron(targetLayer, neuron);
        if (this.network.getLayerNeuronCount(targetLayer) <= 1) {
            throw new NeuralNetworkError("A layer must have at least a single neuron.  If you want to remove the entire layer you must create a new network.");
        }
        FlatNetwork flat = this.network.getStructure().getFlat();
        int connections = flat.getWeights().length;
        if (targetLayer > 0) {
            connections -= this.network.getLayerTotalNeuronCount(targetLayer - 1);
        }
        if (targetLayer < this.network.getLayerCount() - 1) {
            connections -= this.network.getLayerNeuronCount(targetLayer + 1);
        }

        // Copy every weight except those touching the removed neuron. Counts are still the old
        // ones here, so getWeight addresses the old layout.
        double[] newWeights = new double[connections];
        int weightsIndex = 0;
        for (int fromLayer = flat.getLayerCounts().length - 2; fromLayer >= 0; fromLayer--) {
            int toLayer = fromLayer + 1;
            int fromNeuronCount = this.network.getLayerTotalNeuronCount(fromLayer);
            int toNeuronCount = this.network.getLayerNeuronCount(toLayer);
            for (int toNeuron = 0; toNeuron < toNeuronCount; toNeuron++) {
                for (int fromNeuron = 0; fromNeuron < fromNeuronCount; fromNeuron++) {
                    boolean skip = (toLayer == targetLayer && toNeuron == neuron)
                            || (fromLayer == targetLayer && fromNeuron == neuron);
                    if (!skip) {
                        newWeights[weightsIndex++] = this.network.getWeight(fromLayer, fromNeuron, toNeuron);
                    }
                }
            }
        }
        flat.setWeights(newWeights);

        int flatLayer = this.network.getLayerCount() - targetLayer - 1;
        flat.getLayerCounts()[flatLayer]--;
        flat.getLayerFeedCounts()[flatLayer]--;
        reindexNetwork();
    }

    /**
     * Re-draws every weight connected to a neuron uniformly from the given range.
     *
     * @param low lower bound of the range
     * @param high upper bound of the range
     * @param targetLayer the layer containing the neuron
     * @param neuron the neuron index
     */
    public void randomizeNeuron(double low, double high, int targetLayer, int neuron) {
        randomizeNeuron(targetLayer, neuron, true, low, high, false, 0.0);
    }

    /**
     * Re-draws every weight connected to a neuron from the range spanned by the network's current
     * smallest and largest weights.
     *
     * @param targetLayer the layer containing the neuron
     * @param neuron the neuron index
     */
    public void randomizeNeuron(int targetLayer, int neuron) {
        FlatNetwork flat = this.network.getStructure().getFlat();
        double low = EngineArray.min(flat.getWeights());
        double high = EngineArray.max(flat.getWeights());
        randomizeNeuron(targetLayer, neuron, true, low, high, false, 0.0);
    }

    /**
     * Applies a randomizer to every weight connected to a neuron.
     *
     * <p>When {@code useRange} is true, a {@link RangeRandomizer} over {@code [low, high]} is used.
     * Otherwise a {@link Distort} with {@code percent} is used. {@code usePercent} is not consulted.
     */
    private void randomizeNeuron(int targetLayer, int neuron, boolean useRange, double low, double high, boolean usePercent, double percent) {
        BasicRandomizer d = useRange ? new RangeRandomizer(low, high) : new Distort(percent);
        this.network.validateNeuron(targetLayer, neuron);
        FlatNetwork flat = this.network.getStructure().getFlat();
        double[] newWeights = new double[flat.getWeights().length];
        int weightsIndex = 0;
        for (int fromLayer = flat.getLayerCounts().length - 2; fromLayer >= 0; fromLayer--) {
            int toLayer = fromLayer + 1;
            int fromNeuronCount = this.network.getLayerTotalNeuronCount(fromLayer);
            int toNeuronCount = this.network.getLayerNeuronCount(toLayer);
            for (int toNeuron = 0; toNeuron < toNeuronCount; toNeuron++) {
                for (int fromNeuron = 0; fromNeuron < fromNeuronCount; fromNeuron++) {
                    boolean randomize = (toLayer == targetLayer && toNeuron == neuron)
                            || (fromLayer == targetLayer && fromNeuron == neuron);
                    double weight = this.network.getWeight(fromLayer, fromNeuron, toNeuron);
                    if (randomize) {
                        weight = d.randomize(weight);
                    }
                    newWeights[weightsIndex++] = weight;
                }
            }
        }
        flat.setWeights(newWeights);
    }

    /**
     * Recomputes the flat network's derived structure after its layer counts changed.
     *
     * <p>This rebuilds the per-layer neuron and weight offsets, reallocates the output and sum
     * buffers, clears context, and refreshes the input/output counts (input = last flat layer,
     * output = first flat layer).
     */
    private void reindexNetwork() {
        FlatNetwork flat = this.network.getStructure().getFlat();
        int[] layerCounts = flat.getLayerCounts();
        int[] layerFeedCounts = flat.getLayerFeedCounts();
        int neuronCount = 0;
        int weightCount = 0;
        for (int i = 0; i < layerCounts.length; i++) {
            if (i > 0) {
                weightCount += layerFeedCounts[i - 1] * layerCounts[i];
            }
            flat.getLayerIndex()[i] = neuronCount;
            flat.getWeightIndex()[i] = weightCount;
            neuronCount += layerCounts[i];
        }
        flat.setLayerOutput(new double[neuronCount]);
        flat.setLayerSums(new double[neuronCount]);
        flat.clearContext();
        flat.setInputCount(layerFeedCounts[layerCounts.length - 1]);
        flat.setOutputCount(layerFeedCounts[0]);
    }

    /**
     * Perturbs every weight connected to a neuron using a {@link Distort} randomizer.
     *
     * @param percent the distortion amount passed to {@code Distort} (presumably a fraction of
     *     each weight; TODO confirm against Distort)
     * @param targetLayer the layer containing the neuron
     * @param neuron the neuron index
     */
    public void stimulateNeuron(double percent, int targetLayer, int neuron) {
        randomizeNeuron(targetLayer, neuron, false, 0.0, 0.0, true, percent);
    }

    /**
     * Stimulates the {@code count} least significant neurons of a layer.
     *
     * @param layer the layer to search
     * @param count how many of the weakest neurons to stimulate
     * @param percent the distortion amount forwarded to {@link #stimulateNeuron}
     */
    public void stimulateWeakNeurons(int layer, int count, double percent) {
        for (int neuron : findWeakestNeurons(layer, count)) {
            stimulateNeuron(percent, layer, neuron);
        }
    }
}

