blob: 826a6e4e5e4a739b6df1e2f054cfb2b33a5cae08 [file]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=missing-function-docstring
import numpy as np
import pytest
import torch
import tvm
import tvm.testing
from tvm.script import tirx as Tx
# CUDA device on which the tests below allocate the tensors they pass to compiled kernels.
DEV = tvm.device("cuda")
def _get_source(func: tvm.tirx.PrimFunc) -> tuple[str, object]:
    """Compile ``func`` for CUDA and return its generated source and the compiled module.

    ``func`` is wrapped in an ``IRModule`` under the name ``"main"`` and compiled
    for the ``"cuda"`` target with the ``"tirx"`` TIR pipeline.

    Parameters
    ----------
    func : tvm.tirx.PrimFunc
        The function to compile.

    Returns
    -------
    src : str
        Source reported by ``inspect_source()`` on the first imported module of
        the compiled result.
    mod : object
        Whatever ``tvm.compile`` returned; callers in this file invoke it either
        directly or through ``mod["main"]``.
    """
    target = tvm.target.Target("cuda")
    ir_mod = tvm.IRModule({"main": func})
    mod = tvm.compile(ir_mod, target=target, tir_pipeline="tirx")
    src = mod.mod.imports[0].inspect_source()
    return src, mod
def _helper_source(src: str, helper_name: str) -> str:
start = src.index(helper_name)
next_helper = src.find("__device__", start + len(helper_name))
if next_helper == -1:
return src[start:]
return src[start:next_helper]
def test_serial_pragma_unroll_codegen():
    """A serial loop with ``unroll=True`` must emit ``#pragma unroll``, a ``for`` loop and ``break``."""

    @Tx.prim_func
    def main(A: Tx.Buffer((4,), "int32")):
        with Tx.kernel():
            tx = Tx.thread_id([32])
            if Tx.filter(tx, tx == 0):
                with Tx.thread():
                    for i in Tx.serial(4, unroll=True):
                        if i == 2:
                            break
                        A[i] = A[i] + 1

    src, _ = _get_source(main)
    for snippet in ("#pragma unroll\n", "for (", "break;"):
        assert snippet in src
def test_cluster_cta_id_codegen_uses_coordinate_sregs():
    """``Tx.cta_id_in_cluster`` must be generated from the ``%cluster_ctaid`` coordinate registers."""

    @Tx.prim_func
    def main(A: Tx.Buffer((1,), "int32")):
        with Tx.kernel():
            cbx, cby = Tx.cta_id_in_cluster([2, 2])
            tx = Tx.thread_id([32])
            if Tx.filter(tx, tx == 0):
                with Tx.thread():
                    A[0] = cbx + cby

    src, _ = _get_source(main)
    # Both per-axis coordinate registers have to be read ...
    for required in ("%cluster_ctaid.x", "%cluster_ctaid.y"):
        assert required in src
    # ... while the rank register and the cooperative-groups accessor must not appear.
    for forbidden in ("%cluster_ctarank", "cooperative_groups::cluster_group::block_index"):
        assert forbidden not in src
def test_cuda_handle_uint64_reinterpret_codegen():
    """Reinterpreting between ``uint64`` and ``handle`` must use ``reinterpret_cast`` both ways."""

    @Tx.prim_func
    def main(A: Tx.Buffer((1,), "uint64")):
        with Tx.kernel():
            tx = Tx.thread_id([32])
            if Tx.filter(tx, tx == 0):
                with Tx.thread():
                    ptr = Tx.reinterpret("handle", A[0])
                    A[0] = Tx.reinterpret("uint64", ptr)

    src, _ = _get_source(main)
    for cast in ("reinterpret_cast<void*>", "reinterpret_cast<uint64_t>"):
        assert cast in src
    # The pointer-punning form must not be emitted for this conversion.
    assert "*(void* *)" not in src
def test_cuda_atomic_add():
    """``Tx.cuda.atomic_add`` on int32 and float32 buffers: check the source and the result."""

    @Tx.prim_func
    def main(A: Tx.Buffer((1,), "int32"), B: Tx.Buffer((1,), "float32")):
        with Tx.kernel():
            cta_id = Tx.cta_id([1])
            tx = Tx.thread_id([32])
            if Tx.filter(tx, tx == 0):
                with Tx.thread():
                    Tx.cuda.atomic_add(A.data, Tx.int32(1))
                    Tx.cuda.atomic_add(B.data, Tx.float32(1.0))

    src, mod = _get_source(main)
    assert "tvm_builtin_cuda_atomic_add" in src

    # Both counters start at zero; a single thread adds one to each.
    int_counter, float_counter = (
        tvm.runtime.tensor(np.zeros(1, dtype=dtype), device=DEV)
        for dtype in ("int32", "float32")
    )
    mod["main"](int_counter, float_counter)
    np.testing.assert_allclose(int_counter.numpy(), 1)
    np.testing.assert_allclose(float_counter.numpy(), 1.0)
