# blob: 3be938852fdfefe73f7a1d40035ce09e3bab7838 [file] [log] [blame]
"""Test code for broadcasting operators."""
from common import get_all_backend
import numpy as np
import tvm
import topi
def verify_broadcast_to_ele(in_shape, out_shape, fbcast):
    """Check a broadcast-to operator against ``np.broadcast_to``.

    Parameters
    ----------
    in_shape : tuple of int
        Shape of the input placeholder.
    out_shape : tuple of int
        Shape the input is broadcast to.
    fbcast : callable
        Operator under test, invoked as ``fbcast(tensor, out_shape)``.
    """
    # The computation is declared once; each device schedules and builds it.
    A = tvm.placeholder(shape=in_shape, name="A")
    B = fbcast(A, out_shape)

    def check_device(device):
        """Build for ``device``, run on random data and compare with NumPy."""
        dev_ctx = tvm.context(device, 0)
        if not dev_ctx.exist:
            print("Skip because %s is not enabled" % device)
            return
        print("Running on target: %s" % device)
        with tvm.target.create(device):
            sched = topi.generic.schedule_broadcast(B)
        func = tvm.build(sched, [A, B], device, name="broadcast_to")

        # Host-side input and the NumPy reference result.
        host_in = np.random.uniform(size=in_shape).astype(A.dtype)
        expected = np.broadcast_to(host_in, out_shape)

        # Device-side buffers; the output buffer is overwritten by the call.
        dev_in = tvm.nd.array(host_in, dev_ctx)
        dev_out = tvm.nd.array(np.empty(out_shape, dtype=B.dtype), dev_ctx)
        func(dev_in, dev_out)
        tvm.testing.assert_allclose(dev_out.asnumpy(), expected)

    # All enabled backends first, then the "sdaccel" target.
    for backend in list(get_all_backend()) + ["sdaccel"]:
        check_device(backend)
def verify_broadcast_binary_ele(lhs_shape, rhs_shape,
                                ftopi, fnumpy,
                                lhs_min=-100, lhs_max=100,
                                rhs_min=-100, rhs_max=100,
                                dtype="float32"):
    """Check a broadcasting binary operator against its NumPy counterpart.

    Parameters
    ----------
    lhs_shape, rhs_shape : tuple of int or None
        Shape of each operand. ``None`` makes that operand a scalar
        ``tvm.var`` instead of a placeholder tensor.
    ftopi : callable
        Operator under test, invoked as ``ftopi(A, B)``. Its ``__name__``
        is used in the name of the built function.
    fnumpy : callable
        Reference implementation, invoked on the host-side operand values.
    lhs_min, lhs_max, rhs_min, rhs_max : number
        Bounds passed to ``np.random.uniform`` when drawing each operand.
    dtype : str
        Data type of both operands.
    """
    # Declare the operands and the computation under test.
    A = (tvm.var("A", dtype=dtype) if lhs_shape is None
         else tvm.placeholder(shape=lhs_shape, name="A", dtype=dtype))
    B = (tvm.var("B", dtype=dtype) if rhs_shape is None
         else tvm.placeholder(shape=rhs_shape, name="B", dtype=dtype))
    C = ftopi(A, B)
    if isinstance(A, tvm.expr.Expr) and isinstance(B, tvm.expr.Expr):
        # Both operands are expressions: nothing is built or run, only the
        # kind of the result is checked.
        assert isinstance(C, tvm.expr.Expr)
        return

    def make_operand(shape, low, high, operand, ctx):
        """Draw one random operand.

        Returns a pair ``(reference_value, call_value)``: the value fed to
        ``fnumpy`` and the value passed to the built function.
        """
        if shape is None:
            scalar = float(np.random.uniform(low=low, high=high))
            # Integer dtypes (signed or unsigned) take a Python int.
            if dtype.startswith(("int", "uint")):
                scalar = int(scalar)
            return scalar, scalar
        # Cast with the operand's own dtype rather than always using A's.
        host = np.random.uniform(low=low, high=high,
                                 size=shape).astype(operand.dtype)
        return host, tvm.nd.array(host, ctx)

    def check_device(device):
        """Build for ``device``, run on random data and compare with NumPy."""
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % device)
            return
        print("Running on target: %s" % device)
        with tvm.target.create(device):
            s = topi.generic.schedule_broadcast(C)
        foo = tvm.build(s, [A, B, C], device,
                        name="broadcast_binary_" + ftopi.__name__)

        # Draw lhs before rhs so the random stream is consumed in a fixed order.
        lhs_npy, lhs_nd = make_operand(lhs_shape, lhs_min, lhs_max, A, ctx)
        rhs_npy, rhs_nd = make_operand(rhs_shape, rhs_min, rhs_max, B, ctx)

        out_npy = fnumpy(lhs_npy, rhs_npy)
        out_nd = tvm.nd.array(np.empty(out_npy.shape).astype(C.dtype), ctx)
        foo(lhs_nd, rhs_nd, out_nd)
        tvm.testing.assert_allclose(out_nd.asnumpy(), out_npy,
                                    rtol=1E-4, atol=1E-4)

    for target in get_all_backend():
        check_device(target)
    check_device("sdaccel")
def test_broadcast_to():
    """Run ``topi.broadcast_to`` over several (input, output) shape pairs."""
    shape_pairs = [
        ((1,), (10,)),
        ((), (10,)),
        ((1, 1, 5, 4), (3, 4, 4, 4, 5, 4)),
        ((1, 128, 1, 32), (64, 128, 64, 32)),
    ]
    for in_shape, out_shape in shape_pairs:
        verify_broadcast_to_ele(in_shape, out_shape, topi.broadcast_to)
