/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.api.activation;

import org.junit.Assert;
import org.junit.Test;
import org.nd4j.linalg.api.activation.Activations;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Sanity checks for the softmax activations.
 *
 * <p>Every slice that a softmax normalises must sum to one. Summing all of those slice sums
 * therefore yields the number of slices:
 * <ul>
 *   <li>3 columns for the column-wise {@code softmax()};</li>
 *   <li>2 rows for {@code softMaxRows()}.</li>
 * </ul>
 * The checks are run for both 'f' and 'c' array orderings.
 *
 * <p>The class is abstract; presumably backend-specific subclasses inherit these tests
 * (TODO confirm).
 */
public abstract class SoftMaxTest {
    private static final Logger log = LoggerFactory.getLogger(SoftMaxTest.class);

    /** Absolute tolerance used when comparing summed probabilities. */
    private static final double DELTA = 0.1;

    /** Checks softmax normalisation with the factory set to 'f' ordering. */
    @Test
    public void testSoftMax() {
        assertSoftMaxSums('f');
    }

    /** Checks softmax normalisation with the factory set to 'c' ordering. */
    @Test
    public void testSoftMaxCOrder() {
        assertSoftMaxSums('c');
    }

    /**
     * Switches the global factory to {@code order} and builds a 2x3 matrix from
     * {@code linspace(1, 6, 6)}. It then asserts that:
     * <ul>
     *   <li>the column-wise softmax sums to 3 overall (1 per column);</li>
     *   <li>the row-wise softmax sums to 2 overall (1 per row).</li>
     * </ul>
     * The previous factory order is always restored, so this test's global setting does not
     * leak into tests that run afterwards.
     *
     * @param order array ordering to use while creating and processing the test matrix
     */
    private static void assertSoftMaxSums(char order) {
        char previousOrder = Nd4j.factory().order();
        Nd4j.factory().setOrder(order);
        try {
            INDArray test = Nd4j.linspace(1, 6, 6).reshape(2, 3);
            INDArray softMaxColumns = (INDArray) Activations.softmax().apply(test);
            INDArray softMaxRows = (INDArray) Activations.softMaxRows().apply(test);

            // Per-column sums (1 x 3) and per-row sums (2 x 1) of the normalised matrices.
            INDArray columns = softMaxColumns.sum(0);
            INDArray rows = softMaxRows.sum(1);

            // Integer.MAX_VALUE appears to request a sum over all elements — TODO confirm
            // against the nd4j sum(int) contract.
            Assert.assertEquals(3.0, columns.sum(Integer.MAX_VALUE).getFloat(0), DELTA);
            Assert.assertEquals(2.0, rows.sum(Integer.MAX_VALUE).getFloat(0), DELTA);
        } finally {
            Nd4j.factory().setOrder(previousOrder);
        }
    }
}

