| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| """Example code to do square matrix multiplication.""" |
| import tvm |
| from tvm import te |
| import os |
| from tvm.contrib import nvcc |
| from tvm.contrib import spirv |
| import numpy as np |
| |
| TASK = "gemm" |
| USE_MANUAL_CODE = False |
| |
| |
@tvm.register_func
def tvm_callback_cuda_compile(code):
    """Compile the CUDA source string ``code`` into PTX with nvcc.

    Registered through ``tvm.register_func``. Presumably TVM invokes it as the
    CUDA compile callback when building for CUDA (confirm against TVM docs).

    Returns whatever ``nvcc.compile_cuda(code, target="ptx")`` returns.
    """
    return nvcc.compile_cuda(code, target="ptx")
| |
| |
def write_code(code, fname):
    """Write the string ``code`` to the file ``fname``.

    The file is opened in text mode ``"w"``, so any existing content is
    replaced. Nothing is returned.
    """
    out_file = open(fname, "w")
    try:
        out_file.write(code)
    finally:
        out_file.close()
| |
| |
@tvm.register_func
def tvm_callback_cuda_postproc(code):
    """Save the generated CUDA source and optionally swap in a hand-written kernel.

    The incoming ``code`` is always written to ``perf/<TASK>_generated.cu``.
    The ``perf`` directory is created if it is missing.

    When ``USE_MANUAL_CODE`` is true, the contents of
    ``perf/<TASK>_manual.cu`` are returned instead of ``code``. That lets
    TVM build a manually edited kernel.

    Parameters
    ----------
    code : str
        CUDA source text handed to this registered callback.

    Returns
    -------
    str
        Either ``code`` unchanged or the manual kernel source.
    """
    # exist_ok avoids the check-then-create race of os.path.exists + os.mkdir.
    os.makedirs("perf", exist_ok=True)
    write_code(code, "perf/%s_generated.cu" % TASK)
    if USE_MANUAL_CODE:
        # Use a context manager so the file handle is closed deterministically.
        with open("perf/%s_manual.cu" % TASK) as manual_file:
            code = manual_file.read()
    return code
