# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Example code to do square matrix multiplication."""
import tvm
import tvm.testing
from tvm import te
import os
from tvm.contrib import nvcc
from tvm.contrib import spirv
import numpy as np

TASK = "gemm"
USE_MANUAL_CODE = False


@tvm.register_func
def tvm_callback_cuda_compile(code):
    ptx = nvcc.compile_cuda(code, target="ptx")
    return ptx
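
# Registering a packed function under this exact name lets TVM's CUDA build path
# hand the generated CUDA C source to it, so nvcc (rather than the default NVRTC
# path) produces the PTX that gets loaded onto the device.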


def write_code(code, fname):
    with open(fname, "w") as f:
        f.write(code)


@tvm.register_func
def tvm_callback_cuda_postproc(code):
    if not os.path.exists("perf"):
        os.mkdir("perf")
    write_code(code, "perf/%s_generated.cu" % TASK)
    if USE_MANUAL_CODE:
        code = open("perf/%s_manual.cu" % TASK).read()
    return code
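
# This post-processing hook dumps every generated kernel to
# perf/gemm_generated.cu; flipping USE_MANUAL_CODE to True swaps in a
# hand-edited perf/gemm_manual.cu so tuned CUDA code can be benchmarked instead.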


def test_gemm():
    # graph
    nn = 2048
    n = te.var("n")
    n = tvm.runtime.convert(nn)
    m, l = n, n
    A = te.placeholder((l, n), name="A")
    B = te.placeholder((l, m), name="B")
    k = te.reduce_axis((0, l), name="k")
    C = te.compute((m, n), lambda ii, jj: te.sum(A[k, jj] * B[k, ii], axis=k), name="C")
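    # In NumPy terms the compute above is C = B.T @ A: both operands are indexed
    # with the reduction axis k as their leading (row) dimension, e.g.
    #   c_ref = np.dot(b_np.T, a_np)
    # which is exactly what the correctness check below compares against.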

    # schedule
    s = te.create_schedule(C.op)
    AA = s.cache_read(A, "shared", [C])
    BB = s.cache_read(B, "shared", [C])
    AL = s.cache_read(AA, "local", [C])
    BL = s.cache_read(BB, "local", [C])
    CC = s.cache_write(C, "local")
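    # Memory hierarchy: AA/BB stage block tiles of A/B in shared memory, AL/BL
    # stage per-thread fragments in thread-local registers, and CC accumulates
    # the result locally before it is written back to C in global memory.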

    scale = 8
    num_thread = 8
    block_factor = scale * num_thread
    block_x = te.thread_axis("blockIdx.x")
    thread_x = te.thread_axis((0, num_thread), "threadIdx.x")
    block_y = te.thread_axis("blockIdx.y")
    thread_y = te.thread_axis((0, num_thread), "threadIdx.y")
    thread_xz = te.thread_axis((0, 2), "vthread", name="vx")
    thread_yz = te.thread_axis((0, 2), "vthread", name="vy")

    by, yi = s[C].split(C.op.axis[0], factor=block_factor)
    bx, xi = s[C].split(C.op.axis[1], factor=block_factor)
    s[C].bind(by, block_y)
    s[C].bind(bx, block_x)
    s[C].reorder(by, bx, yi, xi)

    tyz, yi = s[C].split(yi, nparts=2)
    ty, yi = s[C].split(yi, nparts=num_thread)
    txz, xi = s[C].split(xi, nparts=2)
    tx, xi = s[C].split(xi, nparts=num_thread)
    s[C].bind(tyz, thread_yz)
    s[C].bind(txz, thread_xz)
    s[C].bind(ty, thread_y)
    s[C].bind(tx, thread_x)
    s[C].reorder(tyz, txz, ty, tx, yi, xi)
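    # Resulting decomposition (block_factor = 64): each thread block owns a
    # 64x64 tile of C, covered by 2x2 virtual threads and 8x8 real threads, so
    # every thread computes a 4x4 sub-tile per vthread (64 / 2 / 8 = 4 per axis).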
    s[CC].compute_at(s[C], tx)

    yo, xo = CC.op.axis
    ko, ki = s[CC].split(k, factor=8)
    kt, ki = s[CC].split(ki, factor=1)
    s[CC].reorder(ko, kt, ki, yo, xo)
    s[AA].compute_at(s[CC], ko)
    s[BB].compute_at(s[CC], ko)
    s[CC].unroll(kt)
    s[AL].compute_at(s[CC], kt)
    s[BL].compute_at(s[CC], kt)
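    # The reduction axis is tiled by 8: the shared-memory tiles (AA/BB) are
    # refilled once per ko iteration, while the register copies (AL/BL) are
    # reloaded in every step of the unrolled kt loop.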

    # Schedule for A's shared memory load
    ty, xi = s[AA].split(s[AA].op.axis[0], nparts=num_thread)
    _, xi = s[AA].split(s[AA].op.axis[1], factor=num_thread * 4)
    tx, xi = s[AA].split(xi, nparts=num_thread)
    s[AA].bind(ty, thread_y)
    s[AA].bind(tx, thread_x)
    s[AA].vectorize(xi)
    # Schedule for B's shared memory load
    ty, xi = s[BB].split(s[BB].op.axis[0], nparts=num_thread)
    _, xi = s[BB].split(s[BB].op.axis[1], factor=num_thread * 4)
    tx, xi = s[BB].split(xi, nparts=num_thread)
    s[BB].bind(ty, thread_y)
    s[BB].bind(tx, thread_x)
    s[BB].vectorize(xi)
    s[AA].double_buffer()
    s[BB].double_buffer()
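    # The 8x8 thread block loads each shared tile cooperatively; after the
    # splits the innermost xi has extent 4, so vectorize(xi) should turn the
    # loads into 4-wide (float4) accesses, and double_buffer() lets the next
    # tile be fetched while the current one is being consumed.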

    # correctness
    def check_device(device):
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % device)
            return
        print("Device %s" % device)
        f = tvm.build(s, [A, B, C], device)
        # launch the kernel.
        n, m, l = nn, nn, nn
        a_np = np.random.uniform(size=(n, l)).astype(A.dtype)
        b_np = np.random.uniform(size=(m, l)).astype(B.dtype)
        a = tvm.nd.array(a_np, ctx)
        b = tvm.nd.array(b_np, ctx)
        c = tvm.nd.array(np.zeros((n, m), dtype=C.dtype), ctx)
        # run the kernel a couple of times, then check the result against NumPy
        for i in range(2):
            f(a, b, c)
        tvm.testing.assert_allclose(c.asnumpy(), np.dot(b_np.T, a_np), rtol=1e-5)

        # report performance: a square nn x nn x nn matmul does 2 * nn^3 flops
        num_flops = 2 * nn * nn * nn
        num_runs = 10
        timer_f = f.time_evaluator(f.entry_name, ctx, number=num_runs)
        t = timer_f(a, b, c).mean  # mean time per run, in seconds
        GFLOPS = num_flops / (t * 1e3) / 1e6
        print("average time cost of %d runs = %g ms, %g GFLOPS." % (num_runs, t * 1e3, GFLOPS))

    for device in ["cuda", "opencl", "rocm", "nvptx", "vulkan"]:
        # cap automatic unrolling at 128 steps; on CUDA leave the unrolling to
        # the compiler instead of unrolling explicitly in the generated code
        with tvm.transform.PassContext(
            config={"tir.UnrollLoop": {"auto_max_step": 128, "explicit_unroll": device != "cuda"}}
        ):
            check_device(device)


if __name__ == "__main__":
    test_gemm()