/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.convolution.test;

import java.util.Arrays;
import org.junit.Assert;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.convolution.Convolution;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Tests for {@link Convolution}.
 *
 * <p>The class is {@code abstract}, so JUnit runs these tests only through concrete subclasses
 * (presumably one per ND4J backend; TODO confirm against the subclasses).
 */
public abstract class ConvolutionTests {
    private static final Logger log = LoggerFactory.getLogger(ConvolutionTests.class);

    /**
     * Checks a 1-D VALID convolution of {@code [1..8]} with the kernel {@code [1, 2, 3]}.
     *
     * <p>VALID mode keeps only full overlaps, so there are 8 - 3 + 1 = 6 outputs. The expected values
     * match a true convolution (the kernel is reversed): {@code out[i] = 3*x[i] + 2*x[i+1] + 1*x[i+2]}.
     * For example, {@code 3*1 + 2*2 + 1*3 = 10}. A plain correlation would give 14 instead.
     */
    @Test
    public void convNTest() {
        INDArray arr = Nd4j.linspace(1, 8, 8);
        INDArray kernel = Nd4j.linspace(1, 3, 3);
        INDArray answer = Nd4j.create(new double[] {10.0, 16.0, 22.0, 28.0, 34.0, 40.0});
        INDArray test = Convolution.convn(arr, kernel, Convolution.Type.VALID);
        Assert.assertEquals(answer, test);
    }

    /**
     * Checks the output shapes of {@link Convolution#conv2d} in FULL and VALID modes. Each input is
     * convolved with a copy of itself.
     *
     * <p>NOTE(review): the expected shapes are pinned as-is. Why VALID on a 2x4x2 input yields
     * {@code {2, 4}} is not evident from this file; confirm against conv2d's contract.
     */
    @Test
    public void testConv2d() {
        // The values 1..8 laid out as a 2x2x2 array.
        INDArray input = Nd4j.create(Nd4j.linspace(1, 8, 8).data(), new int[] {2, 2, 2});
        INDArray kernel = input.dup();
        INDArray convolution = Convolution.conv2d(input, kernel, Convolution.Type.FULL);
        // assertArrayEquals reports the differing index/length on failure, unlike assertTrue(Arrays.equals(...)).
        Assert.assertArrayEquals(new int[] {3, 3}, convolution.shape());

        // The values 1..16 laid out as a 2x4x2 array.
        INDArray input2 = Nd4j.create(Nd4j.linspace(1, 16, 16).data(), new int[] {2, 4, 2});
        INDArray kernel2 = input2.dup();
        INDArray convolution2 = Convolution.conv2d(input2, kernel2, Convolution.Type.VALID);
        Assert.assertArrayEquals(new int[] {2, 4}, convolution2.shape());
    }

    /**
     * Smoke test: runs {@link Convolution#convn} on a 2x6 image with a 2x2 kernel in FULL and VALID
     * modes and logs both results. It makes no assertions, so it fails only if convn throws.
     */
    @Test
    public void testConvolution() {
        INDArray image = Nd4j.create(new double[][] {
            {3.0, 2.0, 5.0, 6.0, 7.0, 8.0},
            {5.0, 4.0, 2.0, 10.0, 8.0, 1.0}
        });
        INDArray kernel = Nd4j.create(new double[][] {{4.0, 5.0}, {1.0, 2.0}});
        // TODO: replace the logging with assertions on the expected FULL/VALID outputs.
        // "{}" renders the same text as toString(), but only when INFO is enabled.
        log.info("{}", Convolution.convn(image, kernel, Convolution.Type.FULL));
        log.info("{}", Convolution.convn(image, kernel, Convolution.Type.VALID));
    }
}