| |
| |
def test_gemm():
    """Schedule, verify and benchmark a 2048x2048 GEMM on several GPU backends.

    The computation is ``C[ii, jj] = sum_k A[k, jj] * B[k, ii]``, i.e.
    ``C = B^T @ A``, with ``A`` of shape ``(l, n)``, ``B`` of shape ``(l, m)``
    and ``C`` of shape ``(m, n)``.

    The schedule:
    - stages A and B through shared memory and then local memory, and
      accumulates into a local cache of C;
    - tiles C into 64x64 blocks (``scale * num_thread``) bound to
      blockIdx.y/x;
    - splits each tile across 2x2 virtual threads and an 8x8 thread grid;
    - double-buffers the shared-memory loads.

    For each device in ``["cuda", "opencl", "rocm", "nvptx", "vulkan"]``
    that exists, it builds the kernel, checks the result against NumPy and
    prints the timing and GFLOPS. Devices that are not enabled are skipped
    with a message.
    """
    # ``tvm.testing.assert_allclose`` is used below, so load the submodule
    # explicitly. This must be the first statement: it makes ``tvm`` a local
    # name for the whole function (bound to the same top-level package).
    import tvm.testing

    # graph
    nn = 2048
    n = tvm.runtime.convert(nn)
    m, l = n, n
    A = te.placeholder((l, n), name="A")
    B = te.placeholder((l, m), name="B")
    k = te.reduce_axis((0, l), name="k")
    C = te.compute((m, n), lambda ii, jj: te.sum(A[k, jj] * B[k, ii], axis=k), name="C")

    # schedule
    s = te.create_schedule(C.op)
    AA = s.cache_read(A, "shared", [C])
    BB = s.cache_read(B, "shared", [C])
    AL = s.cache_read(AA, "local", [C])
    BL = s.cache_read(BB, "local", [C])
    CC = s.cache_write(C, "local")

    scale = 8
    num_thread = 8
    block_factor = scale * num_thread
    block_x = te.thread_axis("blockIdx.x")
    thread_x = te.thread_axis((0, num_thread), "threadIdx.x")
    block_y = te.thread_axis("blockIdx.y")
    thread_y = te.thread_axis((0, num_thread), "threadIdx.y")
    thread_xz = te.thread_axis((0, 2), "vthread", name="vx")
    thread_yz = te.thread_axis((0, 2), "vthread", name="vy")

    # Tile C into block_factor x block_factor blocks, one per GPU block.
    by, yi = s[C].split(C.op.axis[0], factor=block_factor)
    bx, xi = s[C].split(C.op.axis[1], factor=block_factor)
    s[C].bind(by, block_y)
    s[C].bind(bx, block_x)
    s[C].reorder(by, bx, yi, xi)

    # Inside a block: 2x2 virtual threads, then a num_thread x num_thread grid.
    tyz, yi = s[C].split(yi, nparts=2)
    ty, yi = s[C].split(yi, nparts=num_thread)
    txz, xi = s[C].split(xi, nparts=2)
    tx, xi = s[C].split(xi, nparts=num_thread)
    s[C].bind(tyz, thread_yz)
    s[C].bind(txz, thread_xz)
    s[C].bind(ty, thread_y)
    s[C].bind(tx, thread_x)
    s[C].reorder(tyz, txz, ty, tx, yi, xi)
    s[CC].compute_at(s[C], tx)

    # Reduction: outer chunks of 8 along k; shared loads happen per chunk,
    # local loads per unrolled inner step.
    yo, xo = CC.op.axis
    ko, ki = s[CC].split(k, factor=8)
    kt, ki = s[CC].split(ki, factor=1)
    s[CC].reorder(ko, kt, ki, yo, xo)
    s[AA].compute_at(s[CC], ko)
    s[BB].compute_at(s[CC], ko)
    s[CC].unroll(kt)
    s[AL].compute_at(s[CC], kt)
    s[BL].compute_at(s[CC], kt)
    # Schedule for A's shared memory load: each thread handles a vectorized
    # run of (num_thread * 4) / num_thread = 4 elements along axis 1.
    ty, xi = s[AA].split(s[AA].op.axis[0], nparts=num_thread)
    _, xi = s[AA].split(s[AA].op.axis[1], factor=num_thread * 4)
    tx, xi = s[AA].split(xi, nparts=num_thread)
    s[AA].bind(ty, thread_y)
    s[AA].bind(tx, thread_x)
    s[AA].vectorize(xi)
    # Schedule for B's shared memory load (same layout as A's).
    ty, xi = s[BB].split(s[BB].op.axis[0], nparts=num_thread)
    _, xi = s[BB].split(s[BB].op.axis[1], factor=num_thread * 4)
    tx, xi = s[BB].split(xi, nparts=num_thread)
    s[BB].bind(ty, thread_y)
    s[BB].bind(tx, thread_x)
    s[BB].vectorize(xi)
    s[AA].double_buffer()
    s[BB].double_buffer()

    # correctness
    def check_device(device):
        """Build for ``device``, check against NumPy, and print timing/GFLOPS."""
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % device)
            return
        print("Device %s" % device)
        f = tvm.build(s, [A, B, C], device)
        # launch the kernel.
        n, m, l = nn, nn, nn
        # Host buffers use the same shapes as the placeholders:
        # A is (l, n), B is (l, m), C is (m, n).
        a_np = np.random.uniform(size=(l, n)).astype(A.dtype)
        b_np = np.random.uniform(size=(l, m)).astype(B.dtype)
        a = tvm.nd.array(a_np, ctx)
        b = tvm.nd.array(b_np, ctx)
        c = tvm.nd.array(np.zeros((m, n), dtype=C.dtype), ctx)
        for _ in range(2):
            f(a, b, c)
        # C[ii, jj] = sum_k A[k, jj] * B[k, ii]  ->  C = B^T @ A.
        tvm.testing.assert_allclose(c.asnumpy(), np.dot(b_np.T, a_np), rtol=1e-5)

        num_flops = 2 * n * m * l
        num_runs = 10
        timer_f = f.time_evaluator(f.entry_name, ctx, number=num_runs)
        t = timer_f(a, b, c).mean
        GFLOPS = num_flops / (t * 1e3) / 1e6
        print("average time cost of %d runs = %g ms, %g GFLOPS." % (num_runs, t * 1e3, GFLOPS))

    for device in ["cuda", "opencl", "rocm", "nvptx", "vulkan"]:
        # Explicit loop unrolling is enabled for every backend except CUDA.
        with tvm.transform.PassContext(
            config={"tir.UnrollLoop": {"auto_max_step": 128, "explicit_unroll": device != "cuda"}}
        ):
            check_device(device)
| |
| |
| if __name__ == "__main__": |
| test_gemm() |