# blob: d44beb316770ad7ec85d4a8045e0939605afb423 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import os
import mxnet as mx
import numpy as np
from mxnet import gluon
from mxnet.gluon import nn
from mxnet.base import MXNetError
from mxnet.test_utils import assert_exception, default_device, set_default_device, use_np
import pytest
# NOTE(review): presumably switches MXNet back to its legacy (non-NumPy)
# semantics before the `mx.nd` / `mx.sym` tests below run -- confirm against
# the mxnet.npx.reset_np documentation.
mx.npx.reset_np()
@pytest.mark.skipif(
    os.environ.get('MXNET_ENGINE_TYPE') == 'NaiveEngine',
    reason="This test assumes asynchronous execution.",
)
def test_exc_imperative():
    """Check when an invalid imperative op reports its error.

    Without reading the result back, the call sequence must complete without
    raising; once the result is fetched with ``asnumpy()`` an ``MXNetError``
    is expected.
    """
    def imperative(exec_numpy=True):
        lhs = mx.nd.random.normal(0, 1, (2, 2))
        # The -1 second argument is what is expected to make this op fail
        # (presumably an invalid scale -- TODO confirm).
        rhs = mx.nd.random.normal(0, -1, (2, 2))
        product = mx.nd.dot(lhs, rhs)
        if not exec_numpy:
            return
        product.asnumpy()

    # No synchronization: nothing may be raised here.
    imperative(exec_numpy=False)
    # Reading the result back must surface the failure.
    with pytest.raises(MXNetError):
        imperative(exec_numpy=True)
def test_exc_symbolic():
    """Errors from a failing op inside a bound symbolic graph must surface on sync.

    The graph multiplies ``z`` by the element-wise sum of ``x`` and ``y`` and
    then by a ``random.normal`` sample created with ``-1`` as its second
    argument (the op expected to fail). Every combination of forward-only /
    forward+backward and ``waitall`` / ``asnumpy`` synchronization is expected
    to raise ``MXNetError``.
    """
    def symbolic(exec_backward=True, waitall=True):
        """Bind and run the graph, then synchronize.

        exec_backward: also run the backward pass after the forward pass.
        waitall: synchronize with ``mx.nd.waitall()``; otherwise read a single
            array (first gradient or first output) with ``asnumpy()``.
        """
        x = mx.sym.Variable('x')
        y = mx.sym.Variable('y')
        z = mx.sym.Variable('z')
        x_shape = (2, 2)
        z_shape = (3, 2)
        inputs = [x, y]
        out = mx.symbol.ElementWiseSum(*inputs, name="esum")
        out = mx.sym.dot(z, out)
        # This is the op that is expected to fail at execution time.
        out2 = mx.sym.random.normal(0, -1, x_shape, ctx=default_device())
        out = mx.sym.dot(out, out2)
        out = mx.sym.make_loss(out)
        arr = {'x': mx.nd.random.normal(0, 1, x_shape, ctx=default_device()),
               'y': mx.nd.random.normal(0, 1, x_shape, ctx=default_device()),
               'z': mx.nd.random.normal(0, 1, z_shape, ctx=default_device())}
        arr_grad = {'x': mx.nd.empty(x_shape), 'y': mx.nd.empty(x_shape), 'z': mx.nd.empty(z_shape)}
        exec1 = out._bind(ctx=default_device(), args=arr, args_grad=arr_grad)
        outputs = exec1.forward()
        if exec_backward:
            exec1.backward()
            if waitall:
                mx.nd.waitall()
            else:
                exec1.grad_arrays[0].asnumpy()
        else:
            if waitall:
                mx.nd.waitall()
            else:
                outputs[0].asnumpy()

    # `symbolic` defaults to waitall=True, so the per-array `asnumpy()` paths
    # are only exercised when waitall=False is passed explicitly.
    pytest.raises(MXNetError, symbolic, exec_backward=False, waitall=False)
    pytest.raises(MXNetError, symbolic, exec_backward=True, waitall=False)

    pytest.raises(MXNetError, symbolic, exec_backward=False, waitall=True)
    pytest.raises(MXNetError, symbolic, exec_backward=True, waitall=True)
def test_exc_multiple_waits():
    """A failing op must raise on each of two consecutive wait attempts.

    Runs the scenario twice: once synchronizing with ``wait_to_read()`` on the
    array and once with ``mx.nd.waitall()``.
    """
    def multiple_waits(waitall=False):
        # Test calling failed op followed by wait_to_read or waitall twice
        # Intention is to test rethrow for multiple wait_to_reads and waitalls
        # for vars with exceptions in same scope
        caught = False
        try:
            a = mx.nd.random.normal(0, -1, (2, 2)).copyto(default_device())
            if waitall:
                mx.nd.waitall()
            else:
                a.wait_to_read()
        except MXNetError:
            caught = True
        assert caught, "No exception thrown, exception should be rethrown with wait_to_read/waitall"

        # Reset the flag: otherwise the True left over from the first attempt
        # makes the second assertion pass unconditionally.
        caught = False
        try:
            b = mx.nd.random.normal(0, -1, (2, 2)).copyto(default_device())
            if waitall:
                mx.nd.waitall()
            else:
                b.wait_to_read()
        except MXNetError:
            caught = True
        assert caught, "No exception thrown, exception should be rethrown with wait_to_read/waitall"

    multiple_waits(waitall=False)
    multiple_waits(waitall=True)
