blob: 5b3ba952aafbc42403b89d17c0d7be61bd5573a0 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import threading
import numpy as np
import mxnet as mx
from mxnet import device, attribute
from mxnet.device import Device
from mxnet.attribute import AttrScope
from mxnet.test_utils import assert_almost_equal, set_default_device
from mxnet.util import _NumpyArrayScope, set_np_shape
def test_device():
    """Check that the current device is tracked per thread.

    Phase 1 records the main thread's current device (expected cpu(0)), then
    a spawned thread calls ``set_default_device(mx.gpu(11))`` and records what
    it sees (expected gpu(11)).

    Phase 2 parks a spawned thread inside a ``mx.cpu(10)`` scope while the main
    thread enters ``Device("cpu", 11)``; after resuming, the spawned thread
    must still report device id 10.
    """
    # Phase 1: main thread's device first, then the spawned thread's.
    seen = [device.current_device()]

    def switch_default():
        set_default_device(mx.gpu(11))
        seen.append(device.current_device())

    worker = threading.Thread(target=switch_default)
    worker.start()
    worker.join()

    assert Device.devtype2str[seen[0].device_typeid] == "cpu"
    assert seen[0].device_id == 0
    assert Device.devtype2str[seen[1].device_typeid] == "gpu"
    assert seen[1].device_id == 11

    # Phase 2: handshake so both threads hold their own scope at once.
    resume = threading.Event()        # main -> worker: main's scope is active
    inside_scope = threading.Event()  # worker -> main: worker's scope is active
    ok = [False]

    def hold_scope():
        with mx.cpu(10):
            inside_scope.set()
            resume.wait()
            ok[0] = device.current_device().device_id == 10

    worker = threading.Thread(target=hold_scope)
    worker.start()
    inside_scope.wait()
    with Device("cpu", 11):
        resume.set()
        # Join while main's scope is still open so the worker's check overlaps it.
        worker.join()
    resume.clear()
    inside_scope.clear()
    assert ok[0], "Spawned thread didn't set the correct device"
def test_attrscope():
    """Check that the active AttrScope is tracked per thread.

    Phase 1: while the main thread holds ``AttrScope(y=..., z=...)``, a spawned
    thread opens ``AttrScope(x="hello")``. The outer scope must hold two
    attributes and the spawned scope must map ``x`` to ``"hello"``.

    Phase 2: a spawned thread parks inside ``AttrScope(x="hello")`` while the
    main thread enters ``AttrScope(x="hi")``. On resuming, the spawned thread's
    current scope must still contain the value ``"hello"``.
    """
    scopes = []
    with AttrScope(y="hi", z="hey") as outer_scope:
        scopes.append(outer_scope)

        def enter_inner():
            with AttrScope(x="hello") as inner_scope:
                scopes.append(inner_scope)

        worker = threading.Thread(target=enter_inner)
        worker.start()
        worker.join()
        assert len(scopes[0]._attr) == 2
        assert scopes[1]._attr["x"] == "hello"

    main_ready = threading.Event()    # main -> worker: main's scope is active
    worker_ready = threading.Event()  # worker -> main: worker's scope is active
    found = [False]

    def hold_scope():
        with mx.AttrScope(x="hello"):
            worker_ready.set()
            main_ready.wait()
            found[0] = "hello" in mx.attribute.current()._attr.values()

    worker = threading.Thread(target=hold_scope)
    worker.start()
    worker_ready.wait()
    with AttrScope(x="hi"):
        main_ready.set()
        # Join while main's scope is still open so the worker's check overlaps it.
        worker.join()
    main_ready.clear()
    worker_ready.clear()
    assert found[0], "Spawned thread didn't set the correct attr key values"
def test_name():
    """Check that the current NameManager is tracked per thread.

    Part 1 calls ``get(None, "main_thread")`` on the main thread's current
    NameManager. It then calls ``get(None, "spawned_thread")`` on a NameManager
    opened inside a spawned thread. Each manager's ``_counter`` is expected to
    contain its own hint.

    Part 2 parks a spawned thread inside its own NameManager while the main
    thread opens a fresh NameManager and requests "main_thread". The spawned
    thread's current manager must not contain that hint.
    """
    name_list = []
    name_manager = mx.name.current()
    name_manager.get(None, "main_thread")
    name_list.append(name_manager)

    def f():
        with mx.name.NameManager():
            # Distinct local name so it is not confused with the enclosing
            # ``name_manager`` (the main thread's manager).
            spawned_manager = mx.name.current()
            spawned_manager.get(None, "spawned_thread")
            name_list.append(spawned_manager)

    thread = threading.Thread(target=f)
    thread.start()
    thread.join()
    # Messages name the exact keys being looked up.
    assert "main_thread" in name_list[0]._counter, (
        "cannot find the string `main_thread` in name_list[0]._counter")
    assert "spawned_thread" in name_list[1]._counter, (
        "cannot find the string `spawned_thread` in name_list[1]._counter")

    e1 = threading.Event()  # main -> spawned: main's NameManager is active
    e2 = threading.Event()  # spawned -> main: spawned NameManager is active
    status = [False]

    def g():
        with mx.name.NameManager():
            e2.set()
            e1.wait()
            if "main_thread" not in mx.name.current()._counter:
                status[0] = True

    thread = threading.Thread(target=g)
    thread.start()
    e2.wait()
    with mx.name.NameManager():
        mx.name.current().get(None, "main_thread")
        e1.set()
        # Join while main's NameManager is still active so the check overlaps it.
        thread.join()
    e1.clear()
    e2.clear()
    assert status[0], "Spawned thread isn't using thread local NameManager"
