# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import numpy as np
import pytest
import tvm
import tvm.testing
from tvm import te, topi
from tvm.contrib.nvcc import have_bf16, have_fp16, have_int8
from tvm.script import ir as I
from tvm.script import tir as T
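# Tests for the CUDA code generator: vectorized arithmetic and loads, fp16/bf16
# handling, cross-thread reductions, intrinsic lowering, and checks on the
# generated CUDA source (device functions, tensormap parameters, thread return).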
@tvm.testing.requires_gpu
@tvm.testing.requires_cuda
def test_cuda_vectorize_add():
num_thread = 8
def check_cuda(dtype, n, lanes):
if dtype == "float16" and not have_fp16(tvm.cuda(0).compute_version):
print("Skip because gpu does not have fp16 support")
return
if dtype == "int8" and not have_int8(tvm.cuda(0).compute_version):
print("skip because gpu does not support int8")
return
A = te.placeholder((n,), name="A", dtype="%sx%d" % (dtype, lanes))
B = te.compute((n,), lambda i: A[i] + tvm.tir.const(1, A.dtype), name="B")
sch = tvm.tir.Schedule(te.create_prim_func([A, B]))
xo, xi = sch.split(sch.get_loops("B")[0], factors=[None, num_thread])
sch.bind(xo, "blockIdx.x")
sch.bind(xi, "threadIdx.x")
fun = tvm.compile(sch.mod, target="cuda")
dev = tvm.cuda(0)
a = tvm.runtime.empty((n,), A.dtype, dev).copyfrom(np.random.uniform(size=(n, lanes)))
c = tvm.runtime.empty((n,), B.dtype, dev)
fun(a, c)
tvm.testing.assert_allclose(c.numpy(), a.numpy() + 1)
check_cuda("float32", 64, 2)
check_cuda("float32", 64, 3)
check_cuda("float32", 64, 4)
check_cuda("int8", 64, 2)
check_cuda("int8", 64, 3)
check_cuda("int8", 64, 4)
check_cuda("uint8", 64, 2)
check_cuda("uint8", 64, 3)
check_cuda("uint8", 64, 4)
check_cuda("float16", 64, 2)
check_cuda("float16", 64, 4)
check_cuda("float16", 64, 6)
check_cuda("float16", 64, 8)
@tvm.testing.requires_gpu
@tvm.testing.requires_cuda
def test_cuda_bf16_vectorize_add():
if not have_bf16(tvm.cuda(0).compute_version):
print("skip because gpu does not support bf16")
return
num_thread = 8
def np_float2np_bf16(arr):
"""Convert a numpy array of float to a numpy array
of bf16 in uint16"""
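        # Round-to-nearest-even on the low 16 bits: e.g. 0x3F808000 (exactly halfway)
        # stays at 0x3F80, while 0x3F818000 rounds up to the even value 0x3F82.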
orig = arr.view("<u4")
bias = np.bitwise_and(np.right_shift(orig, 16), 1) + 0x7FFF
return np.right_shift(orig + bias, 16).astype("uint16")
def np_bf162np_float(arr):
"""Convert a numpy array of bf16 (uint16) to a numpy array
of float"""
u32 = np.left_shift(arr.astype("uint32"), 16)
return u32.view("<f4")
def check_cuda(n, lanes):
A = te.placeholder((n,), name="A", dtype="bfloat16x%d" % lanes)
B = te.compute((n,), lambda i: A[i] + tvm.tir.const(1, A.dtype), name="B")
sch = tvm.tir.Schedule(te.create_prim_func([A, B]))
xo, xi = sch.split(sch.get_loops("B")[0], factors=[None, num_thread])
sch.bind(xo, "blockIdx.x")
sch.bind(xi, "threadIdx.x")
with tvm.transform.PassContext(
disabled_pass=["tir.BF16Promote", "tir.BF16CastElimination", "tir.BF16TypeLowering"]
):
fun = tvm.compile(sch.mod, target="cuda")
dev = tvm.cuda(0)
np_a = np.random.uniform(size=(n, lanes)).astype("float32")
np_a = np_bf162np_float(np_float2np_bf16(np_a))
a = tvm.runtime.empty((n,), A.dtype, dev).copyfrom(np_float2np_bf16(np_a))
c = tvm.runtime.empty((n,), B.dtype, dev)
fun(a, c)
c = tvm.runtime.empty((n, lanes), "uint16", dev).copyfrom(c)
tvm.testing.assert_allclose(c.numpy(), np_float2np_bf16(np_a + 1))
check_cuda(64, 2)
check_cuda(64, 4)
check_cuda(64, 6)
check_cuda(64, 8)
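# test_cuda_multiply_add lowers a call_pure_extern to the CUDA __dp4a intrinsic,
# a 4-way int8 dot product accumulated into int32.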
@tvm.testing.requires_gpu
@tvm.testing.requires_cuda
def test_cuda_multiply_add():
num_thread = 8
def check_cuda(dtype, n, lanes):
if dtype == "int8" and not have_int8(tvm.cuda(0).compute_version):
print("skip because gpu does not support int8")
return
A = te.placeholder((n,), name="A", dtype="%sx%d" % (dtype, lanes))
B = te.placeholder((n,), name="B", dtype="%sx%d" % (dtype, lanes))
C = te.placeholder((n,), name="C", dtype="int32")
D = te.compute(
(n,), lambda i: tvm.tir.call_pure_extern("int32", "__dp4a", A[i], B[i], C[i]), name="D"
)
sch = tvm.tir.Schedule(te.create_prim_func([A, B, C, D]))
xo, xi = sch.split(sch.get_loops("D")[0], factors=[None, num_thread])
sch.bind(xo, "blockIdx.x")
sch.bind(xi, "threadIdx.x")
fun = tvm.compile(sch.mod, target="cuda")
np_a = np.random.randint(low=-128, high=127, size=(n, lanes))
np_b = np.random.randint(low=-128, high=127, size=(n, lanes))
np_c = np.random.randint(low=0, high=127, size=(n,))
np_d = [sum(x * y) + z for x, y, z in zip(np_a, np_b, np_c)]
dev = tvm.cuda(0)
a = tvm.runtime.empty((n,), A.dtype, dev).copyfrom(np_a)
b = tvm.runtime.empty((n,), B.dtype, dev).copyfrom(np_b)
c = tvm.runtime.empty((n,), C.dtype, dev).copyfrom(np_c)
d = tvm.runtime.empty((n,), D.dtype, dev)
fun(a, b, c, d)
tvm.testing.assert_allclose(d.numpy(), np_d)
check_cuda("int8", 64, 4)
@tvm.testing.requires_gpu
@tvm.testing.requires_cuda
def test_cuda_vectorize_load():
num_thread = 8
def check_cuda(dtype, n, lanes):
dev = tvm.cuda(0)
A = te.placeholder((n,), name="A", dtype="%sx%d" % (dtype, lanes))
B = te.compute((n,), lambda i: A[i], name="B")
sch = tvm.tir.Schedule(te.create_prim_func([A, B]))
xo, xi = sch.split(sch.get_loops("B")[0], factors=[None, num_thread])
sch.bind(xo, "blockIdx.x")
sch.bind(xi, "threadIdx.x")
fun = tvm.compile(sch.mod, target="cuda")
np_a = np.random.randint(low=-128, high=127, size=(n, lanes))
a = tvm.runtime.empty((n,), A.dtype, dev).copyfrom(np_a)
b = tvm.runtime.empty((n,), B.dtype, dev)
fun(a, b)
tvm.testing.assert_allclose(a.numpy(), b.numpy())
check_cuda("int8", 64, 2)
check_cuda("int8", 64, 3)
check_cuda("int8", 64, 4)
check_cuda("int8", 64, 8)
check_cuda("int8", 64, 16)
@tvm.testing.requires_gpu
@tvm.testing.requires_cuda
def test_cuda_make_int8():
def check_cuda(n, value, lanes):
dtype = "int8"
dev = tvm.cuda(0)
A = te.compute((n, lanes), lambda i, j: tvm.tir.const(value, dtype=dtype), name="A")
sch = tvm.tir.Schedule(te.create_prim_func([A]))
y, x = sch.get_loops("A")
sch.vectorize(x)
sch.bind(y, "blockIdx.x")
fun = tvm.compile(sch.mod, target="cuda")
np_a = np.full((n, lanes), value, dtype=dtype)
a = tvm.runtime.empty(np_a.shape, dtype, dev)
fun(a)
np.testing.assert_equal(a.numpy(), np_a)
check_cuda(64, np.int8(0xAB), 4)
check_cuda(64, 0, 4)
check_cuda(64, -3, 4)
check_cuda(64, np.int8(0xAB), 3)
check_cuda(64, 0, 3)
check_cuda(64, -3, 3)
check_cuda(64, np.int8(0xAB), 2)
check_cuda(64, 0, 2)
check_cuda(64, -3, 2)
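# test_cuda_inf_nan only checks that +/-inf and nan constants compile and launch;
# the outputs are not inspected.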
@tvm.testing.requires_gpu
@tvm.testing.requires_cuda
def test_cuda_inf_nan():
target = "cuda"
def check_inf_nan(dev, n, value, dtype):
A = te.placeholder((n,), name="A", dtype=dtype)
inf_value = tvm.tir.const(value, dtype=dtype)
C = te.compute((n,), lambda i: inf_value, name="C")
sch = tvm.tir.Schedule(te.create_prim_func([A, C]))
xo, xi = sch.split(sch.get_loops("C")[0], factors=[None, 8])
sch.bind(xo, "blockIdx.x")
sch.bind(xi, "threadIdx.x")
fun = tvm.compile(sch.mod, target="cuda")
a = tvm.runtime.empty((n,), A.dtype, dev)
c = tvm.runtime.empty((n,), A.dtype, dev)
# Only need to test compiling here
fun(a, c)
dev = tvm.device(target, 0)
check_inf_nan(dev, 1, -float("inf"), "float32")
check_inf_nan(dev, 1, -float("inf"), "float64")
check_inf_nan(dev, 1, float("inf"), "float32")
check_inf_nan(dev, 1, float("inf"), "float64")
check_inf_nan(dev, 1, float("nan"), "float32")
check_inf_nan(dev, 1, float("nan"), "float64")
@tvm.testing.parametrize_targets("cuda", "rocm")
def test_crossthread_reduction1(target, dev):
n = te.var("n")
m = te.var("m")
A = te.placeholder((n, m), name="A")
k = te.reduce_axis((0, m), "m")
B = te.compute((n,), lambda i: te.sum(A[i, k], axis=k), name="B")
def sched(nthd):
sch = tvm.tir.Schedule(te.create_prim_func([A, B]))
x, k = sch.get_loops("B")
ko, _ = sch.split(k, factors=[nthd, None])
sch.bind(ko, "threadIdx.x")
sch.bind(x, "blockIdx.x")
        fun = tvm.compile(sch.mod, target=target)
return fun
def verify(nthd):
func = sched(nthd)
nn = 3
# checks three typical cases
vals = [nthd - 1, nthd, nthd + 1]
        for kk in vals:
size = (nn, kk)
a = tvm.runtime.tensor(np.random.uniform(size=size).astype(A.dtype), dev)
b = tvm.runtime.tensor(np.zeros(nn, dtype=B.dtype), dev)
func(a, b)
tvm.testing.assert_allclose(b.numpy(), np.sum(a.numpy(), axis=1), rtol=1e-3)
verify(16)
verify(32)
verify(64)
@tvm.testing.parametrize_targets("cuda", "rocm")
def test_crossthread_reduction2(target, dev):
n = te.var("n")
k0 = te.var("k0")
k1 = te.var("k1")
A = te.placeholder((n, k0, k1), name="A")
k0 = te.reduce_axis((0, k0), "k0")
k1 = te.reduce_axis((0, k1), "k1")
B = te.compute((n,), lambda i: te.sum(A[i, k0, k1], axis=(k0, k1)), name="B")
def sched(nthdx, nthdy):
sch = tvm.tir.Schedule(te.create_prim_func([A, B]))
x, k0, k1 = sch.get_loops("B")
k0o, _ = sch.split(k0, factors=[nthdx, None])
k1o, _ = sch.split(k1, factors=[nthdy, None])
sch.bind(k0o, "threadIdx.x")
sch.bind(k1o, "threadIdx.y")
sch.bind(x, "blockIdx.x")
        func = tvm.compile(sch.mod, target=target)
return func
def verify(nthdx, nthdy):
func = sched(nthdx, nthdy)
nn = 3
# checks three typical cases
vx = [nthdx - 1, nthdx, nthdx + 1]
vy = [nthdy - 1, nthdy, nthdy + 1]
for kk0, kk1 in [(x, y) for x in vx for y in vy]:
size = (nn, kk0, kk1)
a = tvm.runtime.tensor(np.random.uniform(size=size).astype(A.dtype), dev)
b = tvm.runtime.tensor(np.zeros(nn, dtype=B.dtype), dev)
func(a, b)
tvm.testing.assert_allclose(b.numpy(), np.sum(a.numpy(), axis=(1, 2)), rtol=1e-3)
verify(16, 16)
verify(32, 32)
verify(16, 32)
verify(32, 16)
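# test_cuda_reduction_binding only checks that the schedule with a reordered
# reduction loop and a blockIdx-bound spatial loop compiles.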
@tvm.testing.requires_gpu
@tvm.testing.requires_cuda
def test_cuda_reduction_binding():
k = te.reduce_axis((0, 32), "k")
A = te.placeholder((96, 32), name="A")
B = te.compute((96,), lambda m: te.sum(A[m, k], axis=k), name="B")
sch = tvm.tir.Schedule(te.create_prim_func([A, B]))
x, k = sch.get_loops("B")
sch.reorder(k, x)
mo, _ = sch.split(x, factors=[None, 32])
sch.bind(mo, "blockIdx.x")
func = tvm.compile(sch.mod, target="cuda")
@tvm.testing.requires_gpu
@tvm.testing.requires_cuda
def test_cuda_const_float_to_half():
# This import is required to use nvcc to perform code gen;
# otherwise it is found that the code gen is done by nvrtc.
shape = (2, 3, 4)
a = te.placeholder(shape, dtype="float16", name="a")
b = tvm.tir.const(0.5, dtype="float16")
c = te.compute(shape, lambda i, j, k: a[i, j, k] > b, name="C")
sch = tvm.tir.Schedule(te.create_prim_func([a, c]))
xo, xi = sch.split(sch.fuse(*sch.get_loops("C")), factors=[None, 64])
sch.bind(xo, "blockIdx.x")
sch.bind(xi, "threadIdx.x")
func = tvm.compile(sch.mod, target="cuda")
dev = tvm.cuda(0)
a_np = np.random.uniform(size=shape).astype(a.dtype)
c_np = np.zeros(shape=shape, dtype=c.dtype)
a = tvm.runtime.tensor(a_np, dev)
c = tvm.runtime.tensor(c_np, dev)
func(a, c)
np.testing.assert_equal(c.numpy(), a_np > b.value)
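# The next two tests vectorize loads indexed by floordiv/floormod with a
# non-power-of-two divisor and compare against a numpy reference.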
@tvm.testing.requires_gpu
@tvm.testing.requires_cuda
def test_cuda_floordiv_with_vectorization():
with tvm.target.cuda():
# B[i] = A[floordiv(i, k)]
n = 256
k = 37
A = te.placeholder((n,), name="A")
B = te.compute((n,), lambda i: A[tvm.tir.floordiv(i, k)], name="B")
sch = tvm.tir.Schedule(te.create_prim_func([A, B]))
xo, xi = sch.split(sch.get_loops("B")[0], factors=[1, None])
xio, xii = sch.split(xi, factors=[None, 4])
sch.vectorize(xii)
sch.bind(xo, "blockIdx.x")
sch.bind(xio, "threadIdx.x")
func = tvm.compile(sch.mod, target="cuda")
dev = tvm.cuda(0)
a_np = np.random.uniform(size=(n,)).astype(A.dtype)
b_np = np.array([a_np[i // k] for i in range(0, n)])
a_nd = tvm.runtime.tensor(a_np, dev)
b_nd = tvm.runtime.tensor(np.zeros(b_np.shape, dtype=b_np.dtype), dev)
func(a_nd, b_nd)
tvm.testing.assert_allclose(b_nd.numpy(), b_np, rtol=1e-3)
@tvm.testing.requires_gpu
@tvm.testing.requires_cuda
def test_cuda_floormod_with_vectorization():
with tvm.target.cuda():
# B[i] = A[floormod(i, k)]
n = 256
k = 37
A = te.placeholder((n,), name="A")
B = te.compute((n,), lambda i: A[tvm.tir.floormod(i, k)], name="B")
sch = tvm.tir.Schedule(te.create_prim_func([A, B]))
xo, xi = sch.split(sch.get_loops("B")[0], factors=[1, None])
xio, xii = sch.split(xi, factors=[None, 4])
sch.vectorize(xii)
sch.bind(xo, "blockIdx.x")
sch.bind(xio, "threadIdx.x")
func = tvm.compile(sch.mod, target="cuda")
dev = tvm.cuda(0)
a_np = np.random.uniform(size=(n,)).astype(A.dtype)
b_np = np.array([a_np[i % k] for i in range(0, n)])
a_nd = tvm.runtime.tensor(a_np, dev)
b_nd = tvm.runtime.tensor(np.zeros(b_np.shape, dtype=b_np.dtype), dev)
func(a_nd, b_nd)
tvm.testing.assert_allclose(b_nd.numpy(), b_np, rtol=1e-3)
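# test_vectorized_casts adds operands of different dtypes so the codegen must emit
# vectorized casts; pairs filtered by skip() below are not covered by the generic
# loops, and int8 <-> uint8 is exercised separately with 16 lanes.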
@tvm.testing.requires_gpu
@tvm.testing.requires_cuda
def test_vectorized_casts():
def check(t0, t1, factor):
if (t0 == "float16" or t1 == "float16") and not have_fp16(tvm.cuda(0).compute_version):
print("Skip because gpu does not have fp16 support")
return
# compute
n = 128
A = te.placeholder((n,), dtype=t0, name="A")
B = te.placeholder((n,), dtype=t1, name="B")
C = te.compute((n,), lambda i: A[i] + topi.cast(B[i], A.dtype), name="C")
# schedule
sch = tvm.tir.Schedule(te.create_prim_func([A, B, C]))
ob, ib = sch.split(sch.get_loops("C")[0], factors=[None, factor])
sch.vectorize(ib)
sch.bind(ob, "threadIdx.x")
func = tvm.compile(sch.mod, target="cuda")
# correctness
dev = tvm.cuda(0)
low, high = (0, 20) if t0.startswith("u") or t1.startswith("u") else (-10, 10)
a_np = np.random.randint(low, high, size=n).astype(A.dtype)
b_np = np.random.randint(low, high, size=n).astype(B.dtype)
c_np = (a_np + b_np).astype(A.dtype)
a_nd = tvm.runtime.tensor(a_np, dev)
b_nd = tvm.runtime.tensor(b_np, dev)
c_nd = tvm.runtime.tensor(np.zeros(c_np.shape, dtype=c_np.dtype), dev)
func(a_nd, b_nd, c_nd)
tvm.testing.assert_allclose(c_nd.numpy(), c_np, rtol=1e-3)
def skip(t0, t1):
if t0 == t1:
return True
        # CUDA does not support vectorized casts between {u}int8 and fp16.
skip_set = {"float16", "uint8", "int8"}
if t0 in skip_set and t1 in skip_set:
return True
return False
types_4 = [
"float16",
"float32",
"int8",
"uint8",
"int16",
"uint16",
"int32",
"uint32",
"float64",
"int64",
"uint64",
]
types_8 = ["float16", "float32", "int8", "uint8", "int16", "uint16", "int32", "uint32"]
for t0, t1 in [(x, y) for x in types_4 for y in types_4 if not skip(x, y)]:
check(t0, t1, 4)
for t0, t1 in [(x, y) for x in types_8 for y in types_8 if not skip(x, y)]:
check(t0, t1, 8)
check("int8", "uint8", 16)
check("uint8", "int8", 16)
def sched(A, B):
# schedule
sch = tvm.tir.Schedule(te.create_prim_func([A, B]))
io, ii = sch.split(sch.get_loops("B")[0], factors=[1, None])
iio, iii = sch.split(ii, factors=[32, None])
_, iiii = sch.split(iii, factors=[None, 4])
sch.vectorize(iiii)
sch.bind(io, "blockIdx.x")
sch.bind(iio, "threadIdx.x")
return tvm.compile(sch.mod, target="cuda")
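# test_vectorized_intrin1 checks elementwise math intrinsics against numpy in
# float32 and, where supported, float16.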
@tvm.testing.requires_gpu
@tvm.testing.requires_cuda
def test_vectorized_intrin1():
test_funcs = [
(tvm.tir.floor, lambda x: np.floor(x)),
(tvm.tir.ceil, lambda x: np.ceil(x)),
(tvm.tir.trunc, lambda x: np.trunc(x)),
(tvm.tir.abs, lambda x: np.fabs(x)),
(tvm.tir.round, lambda x: np.round(x)),
(tvm.tir.exp, lambda x: np.exp(x)),
(tvm.tir.exp2, lambda x: np.exp2(x)),
(tvm.tir.exp10, lambda x: np.power(10, x)),
(tvm.tir.log, lambda x: np.log(x)),
(tvm.tir.log2, lambda x: np.log2(x)),
(tvm.tir.log10, lambda x: np.log10(x)),
(tvm.tir.tan, lambda x: np.tan(x)),
(tvm.tir.cos, lambda x: np.cos(x)),
(tvm.tir.cosh, lambda x: np.cosh(x)),
(tvm.tir.sin, lambda x: np.sin(x)),
(tvm.tir.sinh, lambda x: np.sinh(x)),
(tvm.tir.atan, lambda x: np.arctan(x)),
(tvm.tir.tanh, lambda x: np.tanh(x)),
(tvm.tir.sqrt, lambda x: np.sqrt(x)),
]
def run_test(tvm_intrin, np_func, dtype):
if dtype == "float16" and not have_fp16(tvm.cuda(0).compute_version):
print("Skip because gpu does not have fp16 support")
return
        # This set of intrinsics does not support fp16 yet.
skip_set = {
tvm.tir.abs,
tvm.tir.round,
tvm.tir.tan,
tvm.tir.atan,
tvm.tir.tanh,
tvm.tir.cosh,
tvm.tir.sinh,
}
if dtype == "float16" and tvm_intrin in skip_set:
print("Skip because '{0}' does not support fp16 yet".format(tvm_intrin.__name__))
return
n = 128
A = te.placeholder((n,), dtype=dtype, name="A")
B = te.compute((n,), lambda *i: tvm_intrin(A(*i)), name="B")
f = sched(A, B)
dev = tvm.cuda(0)
a = tvm.runtime.tensor(np.random.uniform(0, 1, size=n).astype(A.dtype), dev)
b = tvm.runtime.tensor(np.zeros(shape=(n,)).astype(A.dtype), dev)
f(a, b)
tvm.testing.assert_allclose(b.numpy(), np_func(a.numpy()), atol=1e-3, rtol=1e-3)
for func in test_funcs:
run_test(*func, "float32")
run_test(*func, "float16")
@tvm.testing.requires_gpu
@tvm.testing.requires_cuda
def test_vectorized_intrin2(dtype="float32"):
c2 = tvm.tir.const(2, dtype=dtype)
test_funcs = [
(tvm.tir.power, lambda x: np.power(x, 2.0)),
(tvm.tir.fmod, lambda x: np.fmod(x, 2.0)),
]
def run_test(tvm_intrin, np_func):
n = 128
A = te.placeholder((n,), dtype=dtype, name="A")
B = te.compute((n,), lambda i: tvm_intrin(A[i], c2), name="B")
f = sched(A, B)
dev = tvm.cuda(0)
a = tvm.runtime.tensor(np.random.uniform(0, 1, size=n).astype(A.dtype), dev)
b = tvm.runtime.tensor(np.zeros(shape=(n,)).astype(A.dtype), dev)
f(a, b)
tvm.testing.assert_allclose(b.numpy(), np_func(a.numpy()), atol=1e-3, rtol=1e-3)
for func in test_funcs:
run_test(*func)
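# test_vectorized_popcount compares tir.popcount against a pure-Python bit-count
# reference for uint32 and uint64.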
@tvm.testing.requires_gpu
@tvm.testing.requires_cuda
def test_vectorized_popcount():
def ref_popcount(x):
cnt = 0
while x:
x -= x & -x
cnt += 1
return cnt
def run_test(dtype):
n = 128
A = te.placeholder((n,), dtype=dtype, name="A")
B = te.compute((n,), lambda i: tvm.tir.popcount(A[i]), name="B")
f = sched(A, B)
dev = tvm.cuda(0)
a = tvm.runtime.tensor(np.random.randint(0, 100000, size=n).astype(A.dtype), dev)
b = tvm.runtime.tensor(np.zeros(shape=(n,)).astype(B.dtype), dev)
f(a, b)
ref = np.vectorize(ref_popcount)(a.numpy())
tvm.testing.assert_allclose(b.numpy(), ref)
run_test("uint32")
run_test("uint64")
@tvm.testing.requires_gpu
@tvm.testing.requires_cuda
def test_cuda_vectorize_load_permute_pad():
def check_cuda(dtype, n, l, padding, lanes):
if dtype == "float16" and not have_fp16(tvm.cuda(0).compute_version):
print("Skip because gpu does not have fp16 support")
return
dev = tvm.cuda(0)
A = tvm.te.placeholder((n, l), name="A", dtype=dtype)
B = tvm.te.compute(
(n // lanes, l + 2 * padding, lanes),
lambda i, j, k: tvm.te.if_then_else(
tvm.te.any(j < padding, j >= l + padding),
tvm.tir.const(0, dtype),
A[i * lanes + k, j - padding],
),
name="B",
)
sch = tvm.tir.Schedule(te.create_prim_func([A, B]))
block, thread, vectorize = sch.get_loops("B")
sch.bind(block, "blockIdx.x")
sch.bind(thread, "threadIdx.x")
sch.vectorize(vectorize)
fun = tvm.compile(sch.mod, target="cuda")
np_a = np.random.randint(low=-128, high=127, size=(n, l)).astype(A.dtype)
a = tvm.runtime.empty((n, l), A.dtype, dev).copyfrom(np_a)
b = tvm.runtime.empty((n // lanes, l + padding * 2, lanes), B.dtype, dev)
fun(a, b)
np_a_reshape = np_a.reshape(n // lanes, lanes, l).transpose(0, 2, 1)
ref = np.pad(
np_a_reshape, ((0, 0), (padding, padding), (0, 0)), mode="constant", constant_values=0
)
tvm.testing.assert_allclose(b.numpy(), ref)
check_cuda("int8", 64, 16, 3, 2)
check_cuda("uint8", 64, 16, 3, 2)
check_cuda("int8", 64, 16, 3, 4)
check_cuda("uint8", 64, 16, 3, 4)
check_cuda("int32", 64, 16, 3, 4)
check_cuda("float16", 64, 16, 3, 4)
check_cuda("float32", 64, 16, 3, 4)
@tvm.testing.requires_gpu
@tvm.testing.requires_cuda
def test_try_unaligned_vector_load():
def get_compute(N, C_N, offset):
A = te.placeholder((N,), name="A", dtype="float16")
C = te.compute((C_N,), lambda i: A[i + offset], name="C")
return N, C_N, A, C
def get_compute_unaligned():
return get_compute(3, 2, 1)
def get_compute_aligned():
return get_compute(4, 2, 2)
def build(A, C, N, C_N):
sch = tvm.tir.Schedule(te.create_prim_func([A, C]))
oi, ii = sch.split(sch.get_loops("C")[0], factors=[None, 2])
sch.bind(oi, "threadIdx.x")
sch.vectorize(ii) # BUG: misalignment
f = tvm.tir.build(sch.mod, target="cuda")
kernel_source = f.imports[0].inspect_source()
dev = tvm.cuda()
a_data = np.arange(0, N).astype(A.dtype)
a = tvm.runtime.tensor(a_data, dev)
c = tvm.runtime.tensor(np.zeros(C_N, dtype=C.dtype), dev)
f(a, c)
return a_data, c.numpy(), kernel_source
N, C_N, A, C = get_compute_unaligned()
a_data, c, kernel_source = build(A, C, N, C_N)
# (uint1*)(A + (1)) is invalid
assert "A + (1)" not in kernel_source
expected = a_data[1 : C_N + 1]
assert np.allclose(c, expected), f"expected={expected}\nactual={c}"
N, C_N, A, C = get_compute_aligned()
a_data, c, kernel_source = build(A, C, N, C_N)
# (uint1*)(A + (2)) is a valid vector load
assert "A + 2" in kernel_source
expected = a_data[2 : C_N + 2]
assert np.allclose(c, expected), f"expected={expected}\nactual={c}"
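# func1 needs a __syncthreads between the shared-memory write and read, but the sync
# would sit inside a data-dependent branch, which the CUDA codegen rejects; wrapping
# the condition in T.tvm_thread_invariant promises all threads take the same branch,
# so func2 and func3 compile.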
@tvm.testing.requires_gpu
@tvm.testing.requires_cuda
def test_cuda_thread_sync_inside_condition():
@T.prim_func
def func1(A: T.Buffer((4, 4), "float32")) -> None:
A_shared = T.alloc_buffer((4, 4), "float32", scope="shared")
for bx in T.thread_binding(1, "blockIdx.x"):
for tx in T.thread_binding(32, "threadIdx.x"):
if A[0, 0] > 1.0:
for i, j in T.grid(4, 4):
A_shared[i, j] = A[i, j]
for i, j in T.grid(4, 4):
A[i, j] = A_shared[i, j] + 1.0
@T.prim_func
def func2(A: T.Buffer((4, 4), "float32")) -> None:
A_shared = T.alloc_buffer((4, 4), "float32", scope="shared")
for bx in T.thread_binding(1, "blockIdx.x"):
for tx in T.thread_binding(32, "threadIdx.x"):
if T.tvm_thread_invariant(A[0, 0] > 1.0):
for i, j in T.grid(4, 4):
A_shared[i, j] = A[i, j]
for i, j in T.grid(4, 4):
A[i, j] = A_shared[i, j] + 1.0
@T.prim_func
def func3(A: T.Buffer((4, 4), "float32")) -> None:
A_shared = T.alloc_buffer((4, 4), "float32", scope="shared")
for bx in T.thread_binding(1, "blockIdx.x"):
for tx in T.thread_binding(32, "threadIdx.x"):
while T.tvm_thread_invariant(A[0, 0] > 1.0):
for i, j in T.grid(4, 4):
A_shared[i, j] = A[i, j]
for i, j in T.grid(4, 4):
A[i, j] = A_shared[i, j] + 1.0
mod = tvm.IRModule({"main": func1})
with pytest.raises(tvm.error.InternalError):
tvm.compile(mod, target="cuda")
mod = tvm.IRModule({"main": func2})
tvm.compile(mod, target="cuda")
mod = tvm.IRModule({"main": func3})
tvm.compile(mod, target="cuda")
@tvm.testing.requires_cuda
def test_invalid_reinterpret():
@T.prim_func
def func(A: T.Buffer((4,), "uint32"), B: T.Buffer((4,), "uint8")) -> None:
for tx in T.thread_binding(4, "threadIdx.x"):
B[tx] = T.call_intrin("uint8", "tir.reinterpret", A[tx])
with pytest.raises(tvm.error.TVMError):
tvm.compile(func, target="cuda")
@tvm.testing.requires_cuda
@tvm.testing.requires_cuda_compute_version(9)
def test_cuda_tensormap():
# fmt: off
@T.prim_func
def main(A_ptr: T.handle):
A = T.match_buffer(A_ptr, (16, 16), dtype="float32", align=16)
A_map: T.handle("tensormap") = T.tvm_stack_alloca("tensormap", 1)
T.call_packed("runtime.cuTensorMapInit", A_map, "float32", 2, A.data,
16, 16, 64, 16, 16, 1, 1, 0, 0, 0, 0)
for blockIdx in T.thread_binding(1, thread="blockIdx.x"):
for threadIdx in T.thread_binding(128, thread="threadIdx.x"):
if threadIdx == 0:
A[0, 0] = T.reinterpret("float64", A_map)
# fmt: on
mod = tvm.IRModule({"main": main})
mod = tvm.compile(mod, target="cuda")
assert (
"""
extern "C" __global__ void __launch_bounds__(128) main_kernel(float* __restrict__ A, const __grid_constant__ CUtensorMap A_map) {
if (((int)threadIdx.x) == 0) {
A[0] = ((float)(*(double *)(&(A_map))));
}
}""".strip()
in mod.mod.imports[0].inspect_source()
)
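# A private PrimFunc called from a kernel should be emitted as a __device__ function
# in the generated CUDA source.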
@tvm.testing.requires_cuda
def test_cuda_device_func_call():
@I.ir_module
class Module:
@T.prim_func(private=True)
def add(a: T.float32, b: T.float32) -> T.float32:
return a + b
@T.prim_func
def main(
A: T.Buffer((1024, 1024), "float32"),
B: T.Buffer((1024, 1024), "float32"),
C: T.Buffer((1024, 1024), "float32"),
):
for bx in T.thread_binding(1024, "blockIdx.x"):
for tx in T.thread_binding(1024, "threadIdx.x"):
C[bx, tx] = Module.add(A[bx, tx], B[bx, tx])
lib = tvm.compile(Module, target="cuda")
cuda_code = lib.mod.imports[0].inspect_source()
assert 'extern "C" __device__ float add(float a, float b) {\n return (a + b);\n}' in cuda_code
@tvm.testing.requires_cuda
def test_cuda_float_const_hex_format():
"""Test that float constants are emitted in hexadecimal format for precision"""
@I.ir_module
class Module:
@T.prim_func
def main(
A: T.Buffer((1024, 1024), "float32"),
):
for bx in T.thread_binding(1024, "blockIdx.x"):
for tx in T.thread_binding(1024, "threadIdx.x"):
A[bx, tx] = T.float32(1 / 27)
lib = tvm.compile(Module, target="cuda")
cuda_code = lib.mod.imports[0].inspect_source()
assert "0x1.2f684bda12f68p-5f" in cuda_code
@tvm.testing.requires_cuda
def test_device_host_call_same_func():
@I.ir_module
class Module:
@T.prim_func(private=True)
def add(a: T.int32, b: T.int32) -> T.int32:
return a + b
@T.prim_func
def main(
A: T.Buffer((128, 128), "int32"),
B: T.Buffer((128, 128), "int32"),
C: T.Buffer((128, 128), "int32"),
):
length: T.int32 = Module.add(64, 64) # Call from host
for bx in T.thread_binding(length, "blockIdx.x"):
for tx in T.thread_binding(length, "threadIdx.x"):
C[bx, tx] = Module.add(A[bx, tx], B[bx, tx]) # Call from device
    # 1. If we set the host to llvm, it raises the error
    # "the tir.ret should be transformed to return zero before the llvm code generation."
    # Need to revisit this.
    # 2. We set a dummy mcpu value for testing purposes, to avoid deciding whether a
    # function is host or device based on the "cpu" substring.
target = tvm.target.Target({"kind": "cuda", "mcpu": "dummy_mcpu"}, host="c")
lib = tvm.compile(Module, target=target)
cuda_code = lib.mod.imports[0].inspect_source()
assert 'extern "C" __device__ int add(int a, int b) {\n return (a + b);\n}' in cuda_code
# Run a simple test
dev = tvm.cuda(0)
a_np = np.random.randint(0, 10, (128, 128), dtype="int32")
b_np = np.random.randint(0, 10, (128, 128), dtype="int32")
a_tvm = tvm.runtime.tensor(a_np, device=dev)
b_tvm = tvm.runtime.tensor(b_np, device=dev)
c_tvm = tvm.runtime.empty((128, 128), dtype="int32", device=dev)
lib["main"](a_tvm, b_tvm, c_tvm)
tvm.testing.assert_allclose(c_tvm.numpy(), a_np + b_np)
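# T.thread_return() should lower to an early `return;` in the generated kernel.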
@tvm.testing.requires_cuda
def test_thread_return():
@I.ir_module
class Module:
@T.prim_func
def main(A: T.Buffer((16, 16), "float32"), B: T.Buffer((16, 16), "float32")):
for bx in T.thread_binding(32, "blockIdx.x"):
for tx in T.thread_binding(32, "threadIdx.x"):
if bx >= 16 or tx >= 16:
T.thread_return()
B[bx, tx] = A[bx, tx]
lib = tvm.compile(Module, target="cuda")
cuda_code = lib.mod.imports[0].inspect_source()
assert "return;" in cuda_code
if __name__ == "__main__":
tvm.testing.main()