def test_add():
    """Check ``topi.add`` against ``np.add`` for two shape combinations."""
    shape_pairs = [
        ((), ()),
        ((5, 2, 3), (2, 1)),
    ]
    for lhs_shape, rhs_shape in shape_pairs:
        verify_broadcast_binary_ele(lhs_shape, rhs_shape, topi.add, np.add)
def test_subtract():
    """Check ``topi.subtract`` against ``np.subtract``.

    A shape of ``None`` selects a scalar variable operand instead of a tensor.
    """
    shape_pairs = [
        ((5, 2, 3), ()),
        ((5, 2, 3), None),
        (None, None),
        ((1, 32), (64, 32)),
    ]
    for lhs_shape, rhs_shape in shape_pairs:
        verify_broadcast_binary_ele(
            lhs_shape, rhs_shape, topi.subtract, np.subtract)
def test_multiply():
    """Check ``topi.multiply`` against ``np.multiply`` on a 3-D x 4-D pair."""
    lhs_shape = (5, 64, 128)
    rhs_shape = (2, 5, 64, 1)
    verify_broadcast_binary_ele(
        lhs_shape, rhs_shape, topi.multiply, np.multiply)
def test_divide():
    """Check ``topi.divide`` against ``np.divide``.

    The divisor is drawn with a lower bound of 0.0001, so it is never zero
    or negative.
    """
    shape_pairs = [
        (None, (10,)),
        ((), None),
        ((2, 3, 1, 32), (64, 32)),
    ]
    for lhs_shape, rhs_shape in shape_pairs:
        verify_broadcast_binary_ele(
            lhs_shape, rhs_shape, topi.divide, np.divide, rhs_min=0.0001)
def test_maximum_minmum():
    """Check ``topi.maximum`` and ``topi.minimum`` against NumPy."""
    cases = [
        ((32,), (64, 32), topi.maximum, np.maximum),
        ((1, 2, 2, 1, 32), (64, 32), topi.minimum, np.minimum),
    ]
    for lhs_shape, rhs_shape, op, reference in cases:
        verify_broadcast_binary_ele(lhs_shape, rhs_shape, op, reference)
def test_power():
    """Check ``topi.power`` against ``np.power``.

    Bases are drawn from 0.001 upward (positive only) and exponents from
    the range 0.001 to 2.
    """
    value_range = dict(lhs_min=0.001, rhs_min=0.001, rhs_max=2)
    verify_broadcast_binary_ele(
        (1, 2, 2), (2,), topi.power, np.power, **value_range)
def test_mod():
    """Check ``topi.mod`` against ``np.mod`` on int32 operands.

    The divisor is drawn with a lower bound of 1, so it is still non-zero
    after the cast to int32.
    """
    options = dict(lhs_min=0.001, rhs_min=1, dtype="int32")
    verify_broadcast_binary_ele(
        (1, 2, 2), (2,), topi.mod, np.mod, **options)
def test_cmp():
    """Check the six TOPI comparison operators against NumPy.

    Each operator is wrapped so that its result is cast to int8, which
    fixes the output type explicitly.
    """
    def int8_cmp(name):
        """Return a wrapper around ``topi.<name>`` that casts to int8."""
        def compare(x, y):
            return getattr(topi, name)(x, y).astype("int8")
        # The wrapper's __name__ becomes part of the built function's name.
        compare.__name__ = name
        return compare

    # Default (float32) operands over the default value range.
    verify_broadcast_binary_ele(
        (1, 2, 2), (2,), int8_cmp("greater"), np.greater)
    verify_broadcast_binary_ele(
        (2, 1, 2), (2, 3, 1), int8_cmp("less"), np.less)

    # int32 operands drawn from a narrow symmetric range [-bound, bound).
    int_cases = [
        ((2, 1, 2), (2, 3, 1), "equal", np.equal, 2),
        ((2, 1, 2), (2, 3, 1), "not_equal", np.not_equal, 2),
        ((7, 1, 5), (7, 3, 1), "greater_equal", np.greater_equal, 3),
        ((7, 1, 5), (7, 3, 1), "less_equal", np.less_equal, 3),
    ]
    for lhs_shape, rhs_shape, name, reference, bound in int_cases:
        verify_broadcast_binary_ele(
            lhs_shape, rhs_shape, int8_cmp(name), reference,
            lhs_min=-bound, lhs_max=bound,
            rhs_min=-bound, rhs_max=bound, dtype='int32')
def test_shift():
    """Check ``topi.right_shift`` and ``topi.left_shift`` against NumPy.

    The operand dtype is given explicitly for every case, and shift amounts
    are drawn from the range 0 to 32.
    """
    cases = [
        ((2, 1, 2), None, topi.right_shift, np.right_shift, "int32"),
        ((1, 2, 2), (2,), topi.left_shift, np.left_shift, "int32"),
        ((1, 2, 2), (2,), topi.left_shift, np.left_shift, "int8"),
    ]
    for lhs_shape, rhs_shape, op, reference, operand_dtype in cases:
        verify_broadcast_binary_ele(
            lhs_shape, rhs_shape, op, reference,
            dtype=operand_dtype, rhs_min=0, rhs_max=32)
if __name__ == "__main__":
    # When run as a script, execute every test in a fixed order.
    for run_test in (
            test_add,
            test_shift,
            test_cmp,
            test_mod,
            test_subtract,
            test_multiply,
            test_divide,
            test_maximum_minmum,
            test_power,
            test_broadcast_to,
    ):
        run_test()