def test_ptx_ld_acquire_and_volatile_codegen():
    """Check the PTX mnemonics generated for ``ld_acquire``, ``ld_global_acquire`` and ``ld_volatile``."""

    @Tx.prim_func
    def main(
        A: Tx.Buffer((1,), "uint64"), B: Tx.Buffer((1,), "int32"), C: Tx.Buffer((1,), "uint32")
    ):
        with Tx.kernel():
            tx = Tx.thread_id([32])
            if Tx.filter(tx, tx == 0):
                with Tx.thread():
                    A[0] = Tx.ptx.ld_acquire(A.data, "uint64", "u64", scope="gpu", space="global")
                    B[0] = Tx.ptx.ld_acquire(B.data, "int32", "s32", scope="sys", space="global")
                    C[0] = Tx.ptx.ld_acquire(C.data, "uint32", "b32", scope="gpu", space="global")
                    Tx.ptx.ld_global_acquire(B[0], B.data)
                    A[0] = Tx.ptx.ld_volatile(A.data, "uint64", "u64", space="global")

    src, _ = _get_source(main)
    expected_snippets = (
        "ld.acquire.gpu.global.u64",
        "ld.acquire.sys.global.s32",
        "ld.acquire.gpu.global.b32",
        "ptx_ld_global_acquire_int32",
        "ld.volatile.global.u64",
    )
    for snippet in expected_snippets:
        assert snippet in src
    # The int32 ``ld_global_acquire`` helper must carry the ``int32`` suffix, not ``b32``.
    assert "ptx_ld_global_acquire_b32" not in src
def test_megamoe_extracted_intrinsics_codegen():
    """Emit a batch of PTX/CUDA intrinsics in one kernel and check each one's text in the source."""

    @Tx.prim_func
    def main(
        U32: Tx.Buffer((4,), "uint32"),
        I32: Tx.Buffer((1,), "int32"),
        U64: Tx.Buffer((1,), "uint64"),
        F32: Tx.Buffer((4,), "float32"),
    ):
        with Tx.kernel():
            tx = Tx.thread_id([32])
            if Tx.filter(tx, tx == 0):
                with Tx.thread():
                    Tx.ptx.red_scalar(
                        U64.data,
                        U64[0],
                        sem="release",
                        scope="gpu",
                        space="global",
                        op="or",
                        ptx_type="b64",
                    )
                    Tx.ptx.red_scalar(
                        I32.data,
                        I32[0],
                        sem="release",
                        scope="sys",
                        space="global",
                        op="add",
                        ptx_type="s32",
                    )
                    U32[0] = Tx.ptx.atom_scalar(
                        U32.data,
                        U32[0],
                        sem="release",
                        scope="gpu",
                        space="global",
                        op="add",
                        ptx_type="u32",
                    )
                    U64[0] = Tx.ptx.atom_scalar(
                        U64.data, U64[0], scope="sys", space="global", op="add", ptx_type="u64"
                    )
                    Tx.ptx.red_scalar(
                        U32.data, U32[0], scope="gpu", space="global", op="add", ptx_type="u32"
                    )
                    Tx.ptx.st(U32.data, U32[0], space="shared", ptx_type="u32")
                    Tx.ptx.st(
                        U32.data,
                        U32[0],
                        U32[1],
                        U32[2],
                        U32[3],
                        space="shared",
                        vec="v4",
                        ptx_type="b32",
                    )
                    Tx.ptx.st_bulk(U32.data, Tx.uint32(16), weak=True, space="shared::cta")
                    U32[0] = Tx.ptx.fns_b32(U32[0], U32[1], I32[0])
                    Tx.ptx.stmatrix(
                        U32.data,
                        U32.data,
                        num=1,
                        trans=True,
                        shape="m16n8",
                        ptx_type="b8",
                        space="shared",
                    )
                    F32[1] = Tx.cuda.uint_as_float(U32[0])
                    F32[2] = Tx.ptx.ld(F32.data, "float32", "f32", space="global")
                    U32[3] = Tx.cuda.float_as_uint(F32[1])
                    F32[0] = Tx.ptx.add_rn_f32_bf16(F32[0], Tx.cast(U32[0], "uint16"))
                    U64[0] = Tx.reinterpret("uint64", U32.data)
                    U32[0] = Tx.cuda.ballot_sync(Tx.uint32(0xFFFFFFFF), I32[0])
                    I32[0] = Tx.cuda.ffs_u32(U32[0])
                    U32[0] = Tx.cuda.reduce_add_sync_u32(Tx.uint32(0xFFFFFFFF), U32[0])
                    U32[0] = Tx.cuda.reduce_min_sync_u32(Tx.uint32(0xFFFFFFFF), U32[0])
                    U64[0] = Tx.cuda.clock64()
                    U32[0] = Tx.cuda.float22bfloat162_rn(F32[0], F32[1])

    src, _ = _get_source(main)
    # One expected mnemonic or builtin name per intrinsic call above.
    expected_snippets = (
        "red.release.gpu.global.or.b64",
        "red.release.sys.global.add.s32",
        "atom.release.gpu.global.add.u32",
        "atom.sys.global.add.u64",
        "red.gpu.global.add.u32",
        "st.shared.u32",
        "st.shared.v4.b32",
        "st.bulk.weak.shared::cta",
        "fns.b32",
        "stmatrix.sync.aligned.m16n8.x1.trans.shared.b8",
        "ld.global.f32",
        "add.rn.f32.bf16",
        "__uint_as_float",
        "__float_as_uint",
        "__ballot_sync",
        "__ffs",
        "__reduce_add_sync",
        "__reduce_min_sync",
        "clock64()",
        "__float22bfloat162_rn",
    )
    for expected in expected_snippets:
        assert expected in src
