# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: skip-file
from __future__ import absolute_import
from __future__ import division
import os
from uuid import uuid4
import numpy as _np
import mxnet as mx
from mxnet import gluon, autograd, np, npx
from mxnet.test_utils import use_np, assert_almost_equal, check_gluon_hybridize_consistency, assert_allclose
from mxnet.gluon import nn
from mxnet.base import MXNetError
import random
import pytest
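# The tests in this file exercise Gluon blocks with MXNet's NumPy-compatible front end
# (mx.np / mx.npx). The `@use_np` decorator switches the wrapped test or block to
# numpy-style shape and array semantics, so parameters and outputs are mx.np.ndarray.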
def test_create_np_param():
M, K, N = 10, 9, 20
def check_block_params(x, TestBlock, hybridize, expected_type, initializer):
net = TestBlock()
net.initialize(initializer())
if hybridize:
net.hybridize()
net(x)
params = net.collect_params()
for _, v in params.items():
assert type(v.data()) is expected_type
@use_np
class TestBlock1(gluon.HybridBlock):
def __init__(self):
super(TestBlock1, self).__init__()
self.w = gluon.Parameter('w', shape=(K, N), allow_deferred_init=True)
def forward(self, x):
device = x.device
return np.dot(x, self.w.data(device))
x = mx.np.random.uniform(size=(M, K))
for initializer in [mx.initializer.Uniform, mx.initializer.Normal]:
check_block_params(x, TestBlock1, False, mx.np.ndarray, initializer)
check_block_params(x, TestBlock1, True, mx.np.ndarray, initializer)
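# The regressor below declares `w1` with a zero input dimension and
# `allow_deferred_init=True`; `infer_shape` completes the shape from the first input
# batch, so the parameter is initialized lazily on the first forward call.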
@use_np
def test_optimizer_with_np_ndarrays():
class LinearRegression(gluon.HybridBlock):
def __init__(self, num_input_dim=0, num_hidden_dim=100, num_output_dim=10):
super(LinearRegression, self).__init__()
self.w1 = gluon.Parameter('w1', shape=(num_input_dim, num_hidden_dim),
allow_deferred_init=True)
self.w2 = gluon.Parameter('w2', shape=(num_hidden_dim, num_output_dim),
allow_deferred_init=True)
def forward(self, x):
device = x.device
h = x.dot(self.w1.data(device)) # equivalent to np.dot(x, w1)
            h_relu = npx.relu(h)            # ReLU via npx; operates on and returns np.ndarray
y_pred = h_relu.dot(self.w2.data(device)) # equivalent to np.dot(h_relu, w2)
return y_pred
def infer_shape(self, x, *args):
pre_shape = self.w1.shape
self.w1.shape = (x.shape[x.ndim-1], pre_shape[1])
class TotalLoss(gluon.HybridBlock):
def forward(self, pred, label):
return ((pred - label) ** 2).sum() # equivalent to np.sum(np.square(pred - label))
regressor = LinearRegression()
regressor.initialize(mx.init.Uniform())
regressor.hybridize()
# Create random input and output data
x = np.random.uniform(size=(64, 1000)) # x is of type mxnet.numpy.ndarray
regressor(x)
y = np.random.uniform(size=(64, 10)) # y is of type mxnet.numpy.ndarray
total_loss = TotalLoss()
total_loss.hybridize()
trainer = gluon.Trainer(regressor.collect_params(),
'sgd',
{'learning_rate': 1e-3, 'momentum': 0.9})
for _ in range(2):
with autograd.record():
            output = regressor(x)  # output is an np.ndarray because np.dot is the last op in the network
loss = total_loss(output, y) # loss is a scalar np.ndarray
loss.backward()
trainer.step(1)
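# Deleting `allow_np_array` emulates an optimizer written before that attribute existed;
# the Updater is still expected to accept mx.np.ndarray weights and gradients.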
@use_np
def test_optimizer_backward_compat():
optimizer = mx.optimizer.SGD()
delattr(optimizer, "allow_np_array")
updater = mx.optimizer.Updater(optimizer)
updater(0, np.ones((0, 0)), np.zeros((0, 0)))
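# A worked example for the loss values checked below (output=[1, 2, 3, 4], label=[1, 3, 5, 7]):
#   L1:  |diff|          = [0, 1, 2, 3]      -> sum = 6    (weight=0.5 scales it to 3)
#   L2:  0.5 * diff**2   = [0, 0.5, 2, 4.5]  -> sum = 7    (weight=0.25 scales it to 1.75)
# The per-sample weighting rescales individual terms before the reduction.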
@use_np
def test_np_loss_ndarray():
# Ported from test_loss.test_loss_ndarray
output = np.array([1, 2, 3, 4])
label = np.array([1, 3, 5, 7])
weighting = np.array([0.5, 1, 0.5, 1])
loss = gluon.loss.L1Loss()
assert float(np.sum(loss(output, label))) == 6.
loss = gluon.loss.L1Loss(weight=0.5)
assert float(np.sum(loss(output, label))) == 3.
loss = gluon.loss.L1Loss()
assert float(np.sum(loss(output, label, weighting))) == 5.
loss = gluon.loss.L2Loss()
assert float(np.sum(loss(output, label))) == 7.
loss = gluon.loss.L2Loss(weight=0.25)
assert float(np.sum(loss(output, label))) == 1.75
loss = gluon.loss.L2Loss()
assert float(np.sum(loss(output, label, weighting))) == 6
output = np.array([[0, 2], [1, 4]])
label = np.array([0, 1])
weighting = np.array([[0.5], [1.0]])
loss = gluon.loss.SoftmaxCrossEntropyLoss()
L = loss(output, label).asnumpy()
assert_almost_equal(L, _np.array([2.12692809, 0.04858733]), use_broadcast=False, rtol=1e-3)
L = loss(output, label, weighting).asnumpy()
assert_almost_equal(L, _np.array([1.06346405, 0.04858733]), use_broadcast=False, rtol=1e-3)
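# `gluon.Constant` wraps a fixed, non-trainable array as a block parameter; the test
# checks that adding it inside forward() matches plain NumPy addition, both imperatively
# and after hybridization.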
@use_np
def test_np_get_constant():
const_arr = _np.random.uniform(0, 100, size=(10, 10)).astype(_np.float32)
class Foo(gluon.HybridBlock):
def __init__(self):
super(Foo, self).__init__()
self.weight = gluon.Constant(const_arr)
def forward(self, x):
device = x.device
return x + self.weight.data(device).astype(np.float32)
x = np.random.uniform(size=const_arr.shape, dtype=const_arr.dtype)
for hybridize in [False, True]:
foo = Foo()
if hybridize:
foo.hybridize()
foo.initialize()
out = foo(x)
assert_almost_equal(out.asnumpy(), (x.asnumpy() + const_arr), atol=1e-5, rtol=1e-4, use_broadcast=False)
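# `v.grad()[()] = 1` writes 1 into every element of the gradient array (an empty-tuple
# index addresses the whole array); `zero_grad()` should then reset all of them to zero.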
@use_np
def test_parameters_zero_grad():
for hybridize in [False, True]:
net = gluon.nn.HybridSequential()
for _ in range(5):
net.add(gluon.nn.Dense(10))
if hybridize:
net.hybridize()
net.initialize()
out = net(mx.np.ones((32, 8)))
for v in net.collect_params().values():
v.grad()[()] = 1
net.zero_grad()
for v in net.collect_params().values():
assert_almost_equal(v.grad().asnumpy(), mx.np.zeros_like(v.grad()).asnumpy())
def check_gluon_save_load(net_builder, data_l):
"""Verify the consistency between the loaded network and the original network.
Known limitations: Currently it only supports loading
Parameters
----------
net_builder : function
The builder of the HybridBlock.
data_l : list of numpy.ndarray
List of the input data that we will use to verify the correctness of the loaded network.
"""
net = net_builder()
net.hybridize()
net.initialize()
out = net(*data_l)
out_np = out.asnumpy()
prefix = str(uuid4())
net.export(prefix)
input_names = 'data' if len(data_l) == 1 else ['data{}'.format(i) for i in range(len(data_l))]
net_imported = gluon.SymbolBlock.imports('{}-symbol.json'.format(prefix),
input_names, param_file='{}-0000.params'.format(prefix))
    # Clean up the exported files
os.remove('{}-symbol.json'.format(prefix))
os.remove('{}-0000.params'.format(prefix))
loaded_out = net_imported(*data_l).asnumpy()
assert_almost_equal(out_np, loaded_out)
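# The slicing tests below build random index tuples and compare the hybridized block
# output against a plain NumPy reference via `check_gluon_hybridize_consistency`.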
def hashable_index(tuple_idx):
"""Return an hashable representation of a tuple of slice object
We add this because the slice object in python is not hashable.
Parameters
----------
tuple_idx : tuple
A tuple of slice/int objects
Returns
-------
ret : tuple
A hashable representation of the slice data
"""
l = []
for ele in tuple_idx:
if isinstance(ele, slice):
l.append(ele.__reduce__())
else:
l.append(ele)
return tuple(l)
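# A sketch of the expected behaviour (relying on CPython's `slice.__reduce__`):
#   hashable_index((slice(1, 5, 2), 3)) == ((slice, (1, 5, 2)), 3)
# which, unlike a raw slice object, can be stored in the `cache` set used below.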
@use_np
def test_symbolic_basic_slicing():
def random_slice_index(shape):
index = []
step_switch = random.randint(-1, 1)
for i in range(len(shape)):
if shape[i] == 0:
index.append(slice(None))
continue
r = random.randint(0, 5)
if r < 2:
index.append(random.randint(1 - shape[i], shape[i] - 1))
continue
elif r < 3:
index.append(slice(None))
continue
s = random.randint(0, shape[i] - 1)
e = random.randint(s + 1, shape[i])
if step_switch == 1:
index.append(slice(s, e, 1))
            elif step_switch == -1:
                e -= 1
                s -= 1
                if s < 0:
                    s = None  # a stop of -1 would wrap around; None lets the reversed slice reach index 0
                index.append(slice(e, s, -1))
else:
index.append(slice(s, e, 2))
return tuple(index)
shapes = [
(4, 6, 8, 5),
(1, 1, 1, 6),
(5, 6, 4),
(5, 6),
(10,),
(100, 0, 10, 0, 0),
(100, 0, 0),
(0, 0, 0),
(),
]
for shape in shapes:
cache = set()
x = mx.np.random.normal(0, 1, shape)
y = mx.np.random.normal(0, 1, shape)
for _ in range(200):
index1 = random_slice_index(shape)
index2 = random_slice_index(x.asnumpy()[index1].shape)
if (hashable_index(index1), hashable_index(index2)) in cache:
continue
cache.add((hashable_index(index1), hashable_index(index2)))
# Test basic slicing on a single symbol
class TestSlicingSingleSymbol1(gluon.HybridBlock):
def forward(self, x, y):
return x[()][index1] + y[()][index1]
# Test basic slicing on a single symbol
class TestSlicingSingleSymbol2(gluon.HybridBlock):
def forward(self, x, y):
return (x[()][index1] + y[()][index1])[index2]
check_gluon_hybridize_consistency(TestSlicingSingleSymbol1, [x, y],
numpy_func=lambda a, b: a[()][index1] + b[()][index1])
check_gluon_hybridize_consistency(TestSlicingSingleSymbol2, [x, y],
numpy_func=lambda a, b:
(a[()][index1] + b[()][index1])[index2])
# Test for split/hsplit/vsplit
class TestSlicingWithSplit(gluon.HybridBlock):
def forward(self, x):
x = mx.np.split(x, shape[2], axis=2)
x = x[1:-1]
x = mx.np.concatenate(x, axis=2)
return x
class TestSlicingWithSplit2(gluon.HybridBlock):
def __init__(self):
super(TestSlicingWithSplit2, self).__init__()
self.layer = gluon.nn.Dense(16, flatten=False)
def forward(self, x, y):
x = mx.np.split(x, 1)
x = x[0]
return self.layer(x[:, -1, :] + y[:, -1, :])
class TestSlicingWithHSplit(gluon.HybridBlock):
def forward(self, x):
x = mx.np.hsplit(x, shape[1])
x = x[1:-1]
x = mx.np.concatenate(x, axis=1)
return x
class TestSlicingWithVSplit(gluon.HybridBlock):
def forward(self, x):
x = mx.np.vsplit(x, shape[0])
x = x[1:-1]
x = mx.np.concatenate(x, axis=0)
return x
if len(shape) > 2 and shape[2] > 2:
check_gluon_hybridize_consistency(
TestSlicingWithSplit, [x],
numpy_func=lambda a: _np.concatenate(_np.split(a, shape[2], axis=2)[1:-1],
axis=2))
if len(shape) == 3 and shape[0] > 0 and shape[1] > 0 and shape[2] > 0:
check_gluon_hybridize_consistency(TestSlicingWithSplit2, [x, y])
if len(shape) > 1 and shape[1] > 2:
check_gluon_hybridize_consistency(
TestSlicingWithHSplit, [x],
numpy_func=lambda a: _np.concatenate(_np.hsplit(a, shape[1])[1:-1], axis=1))
if len(shape) > 1 and shape[0] > 2:
check_gluon_hybridize_consistency(
TestSlicingWithVSplit, [x],
numpy_func=lambda a: _np.concatenate(_np.vsplit(a, shape[0])[1:-1], axis=0))
for data_shape, idx in [((4, 3), 2),
((3,), -1),
((3,), 0)]:
class IntegerIndexing(gluon.HybridBlock):
def forward(self, x):
return x[idx]
check_gluon_hybridize_consistency(IntegerIndexing,
[mx.np.ones(data_shape)],
numpy_func=lambda a: a[idx])
@use_np
def test_net_symbol_save_load():
class Case1(gluon.HybridBlock):
def __init__(self):
super(Case1, self).__init__()
self.layer = gluon.nn.Dense(64, flatten=False)
def forward(self, x, y):
x = mx.np.split(x, 1)
x = x[0]
return self.layer(x[:, -1, :] + y[:, -1, :])
check_gluon_save_load(Case1, [mx.np.random.normal(0, 1, (10, 5, 8, 6)),
mx.np.random.normal(0, 1, (10, 5, 8, 6))])
class Case2(gluon.HybridBlock):
def __init__(self):
super(Case2, self).__init__()
self.layer1 = gluon.nn.Dense(64, flatten=False)
self.layer2 = gluon.nn.Dense(64, flatten=False)
def forward(self, x, y):
x = mx.np.split(x, 1)
x = x[0]
return self.layer1(x[:, -1, :]) + self.layer2(y[:, -1, :])
check_gluon_save_load(Case2, [mx.np.random.normal(0, 1, (10, 5, 8)),
mx.np.random.normal(0, 1, (10, 5, 8))])
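# The mask below has boolean dtype; the test verifies that the imperative and the
# hybridized runs of the same block produce identical boolean outputs.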
@use_np
def test_hybridize_boolean_dtype():
class Foo(gluon.HybridBlock):
def __init__(self):
super(Foo, self).__init__()
def forward(self, valid_length):
mask = ((np.ones((10,)) / 2) < valid_length)
return mask
valid_length = mx.np.random.uniform(size=(10,))
foo = Foo()
out1 = foo(valid_length)
foo = Foo()
foo.hybridize()
out2 = foo(valid_length)
assert mx.test_utils.same(out1.asnumpy(), out2.asnumpy())
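# `optimize_for` re-partitions the hybridized graph with a backend-specific subgraph
# pass; the "ONEDNN" backend is assumed to be available here (it requires MXNet built
# with oneDNN support).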
@use_np
def test_optimize_for():
class TestBlock(gluon.HybridBlock):
def __init__(self):
super(TestBlock, self).__init__()
self.d = mx.gluon.nn.Dense(1)
def forward(self, a):
res = self.d(a)
return res
a = mx.np.random.uniform(low=-1, high=1, size=(1,1))
net = TestBlock()
net.initialize()
net.hybridize()
out = net(a)
b = net.collect_params().pop('d.weight').data()
net.optimize_for(a, b, backend="ONEDNN")
out2 = net(a)
@use_np
def test_activations_leakyrelu():
    # Currently, the activation tests only check that each layer runs and produces output.
act_layer = nn.LeakyReLU(0.1)
out = act_layer(mx.np.random.uniform(size=(10,)))
out.asnumpy()
@use_np
def test_activations_prelu():
act_layer = nn.PReLU()
act_layer.initialize()
out = act_layer(mx.np.random.uniform(size=(10,)))
out.asnumpy()
@use_np
def test_activations_elu():
act_layer = nn.ELU(1.0)
out = act_layer(mx.np.random.uniform(size=(10,)))
out.asnumpy()
@use_np
def test_activations_selu():
act_layer = nn.SELU()
out = act_layer(mx.np.random.uniform(size=(10,)))
out.asnumpy()
@use_np
def test_activations_gelu():
act_layer = nn.GELU()
out = act_layer(mx.np.random.uniform(size=(10,)))
out.asnumpy()
@use_np
def test_activations_gelu_tanh():
act_layer = nn.GELU(approximation='tanh')
out = act_layer(mx.np.random.uniform(size=(10,)))
out.asnumpy()
@use_np
def test_activations_swish():
act_layer = nn.Swish()
out = act_layer(mx.np.random.uniform(size=(10,)))
out.asnumpy()
@use_np
def test_activations_silu():
act_layer = nn.SiLU()
out = act_layer(mx.np.random.uniform(size=(10,)))
out.asnumpy()
@use_np
def test_concatenate():
model = nn.HybridConcatenate(axis=1)
model.add(nn.Dense(128, activation='tanh', in_units=10))
model.add(nn.Dense(64, activation='tanh', in_units=10))
model.add(nn.Dense(32, in_units=10))
model2 = nn.Concatenate(axis=1)
model2.add(nn.Dense(128, activation='tanh', in_units=10))
model2.add(nn.Dense(64, activation='tanh', in_units=10))
model2.add(nn.Dense(32, in_units=10))
# ndarray
model.initialize(mx.init.Xavier(magnitude=2.24))
model2.initialize(mx.init.Xavier(magnitude=2.24))
x = model(mx.np.zeros((32, 10)))
x2 = model2(mx.np.zeros((32, 10)))
assert x.shape == (32, 224)
assert x2.shape == (32, 224)
x.wait_to_read()
x2.wait_to_read()
@use_np
def test_identity():
model = nn.Identity()
x = mx.np.random.uniform(size=(128, 33, 64))
assert_almost_equal(model(x), x)
@use_np
def test_pixelshuffle1d():
nchan = 2
up_x = 2
nx = 3
shape_before = (1, nchan * up_x, nx)
shape_after = (1, nchan, nx * up_x)
layer = nn.PixelShuffle1D(up_x)
x = mx.np.reshape(mx.np.arange(_np.prod(shape_before)), shape_before)
y = layer(x)
assert y.shape == shape_after
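    # - Channels are reshaped to form blocks of `up_x` columns
    # - Within each block, the increment is `nx` when increasing the column index by 1
    # - Increasing the block index adds an offset of 1
    # - Increasing the channel index adds an offset of `nx * up_x`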
assert_allclose(
y,
[[[0, 3, 1, 4, 2, 5],
[6, 9, 7, 10, 8, 11]]]
)
@use_np
def test_pixelshuffle2d():
nchan = 2
up_x = 2
up_y = 3
nx = 2
ny = 3
shape_before = (1, nchan * up_x * up_y, nx, ny)
shape_after = (1, nchan, nx * up_x, ny * up_y)
layer = nn.PixelShuffle2D((up_x, up_y))
x = mx.np.reshape(mx.np.arange(_np.prod(shape_before)), shape_before)
y = layer(x)
assert y.shape == shape_after
# - Channels are reshaped to form 2x3 blocks
# - Within each block, the increment is `nx * ny` when increasing the column
# index by 1
# - Increasing the block index adds an offset of 1
# - Increasing the channel index adds an offset of `nx * up_x * ny * up_y`
assert_allclose(
y,
[[[[ 0, 6, 12, 1, 7, 13, 2, 8, 14],
[18, 24, 30, 19, 25, 31, 20, 26, 32],
[ 3, 9, 15, 4, 10, 16, 5, 11, 17],
[21, 27, 33, 22, 28, 34, 23, 29, 35]],
[[36, 42, 48, 37, 43, 49, 38, 44, 50],
[54, 60, 66, 55, 61, 67, 56, 62, 68],
[39, 45, 51, 40, 46, 52, 41, 47, 53],
[57, 63, 69, 58, 64, 70, 59, 65, 71]]]]
)
@use_np
def test_pixelshuffle3d():
nchan = 1
up_x = 2
up_y = 1
up_z = 2
nx = 2
ny = 3
nz = 4
shape_before = (1, nchan * up_x * up_y * up_z, nx, ny, nz)
shape_after = (1, nchan, nx * up_x, ny * up_y, nz * up_z)
layer = nn.PixelShuffle3D((up_x, up_y, up_z))
x = mx.np.arange(_np.prod(shape_before)).reshape(shape_before)
y = layer(x)
assert y.shape == shape_after
# - Channels are reshaped to form 2x1x2 blocks
# - Within each block, the increment is `nx * ny * nz` when increasing the
# column index by 1, e.g. the block [[[ 0, 24]], [[48, 72]]]
# - Increasing the block index adds an offset of 1
assert_allclose(
y,
[[[[[ 0, 24, 1, 25, 2, 26, 3, 27],
[ 4, 28, 5, 29, 6, 30, 7, 31],
[ 8, 32, 9, 33, 10, 34, 11, 35]],
[[48, 72, 49, 73, 50, 74, 51, 75],
[52, 76, 53, 77, 54, 78, 55, 79],
[56, 80, 57, 81, 58, 82, 59, 83]],
[[12, 36, 13, 37, 14, 38, 15, 39],
[16, 40, 17, 41, 18, 42, 19, 43],
[20, 44, 21, 45, 22, 46, 23, 47]],
[[60, 84, 61, 85, 62, 86, 63, 87],
[64, 88, 65, 89, 66, 90, 67, 91],
[68, 92, 69, 93, 70, 94, 71, 95]]]]]
)
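# In the first check every index in [3, 4, 2, 0, 1] appears exactly once, so rows 0-4 of
# the embedding weight receive a gradient of 1 while rows 5-9 stay at 0. The large-input
# check looks up index 1 a total of 20481 times, so the weight gradient sums to 20481.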
@use_np
def test_embedding():
def check_embedding():
layer = gluon.nn.Embedding(10, 100)
layer.initialize()
x = mx.np.array([3,4,2,0,1])
with mx.autograd.record():
y = layer(x)
y.backward()
assert (layer.weight.grad().asnumpy()[:5] == 1).all()
assert (layer.weight.grad().asnumpy()[5:] == 0).all()
def check_embedding_large_input():
embedding = mx.gluon.nn.Embedding(10, 1)
embedding.initialize()
embedding.hybridize()
shape = (20481,)
with mx.autograd.record():
emb_in = embedding(mx.np.ones(shape))
loss = emb_in.sum()
loss.backward()
assert embedding.weight.grad().sum().item() == 20481
check_embedding()
check_embedding_large_input()
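# test_layernorm runs the layer imperatively first, then hybridized, and checks that the
# outputs and the input gradients agree within tolerance.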
@use_np
@pytest.mark.parametrize('dshape', [(10, ), (2, 10, 10, 10)])
def test_layernorm(dshape):
layer = nn.LayerNorm(in_channels=10)
print("checking layer {}\nshape: {}.".format(layer, dshape))
layer.initialize()
x = mx.np.ones(shape=dshape)
x.attach_grad()
with mx.autograd.record():
out = layer(x)
out.backward()
np_out = out.asnumpy()
np_dx = x.grad.asnumpy()
layer.hybridize()
x = mx.np.ones(shape=dshape)
x.attach_grad()
with mx.autograd.record():
out = layer(x)
out.backward()
mx.test_utils.assert_almost_equal(np_out, out.asnumpy(), rtol=1e-5, atol=1e-6)
mx.test_utils.assert_almost_equal(np_dx, x.grad.asnumpy(), rtol=1e-5, atol=1e-6)