| import numpy as np |
| import tvm |
| import topi |
| from topi.util import get_const_tuple |
| |
| def with_tvm(lam, *args): |
| """ Take numpy arrays as args, convert them to TVM tensors and call `lam`. |
| Result of lambda is converted back to numpy array and returned. |
| """ |
| ctx = tvm.cpu(0) |
| pls = [] # placeholders |
| vals_nd = [] # initial values |
| for i,arg in enumerate(args): |
| pls.append(tvm.placeholder(arg.shape, name='pl'+str(i))) |
| vals_nd.append(tvm.nd.array(arg, ctx)) |
| |
| out = lam(*pls) |
| out_nd = tvm.nd.array(np.zeros(get_const_tuple(out.shape), dtype=out.dtype), ctx) |
| s = tvm.create_schedule([out.op]) |
| m = tvm.build(s, pls + [out], "llvm") |
| m(*(vals_nd+[out_nd])) |
| return out_nd.asnumpy() |
| |
| def verify_matmul(sa, sb, transp_a, transp_b): |
| a = np.random.uniform(low=-1.0, high=1.0, size=sa).astype(np.float32) |
| b = np.random.uniform(low=-1.0, high=1.0, size=sb).astype(np.float32) |
| c1 = np.matmul(np.transpose(a) if transp_a else a, |
| np.transpose(b) if transp_b else b) |
| c2 = with_tvm(lambda A,B: topi.matmul(A,B,transp_a,transp_b), a,b) |
| tvm.testing.assert_allclose(c1, c2, rtol=1e-5, atol=1e-5) |
| |
| def test_matmul(): |
| verify_matmul((1,1),(1,1),False,False) |
| verify_matmul((1,1),(1,1),True,True) |
| verify_matmul((2,2),(2,2),False,False) |
| verify_matmul((2,2),(2,2),True,True) |
| verify_matmul((2,3),(3,5),False,False) |
| verify_matmul((5,3),(3,2),False,False) |
| verify_matmul((3,5),(3,2),True,False) |
| verify_matmul((3,5),(2,3),True,True) |
| |
| def verify_tensordot(sa, sb, axes): |
| a = np.random.uniform(low=-1.0, high=1.0, size=sa).astype(np.float32) |
| b = np.random.uniform(low=-1.0, high=1.0, size=sb).astype(np.float32) |
| c1 = np.tensordot(a, b, axes) |
| c2 = with_tvm(lambda A, B: topi.tensordot(A, B, axes), a, b) |
| tvm.testing.assert_allclose(c1, c2, rtol=1e-5, atol=1e-5) |
| |
| def test_tensordot(): |
| verify_tensordot((3), (3), 0) |
| verify_tensordot((2, 3), (3, 5), 1) |
| verify_tensordot((2, 2, 3), (2, 3, 5), 2) |
| verify_tensordot((2, 2, 3, 4), (2, 3, 4, 5), 3) |
| verify_tensordot((3, 2, 2), (2, 3, 5), (1, 0)) |
| verify_tensordot((3, 2, 2), (2, 3, 5), ((1, 0), (0, 1))) |
| verify_tensordot((4, 3, 2, 2), (2, 4, 3, 5), ((1, 2, 0), (2, 0, 1))) |
| |
| if __name__ == "__main__": |
| test_matmul() |
| test_tensordot() |
| |