def test_ptx_cp_async_bulk_non_tma_form_codegen():
    """Check the ``cp.async.bulk`` forms emitted for g2s (cta/cluster) and s2g with a cache policy."""

    @Tx.prim_func
    def main(
        A: Tx.Buffer((128,), "float32"),
        B: Tx.Buffer((128,), "float32"),
        C: Tx.Buffer((1,), "uint64"),
    ):
        with Tx.kernel():
            tx = Tx.thread_id([32])
            if Tx.filter(tx, tx == 0):
                with Tx.thread():
                    smem = Tx.alloc_shared([128], "float32")
                    Tx.ptx.cp_async_bulk_g2s_cta(
                        smem.ptr_to([0]), A.data, Tx.uint32(64), smem.ptr_to([0]), cache_policy=C[0]
                    )
                    Tx.ptx.cp_async_bulk_g2s_cluster(
                        smem.ptr_to([0]), A.data, Tx.uint32(64), smem.ptr_to([0]), cache_policy=C[0]
                    )
                    Tx.ptx.cp_async_bulk_s2g(
                        B.data, smem.ptr_to([0]), Tx.uint32(64), cache_policy=C[0]
                    )

    src, _ = _get_source(main)
    expected_snippets = (
        "cp.async.bulk.shared::cta.global.mbarrier::complete_tx::bytes.L2::cache_hint",
        "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint",
        "cp.async.bulk.global.shared::cta.bulk_group.L2::cache_hint",
        "unsigned long long cache_policy",
    )
    for expected in expected_snippets:
        assert expected in src
def test_tensor_map_param_codegen():
    """A ``Tx.TensorMap`` parameter must appear as a ``__grid_constant__ CUtensorMap`` argument."""

    @Tx.prim_func
    def main(A_map: Tx.TensorMap()):
        with Tx.kernel():
            tx = Tx.thread_id([32])
            if Tx.filter(tx, tx == 0):
                with Tx.thread():
                    Tx.evaluate(Tx.address_of(A_map))

    src, _ = _get_source(main)
    expected_snippets = (
        "const __grid_constant__ CUtensorMap A_map",
        "((unsigned long long)(&(A_map)))",
    )
    for expected in expected_snippets:
        assert expected in src
