blob: 1791241d68eeefec326a9adca30008108004aecd [file] [log] [blame]
import tvm
import numpy as np
def test_local_multi_stage():
if not tvm.module.enabled("opengl"):
return
if not tvm.module.enabled("llvm"):
return
n = tvm.var("n")
A = tvm.placeholder((n,), name='A', dtype="int32")
B = tvm.compute((n,), lambda i: A[i] + 1, name="B")
C = tvm.compute((n,), lambda i: B[i] * 2, name="C")
s = tvm.create_schedule(C.op)
s[B].opengl()
s[C].opengl()
f = tvm.build(s, [A, C], "opengl", name="multi_stage")
ctx = tvm.opengl(0)
n = 10
a = tvm.nd.array(np.random.uniform(size=(n,)).astype(A.dtype), ctx)
c = tvm.nd.array(np.random.uniform(size=(n,)).astype(B.dtype), ctx)
f(a, c)
tvm.testing.assert_allclose(c.asnumpy(), (a.asnumpy() + 1) * 2)
if __name__ == "__main__":
test_local_multi_stage()