blob: f0c8911cd9b47fb93d6c46649e754aa5a8d37995 [file]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=missing-function-docstring
"""Codegen tests for Ampere (sm_80) warp-level ``mma.sync`` tensor cores.
These exercise the ``T.ptx.mma`` intrinsic directly (not via the gemm
dispatch). ``ptx.mma`` takes one pointer per 32-bit register for each operand
(``d_ptrs`` / ``a_ptrs`` / ``b_ptrs`` / ``c_ptrs``), enumerated in the fixed
PTX register order, so the b32 registers may be scattered in the register file
while the two packed fp16/bf16 within a b32 stay contiguous. For m16n8k{8,16}
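with f32 accumulation the per-lane register counts, and the element offsets
into each test's local buffer that the per-register pointers refer to, are:
A: 2 inputs per b32 -> k16: 4 b32 (offsets 0,2,4,6); k8: 2 b32 (offsets 0,2)
B: 2 inputs per b32 -> k16: 2 b32 (offsets 0,2); k8: 1 b32 (offset 0)
D/C: 1 f32 per b32 -> 4 f32 accumulator registers (offsets 0,1,2,3)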
"""
import numpy as np
import pytest
import tvm
import tvm.testing
from tvm.script import tirx as T
# CUDA device on which the host-side test tensors are allocated.
DEV = tvm.device("cuda")
def _get_source(func: tvm.tirx.PrimFunc):
    """Build ``func`` for CUDA through the ``tirx`` TIR pipeline.

    The function is wrapped in a one-function ``IRModule`` under the name
    ``"main"`` before compilation.

    Returns:
        A ``(src, mod)`` pair: the source text of the compiled artifact's first
        imported module (presumably the generated CUDA device code; the tests
        search it for the expected ``mma.sync`` instruction) and the compiled
        module itself.
    """
    compiled = tvm.compile(
        tvm.IRModule({"main": func}),
        target=tvm.target.Target("cuda"),
        tir_pipeline="tirx",
    )
    device_src = compiled.mod.imports[0].inspect_source()
    return device_src, compiled
def _np_in(dtype):
    """Map an mma input dtype name to the NumPy scalar type for host data.

    Args:
        dtype: ``"float16"`` or ``"bfloat16"``.

    Returns:
        ``np.float16`` for ``"float16"``; ``ml_dtypes.bfloat16`` for
        ``"bfloat16"``. The latter needs the optional ``ml_dtypes`` package.

    Raises:
        ValueError: for any other dtype, rather than silently falling back to
            ``np.float16`` and testing the wrong input type.
    """
    if dtype == "bfloat16":
        import ml_dtypes  # pylint: disable=import-outside-toplevel

        return ml_dtypes.bfloat16
    if dtype == "float16":
        return np.float16
    raise ValueError(f"unsupported mma input dtype: {dtype!r}")
def _run_mma(mod, K, no_c_ptr, np_in):
    """Run an m16n8kK mma kernel and check D == A @ B (+ C) against numpy.

    Args:
        mod: compiled module (as returned by ``_get_source``), invoked as
            ``mod(D, A, B, C)``.
        K: reduction extent; A is 16 x K and B is K x 8.
        no_c_ptr: when True the kernel was built without ``c_ptrs``, so the
            expected result is ``A @ B`` without adding ``C``.
        np_in: NumPy scalar type used for the A/B host data (e.g. from
            ``_np_in``).

    Raises:
        AssertionError: if the device result differs from the float32 host
            reference beyond ``atol=rtol=1e-2``.
    """
    # A local Generator keeps the inputs deterministic without reseeding the
    # process-wide legacy NumPy RNG, which other tests may also rely on.
    rng = np.random.default_rng(0)
    A_np = rng.standard_normal((16, K)).astype(np_in)
    B_np = rng.standard_normal((K, 8)).astype(np_in)
    C_np = rng.standard_normal((16, 8)).astype(np.float32)
    # D starts as zeros. The kernels in this file also load D before the mma,
    # so anything accumulated from it contributes nothing to the expected result.
    D = tvm.runtime.tensor(np.zeros((16, 8), np.float32), device=DEV)
    A = tvm.runtime.tensor(A_np, device=DEV)
    B = tvm.runtime.tensor(B_np, device=DEV)
    C = tvm.runtime.tensor(C_np, device=DEV)
    mod(D, A, B, C)
    # Reference from the already-rounded fp16/bf16 inputs, computed in float32.
    ref = A_np.astype(np.float32) @ B_np.astype(np.float32)
    if not no_c_ptr:
        ref = ref + C_np
    np.testing.assert_allclose(D.numpy(), ref, atol=1e-2, rtol=1e-2)
@tvm.testing.requires_cuda
@pytest.mark.parametrize("a_type", ["float16", "bfloat16"])
@pytest.mark.parametrize("no_c_ptr", [False, True])
def test_ptx_mma_m16n8k16(a_type: str, no_c_ptr: bool) -> None:
    """m16n8k16 row.col mma, f32 accumulate: A is 16x16 (4 b32/lane), B is 16x8
    as [K, N] (2 b32/lane), D/C is 16x8 (4 f32/lane).

    Args:
        a_type: dtype of A; B uses the same dtype. ``"bfloat16"`` is skipped
            when ``ml_dtypes`` is not installed.
        no_c_ptr: if True, ``T.ptx.mma`` is called without ``c_ptrs`` and the
            host reference omits ``C``.
    """
    if a_type == "bfloat16":
        pytest.importorskip("ml_dtypes")
    b_type = a_type
    # fmt: off
    @T.prim_func
    def main(
        D: T.Buffer((16, 8), "float32"),
        A: T.Buffer((16, 16), a_type),
        B: T.Buffer((16, 8), b_type),
        C: T.Buffer((16, 8), "float32"),
    ):
        T.device_entry()
        # One CTA of 32 threads, i.e. a single warp for the warp-level mma.sync;
        # cta_id itself is not read below.
        cta_id = T.cta_id([1])
        tx = T.thread_id([32])
        # Fragment buffers: for A/B, elements 2*r and 2*r+1 form b32 register r;
        # D/C use one f32 element per register.
        D_local = T.alloc_local([4], "float32")
        A_local = T.alloc_local([8], a_type)
        B_local = T.alloc_local([4], b_type)
        C_local = T.alloc_local([4], "float32")
        # Copy this lane's fragment elements from a global buffer. Iteration i
        # fills b32 register i: i % 2 selects the 8-row half and i // 2 the
        # 8-column half. Each register holds 2 consecutive elements, taken along
        # the column in "row" mode and along the row (K) in "col" mode. The lane
        # indexing (tx // 4, (tx % 4) * 2) follows the PTX ISA m16n8k16 fragment
        # layouts (groupID = tx // 4, threadID_in_group = tx % 4).
        @T.inline
        def G2L(buf_local, buf_global, block_8x8, mode="row"):
            if mode == "row":
                for i in range(block_8x8):
                    row = T.meta_var(i % 2 * 8 + tx // 4)
                    col = T.meta_var(i // 2 * 8 + (tx % 4) * 2)
                    for j in range(2):
                        buf_local[i * 2 + j] = buf_global[row, col + j]
            elif mode == "col":
                for i in range(block_8x8):
                    row = T.meta_var(i % 2 * 8 + (tx % 4) * 2)
                    col = T.meta_var(i // 2 * 8 + tx // 4)
                    for j in range(2):
                        buf_local[i * 2 + j] = buf_global[row + j, col]
        # D is loaded too. NOTE(review): presumably it serves as the accumulator
        # input when c_ptrs is omitted (the host passes D as zeros); confirm
        # against T.ptx.mma.
        G2L(D_local, D, 2)
        G2L(A_local, A, 4)
        G2L(B_local, B, 2, "col")
        G2L(C_local, C, 2)
        # One pointer per b32 register, in PTX order: A=4, B=2, D/C=4.
        # A/B pointers step by 2 elements because each b32 packs two 16-bit values.
        d_ptrs = [D_local.ptr_to([i]) for i in range(4)]
        a_ptrs = [A_local.ptr_to([2 * i]) for i in range(4)]
        b_ptrs = [B_local.ptr_to([2 * i]) for i in range(2)]
        if no_c_ptr:
            T.ptx.mma("m16n8k16", "row", "col", "float32", a_type, b_type, "float32",
                      d_ptrs, a_ptrs, b_ptrs)
        else:
            c_ptrs = [C_local.ptr_to([i]) for i in range(4)]
            T.ptx.mma("m16n8k16", "row", "col", "float32", a_type, b_type, "float32",
                      d_ptrs, a_ptrs, b_ptrs, c_ptrs)
        # Write the 4 f32 accumulators back using the same "row"-mode layout
        # that G2L used to load D/C.
        for i in range(2):
            row = T.meta_var(i % 2 * 8 + tx // 4)
            col = T.meta_var(i // 2 * 8 + (tx % 4) * 2)
            for j in range(2):
                D[row, col + j] = D_local[i * 2 + j]
    # fmt: on
    # Check the emitted PTX instruction, then validate the numerics on the GPU.
    src, mod = _get_source(main)
    assert "mma.sync.aligned.m16n8k16.row.col" in src
    _run_mma(mod, 16, no_c_ptr, _np_in(a_type))
@tvm.testing.requires_cuda
@pytest.mark.parametrize("a_type", ["float16", "bfloat16"])
@pytest.mark.parametrize("no_c_ptr", [False, True])
def test_ptx_mma_m16n8k8(a_type: str, no_c_ptr: bool) -> None:
    """m16n8k8 row.col mma, f32 accumulate: A is 16x8 (2 b32/lane), B is 8x8
    as [K, N] (1 b32/lane), D/C is 16x8 (4 f32/lane).

    Args:
        a_type: dtype of A; B uses the same dtype. ``"bfloat16"`` is skipped
            when ``ml_dtypes`` is not installed.
        no_c_ptr: if True, ``T.ptx.mma`` is called without ``c_ptrs`` and the
            host reference omits ``C``.
    """
    if a_type == "bfloat16":
        pytest.importorskip("ml_dtypes")
    b_type = a_type
    # fmt: off
    @T.prim_func
    def main(
        D: T.Buffer((16, 8), "float32"),
        A: T.Buffer((16, 8), a_type),
        B: T.Buffer((8, 8), b_type),
        C: T.Buffer((16, 8), "float32"),
    ):
        T.device_entry()
        # One CTA of 32 threads, i.e. a single warp for the warp-level mma.sync;
        # cta_id itself is not read below.
        cta_id = T.cta_id([1])
        tx = T.thread_id([32])
        # Fragment buffers: for A/B, elements 2*r and 2*r+1 form b32 register r;
        # D/C use one f32 element per register.
        D_local = T.alloc_local([4], "float32")
        A_local = T.alloc_local([4], a_type)
        B_local = T.alloc_local([2], b_type)
        C_local = T.alloc_local([4], "float32")
        # Copy this lane's fragment elements from a global buffer. Iteration i
        # fills b32 register i: i % 2 selects the 8-row half and i // 2 the
        # 8-column half. Each register holds 2 consecutive elements, taken along
        # the column in "row" mode and along the row (K) in "col" mode. The lane
        # indexing (tx // 4, (tx % 4) * 2) follows the PTX ISA m16n8k8 fragment
        # layouts (groupID = tx // 4, threadID_in_group = tx % 4).
        @T.inline
        def G2L(buf_local, buf_global, block_8x8, mode="row"):
            if mode == "row":
                for i in range(block_8x8):
                    row = T.meta_var(i % 2 * 8 + tx // 4)
                    col = T.meta_var(i // 2 * 8 + (tx % 4) * 2)
                    for j in range(2):
                        buf_local[i * 2 + j] = buf_global[row, col + j]
            elif mode == "col":
                for i in range(block_8x8):
                    row = T.meta_var(i % 2 * 8 + (tx % 4) * 2)
                    col = T.meta_var(i // 2 * 8 + tx // 4)
                    for j in range(2):
                        buf_local[i * 2 + j] = buf_global[row + j, col]
        # D is loaded too. NOTE(review): presumably it serves as the accumulator
        # input when c_ptrs is omitted (the host passes D as zeros); confirm
        # against T.ptx.mma.
        G2L(D_local, D, 2)
        G2L(A_local, A, 2)
        G2L(B_local, B, 1, "col")
        G2L(C_local, C, 2)
        # One pointer per b32 register, in PTX order: A=2, B=1, D/C=4.
        # A pointers step by 2 elements because each b32 packs two 16-bit values.
        d_ptrs = [D_local.ptr_to([i]) for i in range(4)]
        a_ptrs = [A_local.ptr_to([2 * i]) for i in range(2)]
        b_ptrs = [B_local.ptr_to([0])]
        if no_c_ptr:
            T.ptx.mma("m16n8k8", "row", "col", "float32", a_type, b_type, "float32",
                      d_ptrs, a_ptrs, b_ptrs)
        else:
            c_ptrs = [C_local.ptr_to([i]) for i in range(4)]
            T.ptx.mma("m16n8k8", "row", "col", "float32", a_type, b_type, "float32",
                      d_ptrs, a_ptrs, b_ptrs, c_ptrs)
        # Write the 4 f32 accumulators back using the same "row"-mode layout
        # that G2L used to load D/C.
        for i in range(2):
            row = T.meta_var(i % 2 * 8 + tx // 4)
            col = T.meta_var(i // 2 * 8 + (tx % 4) * 2)
            for j in range(2):
                D[row, col + j] = D_local[i * 2 + j]
    # fmt: on
    # Check the emitted PTX instruction, then validate the numerics on the GPU.
    src, mod = _get_source(main)
    assert "mma.sync.aligned.m16n8k8.row.col" in src
    _run_mma(mod, 8, no_c_ptr, _np_in(a_type))
if __name__ == "__main__":
    # Allow running this test file directly as a script.
    tvm.testing.main()