blob: 2a8eaeb608dddfc6fcf3d8da956a2788d6398bc3 [file] [log] [blame]
import numpy as np
import tvm
import topi
from topi.util import get_const_tuple
def with_tvm(lam, *args):
    """Take numpy arrays as args, convert them to TVM tensors and call `lam`.

    Each array in `args` becomes a `tvm.placeholder` with the same shape and
    dtype, and `lam` is called with those placeholders to build the output
    tensor. That tensor is compiled for the "llvm" target (no schedule
    transformations are applied), run on `tvm.cpu(0)` with the array values,
    and the result is converted back to a numpy array and returned.

    Parameters
    ----------
    lam : callable
        Receives one TVM tensor per entry of `args` and returns a single TVM
        tensor (its `.shape`, `.dtype` and `.op` are used below).
    *args : numpy.ndarray
        Concrete input values, in the order `lam` expects its arguments.

    Returns
    -------
    numpy.ndarray
        The computed value of the tensor returned by `lam`.
    """
    ctx = tvm.cpu(0)
    placeholders = []  # symbolic inputs handed to `lam`
    inputs_nd = []     # concrete input buffers living on `ctx`
    for i, arg in enumerate(args):
        # Match the placeholder dtype to the data: the placeholder default is
        # float32, which would not match e.g. a float64 or int32 buffer when
        # the compiled function is called.
        placeholders.append(
            tvm.placeholder(arg.shape, dtype=str(arg.dtype), name='pl' + str(i)))
        inputs_nd.append(tvm.nd.array(arg, ctx))
    out = lam(*placeholders)
    # Pre-allocate the output buffer with the shape/dtype TVM inferred.
    out_nd = tvm.nd.array(
        np.zeros(get_const_tuple(out.shape), dtype=out.dtype), ctx)
    s = tvm.create_schedule([out.op])
    m = tvm.build(s, placeholders + [out], "llvm")
    # Argument order must mirror the tensor list given to tvm.build.
    m(*(inputs_nd + [out_nd]))
    return out_nd.asnumpy()
def verify_matmul(sa, sb, transp_a, transp_b):
    """Check `topi.matmul` against `numpy.matmul` for one configuration.

    Two random float32 operands of shapes `sa` and `sb` are drawn uniformly
    from [-1, 1). When `transp_a` / `transp_b` is set, the corresponding
    operand is transposed before multiplying, both in the numpy reference and
    in the flags passed to `topi.matmul`. Results must agree to 1e-5.
    """
    lhs = np.random.uniform(low=-1.0, high=1.0, size=sa).astype(np.float32)
    rhs = np.random.uniform(low=-1.0, high=1.0, size=sb).astype(np.float32)

    lhs_eff = lhs.T if transp_a else lhs
    rhs_eff = rhs.T if transp_b else rhs
    expected = np.matmul(lhs_eff, rhs_eff)

    def build(A, B):
        return topi.matmul(A, B, transp_a, transp_b)

    actual = with_tvm(build, lhs, rhs)
    tvm.testing.assert_allclose(expected, actual, rtol=1e-5, atol=1e-5)
def test_matmul():
    """Run `verify_matmul` over square, rectangular and transposed cases."""
    # (shape_a, shape_b, transpose_a, transpose_b)
    cases = [
        ((1, 1), (1, 1), False, False),
        ((1, 1), (1, 1), True, True),
        ((2, 2), (2, 2), False, False),
        ((2, 2), (2, 2), True, True),
        ((2, 3), (3, 5), False, False),
        ((5, 3), (3, 2), False, False),
        ((3, 5), (3, 2), True, False),
        ((3, 5), (2, 3), True, True),
    ]
    for shape_a, shape_b, trans_a, trans_b in cases:
        verify_matmul(shape_a, shape_b, trans_a, trans_b)
def verify_tensordot(sa, sb, axes):
    """Check `topi.tensordot` against `numpy.tensordot` for one configuration.

    Random float32 operands of shapes `sa` and `sb` are drawn uniformly from
    [-1, 1). The same `axes` argument is forwarded to both implementations,
    and the results must agree to 1e-5.
    """
    def _rand(shape):
        return np.random.uniform(low=-1.0, high=1.0, size=shape).astype(np.float32)

    a, b = _rand(sa), _rand(sb)  # left-to-right: `a` is drawn first
    expected = np.tensordot(a, b, axes)
    actual = with_tvm(lambda A, B: topi.tensordot(A, B, axes), a, b)
    tvm.testing.assert_allclose(expected, actual, rtol=1e-5, atol=1e-5)
def test_tensordot():
    """Run `verify_tensordot` with integer and explicit-axis `axes` forms."""
    # Integer `axes=k` (numpy semantics): contract the last k axes of the
    # first operand with the first k axes of the second; k=0 is an outer
    # product. Note `(3,)` -- a bare `(3)` is just the int 3, not a tuple.
    verify_tensordot((3,), (3,), 0)
    verify_tensordot((2, 3), (3, 5), 1)
    verify_tensordot((2, 2, 3), (2, 3, 5), 2)
    verify_tensordot((2, 2, 3, 4), (2, 3, 4, 5), 3)
    # Explicit pairing: axis 1 of the first operand with axis 0 of the second.
    verify_tensordot((3, 2, 2), (2, 3, 5), (1, 0))
    # Sequence pairs: the i-th listed axis of each operand are contracted
    # together, in the listed (not sorted) order.
    verify_tensordot((3, 2, 2), (2, 3, 5), ((1, 0), (0, 1)))
    verify_tensordot((4, 3, 2, 2), (2, 4, 3, 5), ((1, 2, 0), (2, 0, 1)))
if __name__ == "__main__":
    # Run the test entry points in order when executed as a script.
    for _run in (test_matmul, test_tensordot):
        _run()