# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=missing-docstring
import tvm
from tvm import te
from tvm.tir.tensor_intrin.rocm import (
    shared_16x4_to_local_64x1_layout_A,
    shared_4x16_to_local_64x1_layout_B,
    shared_16x16_to_local_64x4_layout_A,
    shared_16x16_to_local_64x4_layout_B,
    shared_16x16_to_local_64x4_layout_C,
    ROCM_MFMA_fill_16x16_f32_INTRIN,
    ROCM_MFMA_LOAD_16x4_A_SHARED_f32_INTRIN,
    ROCM_MFMA_LOAD_16x4_B_SHARED_f32_INTRIN,
    ROCM_MFMA_f32f32f32_INTRIN,
    ROCM_MFMA_STORE_16x16_f32_INTRIN,
    ROCM_MFMA_LOAD_16x16_A_SHARED_f16_INTRIN,
    ROCM_MFMA_LOAD_16x16_B_SHARED_f16_INTRIN,
    ROCM_MFMA_f16f16f32_INTRIN,
    ROCM_MFMA_fill_16x16_i32_INTRIN,
    ROCM_MFMA_LOAD_16x16_A_SHARED_s8_INTRIN,
    ROCM_MFMA_LOAD_16x16_B_SHARED_s8_INTRIN,
    ROCM_MFMA_s8s8s32_INTRIN,
    ROCM_MFMA_STORE_16x16_s32_INTRIN,
)
import tvm.testing
import numpy as np
from tvm.testing.tir import mfma_schedule
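
# Square GEMM problem size shared by all tests below; `gflops` is the operation count
# of one M x N x K matmul (2 * M * N * K) in giga-ops, used for the optional perf report
# when `measure_perf` is enabled.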
M = 1024
N = 1024
K = 1024
measure_perf = False
gflops = (N * M * K) * 2 / 1e9
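

# Construct the TE compute for C = A @ B (with B optionally stored transposed).
# Inputs are cast to `out_dtype` before accumulation when the dtypes differ.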
def matmul(m, n, k, in_dtype, out_dtype, b_transposed):
    b_shape = (n, k) if b_transposed else (k, n)
    a = te.placeholder((m, k), name="A", dtype=in_dtype)
    b = te.placeholder(b_shape, name="B", dtype=in_dtype)
    k = te.reduce_axis((0, k), name="k")

    def maybe_cast(v):
        if in_dtype != out_dtype:
            return tvm.tir.Cast(out_dtype, v)
        return v

    def maybe_swap(i, j):
        if b_transposed:
            return j, i
        return i, j

    c = te.compute(
        (m, n),
        lambda i, j: te.sum(maybe_cast(a[i, k]) * maybe_cast(b[maybe_swap(k, j)]), axis=[k]),
        name="C",
    )

    return (a, b, c)
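

# Apply the MFMA schedule template to the matmul, compile it for ROCm, run it, and check
# the result against a numpy reference. Returns a thunk that times the compiled kernel.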
def run_test(
    k_inner,
    in_dtype,
    out_dtype,
    b_transposed,
    i_factors,
    j_factors,
    k_factors,
    index_map_A,
    index_map_B,
    index_map_C,
    ldmatrix_a_intrin,
    ldmatrix_b_intrin,
    mma_intrin,
    mma_fill_intrin,
    mma_store_intrin,
):
    sch = mfma_schedule(
        te.create_prim_func(matmul(M, N, K, in_dtype, out_dtype, b_transposed)),
        k_inner,
        in_dtype,
        b_transposed,
        i_factors,
        j_factors,
        k_factors,
        index_map_A,
        index_map_B,
        index_map_C,
        ldmatrix_a_intrin,
        ldmatrix_b_intrin,
        mma_intrin,
        mma_fill_intrin,
        mma_store_intrin,
    )

    f = tvm.compile(sch.mod["main"], target="rocm")

    dev = tvm.device("rocm", 0)

    if in_dtype == "float32":
        a_np = np.random.uniform(size=(M, K)).astype("float32")

        if b_transposed:
            b_np = np.random.uniform(size=(N, K)).astype("float32")
            c_np = np.dot(a_np.astype("float32"), b_np.astype("float32").transpose()).astype(
                out_dtype
            )
        else:
            b_np = np.random.uniform(size=(K, N)).astype("float32")
            c_np = np.dot(a_np.astype("float32"), b_np.astype("float32")).astype(out_dtype)
    elif in_dtype == "float16":
        a_np = np.random.uniform(size=(M, K)).astype("float16")

        if b_transposed:
            b_np = np.random.uniform(size=(N, K)).astype("float16")
            c_np = np.dot(a_np.astype("float32"), b_np.astype("float32").transpose()).astype(
                out_dtype
            )
        else:
            b_np = np.random.uniform(size=(K, N)).astype("float16")
            c_np = np.dot(a_np.astype("float32"), b_np.astype("float32")).astype(out_dtype)
    else:
        a_np = np.random.randint(-128, 128, (M, K)).astype("int8")

        if b_transposed:
            b_np = np.random.randint(-128, 128, (N, K)).astype("int8")
            c_np = np.dot(a_np.astype("float32"), b_np.astype("float32").transpose()).astype(
                "int32"
            )
        else:
            b_np = np.random.randint(-128, 128, (K, N)).astype("int8")
            c_np = np.dot(a_np.astype("float32"), b_np.astype("float32")).astype("int32")

    a = tvm.runtime.tensor(a_np, dev)
    b = tvm.runtime.tensor(b_np, dev)
    c = tvm.runtime.tensor(np.zeros((M, N), dtype=out_dtype), dev)

    f(a, b, c)

    if in_dtype != "float16":
        # The numpy reference is computed with fp32 precision (otherwise too slow),
        # so there is a non-trivial accuracy difference if the TVM result is computed
        # with fp16 accumulation.
        tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-2, atol=1e-2)

    return lambda: f.time_evaluator(f.entry_name, dev, number=500)(a, b, c)
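

# int8 x int8 -> int32 GEMM tensorized with the m16n16k16 MFMA intrinsics.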
@tvm.testing.requires_matrixcore
def test_i8i8i32_m16n16k16():
    def index_map_A(i, j):
        return (
            i // 16,
            j // 16,
            *shared_16x16_to_local_64x4_layout_A(i % 16, j % 16),
        )

    def index_map_B(i, j):
        return (
            i // 16,
            j // 16,
            *shared_16x16_to_local_64x4_layout_B(i % 16, j % 16),
        )

    def index_map_C(i, j):
        return (
            i // 16,
            j // 16,
            *shared_16x16_to_local_64x4_layout_C(i % 16, j % 16),
        )

    k_inner = 16
    in_dtype = "int8"
    out_dtype = "int32"
    i_factors, j_factors, k_factors = [1, 8, 2, 4, 1], [1, 16, 2, 1, 2], [32, 2, 1]

    timer = run_test(
        k_inner,
        in_dtype,
        out_dtype,
        False,  # b_transposed
        i_factors,
        j_factors,
        k_factors,
        index_map_A,
        index_map_B,
        index_map_C,
        ROCM_MFMA_LOAD_16x16_A_SHARED_s8_INTRIN,
        ROCM_MFMA_LOAD_16x16_B_SHARED_s8_INTRIN,
        ROCM_MFMA_s8s8s32_INTRIN,
        ROCM_MFMA_fill_16x16_i32_INTRIN,
        ROCM_MFMA_STORE_16x16_s32_INTRIN,
    )

    if measure_perf and timer:
        print("test_i8i8i32_m16n16k16: %f GFLOPS" % (gflops / (timer().mean)))
@tvm.testing.requires_matrixcore
def test_f16f16f32_m16n16k16():
    def index_map_A(i, j):
        return (
            i // 16,
            j // 16,
            *shared_16x16_to_local_64x4_layout_A(i % 16, j % 16),
        )

    def index_map_B(i, j):
        return (
            i // 16,
            j // 16,
            *shared_16x16_to_local_64x4_layout_B(i % 16, j % 16),
        )

    def index_map_C(i, j):
        return (
            i // 16,
            j // 16,
            *shared_16x16_to_local_64x4_layout_C(i % 16, j % 16),
        )

    k_inner = 16
    in_dtype = "float16"
    out_dtype = "float32"
    i_factors, j_factors, k_factors = [1, 8, 2, 4, 1], [1, 16, 2, 1, 2], [32, 2, 1]

    timer = run_test(
        k_inner,
        in_dtype,
        out_dtype,
        False,  # b_transposed
        i_factors,
        j_factors,
        k_factors,
        index_map_A,
        index_map_B,
        index_map_C,
        ROCM_MFMA_LOAD_16x16_A_SHARED_f16_INTRIN,
        ROCM_MFMA_LOAD_16x16_B_SHARED_f16_INTRIN,
        ROCM_MFMA_f16f16f32_INTRIN,
        ROCM_MFMA_fill_16x16_f32_INTRIN,
        ROCM_MFMA_STORE_16x16_f32_INTRIN,
    )

    if measure_perf and timer:
        print("test_f16f16f32_m16n16k16: %f GFLOPS" % (gflops / (timer().mean)))
@tvm.testing.requires_matrixcore
def test_f32f32f32_m16n16k4():
    def index_map_A(i, j):
        return (
            i // 16,
            j // 16,
            *shared_16x4_to_local_64x1_layout_A(i % 16, j % 16),
        )

    def index_map_B(i, j):
        return (
            i // 16,
            j // 16,
            *shared_4x16_to_local_64x1_layout_B(i % 16, j % 16),
        )

    def index_map_C(i, j):
        return (
            i // 16,
            j // 16,
            *shared_16x16_to_local_64x4_layout_C(i % 16, j % 16),
        )

    k_inner = 4
    in_dtype = "float32"
    out_dtype = "float32"
    i_factors, j_factors, k_factors = [4, 2, 1, 4, 2], [4, 2, 2, 1, 4], [128, 2, 1]

    timer = run_test(
        k_inner,
        in_dtype,
        out_dtype,
        False,  # b_transposed
        i_factors,
        j_factors,
        k_factors,
        index_map_A,
        index_map_B,
        index_map_C,
        ROCM_MFMA_LOAD_16x4_A_SHARED_f32_INTRIN,
        ROCM_MFMA_LOAD_16x4_B_SHARED_f32_INTRIN,
        ROCM_MFMA_f32f32f32_INTRIN,
        ROCM_MFMA_fill_16x16_f32_INTRIN,
        ROCM_MFMA_STORE_16x16_f32_INTRIN,
    )

    if measure_perf and timer:
        print("test_f32f32f32_m16n16k4: %f GFLOPS" % (gflops / (timer().mean)))


if __name__ == "__main__":
    tvm.testing.main()