def test_tma_cache_policy_operand_codegen():
    """Check the source generated for bulk tensor copies that are given a ``cache_policy`` operand."""

    @Tx.prim_func
    def main(Cache: Tx.Buffer((1,), "uint64")):
        A_map: Tx.let[Tx.handle("tensormap")] = Tx.tvm_stack_alloca("tensormap", 1)
        B_map: Tx.let[Tx.handle("tensormap")] = Tx.tvm_stack_alloca("tensormap", 1)
        with Tx.kernel():
            tx = Tx.thread_id([32])
            if Tx.filter(tx, tx == 0):
                with Tx.thread():
                    smem = Tx.alloc_buffer((128,), "float32", scope="shared", align=128)
                    bar = Tx.shared_scalar("uint64")
                    Tx.ptx.cp_async.bulk.tensor.g2c(
                        2,
                        smem.data,
                        Tx.address_of(bar),
                        Tx.address_of(A_map),
                        1,
                        2,
                        "",
                        0,
                        0,
                        cache_policy=Cache[0],
                    )
                    Tx.ptx.cp_async.bulk.tensor.g2c(
                        2,
                        smem.data,
                        Tx.address_of(bar),
                        Tx.address_of(A_map),
                        3,
                        2,
                        "",
                        0,
                        0,
                        cache_policy=Cache[0],
                    )
                    Tx.ptx.cp_async.bulk.tensor.s2g(
                        2, smem.data, Tx.address_of(A_map), "", 0, 0, cache_policy=Cache[0]
                    )
                    masked_bar = Tx.cuda.sm100_tma_2sm_mbarrier_addr(Tx.address_of(bar))
                    Tx.ptx.cp_async.bulk.tensor.g2c_bar_addr(
                        2,
                        smem.data,
                        masked_bar,
                        Tx.address_of(A_map),
                        1,
                        2,
                        "",
                        0,
                        0,
                        cache_policy=Cache[0],
                    )
                    if tx == 0:
                        Tx.ptx.cp_async.bulk.tensor.g2c_bar_addr(
                            2,
                            smem.data,
                            masked_bar,
                            Tx.address_of(A_map),
                            1,
                            2,
                            "",
                            0,
                            0,
                            cache_policy=Cache[0],
                        )
                    else:
                        Tx.ptx.cp_async.bulk.tensor.g2c_bar_addr(
                            2,
                            smem.data,
                            masked_bar,
                            Tx.address_of(B_map),
                            1,
                            2,
                            "",
                            0,
                            0,
                            cache_policy=Cache[0],
                        )

    src, _ = _get_source(main)

    # Helper names, PTX mnemonics and operand text that must be present.
    required_snippets = (
        "ptx_cp_async_bulk_tensor_g2cluster_tile_2d_cache_hint",
        "ptx_cp_async_bulk_tensor_g2cluster_tile_2d_multicast_cache_hint",
        (
            "cp.async.bulk.tensor.2d.shared::cluster.global"
            ".mbarrier::complete_tx::bytes.cta_group::2.L2::cache_hint"
        ),
        (
            "cp.async.bulk.tensor.2d.shared::cluster.global"
            ".mbarrier::complete_tx::bytes.multicast::cluster"
            ".cta_group::2.L2::cache_hint"
        ),
        "cp.async.bulk.tensor.2d.global.shared::cta.tile.bulk_group.L2::cache_hint",
        "tvm_builtin_cuda_cvta_generic_to_shared((&(bar_ptr[0]))) & (uint)4278190079",
        "ptx_cp_async_bulk_tensor_g2cluster_tile_2d_cache_hint_bar_addr",
        "unsigned long long cache_policy",
    )
    for required in required_snippets:
        assert required in src

    # Helper names that must not show up in the generated source.
    forbidden_snippets = (
        "g2cluster_unicast",
        "ptx_cp_async_bulk_tensor_g2cta",
        "tvm_builtin_cp_async_bulk_tensor_2d_g2c_cta_group2",
    )
    for forbidden in forbidden_snippets:
        assert forbidden not in src
def test_cuda_thread_fence():
    """``Tx.cuda.thread_fence`` must show up as ``tvm_builtin_cuda_thread_fence`` in the source."""

    @Tx.prim_func
    def main(A: Tx.Buffer((16, 16), "int32")):
        with Tx.kernel():
            cta_id = Tx.cta_id([1])
            tx = Tx.thread_id([32])
            if Tx.filter(tx, tx == 0):
                with Tx.thread():
                    Tx.cuda.thread_fence()

    # Only the generated source is inspected; the kernel is not launched.
    generated, _ = _get_source(main)
    assert "tvm_builtin_cuda_thread_fence" in generated
def test_cuda_nano_sleep():
    """``Tx.cuda.nano_sleep`` must show up as ``tvm_builtin_cuda_nano_sleep`` in the source."""

    @Tx.prim_func
    def main(A: Tx.Buffer((16, 16), "int32")):
        with Tx.kernel():
            cta_id = Tx.cta_id([1])
            tx = Tx.thread_id([32])
            if Tx.filter(tx, tx == 0):
                with Tx.thread():
                    Tx.cuda.nano_sleep(1)

    # Only the generated source is inspected; the kernel is not launched.
    generated, _ = _get_source(main)
    assert "tvm_builtin_cuda_nano_sleep" in generated
def test_cuda_atomic_cas():
    """``Tx.cuda.atomic_cas`` must show up as ``tvm_builtin_cuda_atomic_cas`` in the source."""

    @Tx.prim_func
    def main(A: Tx.Buffer((16, 16), "int32")):
        with Tx.kernel():
            cta_id = Tx.cta_id([1])
            tx = Tx.thread_id([32])
            if Tx.filter(tx, tx == 0):
                with Tx.thread():
                    Tx.cuda.atomic_cas(A.data, Tx.int32(1), Tx.int32(2))

    # Only the generated source is inspected; the kernel is not launched.
    generated, _ = _get_source(main)
    assert "tvm_builtin_cuda_atomic_cas" in generated