@pytest.mark.skipif(
    os.environ.get('MXNET_ENGINE_TYPE') == 'NaiveEngine',
    reason="This test assumes asynchronous execution.",
)
def test_exc_post_fail():
    """After a failure has been reported once, a later read must not raise.

    The failing array is unpacked into two parts; the error is expected when
    synchronizing (via ``waitall`` or by reading the first part), and reading
    the second part afterwards has to succeed.
    """
    def post_fail(waitall=False):
        raised = False
        try:
            first, second = mx.nd.random_normal(0, -1, (2, 2)).copyto(default_device())
            if not waitall:
                first.asnumpy()
            else:
                mx.nd.waitall()
        except MXNetError:
            raised = True
        assert raised, "No exception thrown"
        # Outside the try block on purpose: any exception here fails the test.
        second.asnumpy()

    for use_waitall in (False, True):
        post_fail(waitall=use_waitall)
def test_exc_mutable_var_fail():
    """An op consuming the output of a failed op must also report the error.

    Checked for both synchronization styles: ``asnumpy()`` on the dependent
    result and a global ``mx.nd.waitall()``.
    """
    def mutable_var_check(waitall=False):
        head, tail = mx.nd.random_normal(0, -1, (2, 2)).copyto(default_device())
        # Rebind the name to a result that depends on the failed array.
        head = mx.nd.dot(head, head)
        if not waitall:
            head.asnumpy()
        else:
            mx.nd.waitall()

    for use_waitall in (False, True):
        with pytest.raises(MXNetError):
            mutable_var_check(waitall=use_waitall)
def test_multiple_waitalls():
    """A second ``waitall`` after the error was reported must not raise."""
    try:
        failing = mx.nd.random.normal(0, -1, (2, 2)).copyto(default_device())
        mx.nd.waitall()
    except MXNetError:
        error_seen = True
    else:
        error_seen = False
    assert error_seen, "No exception thrown"
    # Not guarded: an exception from this call fails the test.
    mx.nd.waitall()
def test_exc_profiler():
    """Smoke test: a forward pass completes while the profiler is running.

    Builds a one-layer ``HybridSequential`` network, switches the profiler to
    ``"run"``, executes a single forward pass, waits for all pending work and
    switches the profiler back to ``"stop"``.

    NOTE(review): the original statements were not enclosed in any ``test_``
    function, so pytest would not collect them; the wrapper restores that.
    """
    net = gluon.nn.HybridSequential()
    net.add(gluon.nn.Dense(10))

    ctx = default_device()
    net.initialize(mx.init.Xavier(), ctx=ctx)

    def run_training_iteration(data):
        """Run one forward pass of ``net`` on ``data``; the output is discarded."""
        net(data)

    data = mx.nd.ones((3, 4))
    mx.profiler.set_state("run")
    try:
        run_training_iteration(data)
        mx.nd.waitall()
    finally:
        # Always stop the profiler so a failure here cannot leave it running
        # for the tests that follow.
        mx.profiler.set_state("stop")
def test_opencv_exception():
    """``imresize`` called with ``interp=-1`` is expected to raise ``MXNetError``."""
    with pytest.raises(MXNetError):
        picture = mx.nd.ones((1200, 1600, 3))
        picture = mx.image.imresize(picture, 320, 320, interp=-1)
        # Reading the result forces the resize to be evaluated.
        picture.asnumpy()
def test_np_reshape_exception():
    """Reshaping a 10x10 array to an incompatible shape raises ``MXNetError``.

    Covers both the method form and the ``mx.np.reshape`` function form; a
    valid flatten is run first to confirm it does not raise.
    """
    grid = mx.np.ones((10, 10))
    grid.reshape((-1,)).asnumpy()  # Check no-raise

    with pytest.raises(MXNetError):
        grid.reshape((1,))
    with pytest.raises(MXNetError):
        mx.np.reshape(grid, (1,))
    with pytest.raises(MXNetError):
        mx.np.reshape(grid, (-1, 3))
@use_np
def test_np_random_incorrect_named_arguments():
    """Passing a ``shape=`` keyword to the ``mx.np.random`` samplers raises ``TypeError``.

    Each listed sampler must exist and must reject ``shape=()`` as well as
    ``shape=None``.
    """
    for sampler_name in ('uniform', 'normal', 'randint', 'choice'):
        sampler = getattr(mx.np.random, sampler_name, None)
        assert sampler is not None
        for bad_shape in ((), None):
            with pytest.raises(TypeError):
                sampler(shape=bad_shape)