def test_blockscope():
    """Check naming inside a block scope opened in a spawned thread.

    Inside ``mx.gluon.block._block_scope(net)``, the first
    ``mx.name.current().get(None, "hello")`` is expected to return
    ``"dummy_block_hello0"``.
    """
    class dummy_block:
        pass

    status = [False]
    event = threading.Event()

    def f():
        net = dummy_block()  # BlockScope only keeps a weakref to the Block
        with mx.gluon.block._block_scope(net):
            x = mx.name.current().get(None, "hello")
            event.wait()
            if x == "dummy_block_hello0":
                status[0] = True

    thread = threading.Thread(target=f)
    thread.start()
    # NOTE(review): the event is set immediately after start, so no main-thread
    # work is ordered against the spawned thread's scope.
    event.set()
    thread.join()
    event.clear()
    assert status[0], "Spawned thread isn't using the correct blockscope namemanager"
def test_createblock():
    """Check that a Gluon Dense layer can be built, initialized and run in a spawned thread."""
    succeeded = [False]

    def build_and_run():
        layer = mx.gluon.nn.Dense(2)
        layer.initialize()
        # Block until the forward result is materialized.
        layer(mx.np.array([1, 2, 3])).wait_to_read()
        succeeded[0] = True

    worker = threading.Thread(target=build_and_run)
    worker.start()
    worker.join()
    assert succeeded[0], "Failed to create a layer within a thread"
def test_symbol():
    """Check that a symbolic ``a + b`` graph can be bound and run in a spawned thread."""
    done = [False]

    def run_graph():
        lhs = mx.sym.var("a")
        rhs = mx.sym.var("b")
        a_val = mx.nd.ones((2, 2))
        b_val = a_val.copy()
        executor = (lhs + rhs)._bind(mx.cpu(), args=dict(a=a_val, b=b_val))
        # Block until the first output is materialized.
        executor.forward()[0].wait_to_read()
        done[0] = True

    worker = threading.Thread(target=run_graph)
    worker.start()
    worker.join()
    assert done[0], "Failed to execute a symbolic graph within a thread"
def test_np_array_scope():
    """Check that NumPy-array mode chosen in one thread is seen correctly in another.

    Part 1 installs ``_NumpyArrayScope(False)`` from the main thread and
    ``_NumpyArrayScope(True)`` from a spawned thread, then checks both were
    recorded with the expected ``_is_np_array`` flag.

    Part 2 parks a spawned thread inside ``mx.np_array(False)`` while the main
    thread installs ``_NumpyArrayScope(True)``. On resuming, the spawned thread
    must still see ``mx.is_np_array()`` as False.

    The test assigns ``_NumpyArrayScope._current`` on the class itself, which
    replaces a process-wide attribute. The original value is saved up front
    and restored in ``finally``, so later tests do not inherit the
    ``_NumpyArrayScope(True)`` installed here.
    """
    saved_current = _NumpyArrayScope._current
    try:
        np_array_scope_list = []
        _NumpyArrayScope._current = _NumpyArrayScope(False)
        np_array_scope_list.append(_NumpyArrayScope._current)

        def f():
            _NumpyArrayScope._current = _NumpyArrayScope(True)
            np_array_scope_list.append(_NumpyArrayScope._current)

        thread = threading.Thread(target=f)
        thread.start()
        thread.join()
        assert len(np_array_scope_list) == 2
        assert not np_array_scope_list[0]._is_np_array
        assert np_array_scope_list[1]._is_np_array

        event = threading.Event()  # main -> spawned: main has installed a True scope
        status = [False]

        def g():
            with mx.np_array(False):
                event.wait()
                if not mx.is_np_array():
                    status[0] = True

        thread = threading.Thread(target=g)
        thread.start()
        _NumpyArrayScope._current = _NumpyArrayScope(True)
        event.set()
        thread.join()
        event.clear()
        assert status[0], "Spawned thread didn't set status correctly"
    finally:
        # Undo the class-attribute overwrites above so global state does not leak.
        _NumpyArrayScope._current = saved_current
def test_np_global_shape():
    """Check that the np-shape setting applied in the main thread is honored in a spawned thread.

    After ``set_np_shape(2)``, a spawned thread creates a zero-dim scalar
    (shape ``()``) and a zero-size array (shape ``(0, 1, 2)``). Both must
    compare equal to the matching ``np.ones`` array. The value returned by
    ``set_np_shape(2)`` is passed back to ``set_np_shape`` in ``finally``.
    """
    prev_np_shape = set_np_shape(2)
    shapes = ((), (0, 1, 2))  # zero-dim scalar, then zero-size array
    created = []

    def make_arrays():
        for shape in shapes:
            created.append(mx.np.ones(shape=shape))

    try:
        worker = threading.Thread(target=make_arrays)
        worker.start()
        worker.join()
        for idx, shape in enumerate(shapes):
            assert_almost_equal(created[idx].asnumpy(), np.ones(shape=shape))
    finally:
        set_np_shape(prev_np_shape)