def test_cuda_func_call():
    """Call user-supplied ``__device__`` functions through ``Tx.cuda.func_call``.

    Two nested scenarios run in order: a value-returning ``add_one`` whose
    output is checked, and a ``void`` ``print`` helper that is only executed.
    """

    def test_add_one():
        # The CUDA snippet is kept flush-left so the string content stays unchanged.
        add_one = """
__device__ int32_t add_one(int32_t a) {
return a + 1;
}
"""

        @Tx.prim_func
        def main(a: Tx.Buffer((16, 16), "int32"), b: Tx.Buffer((16, 16), "int32")):
            with Tx.kernel():
                cta_id = Tx.cta_id([1])
                tx = Tx.thread_id([32])
                if Tx.filter(tx, tx == 0):
                    with Tx.thread():
                        for i, j in Tx.grid(16, 16):
                            b[i, j] = Tx.cuda.func_call(
                                "add_one", a[i, j], source_code=add_one, return_type="int32"
                            )

        src, mod = _get_source(main)
        a_host = np.random.randint(0, 10, (16, 16)).astype("int32")
        b_host = np.zeros((16, 16), dtype="int32")
        a_dev = tvm.runtime.tensor(a_host, device=DEV)
        b_dev = tvm.runtime.tensor(b_host, device=DEV)
        mod["main"](a_dev, b_dev)
        # Every output element is the matching input element plus one.
        np.testing.assert_allclose(b_dev.numpy(), a_host + 1)
        print(src)

    test_add_one()

    def test_print():
        # The CUDA snippet is kept flush-left so the string content stays unchanged.
        print_func = """
__device__ void print(int32_t a) {
printf("%d\\n", a);
}
"""

        @Tx.prim_func
        def main(a: Tx.Buffer((16, 16), "int32")):
            with Tx.kernel():
                cta_id = Tx.cta_id([1])
                tx = Tx.thread_id([32])
                if Tx.filter(tx, tx == 0):
                    with Tx.thread():
                        for i, j in Tx.grid(16, 16):
                            Tx.cuda.func_call("print", a[i, j], source_code=print_func)

        src, mod = _get_source(main)
        a_host = np.random.randint(0, 10, (16, 16)).astype("int32")
        a_dev = tvm.runtime.tensor(a_host, device=DEV)
        # No result to compare: the kernel only has to compile and run.
        mod["main"](a_dev)
        print(src)

    test_print()
def test_warp_shuffle_xor_sync():
    """Sum across 32 lanes with ``Tx.tvm_warp_shuffle_xor`` and check source and result.

    Each lane starts with ``31 - lane_id`` and repeatedly adds the value
    shuffled in with XOR distances 16, 8, 4, 2, 1, so every lane should end up
    holding ``0 + 1 + ... + 31 == 496``.
    """

    # fmt: off
    @Tx.prim_func
    def func(A_ptr: Tx.handle):
        A = Tx.match_buffer(A_ptr, (32,), dtype="float32", align=16)

        with Tx.kernel():
            cta_id = Tx.cta_id([1])
            warp_id = Tx.warp_id([1])
            lane_id = Tx.lane_id([32])

            with Tx.thread():
                A_local = Tx.alloc_buffer([1], "float32", scope="local")
                i = Tx.alloc_buffer([1], "int32", scope="local")

                A_local[0] = Tx.float32(31 - lane_id)
                i[0] = 16
                while i[0] >= 1:
                    A_local[0] += Tx.tvm_warp_shuffle_xor(0xFFFFFFFF, A_local[0], i[0], 32, 32)
                    i[0] = i[0] // 2
                A[lane_id] = A_local[0]
    # fmt: on

    # Reuse the shared compile helper and the module-level device instead of
    # repeating the pipeline and shadowing ``DEV`` with a local of the same name.
    src, mod = _get_source(func)
    A_np = np.zeros(32, dtype="float32")
    A = tvm.runtime.tensor(A_np, device=DEV)
    mod(A)

    assert "__shfl_xor_sync" in src

    A_ref = np.ones(32, dtype="float32") * 496
    np.testing.assert_allclose(A.numpy(), A_ref)
