/*
 * Decompiled with CFR 0.152.
 */
package org.encog.mathutil.randomize;

import org.encog.engine.network.activation.ActivationReLU;
import org.encog.mathutil.matrices.Matrix;
import org.encog.mathutil.randomize.Randomizer;
import org.encog.mathutil.randomize.generate.GenerateRandom;
import org.encog.mathutil.randomize.generate.MersenneTwisterGenerateRandom;
import org.encog.ml.MLMethod;
import org.encog.neural.networks.BasicNetwork;

/**
 * Randomizer that initializes {@link BasicNetwork} weights with a Xavier/Glorot-style
 * symmetric range, switching to a fan-in-only range when the source layer of a weight
 * matrix uses {@link ActivationReLU}.
 *
 * <p>Single values, arrays and matrices are simply filled from {@code rnd.nextDouble()}.
 * The class name keeps its historical "Xaiver" spelling for compatibility with callers.
 */
public class XaiverRandomizer
implements Randomizer {
    /** Source of random numbers; replaceable via {@link #setRandom(GenerateRandom)}. */
    private GenerateRandom rnd;

    /**
     * Creates a randomizer seeded from {@link System#currentTimeMillis()}.
     */
    public XaiverRandomizer() {
        this(System.currentTimeMillis());
    }

    /**
     * Creates a randomizer backed by a Mersenne Twister generator.
     *
     * @param seed the seed for the generator.
     */
    public XaiverRandomizer(long seed) {
        this.rnd = new MersenneTwisterGenerateRandom(seed);
    }

    /**
     * Returns a fresh random value; the supplied value is ignored.
     *
     * @param d ignored.
     * @return the next value from {@code rnd.nextDouble()}.
     */
    @Override
    public double randomize(double d) {
        return this.rnd.nextDouble();
    }

    /**
     * Initializes the weights connecting layer {@code fromLayer} to layer {@code fromLayer + 1}.
     *
     * <p>If the source layer is biased, its bias weights are set to zero. Every other weight is
     * drawn with {@code rnd.nextDouble(-d, d)}, where {@code d = 2 / sqrt(fromCount)} when the
     * source layer's activation is {@link ActivationReLU}, and
     * {@code d = 2 / sqrt(fromCount + toCount)} otherwise.
     *
     * @param network   the network whose weights are written.
     * @param fromLayer index of the source layer; {@code fromLayer + 1} must also exist.
     */
    public void randomize(BasicNetwork network, int fromLayer) {
        final int fromCount = network.getLayerNeuronCount(fromLayer);
        final int toCount = network.getLayerNeuronCount(fromLayer + 1);

        // Zero the bias weights once, and only when a bias neuron exists: on an unbiased layer
        // index fromCount does not address a bias, so writing it would hit some other weight.
        // NOTE(review): assumes getLayerNeuronCount excludes the bias neuron, so the bias sits
        // at index fromCount — confirm against BasicNetwork.
        if (network.isLayerBiased(fromLayer)) {
            for (int toNeuron = 0; toNeuron < toCount; toNeuron++) {
                network.setWeight(fromLayer, fromCount, toNeuron, 0.0);
            }
        }

        // The range depends only on the layer pair, so compute it once.
        // NOTE(review): the textbook Glorot-uniform bound is sqrt(6 / (fanIn + fanOut)) and He
        // uses sqrt(6 / fanIn); this implementation uses 2 / sqrt(...) — kept as-is to preserve
        // existing behavior. It also inspects the *source* layer's activation; confirm intent.
        final double range = network.getActivation(fromLayer) instanceof ActivationReLU
                ? 2.0 / Math.sqrt(fromCount)
                : 2.0 / Math.sqrt(fromCount + toCount);

        for (int fromNeuron = 0; fromNeuron < fromCount; fromNeuron++) {
            for (int toNeuron = 0; toNeuron < toCount; toNeuron++) {
                network.setWeight(fromLayer, fromNeuron, toNeuron,
                        this.rnd.nextDouble(-range, range));
            }
        }
    }

    /**
     * Initializes every weight layer of the network, from the first layer to the last.
     *
     * @param method the method to randomize; must be a {@link BasicNetwork}.
     * @throws ClassCastException if {@code method} is not a {@link BasicNetwork}.
     */
    @Override
    public void randomize(MLMethod method) {
        final BasicNetwork network = (BasicNetwork) method;
        for (int layer = 0; layer < network.getLayerCount() - 1; layer++) {
            this.randomize(network, layer);
        }
    }

    /**
     * Overwrites every element of the array with {@code rnd.nextDouble()}.
     *
     * @param d the array to fill.
     */
    @Override
    public void randomize(double[] d) {
        this.randomize(d, 0, d.length);
    }

    /**
     * Overwrites every element of a (possibly jagged) 2-D array with {@code rnd.nextDouble()},
     * filling row by row.
     *
     * @param d the array to fill.
     */
    @Override
    public void randomize(double[][] d) {
        // Each row is bounded by its own length, so non-square and jagged arrays work.
        for (final double[] row : d) {
            this.randomize(row, 0, row.length);
        }
    }

    /**
     * Overwrites every element of the matrix's backing data with {@code rnd.nextDouble()}.
     *
     * @param m the matrix to fill.
     */
    @Override
    public void randomize(Matrix m) {
        this.randomize(m.getData());
    }

    /**
     * Overwrites {@code size} elements of {@code d}, starting at {@code begin}, with
     * {@code rnd.nextDouble()}.
     *
     * @param d     the array to fill.
     * @param begin first index to overwrite.
     * @param size  number of elements to overwrite.
     */
    @Override
    public void randomize(double[] d, int begin, int size) {
        for (int i = 0; i < size; i++) {
            d[begin + i] = this.rnd.nextDouble();
        }
    }

    /**
     * Replaces the random number generator.
     *
     * @param theRandom the new generator.
     */
    @Override
    public void setRandom(GenerateRandom theRandom) {
        this.rnd = theRandom;
    }

    /**
     * @return the random number generator currently in use.
     */
    @Override
    public GenerateRandom getRandom() {
        return this.rnd;
    }
}

