# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import os
import re

import tvm
from tvm import te
import numpy as np
from tvm import topi
from tvm.contrib.nvcc import have_fp16, have_int8, have_bf16
from tvm.contrib import utils
import tvm.testing
import pytest

tx = te.thread_axis("threadIdx.x")
bx = te.thread_axis("blockIdx.x")


@tvm.testing.requires_gpu
@tvm.testing.requires_cuda
def test_cuda_vectorize_add():
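    """Vectorized add on CUDA for float32/int8/uint8/float16 vector dtypes with various lane counts."""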
    num_thread = 8

    def check_cuda(dtype, n, lanes):
        if dtype == "float16" and not have_fp16(tvm.cuda(0).compute_version):
            print("Skip because gpu does not have fp16 support")
            return
        if dtype == "int8" and not have_int8(tvm.cuda(0).compute_version):
            print("skip because gpu does not support int8")
            return
        A = te.placeholder((n,), name="A", dtype="%sx%d" % (dtype, lanes))
        B = te.compute((n,), lambda i: A[i] + tvm.tir.const(1, A.dtype), name="B")
        s = te.create_schedule(B.op)
        xo, xi = s[B].split(B.op.axis[0], factor=num_thread)
        s[B].bind(xo, bx)
        s[B].bind(xi, tx)
        fun = tvm.build(s, [A, B], "cuda")
        dev = tvm.cuda(0)
        a = tvm.nd.empty((n,), A.dtype, dev).copyfrom(np.random.uniform(size=(n, lanes)))
        c = tvm.nd.empty((n,), B.dtype, dev)
        fun(a, c)
        tvm.testing.assert_allclose(c.numpy(), a.numpy() + 1)

    check_cuda("float32", 64, 2)
    check_cuda("float32", 64, 3)
    check_cuda("float32", 64, 4)
    check_cuda("int8", 64, 2)
    check_cuda("int8", 64, 3)
    check_cuda("int8", 64, 4)
    check_cuda("uint8", 64, 2)
    check_cuda("uint8", 64, 3)
    check_cuda("uint8", 64, 4)
    check_cuda("float16", 64, 2)
    check_cuda("float16", 64, 4)
    check_cuda("float16", 64, 6)
    check_cuda("float16", 64, 8)


@tvm.testing.requires_gpu
@tvm.testing.requires_cuda
def test_cuda_bf16_vectorize_add():
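    """Vectorized add for bfloat16 vectors; bf16 values are handled as uint16 on the host side."""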
    if not have_bf16(tvm.cuda(0).compute_version):
        print("skip because gpu does not support bf16")
        return
    num_thread = 8

    def np_float2np_bf16(arr):
        """Convert a numpy array of float to a numpy array
        of bf16 in uint16"""
        orig = arr.view("<u4")
        bias = np.bitwise_and(np.right_shift(orig, 16), 1) + 0x7FFF
        return np.right_shift(orig + bias, 16).astype("uint16")

    def np_bf162np_float(arr):
        """Convert a numpy array of bf16 (uint16) to a numpy array
        of float"""
        u32 = np.left_shift(arr.astype("uint32"), 16)
        return u32.view("<f4")

    def check_cuda(n, lanes):
        A = te.placeholder((n,), name="A", dtype="bfloat16x%d" % lanes)
        B = te.compute((n,), lambda i: A[i] + tvm.tir.const(1, A.dtype), name="B")
        s = te.create_schedule(B.op)
        xo, xi = s[B].split(B.op.axis[0], factor=num_thread)
        s[B].bind(xo, bx)
        s[B].bind(xi, tx)
        with tvm.transform.PassContext(
            disabled_pass=["tir.BF16Promote", "tir.BF16CastElimination", "tir.BF16TypeLowering"]
        ):
            fun = tvm.build(s, [A, B], "cuda")
        dev = tvm.cuda(0)
        np_a = np.random.uniform(size=(n, lanes)).astype("float32")
        np_a = np_bf162np_float(np_float2np_bf16(np_a))
        a = tvm.nd.empty((n,), A.dtype, dev).copyfrom(np_float2np_bf16(np_a))
        c = tvm.nd.empty((n,), B.dtype, dev)
        fun(a, c)
        c = tvm.nd.empty((n, lanes), "uint16", dev).copyfrom(c)
        tvm.testing.assert_allclose(c.numpy(), np_float2np_bf16(np_a + 1))

    check_cuda(64, 2)
    check_cuda(64, 4)
    check_cuda(64, 6)
    check_cuda(64, 8)


@tvm.testing.requires_gpu
@tvm.testing.requires_cuda
def test_cuda_multiply_add():
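    """int8x4 multiply-accumulate through the CUDA __dp4a intrinsic."""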
    num_thread = 8

    def check_cuda(dtype, n, lanes):
        if dtype == "int8" and not have_int8(tvm.cuda(0).compute_version):
            print("skip because gpu does not support int8")
            return
        A = te.placeholder((n,), name="A", dtype="%sx%d" % (dtype, lanes))
        B = te.placeholder((n,), name="B", dtype="%sx%d" % (dtype, lanes))
        C = te.placeholder((n,), name="C", dtype="int32")
        D = te.compute(
            (n,), lambda i: tvm.tir.call_pure_extern("int32", "__dp4a", A[i], B[i], C[i]), name="D"
        )
        s = te.create_schedule(D.op)
        xo, xi = s[D].split(D.op.axis[0], factor=num_thread)
        s[D].bind(xo, bx)
        s[D].bind(xi, tx)
        fun = tvm.build(s, [A, B, C, D], "cuda")
        np_a = np.random.randint(low=-128, high=127, size=(n, lanes))
        np_b = np.random.randint(low=-128, high=127, size=(n, lanes))
        np_c = np.random.randint(low=0, high=127, size=(n,))
        np_d = [sum(x * y) + z for x, y, z in zip(np_a, np_b, np_c)]
        dev = tvm.cuda(0)
        a = tvm.nd.empty((n,), A.dtype, dev).copyfrom(np_a)
        b = tvm.nd.empty((n,), B.dtype, dev).copyfrom(np_b)
        c = tvm.nd.empty((n,), C.dtype, dev).copyfrom(np_c)
        d = tvm.nd.empty((n,), D.dtype, dev)
        fun(a, b, c, d)
        tvm.testing.assert_allclose(d.numpy(), np_d)

    check_cuda("int8", 64, 4)


@tvm.testing.requires_gpu
@tvm.testing.requires_cuda
def test_cuda_vectorize_load():
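    """Vectorized int8 loads via an identity (copy) kernel for several lane counts."""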
    num_thread = 8

    def check_cuda(dtype, n, lanes):
        dev = tvm.cuda(0)
        A = te.placeholder((n,), name="A", dtype="%sx%d" % (dtype, lanes))
        B = te.compute((n,), lambda i: A[i], name="B")
        s = te.create_schedule(B.op)
        block, thread = s[B].split(B.op.axis[0], factor=num_thread)
        s[B].bind(block, bx)
        s[B].bind(thread, tx)
        fun = tvm.build(s, [A, B], "cuda", name="vector_load")
        np_a = np.random.randint(low=-128, high=127, size=(n, lanes))
        a = tvm.nd.empty((n,), A.dtype, dev).copyfrom(np_a)
        b = tvm.nd.empty((n,), B.dtype, dev)
        fun(a, b)
        tvm.testing.assert_allclose(a.numpy(), b.numpy())

    check_cuda("int8", 64, 2)
    check_cuda("int8", 64, 3)
    check_cuda("int8", 64, 4)
    check_cuda("int8", 64, 8)
    check_cuda("int8", 64, 16)


@tvm.testing.requires_gpu
@tvm.testing.requires_cuda
def test_cuda_make_int8():
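    """Fill a (n, lanes) buffer with an int8 constant using a vectorized inner axis."""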
    def check_cuda(n, value, lanes):
        dtype = "int8"
        dev = tvm.cuda(0)
        A = te.compute((n, lanes), lambda i, j: tvm.tir.const(value, dtype=dtype))
        s = te.create_schedule(A.op)
        y, x = s[A].op.axis
        s[A].vectorize(x)
        s[A].bind(y, bx)
        fun = tvm.build(s, [A], "cuda", name="make_int8x4")
        np_a = np.full((n, lanes), value, dtype=dtype)
        a = tvm.nd.empty(np_a.shape, dtype, dev)
        fun(a)
        np.testing.assert_equal(a.numpy(), np_a)

    check_cuda(64, np.int8(0xAB), 4)
    check_cuda(64, 0, 4)
    check_cuda(64, -3, 4)
    check_cuda(64, np.int8(0xAB), 3)
    check_cuda(64, 0, 3)
    check_cuda(64, -3, 3)
    check_cuda(64, np.int8(0xAB), 2)
    check_cuda(64, 0, 2)
    check_cuda(64, -3, 2)


@tvm.testing.requires_gpu
@tvm.testing.requires_cuda
def test_cuda_make_int4():
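    """Fill a (n, lanes) buffer with an int4 constant using a vectorized inner axis."""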
    def check_cuda(n, value, lanes):
        dtype = "int4"
        dev = tvm.cuda(0)
        A = te.compute((n, lanes), lambda i, j: tvm.tir.const(value, dtype=dtype))
        s = te.create_schedule(A.op)
        y, x = s[A].op.axis
        s[A].vectorize(x)
        s[A].bind(y, bx)
        kernel_name = "make_int4x" + str(lanes)
        fun = tvm.build(s, [A], "cuda", name=kernel_name)
        np_a = np.full((n, lanes), value, dtype="int8")
        a = tvm.nd.empty((n, lanes), dtype, dev)
        fun(a)
        np.testing.assert_equal(a.numpy(), np_a)

    check_cuda(64, 1, 4)
    check_cuda(64, 7, 4)
    check_cuda(64, 1, 8)
    check_cuda(64, 7, 8)
    check_cuda(64, 1, 16)
    check_cuda(64, 7, 16)
    check_cuda(64, 1, 32)
    check_cuda(64, 7, 32)


@tvm.testing.requires_gpu
@tvm.testing.requires_cuda
def test_cuda_inf_nan():
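    """Compile and run kernels that write inf/-inf/nan constants; no numeric check is performed."""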
    target = "cuda"

    def check_inf_nan(dev, n, value, dtype):
        A = te.placeholder((n,), name="A", dtype=dtype)
        inf_value = tvm.tir.const(value, dtype=dtype)
        C = te.compute((n,), lambda i: inf_value, name="C")
        s = te.create_schedule(C.op)
        s[C].bind(s[C].op.axis[0], tx)
        fun = tvm.build(s, [A, C], target)
        a = tvm.nd.empty((n,), A.dtype, dev)
        c = tvm.nd.empty((n,), A.dtype, dev)
        # Only need to test compiling here
        fun(a, c)

    dev = tvm.device(target, 0)

    check_inf_nan(dev, 1, -float("inf"), "float32")
    check_inf_nan(dev, 1, -float("inf"), "float64")
    check_inf_nan(dev, 1, float("inf"), "float32")
    check_inf_nan(dev, 1, float("inf"), "float64")
    check_inf_nan(dev, 1, float("nan"), "float32")
    check_inf_nan(dev, 1, float("nan"), "float64")


@tvm.testing.requires_gpu
@tvm.testing.requires_cuda
def test_cuda_shuffle():
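    """Check codegen for tir.Shuffle by rewriting a vectorized loop with a custom lowering pass."""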
    idxm = tvm.tir.indexmod
    a = te.placeholder((64,), "int32")
    b = te.placeholder((64,), "int32")
    c = te.compute((64,), lambda x: a[x] + b[x - idxm(x, 4) + (3 - idxm(x, 4))])
    sch = te.create_schedule(c.op)
    x = c.op.axis[0]
    xo, xi = sch[c].split(x, 4)
    thrx = te.thread_axis("threadIdx.x")
    sch[c].bind(xo, thrx)
    sch[c].vectorize(xi)

    def MyVectorize():
        def vectorizer(op):
            if op.kind == tvm.tir.ForKind.VECTORIZED:
                idx = tvm.tir.Ramp(4 * thrx.var, 1, 4)
                store = op.body
                value = store.value
                new_a = tvm.tir.BufferLoad(value.a.buffer, [idx])
                bs, ids = [], []
                for i in range(4):
                    bs.append(tvm.tir.BufferLoad(value.b.buffer, [4 * thrx.var + i]))
                    ids.append(3 - i)
                new_b = tvm.tir.Shuffle(bs, ids)
                return tvm.tir.BufferStore(store.buffer, new_a + new_b, [idx])
            return None

        def _transform(f, *_):
            return f.with_body(
                tvm.tir.stmt_functor.ir_transform(f.body, None, vectorizer, ["tir.For"])
            )

        return tvm.tir.transform.prim_func_pass(_transform, opt_level=0, name="MyVectorize")

    with tvm.transform.PassContext(config={"tir.add_lower_pass": [(1, MyVectorize())]}):
        module = tvm.build(sch, [a, b, c], target="cuda")
        a_ = np.array(list(range(64)), dtype="int32")
        b_ = np.array((list(range(4))[::-1]) * 16, dtype="int32")
        c_ = np.zeros((64,), dtype="int32")
        ref = a_ + np.array((list(range(4))) * 16, dtype="int32")
        nda, ndb, ndc = [tvm.nd.array(i, tvm.cuda(0)) for i in [a_, b_, c_]]
        module(nda, ndb, ndc)
        tvm.testing.assert_allclose(ndc.numpy(), ref)


@tvm.testing.parametrize_targets("cuda", "rocm")
def test_crossthread_reduction1(target, dev):
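    """Cross-thread sum reduction over a single reduce axis, checked around the thread-count boundary."""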
    n = te.var("n")
    m = te.var("m")
    A = te.placeholder((n, m), name="A")
    k = te.reduce_axis((0, m), "m")
    B = te.compute((n,), lambda i: te.sum(A[i, k], axis=k), name="B")

    def sched(nthd):
        s = te.create_schedule(B.op)
        ko, _ = s[B].split(B.op.reduce_axis[0], nparts=nthd)
        s[B].bind(ko, te.thread_axis("threadIdx.x"))
        s[B].bind(B.op.axis[0], te.thread_axis("blockIdx.x"))
        func = tvm.build(s, [A, B], target)
        return func

    def verify(nthd):
        func = sched(nthd)
        nn = 3
        # checks three typical cases
        vals = [nthd - 1, nthd, nthd + 1]
        for kk in [x for x in vals]:
            size = (nn, kk)
            a = tvm.nd.array(np.random.uniform(size=size).astype(A.dtype), dev)
            b = tvm.nd.array(np.zeros(nn, dtype=B.dtype), dev)
            func(a, b)
            tvm.testing.assert_allclose(b.numpy(), np.sum(a.numpy(), axis=1), rtol=1e-3)

    verify(16)
    verify(32)
    verify(64)


@tvm.testing.parametrize_targets("cuda", "rocm")
def test_crossthread_reduction2(target, dev):
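    """Cross-thread sum reduction over two reduce axes bound to threadIdx.x and threadIdx.y."""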
    n = te.var("n")
    k0 = te.var("k0")
    k1 = te.var("k1")
    A = te.placeholder((n, k0, k1), name="A")
    k0 = te.reduce_axis((0, k0), "k0")
    k1 = te.reduce_axis((0, k1), "k1")
    B = te.compute((n,), lambda i: te.sum(A[i, k0, k1], axis=(k0, k1)), name="B")

    def sched(nthdx, nthdy):
        s = te.create_schedule(B.op)
        k0o, _ = s[B].split(B.op.reduce_axis[0], nparts=nthdx)
        k1o, _ = s[B].split(B.op.reduce_axis[1], nparts=nthdy)
        s[B].bind(k0o, te.thread_axis("threadIdx.x"))
        s[B].bind(k1o, te.thread_axis("threadIdx.y"))
        s[B].bind(B.op.axis[0], te.thread_axis("blockIdx.x"))
        func = tvm.build(s, [A, B], target)
        return func

    def verify(nthdx, nthdy):
        func = sched(nthdx, nthdy)
        nn = 3
        # checks three typical cases
        vx = [nthdx - 1, nthdx, nthdx + 1]
        vy = [nthdy - 1, nthdy, nthdy + 1]
        for kk0, kk1 in [(x, y) for x in vx for y in vy]:
            size = (nn, kk0, kk1)
            a = tvm.nd.array(np.random.uniform(size=size).astype(A.dtype), dev)
            b = tvm.nd.array(np.zeros(nn, dtype=B.dtype), dev)
            func(a, b)
            tvm.testing.assert_allclose(b.numpy(), np.sum(a.numpy(), axis=(1, 2)), rtol=1e-3)

    verify(16, 16)
    verify(32, 32)
    verify(16, 32)
    verify(32, 16)


@tvm.testing.requires_gpu
@tvm.testing.requires_cuda
def test_cuda_reduction_binding():
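    """Build a reduction whose reduce axis is reordered ahead of the spatial axis before binding."""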
    k = te.reduce_axis((0, 32), "k")
    A = te.placeholder((96, 32), name="A")
    B = te.compute((96,), lambda m: te.sum(A[m, k], axis=k), name="B")
    s = te.create_schedule(B.op)

    s[B].reorder(B.op.reduce_axis[0], B.op.axis[0])

    mo, _ = s[B].split(B.op.axis[0], 32)
    s[B].bind(mo, te.thread_axis("blockIdx.x"))

    fcuda = tvm.build(s, [A, B], "cuda")


@tvm.testing.parametrize_targets("cuda", "rocm")
def test_rfactor_predicates(target, dev):
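    """Build a two-level rfactor schedule with a store predicate on threadIdx.x."""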
    n = te.reduce_axis((0, 129), "n")
    A = te.placeholder((129,), name="A")
    B = te.compute((1,), lambda b: te.sum(A[n], axis=n), name="B")

    s = te.create_schedule(B.op)

    _, ni = s[B].split(s[B].op.reduce_axis[0], factor=8)

    BF = s.rfactor(B, ni, 0)
    s[B].set_store_predicate(tx.var.equal(0))

    s[B].bind(s[B].op.reduce_axis[0], tx)
    s[B].bind(s[B].op.axis[0], bx)

    s[BF].compute_at(s[B], s[B].op.axis[0])

    _, noi = s[BF].split(s[BF].op.reduce_axis[0], factor=2)

    BF2 = s.rfactor(BF, noi, 0)

    s[BF].bind(s[BF].op.axis[0], tx)
    s[BF2].compute_at(s[BF], s[BF].op.axis[1])

    fcuda = tvm.build(s, [A, B], target)


@tvm.testing.requires_gpu
@tvm.testing.requires_cuda
def test_cuda_const_float_to_half():
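    """Compare a float16 tensor against a float16 constant to check half-precision constant codegen."""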
    # This import is required so that nvcc is used for code generation;
    # otherwise code generation falls back to nvrtc.
    from tvm import autotvm

    shape = (2, 3, 4)
    a = te.placeholder(shape, dtype="float16", name="a")
    b = tvm.tir.const(0.5, dtype="float16")
    c = te.compute(shape, lambda i, j, k: a[i, j, k] > b, name="c")
    s = te.create_schedule(c.op)
    axes = [axis for axis in c.op.axis]
    fused = s[c].fuse(*axes)
    bx, tx = s[c].split(fused, factor=64)
    s[c].bind(bx, te.thread_axis("blockIdx.x"))
    s[c].bind(tx, te.thread_axis("threadIdx.x"))

    func = tvm.build(s, [a, c], "cuda")
    dev = tvm.cuda(0)
    a_np = np.random.uniform(size=shape).astype(a.dtype)
    c_np = np.zeros(shape=shape, dtype=c.dtype)
    a = tvm.nd.array(a_np, dev)
    c = tvm.nd.array(c_np, dev)
    func(a, c)
    np.testing.assert_equal(c.numpy(), a_np > b.value)


@tvm.testing.requires_gpu
@tvm.testing.requires_cuda
def test_cuda_reduction():
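    """Full reduction of an elementwise expression scheduled with topi.cuda.schedule_reduce."""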
    def check(device, dtype, m=32, n=32):
        if not tvm.testing.device_enabled(device):
            print("Skipping", device)
            return
        dev = tvm.device(device, 0)
        a = te.placeholder((m, n), name="a", dtype=dtype)
        b = te.placeholder((m, n), name="b", dtype=dtype)
        c = a + b
        d = a * b
        e = topi.elemwise_sum([c, d])
        g = topi.sum(e)
        with tvm.target.Target(device):
            sg = topi.cuda.schedule_reduce(g)
            func = tvm.build(sg, [a, b, g], device)
        a_np = np.random.uniform(size=(m, n)).astype(a.dtype)
        b_np = np.random.uniform(size=(m, n)).astype(b.dtype)
        g_np = np.sum(np.add(a_np * b_np, a_np + b_np))
        a_nd = tvm.nd.array(a_np, dev)
        b_nd = tvm.nd.array(b_np, dev)
        g_nd = tvm.nd.array(np.zeros(g_np.shape, dtype=g_np.dtype), dev)
        func(a_nd, b_nd, g_nd)
        tvm.testing.assert_allclose(g_nd.numpy(), g_np, rtol=1e-3)

    check("cuda", "float32")
    check("rocm", "float32")
    check("cuda", "float16")


@tvm.testing.requires_gpu
@tvm.testing.requires_cuda
def test_cuda_mix_threaded_and_normal_reduction():
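    """Sum over a 2D tensor with one reduce axis bound to threadIdx.x and the other left serial."""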
    def check(device, dtype, m=32, n=32):
        if not tvm.testing.device_enabled(device):
            print("Skipping", device)
            return
        dev = tvm.device(device, 0)
        if dtype == "float16" and not have_fp16(dev.compute_version):
            print("Skip because gpu does not have fp16 support")
            return

        a = tvm.te.placeholder((m, n), name="a", dtype=dtype)
        b = topi.sum(a)
        with tvm.target.Target(device):
            sb = tvm.te.create_schedule(b.op)
            i, _ = b.op.reduce_axis
            sb[b].bind(i, tvm.te.thread_axis("threadIdx.x"))
            func = tvm.build(sb, [a, b], device)
        a_np = np.random.uniform(size=(m, n)).astype(a.dtype)
        b_np = np.sum(a_np)
        a_nd = tvm.nd.array(a_np, dev)
        b_nd = tvm.nd.array(np.zeros(b_np.shape, dtype=b_np.dtype), dev)
        func(a_nd, b_nd)
        tvm.testing.assert_allclose(b_nd.numpy(), b_np, rtol=1e-3)

    check("cuda", "float32")
    check("rocm", "float32")
    check("cuda", "float16")


@tvm.testing.requires_gpu
@tvm.testing.requires_cuda
def test_cuda_floordiv_with_vectorization():
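    """Vectorized load indexed by floordiv(i, k)."""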
    with tvm.target.cuda():
        # B[i] = A[floordiv(i, k)]
        n = 256
        k = 37
        A = te.placeholder((n,), name="A")
        B = te.compute((n,), lambda i: A[tvm.tir.floordiv(i, k)], name="B")
        s = te.create_schedule(B.op)
        xo, xi = s[B].split(B.op.axis[0], nparts=1)
        xio, xii = s[B].split(xi, factor=4)
        s[B].vectorize(xii)
        s[B].bind(xo, bx)
        s[B].bind(xio, tx)
        func = tvm.build(s, [A, B], "cuda")

        dev = tvm.cuda(0)
        a_np = np.random.uniform(size=(n,)).astype(A.dtype)
        b_np = np.array([a_np[i // k] for i in range(0, n)])
        a_nd = tvm.nd.array(a_np, dev)
        b_nd = tvm.nd.array(np.zeros(b_np.shape, dtype=b_np.dtype), dev)
        func(a_nd, b_nd)
        tvm.testing.assert_allclose(b_nd.numpy(), b_np, rtol=1e-3)


@tvm.testing.requires_gpu
@tvm.testing.requires_cuda
def test_cuda_floormod_with_vectorization():
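    """Vectorized load indexed by floormod(i, k)."""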
    with tvm.target.cuda():
        # B[i] = A[floormod(i, k)]
        n = 256
        k = 37
        A = te.placeholder((n,), name="A")
        B = te.compute((n,), lambda i: A[tvm.tir.floormod(i, k)], name="B")
        s = te.create_schedule(B.op)
        xo, xi = s[B].split(B.op.axis[0], nparts=1)
        xio, xii = s[B].split(xi, factor=4)
        s[B].vectorize(xii)
        s[B].bind(xo, bx)
        s[B].bind(xio, tx)
        func = tvm.build(s, [A, B], "cuda")

        dev = tvm.cuda(0)
        a_np = np.random.uniform(size=(n,)).astype(A.dtype)
        b_np = np.array([a_np[i % k] for i in range(0, n)])
        a_nd = tvm.nd.array(a_np, dev)
        b_nd = tvm.nd.array(np.zeros(b_np.shape, dtype=b_np.dtype), dev)
        func(a_nd, b_nd)
        tvm.testing.assert_allclose(b_nd.numpy(), b_np, rtol=1e-3)


@tvm.testing.requires_gpu
@tvm.testing.requires_cuda
def test_vectorized_casts():
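    """Vectorized add over mixed dtypes, casting the second operand to the first operand's dtype."""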
    def check(t0, t1, factor):
        if (t0 == "float16" or t1 == "float16") and not have_fp16(tvm.cuda(0).compute_version):
            print("Skip because gpu does not have fp16 support")
            return

        # compute
        n = 128
        A = te.placeholder((n,), dtype=t0, name="A")
        B = te.placeholder((n,), dtype=t1, name="B")
        C = te.compute((n,), lambda i: A[i] + topi.cast(B[i], A.dtype), name="C")

        # schedule
        s = tvm.te.create_schedule(C.op)
        ob, ib = s[C].split(s[C].op.axis[0], factor=factor)
        s[C].vectorize(ib)
        s[C].bind(ob, tx)
        func = tvm.build(s, [A, B, C], "cuda")

        # correctness
        dev = tvm.cuda(0)
        low, high = (0, 20) if t0.startswith("u") or t1.startswith("u") else (-10, 10)
        a_np = np.random.randint(low, high, size=n).astype(A.dtype)
        b_np = np.random.randint(low, high, size=n).astype(B.dtype)
        c_np = (a_np + b_np).astype(A.dtype)
        a_nd = tvm.nd.array(a_np, dev)
        b_nd = tvm.nd.array(b_np, dev)
        c_nd = tvm.nd.array(np.zeros(c_np.shape, dtype=c_np.dtype), dev)
        func(a_nd, b_nd, c_nd)
        tvm.testing.assert_allclose(c_nd.numpy(), c_np, rtol=1e-3)

    def skip(t0, t1):
        if t0 == t1:
            return True
        # CUDA does support casts between {u}int8 and fp16, but those pairs are not exercised here.
        skip_set = {"float16", "uint8", "int8"}
        if t0 in skip_set and t1 in skip_set:
            return True
        return False

    types_4 = [
        "float16",
        "float32",
        "int8",
        "uint8",
        "int16",
        "uint16",
        "int32",
        "uint32",
        "float64",
        "int64",
        "uint64",
    ]
    types_8 = ["float16", "float32", "int8", "uint8", "int16", "uint16", "int32", "uint32"]
    for t0, t1 in [(x, y) for x in types_4 for y in types_4 if not skip(x, y)]:
        check(t0, t1, 4)
    for t0, t1 in [(x, y) for x in types_8 for y in types_8 if not skip(x, y)]:
        check(t0, t1, 8)
    check("int8", "uint8", 16)
    check("uint8", "int8", 16)


def sched(B):
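    """Common schedule helper: one block, 32 threads, innermost axis vectorized by a factor of 4."""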
    s = te.create_schedule(B.op)
    io, ii = s[B].split(s[B].op.axis[0], nparts=1)
    iio, iii = s[B].split(ii, nparts=32)
    _, iiii = s[B].split(iii, factor=4)
    s[B].vectorize(iiii)
    s[B].bind(io, bx)
    s[B].bind(iio, tx)
    return s


@tvm.testing.requires_gpu
@tvm.testing.requires_cuda
def test_vectorized_intrin1():
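    """Vectorized unary math intrinsics checked against their numpy equivalents in fp32 and fp16."""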
    test_funcs = [
        (tvm.tir.floor, lambda x: np.floor(x)),
        (tvm.tir.ceil, lambda x: np.ceil(x)),
        (tvm.tir.trunc, lambda x: np.trunc(x)),
        (tvm.tir.abs, lambda x: np.fabs(x)),
        (tvm.tir.round, lambda x: np.round(x)),
        (tvm.tir.exp, lambda x: np.exp(x)),
        (tvm.tir.exp2, lambda x: np.exp2(x)),
        (tvm.tir.exp10, lambda x: np.power(10, x)),
        (tvm.tir.log, lambda x: np.log(x)),
        (tvm.tir.log2, lambda x: np.log2(x)),
        (tvm.tir.log10, lambda x: np.log10(x)),
        (tvm.tir.tan, lambda x: np.tan(x)),
        (tvm.tir.cos, lambda x: np.cos(x)),
        (tvm.tir.cosh, lambda x: np.cosh(x)),
        (tvm.tir.sin, lambda x: np.sin(x)),
        (tvm.tir.sinh, lambda x: np.sinh(x)),
        (tvm.tir.atan, lambda x: np.arctan(x)),
        (tvm.tir.tanh, lambda x: np.tanh(x)),
        (tvm.tir.sqrt, lambda x: np.sqrt(x)),
    ]

    def run_test(tvm_intrin, np_func, dtype):
        if dtype == "float16" and not have_fp16(tvm.cuda(0).compute_version):
            print("Skip because gpu does not have fp16 support")
            return
        # Set of intrinsics that do not support fp16 yet.
        skip_set = {
            tvm.tir.abs,
            tvm.tir.round,
            tvm.tir.tan,
            tvm.tir.atan,
            tvm.tir.tanh,
            tvm.tir.cosh,
            tvm.tir.sinh,
        }
        if dtype == "float16" and tvm_intrin in skip_set:
            print("Skip because '{0}' does not support fp16 yet".format(tvm_intrin.__name__))
            return

        n = 128
        A = te.placeholder((n,), dtype=dtype, name="A")
        B = te.compute((n,), lambda *i: tvm_intrin(A(*i)), name="B")
        s = sched(B)
        f = tvm.build(s, [A, B], "cuda")
        dev = tvm.cuda(0)
        a = tvm.nd.array(np.random.uniform(0, 1, size=n).astype(A.dtype), dev)
        b = tvm.nd.array(np.zeros(shape=(n,)).astype(A.dtype), dev)
        f(a, b)
        tvm.testing.assert_allclose(b.numpy(), np_func(a.numpy()), atol=1e-3, rtol=1e-3)

    for func in test_funcs:
        run_test(*func, "float32")
        run_test(*func, "float16")


@tvm.testing.requires_gpu
@tvm.testing.requires_cuda
def test_vectorized_intrin2(dtype="float32"):
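    """Vectorized binary intrinsics (power, fmod) with a constant second operand."""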
    c2 = tvm.tir.const(2, dtype=dtype)
    test_funcs = [
        (tvm.tir.power, lambda x: np.power(x, 2.0)),
        (tvm.tir.fmod, lambda x: np.fmod(x, 2.0)),
    ]

    def run_test(tvm_intrin, np_func):
        n = 128
        A = te.placeholder((n,), dtype=dtype, name="A")
        B = te.compute((n,), lambda i: tvm_intrin(A[i], c2), name="B")
        s = sched(B)
        f = tvm.build(s, [A, B], "cuda")
        dev = tvm.cuda(0)
        a = tvm.nd.array(np.random.uniform(0, 1, size=n).astype(A.dtype), dev)
        b = tvm.nd.array(np.zeros(shape=(n,)).astype(A.dtype), dev)
        f(a, b)
        tvm.testing.assert_allclose(b.numpy(), np_func(a.numpy()), atol=1e-3, rtol=1e-3)

    for func in test_funcs:
        run_test(*func)


@tvm.testing.requires_gpu
@tvm.testing.requires_cuda
def test_vectorized_popcount():
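    """Vectorized popcount on uint32/uint64, checked against a scalar reference implementation."""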
    def ref_popcount(x):
        cnt = 0
        while x:
            x -= x & -x
            cnt += 1
        return cnt

    def run_test(dtype):
        n = 128
        A = te.placeholder((n,), dtype=dtype, name="A")
        B = te.compute((n,), lambda i: tvm.tir.popcount(A[i]), name="B")
        s = sched(B)
        f = tvm.build(s, [A, B], "cuda")
        dev = tvm.cuda(0)
        a = tvm.nd.array(np.random.randint(0, 100000, size=n).astype(A.dtype), dev)
        b = tvm.nd.array(np.zeros(shape=(n,)).astype(B.dtype), dev)
        f(a, b)
        ref = np.vectorize(ref_popcount)(a.numpy())
        tvm.testing.assert_allclose(b.numpy(), ref)

    run_test("uint32")
    run_test("uint64")


@tvm.testing.requires_gpu
@tvm.testing.requires_cuda
def test_cuda_vectorize_load_permute_pad():
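    """Vectorized load combined with a layout permute and zero padding along the second axis."""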
    def check_cuda(dtype, n, l, padding, lanes):
        if dtype == "float16" and not have_fp16(tvm.cuda(0).compute_version):
            print("Skip because gpu does not have fp16 support")
            return

        dev = tvm.cuda(0)
        A = tvm.te.placeholder((n, l), name="A", dtype=dtype)
        B = tvm.te.compute(
            (n // lanes, l + 2 * padding, lanes),
            lambda i, j, k: tvm.te.if_then_else(
                tvm.te.any(j < padding, j >= l + padding),
                tvm.runtime.convert(0).astype(dtype),
                A[i * lanes + k, j - padding],
            ),
            name="B",
        )
        s = te.create_schedule(B.op)
        block, thread, vectorize = s[B].op.axis
        s[B].bind(block, bx)
        s[B].bind(thread, tx)
        s[B].vectorize(vectorize)
        fun = tvm.build(s, [A, B], "cuda", name="vector_load_permute_pad")
        np_a = np.random.randint(low=-128, high=127, size=(n, l)).astype(A.dtype)
        a = tvm.nd.empty((n, l), A.dtype, dev).copyfrom(np_a)
        b = tvm.nd.empty((n // lanes, l + padding * 2, lanes), B.dtype, dev)
        fun(a, b)
        np_a_reshape = np_a.reshape(n // lanes, lanes, l).transpose(0, 2, 1)
        ref = np.pad(
            np_a_reshape, ((0, 0), (padding, padding), (0, 0)), mode="constant", constant_values=0
        )
        tvm.testing.assert_allclose(b.numpy(), ref)

    check_cuda("int8", 64, 16, 3, 2)
    check_cuda("uint8", 64, 16, 3, 2)
    check_cuda("int8", 64, 16, 3, 4)
    check_cuda("uint8", 64, 16, 3, 4)
    check_cuda("int32", 64, 16, 3, 4)
    check_cuda("float16", 64, 16, 3, 4)
    check_cuda("float32", 64, 16, 3, 4)


def vcf_check_common(s, args):
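    """Lower the schedule and check that vectorized accesses become Ramp/Broadcast patterns,
    then build for CUDA and verify a 512x512 matmul against numpy."""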
    N = 512

    # To check if every vectorize loop transforms to ramp expr successfully
    stmt = tvm.lower(s, args)
    # Use this as a stack flag to show whether this stmt is inside a BroadcastNode
    inside_broadcast = [False]

    # Possible patterns:
    # Reduce init: BufferStore[Ramp] = Broadcast(0)
    # Shared memory copy: BufferStore[Ramp] = BufferLoad[Ramp]
    # Compute: BufferStore[Ramp] = BufferLoad[Ramp] ... Broadcast[Load]

    def pre_visit(stmt):
        if isinstance(stmt, tvm.tir.Broadcast):
            inside_broadcast[0] = True
            # Check Broadcast[Imm numbers] or Broadcast[Load] patterns
            assert isinstance(stmt.value, (tvm.tir.IntImm, tvm.tir.FloatImm, tvm.tir.BufferLoad))

        if isinstance(stmt, (tvm.tir.BufferStore, tvm.tir.BufferLoad)):
            is_ramp_index = isinstance(stmt.indices[-1], tvm.tir.Ramp)
            is_vectorized_buffer = re.match(r"^.*x\d+$", stmt.buffer.dtype)
            if isinstance(stmt, tvm.tir.BufferLoad):
                # Check Broadcast[BufferLoad] or BufferLoad[Ramp] patterns
                assert inside_broadcast[0] or is_ramp_index or is_vectorized_buffer
                # Skip the rest of the BufferLoad
                return stmt
            else:
                assert is_ramp_index or is_vectorized_buffer

        return None

    def post_visit(stmt):
        if isinstance(stmt, tvm.tir.Broadcast):
            inside_broadcast[0] = False
        return None

    tvm.tir.stmt_functor.ir_transform(stmt["main"].body, pre_visit, post_visit)

    tgt = tvm.target.cuda()
    mod = tvm.build(s, args, tgt)
    # To check if every vectorize loop transforms to correct instruction
    # print(mod.imported_modules[0].get_source())

    dev = tvm.device("cuda", 0)
    a = tvm.nd.array(np.random.uniform(size=(512, 512)).astype("float32"), dev)
    b = tvm.nd.array(np.random.uniform(size=(512, 512)).astype("float32"), dev)
    c = tvm.nd.array(np.zeros((512, 512), dtype="float32"), dev)
    mod(a, b, c)
    tvm.testing.assert_allclose(c.numpy(), np.dot(a.numpy(), b.numpy()), rtol=1e-5)


@tvm.testing.requires_gpu
@tvm.testing.requires_cuda
def test_vectorized_cooperative_fetching_x():
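    """Tiled matmul with cooperative fetching into shared memory using threadIdx.x only."""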
    N = 512
    A = te.placeholder((N, N), name="A", dtype="float32")
    B = te.placeholder((N, N), name="B", dtype="float32")
    k = te.reduce_axis((0, N), name="k")
    C = te.compute((N, N), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k))
    s = te.create_schedule(C.op)
    i, j = s[C].op.axis
    k = s[C].op.reduce_axis[0]

    AA = s.cache_read(A, "shared", [C])
    BB = s.cache_read(B, "shared", [C])

    i3, i4 = s[C].split(i, factor=4)
    i2, i3 = s[C].split(i3, factor=2)
    i1, i2 = s[C].split(i2, factor=8)
    i0, i1 = s[C].split(i1, factor=1)
    j3, j4 = s[C].split(j, factor=4)
    j2, j3 = s[C].split(j3, factor=2)
    j1, j2 = s[C].split(j2, factor=8)
    j0, j1 = s[C].split(j1, factor=2)
    k1, k2 = s[C].split(k, factor=8)
    k0, k1 = s[C].split(k1, factor=8)
    s[C].reorder(i0, j0, i1, j1, i2, j2, k0, k1, i3, j3, k2, i4, j4)
    block_it = s[C].fuse(i0, j0)
    s[C].bind(block_it, tvm.te.thread_axis("blockIdx.x"))
    vthread_it = s[C].fuse(i1, j1)
    s[C].bind(vthread_it, tvm.te.thread_axis("vthread"))
    thread_it = s[C].fuse(i2, j2)
    s[C].bind(thread_it, tvm.te.thread_axis("threadIdx.x"))
    s[C].vectorize(j4)

    s[AA].compute_at(s[C], k0)
    iaa, jaa = s[AA].op.axis
    s[BB].compute_at(s[C], k0)
    ibb, jbb = s[BB].op.axis
    aa_fused = s[AA].fuse(iaa, jaa)
    bb_fused = s[BB].fuse(ibb, jbb)
    aa1, aa2 = s[AA].split(aa_fused, factor=4)
    aa0, aa1 = s[AA].split(aa1, factor=64)
    bb1, bb2 = s[BB].split(bb_fused, factor=4)
    bb0, bb1 = s[BB].split(bb1, factor=64)
    s[AA].bind(aa1, tvm.te.thread_axis("threadIdx.x"))
    s[AA].vectorize(aa2)
    s[BB].bind(bb1, tvm.te.thread_axis("threadIdx.x"))
    s[BB].vectorize(bb2)

    vcf_check_common(s, [A, B, C])


@tvm.testing.requires_gpu
@tvm.testing.requires_cuda
def test_vectorized_cooperative_fetching_xy():
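    """Tiled matmul with cooperative fetching into shared memory using threadIdx.x and threadIdx.y."""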
    N = 512
    A = te.placeholder((N, N), name="A")
    B = te.placeholder((N, N), name="B")
    k = te.reduce_axis((0, N), name="k")
    C = te.compute((N, N), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k))
    s = te.create_schedule(C.op)
    i, j = s[C].op.axis
    k = s[C].op.reduce_axis[0]

    AA = s.cache_read(A, "shared", [C])
    BB = s.cache_read(B, "shared", [C])

    i3, i4 = s[C].split(i, factor=4)
    i2, i3 = s[C].split(i3, factor=2)
    i1, i2 = s[C].split(i2, factor=8)
    i0, i1 = s[C].split(i1, factor=1)
    j3, j4 = s[C].split(j, factor=4)
    j2, j3 = s[C].split(j3, factor=2)
    j1, j2 = s[C].split(j2, factor=8)
    j0, j1 = s[C].split(j1, factor=2)
    k1, k2 = s[C].split(k, factor=8)
    k0, k1 = s[C].split(k1, factor=8)
    s[C].reorder(i0, j0, i1, j1, i2, j2, k0, k1, i3, j3, k2, i4, j4)
    block_it = s[C].fuse(i0, j0)
    s[C].bind(block_it, tvm.te.thread_axis("blockIdx.x"))
    vthread_it = s[C].fuse(i1, j1)
    s[C].bind(vthread_it, tvm.te.thread_axis("vthread"))
    s[C].bind(i2, tvm.te.thread_axis("threadIdx.y"))
    s[C].bind(j2, tvm.te.thread_axis("threadIdx.x"))
    s[C].vectorize(j4)

    s[AA].compute_at(s[C], k0)
    iaa, jaa = s[AA].op.axis
    s[BB].compute_at(s[C], k0)
    ibb, jbb = s[BB].op.axis
    aa_fused = s[AA].fuse(iaa, jaa)
    bb_fused = s[BB].fuse(ibb, jbb)
    aa2, aa3 = s[AA].split(aa_fused, factor=4)
    aa1, aa2 = s[AA].split(aa2, factor=8)
    aa0, aa1 = s[AA].split(aa1, factor=8)
    bb2, bb3 = s[BB].split(bb_fused, factor=4)
    bb1, bb2 = s[BB].split(bb2, factor=8)
    bb0, bb1 = s[BB].split(bb1, factor=8)
    s[AA].bind(aa1, tvm.te.thread_axis("threadIdx.y"))
    s[AA].bind(aa2, tvm.te.thread_axis("threadIdx.x"))
    s[AA].vectorize(aa3)
    s[BB].bind(bb1, tvm.te.thread_axis("threadIdx.y"))
    s[BB].bind(bb2, tvm.te.thread_axis("threadIdx.x"))
    s[BB].vectorize(bb3)

    vcf_check_common(s, [A, B, C])


@tvm.testing.requires_gpu
@tvm.testing.requires_cuda
def test_unrolled_vectorization():
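    """Matmul with an unrolled reduction and vectorized inner axes, verified with all-ones inputs."""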
    dtype = "float32"
    target = "cuda"

    # Compute declaration
    N = 128
    A = te.placeholder((N, N), name="A")
    B = te.placeholder((N, N), name="B")
    k = te.reduce_axis((0, N), name="k")
    C = te.compute((N, N), lambda i, j: te.sum(A[i][k] * B[k][j], axis=[k]), name="C")

    # Schedule
    s = te.create_schedule([C.op])
    CC = s.cache_write(C, "local")
    i, j = s[C].op.axis
    bx, tx, ii, ji = s[C].tile(i, j, 1, 2)
    s[C].bind(bx, te.thread_axis("blockIdx.x"))
    s[C].bind(tx, te.thread_axis("threadIdx.x"))
    s[C].vectorize(ji)
    s[CC].compute_at(s[C], tx)
    i, j = s[CC].op.axis
    k = s[CC].op.reduce_axis[0]
    ko, ki = s[CC].split(k, 2)
    s[CC].unroll(ki)
    s[CC].vectorize(j)

    # Check correctness
    dev = tvm.device(target)
    a_tvm = tvm.nd.array(np.ones((N, N)).astype(dtype), device=dev)
    b_tvm = tvm.nd.array(np.ones((N, N)).astype(dtype), device=dev)
    c_tvm = tvm.nd.empty((N, N), device=dev)
    func_tvm = tvm.build(s, [A, B, C], target=target)
    func_tvm(a_tvm, b_tvm, c_tvm)
    c_np = c_tvm.numpy()
    tvm.testing.assert_allclose(c_np, N * np.ones((N, N)))


@tvm.testing.requires_gpu
@tvm.testing.requires_cuda
def test_try_unaligned_vector_load():
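    """A potentially misaligned vectorized load must not be emitted as a vector load; an aligned one should be."""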
    def get_compute(N, C_N, offset):
        A = te.placeholder((N,), name="A", dtype="float16")
        C = te.compute((C_N,), lambda i: A[i + offset], name="C")
        return N, C_N, A, C

    def get_compute_unaligned():
        return get_compute(3, 2, 1)

    def get_compute_aligned():
        return get_compute(4, 2, 2)

    def build(A, C, N, C_N):
        s = te.create_schedule(C.op)
        oi, ii = s[C].split(C.op.axis[0], factor=2)
        s[C].bind(oi, te.thread_axis("threadIdx.x"))
        s[C].vectorize(ii)  # BUG: misalignment

        tgt = tvm.target.Target(target="cuda", host="llvm")
        dev = tvm.device(tgt.kind.name, 0)
        f = tvm.build(s, [A, C], tgt, name="foo")
        kernel_source = f.imported_modules[0].get_source()

        a_data = np.arange(0, N).astype(A.dtype)
        a = tvm.nd.array(a_data, dev)
        c = tvm.nd.array(np.zeros(C_N, dtype=C.dtype), dev)
        f(a, c)

        return a_data, c.numpy(), kernel_source

    N, C_N, A, C = get_compute_unaligned()
    a_data, c, kernel_source = build(A, C, N, C_N)
    # (uint1*)(A + (1)) is invalid
    assert "A + (1)" not in kernel_source

    expected = a_data[1 : C_N + 1]
    assert np.allclose(c, expected), f"expected={expected}\nactual={c}"

    N, C_N, A, C = get_compute_aligned()
    a_data, c, kernel_source = build(A, C, N, C_N)
    # (uint1*)(A + (2)) is a valid vector load
    assert "A + 2" in kernel_source

    expected = a_data[2 : C_N + 2]
    assert np.allclose(c, expected), f"expected={expected}\nactual={c}"


@tvm.testing.requires_gpu
@tvm.testing.requires_cuda
def test_cuda_save_kernels_for_profiling():
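    """Check that the cuda.kernels_output_dir option writes tvm_kernels.cu to the given directory."""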
    num_thread = 8

    def check_cuda(n, lanes):
        dtype = "float32"
        A = te.placeholder((n,), name="A", dtype="%sx%d" % (dtype, lanes))
        B = te.compute((n,), lambda i: A[i] + tvm.tir.const(1, A.dtype), name="B")
        s = te.create_schedule(B.op)
        xo, xi = s[B].split(B.op.axis[0], factor=num_thread)
        s[B].bind(xo, bx)
        s[B].bind(xi, tx)
        tempdir = utils.tempdir()
        tmp_path = str(tempdir.path)
        with tvm.transform.PassContext(opt_level=3, config={"cuda.kernels_output_dir": tmp_path}):
            _ = tvm.build(s, [A, B], "cuda")
        assert "tvm_kernels.cu" in os.listdir(tmp_path)

    check_cuda(64, 2)


if __name__ == "__main__":
    tvm.testing.main()