@pytest.mark.parametrize("cp_size", [4, 8, 16])
@pytest.mark.parametrize("cache_hint", ["", "evict_last"])
@pytest.mark.parametrize("prefetch_size", [-1, 64, 128, 256])
@pytest.mark.parametrize("predicate", [-1, Tx.int32(0), Tx.int32(1)])
@pytest.mark.parametrize("fill_mode", ["", "zero"])
def test_ptx_cp_async(cp_size, cache_hint, prefetch_size, predicate, fill_mode):
    """Run ``Tx.ptx.cp_async`` for each parameter combination and check the values read back.

    Shared memory is pre-filled with 5.0, the global input holds 1.0, and the
    kernel writes ``shared + 1`` back. The expected output is therefore 2 when
    the copy happens, and 1 (``fill_mode == "zero"``) or 6 (no fill) when the
    predicate is 0. Combinations with a fill mode but no predicate return early.
    """
    if fill_mode != "" and predicate == -1:
        return

    # Number of float16 elements covered by ``cp_size`` bytes.
    N = cp_size // 2

    # fmt: off
    @Tx.prim_func
    def main(A: Tx.Buffer((N), "float16")):
        with Tx.kernel():
            cta_id = Tx.cta_id([1])
            tid = Tx.thread_id([32])

            with Tx.thread():
                A_shared = Tx.alloc_shared([N], "float16")

                for i in Tx.vectorized(N):
                    A_shared[i] = 5.0
                Tx.ptx.fence.proxy_async("shared::cta")
                Tx.ptx.cp_async(A_shared.ptr_to([0]), A.ptr_to([0]), cp_size, cache_hint=cache_hint, prefetch_size=prefetch_size, predicate=predicate, fill_mode=fill_mode)  # noqa: E501
                Tx.ptx.cp_async.commit_group()
                Tx.ptx.cp_async.wait_group(0)

                for i in Tx.serial(N):
                    A[i] = A_shared[i] + 1.0
    # fmt: on

    src, mod = _get_source(main)
    device_buf = tvm.runtime.tensor(np.ones(N, dtype="float16"), device=DEV)
    mod(device_buf)

    if int(predicate) != 0:
        expected_value = 2
    elif fill_mode == "zero":
        expected_value = 1
    else:
        expected_value = 6
    expected = np.full(N, expected_value, dtype="float16")
    np.testing.assert_allclose(device_buf.numpy(), expected)
    print(src)
