blob: 1d042610ac071cdba683b3a5be5d13a783a6f0de [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Test workload for lowering and build."""
import numpy as np
import tvm
import tvm.testing
from tvm.script import tir as T
# complains that index_i is defined outside of a block
@T.prim_func(check_well_formed=False)
def tensorcore_gemm(handle_a: T.handle, handle_b: T.handle, handle_c: T.handle) -> None:
# pylint: disable=missing-function-docstring
# match buffer
match_buffer_a = T.match_buffer(handle_a, [1024, 1024], "float16")
match_buffer_b = T.match_buffer(handle_b, [1024, 1024], "float16")
match_buffer_c = T.match_buffer(handle_c, [1024, 1024], "float32")
# body
for block_idx_x in T.thread_binding(0, 16, "blockIdx.x"):
for block_idx_y in T.thread_binding(0, 8, "blockIdx.y"):
with T.block():
axis_bx, axis_by = T.axis.remap("SS", [block_idx_x, block_idx_y])
shared_a = T.alloc_buffer([1024, 1024], "float16", scope="shared")
shared_b = T.alloc_buffer([1024, 1024], "float16", scope="shared")
wmma_a = T.alloc_buffer([1024, 1024], "float16", scope="wmma.matrix_a")
wmma_b = T.alloc_buffer([1024, 1024], "float16", scope="wmma.matrix_b")
wmma_c = T.alloc_buffer([1024, 1024], "float32", scope="wmma.accumulator")
# pylint: disable=too-many-nested-blocks
for thread_ty in T.thread_binding(0, 2, "threadIdx.y"):
for thread_tz in T.thread_binding(0, 2, "threadIdx.z"):
for index_i, index_jj in T.grid(2, 4):
with T.block():
new_axis_vi = T.axis.S(64, axis_bx * 4 + thread_ty * 2 + index_i)
new_axis_vj = T.axis.S(64, axis_by * 8 + thread_tz * 4 + index_jj)
T.reads([])
T.writes(
wmma_c[
new_axis_vi * 16 : new_axis_vi * 16 + 16,
new_axis_vj * 16 : new_axis_vj * 16 + 16,
]
)
match_buffer_c0 = T.match_buffer(
wmma_c[
new_axis_vi * 16 : new_axis_vi * 16 + 16,
new_axis_vj * 16 : new_axis_vj * 16 + 16,
],
(16, 16),
"float32",
strides=[16 * 4, 1],
scope="wmma.accumulator",
offset_factor=1,
)
T.evaluate(
T.tvm_fill_fragment(
match_buffer_c0.data,
16,
16,
16,
index_i * 4 + index_jj,
T.float32(0), # pylint: disable=not-callable
dtype="handle",
)
)
for k_o in range(0, 32):
# copy data from global to shared
for thread_tx in T.thread_binding(0, 32, "threadIdx.x"):
for index_i0, index_j0 in T.grid(1, 4):
for index_j1 in T.vectorized(0, 4):
with T.block():
new_axis_vi = T.axis.S(
1024,
axis_bx * 64
+ thread_ty * 32
+ thread_tx
+ index_i0,
)
new_axis_vj = T.axis.S(
1024,
k_o * 32 + thread_tz * 16 + index_j0 * 4 + index_j1,
)
shared_a[new_axis_vi, new_axis_vj + 8] = match_buffer_a[
new_axis_vi, new_axis_vj
]
for index_i0, index_j0 in T.grid(2, 4):
for index_j1 in T.vectorized(0, 4):
with T.block():
new_axis_vi = T.axis.S(
1024,
axis_by * 128
+ thread_ty * 64
+ thread_tx * 2
+ index_i0,
)
new_axis_vj = T.axis.S(
1024,
k_o * 32 + thread_tz * 16 + index_j0 * 4 + index_j1,
)
shared_b[new_axis_vi, new_axis_vj + 8] = match_buffer_b[
new_axis_vi, new_axis_vj
]
for k_i in range(0, 2):
for index_i in range(0, 2):
with T.block():
new_axis_vi = T.axis.S(
64, axis_bx * 4 + thread_ty * 2 + index_i
)
axis_vk = T.axis.S(64, k_o * 2 + k_i)
T.reads(
shared_a[
new_axis_vi * 16 : new_axis_vi * 16 + 16,
axis_vk * 16 : axis_vk * 16 + 16 + 8,
]
)
T.writes(
wmma_a[
new_axis_vi * 16 : new_axis_vi * 16 + 16,
axis_vk * 16 : axis_vk * 16 + 16,
]
)
stride0 = T.int32()
stride1 = T.int32()
match_buffer_a0 = T.match_buffer(
shared_a[
new_axis_vi * 16 : new_axis_vi * 16 + 16,
axis_vk * 16 : axis_vk * 16 + 16 + 8,
],
(16, 16 + 8),
"float16",
strides=[stride0, stride1],
scope="shared",
offset_factor=1,
)
wmma_a0 = T.match_buffer(
wmma_a[
new_axis_vi * 16 : new_axis_vi * 16 + 16,
axis_vk * 16 : axis_vk * 16 + 16,
],
(16, 16),
"float16",
strides=[16, 1],
scope="wmma.matrix_a",
offset_factor=1,
)
T.evaluate(
T.tvm_load_matrix_sync(
wmma_a0.data,
16,
16,
16,
index_i,
T.tvm_access_ptr(
T.type_annotation(dtype="float16"),
match_buffer_a0.data,
match_buffer_a0.elem_offset + 8,
match_buffer_a0.strides[0],
1,
dtype="handle",
),
match_buffer_a0.strides[0],
"row_major",
dtype="handle",
)
)
for index_jj in range(0, 4):
with T.block():
new_axis_vj = T.axis.S(
64, axis_by * 8 + thread_tz * 4 + index_jj
)
axis_vk = T.axis.S(64, k_o * 2 + k_i)
T.reads(
shared_b[
new_axis_vj * 16 : new_axis_vj * 16 + 16,
axis_vk * 16 : axis_vk * 16 + 16 + 8,
]
)
T.writes(
wmma_b[
new_axis_vj * 16 : new_axis_vj * 16 + 16,
axis_vk * 16 : axis_vk * 16 + 16,
]
)
stride0 = T.int32()
stride1 = T.int32()
match_buffer_b0 = T.match_buffer(
shared_b[
new_axis_vj * 16 : new_axis_vj * 16 + 16,
axis_vk * 16 : axis_vk * 16 + 16 + 8,
],
(16, 16 + 8),
"float16",
strides=[stride0, stride1],
scope="shared",
offset_factor=1,
)
wmma_b0 = T.match_buffer(
wmma_b[
new_axis_vj * 16 : new_axis_vj * 16 + 16,
axis_vk * 16 : axis_vk * 16 + 16,
],
(16, 16),
"float16",
strides=[16, 1],
scope="wmma.matrix_b",
offset_factor=1,
)
T.evaluate(
T.tvm_load_matrix_sync(
wmma_b0.data,
16,
16,
16,
index_jj,
T.tvm_access_ptr(
T.type_annotation(dtype="float16"),
match_buffer_b0.data,
match_buffer_b0.elem_offset + 8,
match_buffer_b0.strides[0],
1,
dtype="handle",
),
match_buffer_b0.strides[0],
"col_major",
dtype="handle",
)
)
for index_i, index_jj in T.grid(2, 4):
with T.block():
new_axis_vi = T.axis.S(
64, axis_bx * 4 + thread_ty * 2 + index_i
)
new_axis_vj = T.axis.S(
64, axis_by * 8 + thread_tz * 4 + index_jj
)
axis_vk = T.axis.R(64, k_o * 2 + k_i)
T.reads(
[
wmma_a[
new_axis_vi * 16 : new_axis_vi * 16 + 16,
axis_vk * 16 : axis_vk * 16 + 16,
],
wmma_b[
new_axis_vj * 16 : new_axis_vj * 16 + 16,
axis_vk * 16 : axis_vk * 16 + 16,
],
wmma_c[
new_axis_vi * 16 : new_axis_vi * 16 + 16,
new_axis_vj * 16 : new_axis_vj * 16 + 16,
],
]
)
T.writes(
wmma_c[
new_axis_vi * 16 : new_axis_vi * 16 + 16,
new_axis_vj * 16 : new_axis_vj * 16 + 16,
]
)
wmma_a1 = T.match_buffer(
wmma_a[
new_axis_vi * 16 : new_axis_vi * 16 + 16,
axis_vk * 16 : axis_vk * 16 + 16,
],
(16, 16),
"float16",
strides=[16, 1],
scope="wmma.matrix_a",
offset_factor=1,
)
wmma_b1 = T.match_buffer(
wmma_b[
new_axis_vj * 16 : new_axis_vj * 16 + 16,
axis_vk * 16 : axis_vk * 16 + 16,
],
(16, 16),
"float16",
strides=[16, 1],
scope="wmma.matrix_b",
offset_factor=1,
)
wmma_c1 = T.match_buffer(
wmma_c[
new_axis_vi * 16 : new_axis_vi * 16 + 16,
new_axis_vj * 16 : new_axis_vj * 16 + 16,
],
(16, 16),
"float32",
strides=[16 * 4, 1],
scope="wmma.accumulator",
offset_factor=1,
)
T.evaluate(
T.tvm_mma_sync(
wmma_c1.data,
index_i * 4 + index_jj,
wmma_a1.data,
index_i,
wmma_b1.data,
index_jj,
wmma_c1.data,
index_i * 4 + index_jj,
dtype="handle",
)
)
for index_i, index_jj in T.grid(2, 4):
with T.block():
new_axis_vi = T.axis.S(64, axis_bx * 4 + thread_ty * 2 + index_i)
new_axis_vj = T.axis.S(64, axis_by * 8 + thread_tz * 4 + index_jj)
T.reads(
wmma_c[
new_axis_vi * 16 : new_axis_vi * 16 + 16,
new_axis_vj * 16 : new_axis_vj * 16 + 16,
]
)
T.writes(
match_buffer_c[
new_axis_vi * 16 : new_axis_vi * 16 + 16,
new_axis_vj * 16 : new_axis_vj * 16 + 16,
]
)
stride0 = T.int32()
stride1 = T.int32()
wmma_c2 = T.match_buffer(
wmma_c[
new_axis_vi * 16 : new_axis_vi * 16 + 16,
new_axis_vj * 16 : new_axis_vj * 16 + 16,
],
(16, 16),
"float32",
strides=[16 * 4, 1],
scope="wmma.accumulator",
offset_factor=1,
)
match_buffer_c1 = T.match_buffer(
match_buffer_c[
new_axis_vi * 16 : new_axis_vi * 16 + 16,
new_axis_vj * 16 : new_axis_vj * 16 + 16,
],
(16, 16),
"float32",
strides=[stride0, stride1],
offset_factor=1,
)
T.evaluate(
T.tvm_store_matrix_sync(
wmma_c2.data,
16,
16,
16,
index_i * 4 + index_jj,
T.tvm_access_ptr(
T.type_annotation(dtype="float32"),
match_buffer_c1.data,
match_buffer_c1.elem_offset,
match_buffer_c1.strides[0],
1,
dtype="handle",
),
match_buffer_c1.strides[0],
"row_major",
dtype="handle",
)
)
@tvm.testing.requires_cuda
def test_gemm_tensorcore():
"""Test running gemm on tensorcore."""
dev = tvm.device("cuda", 0)
a_np = np.random.uniform(size=(1024, 1024)).astype("float16")
b_np = np.random.uniform(size=(1024, 1024)).astype("float16")
c_np = np.dot(a_np.astype("float32"), b_np.T.astype("float32"))
buff_a = tvm.nd.array(a_np, dev)
buff_b = tvm.nd.array(b_np, dev)
buff_c = tvm.nd.array(np.zeros((1024, 1024), dtype="float32"), dev)
myfunc = tvm.build(tensorcore_gemm, target="cuda", name="dense")
myfunc(buff_a, buff_b, buff_c)
tvm.testing.assert_allclose(buff_c.numpy(), c_np, rtol=1e-3)
evaluator = myfunc.time_evaluator(myfunc.entry_name, dev, number=100)
time_elapsed = evaluator(buff_a, buff_b, buff_c).mean
num_flops = 2 * 1024 * 1024 * 1024
gflops = num_flops / (time_elapsed * 1e3) / 1e6
print("gemm with tensor core: %f ms" % (time_elapsed * 1e3))
print("GFLOPS: %f" % gflops)
if __name__ == "__main__":
test_gemm_tensorcore()