@pytest.mark.parametrize("trans", [False, True])
@pytest.mark.parametrize("num", [1, 2, 4])
def test_ptx_ldmatrix(trans, num):
    """Load ``num`` 8x8 tiles with ``Tx.ptx.ldmatrix`` and compare them with the input tiles.

    The 16x16 input is staged in shared memory, every thread loads its
    registers with ``ldmatrix`` and scatters them to ``B``. Tiles that were
    loaded must equal the matching input tile (transposed when ``trans``);
    tiles that were not loaded are expected to stay zero.
    """
    dtype = ".b16"

    # fmt: off
    @Tx.prim_func
    def main(A: Tx.Buffer((16, 16), "float16"), B: Tx.Buffer((16, 16), "float16")):
        with Tx.kernel():
            cta_id = Tx.cta_id([1])
            tx = Tx.thread_id([32])

            A_shared = Tx.alloc_shared([16, 16], "float16")

            if Tx.filter(tx, tx == 0):
                with Tx.thread():
                    for i, j in Tx.grid(16, 16):
                        A_shared[i, j] = A[i, j]
            Tx.cuda.cta_sync()

            with Tx.thread():
                A_local = Tx.alloc_local([8], "float16")
                A_local[0] = -1.0
                # ldmatrix .x{num}.b16 writes `num` 32-bit registers; A_local
                # is a contiguous fp16[8] buffer, so consecutive register
                # destinations land 2 fp16 elements apart.
                if num == 1:
                    Tx.ptx.ldmatrix(
                        trans, num, dtype,
                        A_shared.ptr_to([tx % 16, tx // 16 * 8]),
                        Tx.address_of(A_local[0]),
                    )
                elif num == 2:
                    Tx.ptx.ldmatrix(
                        trans, num, dtype,
                        A_shared.ptr_to([tx % 16, tx // 16 * 8]),
                        Tx.address_of(A_local[0]),
                        Tx.address_of(A_local[2]),
                    )
                else:
                    Tx.ptx.ldmatrix(
                        trans, num, dtype,
                        A_shared.ptr_to([tx % 16, tx // 16 * 8]),
                        Tx.address_of(A_local[0]),
                        Tx.address_of(A_local[2]),
                        Tx.address_of(A_local[4]),
                        Tx.address_of(A_local[6]),
                    )

                for i in range(8):
                    row: Tx.let = (i // 2) % 2 * 8
                    col: Tx.let = (i // 4) * 8
                    B[row + tx // 4, col + tx % 4 * 2 + i % 2] = A_local[i]
    # fmt: on

    _, mod = _get_source(main)
    a_host = np.arange(16 * 16, dtype="float16").reshape((16, 16))
    a_dev = tvm.runtime.tensor(a_host, device=DEV)
    b_dev = tvm.runtime.tensor(np.zeros((16, 16), dtype="float16"), device=DEV)
    mod(a_dev, b_dev)

    # (row, col) origins of the 8x8 tiles expected to be filled for each ``num``.
    tile_origins = {
        1: ((0, 0),),
        2: ((0, 0), (8, 0)),
        4: ((0, 0), (0, 8), (8, 0), (8, 8)),
    }
    expected = np.zeros((16, 16), dtype="float16")
    for r0, c0 in tile_origins.get(num, ()):
        tile = a_host[r0 : r0 + 8, c0 : c0 + 8]
        expected[r0 : r0 + 8, c0 : c0 + 8] = tile.T if trans else tile
    np.testing.assert_allclose(b_dev.numpy(), expected)
@pytest.mark.parametrize("d_type", ["float16", "float32"])
@pytest.mark.parametrize("no_c_ptr", [False, True])
def test_ptx_mma_half_m16n8k16(d_type, no_c_ptr):
    """Run ``Tx.ptx.mma`` with shape ``m16n8k16`` on float16 inputs against a torch reference.

    Every thread loads its pieces of D, A, B and C from global memory with the
    inline ``G2L`` helper, issues a single ``mma`` (with or without the C
    pointer, depending on ``no_c_ptr``) and writes D back with ``L2G``. The
    result is compared with ``A @ B`` (plus ``C`` when the C pointer is given).
    """
    shape = "m16n8k16"
    a_type = "float16"
    b_type = "float16"
    c_type = d_type
    a_layout = "row"
    b_layout = "col"

    # fmt: off
    @Tx.prim_func
    def main(
        D: Tx.Buffer((16, 8), d_type),
        A: Tx.Buffer((16, 16), a_type),
        B: Tx.Buffer((16, 8), b_type),
        C: Tx.Buffer((16, 8), c_type),
    ):
        with Tx.kernel():
            cta_id = Tx.cta_id([1])
            tx = Tx.thread_id([32])

            with Tx.thread():
                D_local = Tx.alloc_local([4], d_type)
                A_local = Tx.alloc_local([8], a_type)
                B_local = Tx.alloc_local([4], b_type)
                C_local = Tx.alloc_local([4], c_type)

                @Tx.inline
                def G2L(buf_local, buf_global, block_8x8, mode="row"):
                    if mode == "row":
                        for i in range(block_8x8):
                            row = Tx.meta_var(i % 2 * 8 + tx // 4)
                            col = Tx.meta_var(i // 2 * 8 + (tx % 4) * 2)
                            for j in range(2):
                                buf_local[i * 2 + j] = buf_global[row, col + j]
                    elif mode == "col":
                        for i in range(block_8x8):
                            row = Tx.meta_var(i % 2 * 8 + (tx % 4) * 2)
                            col = Tx.meta_var(i // 2 * 8 + tx // 4)
                            for j in range(2):
                                buf_local[i * 2 + j] = buf_global[row + j, col]

                @Tx.inline
                def L2G(buf_local, buf_global, block_8x8):
                    for i in range(block_8x8):
                        row = Tx.meta_var(i % 2 * 8 + tx // 4)
                        col = Tx.meta_var(i // 2 * 8 + (tx % 4) * 2)
                        for j in range(2):
                            buf_global[row, col + j] = buf_local[i * 2 + j]

                G2L(D_local, D, 2)
                G2L(A_local, A, 4)
                G2L(B_local, B, 2, "col")
                G2L(C_local, C, 2)

                if no_c_ptr:
                    Tx.ptx.mma(shape, a_layout, b_layout, d_type, a_type, b_type, c_type,
                               D_local.ptr_to([0]), A_local.ptr_to([0]), B_local.ptr_to([0]))
                else:
                    Tx.ptx.mma(shape, a_layout, b_layout, d_type, a_type, b_type, c_type,
                               D_local.ptr_to([0]), A_local.ptr_to([0]), B_local.ptr_to([0]), C_local.ptr_to([0]))  # noqa: E501

                L2G(D_local, D, 2)
    # fmt: on

    # The generated source is not inspected here; only the compiled module is used.
    _, mod = _get_source(main)

    # Fixed seed so the random operands are reproducible between runs.
    np.random.seed(0)
    D_np = np.zeros((16, 8), dtype=d_type)
    A_np = np.random.randn(16, 16).astype(a_type)
    B_np = np.random.randn(16, 8).astype(b_type)
    C_np = np.random.randn(16, 8).astype(c_type)
    D = tvm.runtime.tensor(D_np, device=DEV)
    A = tvm.runtime.tensor(A_np, device=DEV)
    B = tvm.runtime.tensor(B_np, device=DEV)
    C = tvm.runtime.tensor(C_np, device=DEV)
    mod(D, A, B, C)

    # Reference result; the earlier zero-filled placeholder was always overwritten
    # by one of the two branches below, so it is no longer created.
    A_torch = torch.from_numpy(A_np)
    B_torch = torch.from_numpy(B_np)
    C_torch = torch.from_numpy(C_np)
    if no_c_ptr:
        D_torch = A_torch @ B_torch
    else:
        D_torch = A_torch @ B_torch + C_torch
    np.testing.assert_allclose(D.numpy(), D_torch.numpy(), atol=1e-3, rtol=1e-3)
@pytest.mark.parametrize("d_type", ["float16", "float32"])
@pytest.mark.parametrize("no_c_ptr", [False, True])
def test_ptx_mma_half_m16n8k8(d_type, no_c_ptr):
    """Run ``Tx.ptx.mma`` with shape ``m16n8k8`` on float16 inputs against a torch reference.

    Every thread loads its pieces of D, A, B and C from global memory with the
    inline ``G2L`` helper, issues a single ``mma`` (with or without the C
    pointer, depending on ``no_c_ptr``) and writes D back with ``L2G``. The
    result is compared with ``A @ B`` (plus ``C`` when the C pointer is given).
    """
    shape = "m16n8k8"
    a_type = "float16"
    b_type = "float16"
    c_type = d_type
    a_layout = "row"
    b_layout = "col"

    # fmt: off
    @Tx.prim_func
    def main(
        D: Tx.Buffer((16, 8), d_type),
        A: Tx.Buffer((16, 8), a_type),
        B: Tx.Buffer((8, 8), b_type),
        C: Tx.Buffer((16, 8), c_type),
    ):
        with Tx.kernel():
            cta_id = Tx.cta_id([1])
            tx = Tx.thread_id([32])

            with Tx.thread():
                D_local = Tx.alloc_local([4], d_type)
                A_local = Tx.alloc_local([4], a_type)
                B_local = Tx.alloc_local([2], b_type)
                C_local = Tx.alloc_local([4], c_type)

                @Tx.inline
                def G2L(buf_local, buf_global, block_8x8, mode="row"):
                    if mode == "row":
                        for i in range(block_8x8):
                            row = Tx.meta_var(i % 2 * 8 + tx // 4)
                            col = Tx.meta_var(i // 2 * 8 + (tx % 4) * 2)
                            for j in range(2):
                                buf_local[i * 2 + j] = buf_global[row, col + j]
                    elif mode == "col":
                        for i in range(block_8x8):
                            row = Tx.meta_var(i % 2 * 8 + (tx % 4) * 2)
                            col = Tx.meta_var(i // 2 * 8 + tx // 4)
                            for j in range(2):
                                buf_local[i * 2 + j] = buf_global[row + j, col]

                @Tx.inline
                def L2G(buf_local, buf_global, block_8x8):
                    for i in range(block_8x8):
                        row = Tx.meta_var(i % 2 * 8 + tx // 4)
                        col = Tx.meta_var(i // 2 * 8 + (tx % 4) * 2)
                        for j in range(2):
                            buf_global[row, col + j] = buf_local[i * 2 + j]

                G2L(D_local, D, 2)
                G2L(A_local, A, 2)
                G2L(B_local, B, 1, "col")
                G2L(C_local, C, 2)

                if no_c_ptr:
                    Tx.ptx.mma(shape, a_layout, b_layout, d_type, a_type, b_type, c_type,
                               D_local.ptr_to([0]), A_local.ptr_to([0]), B_local.ptr_to([0]))
                else:
                    Tx.ptx.mma(shape, a_layout, b_layout, d_type, a_type, b_type, c_type,
                               D_local.ptr_to([0]), A_local.ptr_to([0]), B_local.ptr_to([0]), C_local.ptr_to([0]))  # noqa: E501

                L2G(D_local, D, 2)
    # fmt: on

    # The generated source is not inspected here; only the compiled module is used.
    _, mod = _get_source(main)

    # Fixed seed so the random operands are reproducible between runs.
    np.random.seed(0)
    D_np = np.zeros((16, 8), dtype=d_type)
    A_np = np.random.randn(16, 8).astype(a_type)
    B_np = np.random.randn(8, 8).astype(b_type)
    C_np = np.random.randn(16, 8).astype(c_type)
    D = tvm.runtime.tensor(D_np, device=DEV)
    A = tvm.runtime.tensor(A_np, device=DEV)
    B = tvm.runtime.tensor(B_np, device=DEV)
    C = tvm.runtime.tensor(C_np, device=DEV)
    mod(D, A, B, C)

    # Reference result; the earlier zero-filled placeholder was always overwritten
    # by one of the two branches below, so it is no longer created.
    A_torch = torch.from_numpy(A_np)
    B_torch = torch.from_numpy(B_np)
    C_torch = torch.from_numpy(C_np)
    if no_c_ptr:
        D_torch = A_torch @ B_torch
    else:
        D_torch = A_torch @ B_torch + C_torch
    np.testing.assert_allclose(D.numpy(), D_torch.numpy(), atol=1e-3, rtol=1e-3)
# When executed as a script, hand control to TVM's test entry point.
if __name__ == "__main__":
    tvm.testing.main()