# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Matrix exponential example.
This is an example for matrix exponential,
which calculates the following recursion formula
```math
X[t] = dot(X[t-1], W)
```
"""
import os
import time

import numpy as np

import tvm
import tvm.testing
from tvm import te
from tvm.contrib import nvcc
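

# Illustrative only: a minimal NumPy sketch of the recurrence described in the
# module docstring. It is not used by the benchmark below; `ref_matexp` and its
# arguments are assumptions for illustration, mirroring the scan computed on
# the GPU (an all-ones initial state updated as X[t] = dot(X[t-1], W)).
def ref_matexp(W, num_step, batch_size=4):
    """Reference recurrence X[t] = dot(X[t-1], W) with X[0] = ones."""
    X = np.ones((batch_size, W.shape[0]), dtype=W.dtype)
    for _ in range(1, num_step):
        X = np.dot(X, W)
    return X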


# Quick knobs
TASK = "matexp"
# If True, replace the generated CUDA kernel with a hand-written one read from
# perf/<TASK>_manual.cu (see tvm_callback_cuda_postproc below).
USE_MANUAL_CODE = False
# Run the whole scan inside one persistent kernel.
PERSIST_KERNEL = True
# Persistent kernels need the global barrier detection pass.
DETECT_GLOBAL_BARRIER = PERSIST_KERNEL
# Skip the NumPy correctness check at the end.
SKIP_CHECK = False


@tvm.register_func
def tvm_callback_cuda_compile(code):
    """Compile the generated CUDA source with nvcc for better performance."""
    ptx = nvcc.compile_cuda(code, target="ptx")
    return ptx


def write_code(code, fname):
    with open(fname, "w") as f:
        f.write(code)


@tvm.register_func
def tvm_callback_cuda_postproc(code):
    # Dump the generated CUDA source for inspection; optionally swap in a
    # hand-written kernel when USE_MANUAL_CODE is set.
    if not os.path.exists("perf"):
        os.mkdir("perf")
    write_code(code, "perf/%s_generated.cu" % TASK)
    if USE_MANUAL_CODE:
        code = open("perf/%s_manual.cu" % TASK).read()
    return code


def rnn_matexp():
    # Problem sizes.
    n_num_step = 128
    n_num_hidden = 1152
    n_batch_size = 4
    detect_global_barrier = DETECT_GLOBAL_BARRIER

    # The number of time steps is symbolic; hidden and batch sizes are fixed.
    num_step = te.var("num_step")
    num_hidden = tvm.runtime.convert(n_num_hidden)
    batch_size = tvm.runtime.convert(n_batch_size)
    # GPU launch configuration: num_sm * num_thread_x = 24 * 48 = 1152 threads
    # along x (one per hidden unit); num_thread_y lanes cooperate on the
    # reduction over the hidden dimension.
    num_thread_y = 8
    num_thread_x = 16 * 3
    num_sm = 24

    Whh = te.placeholder((num_hidden, num_hidden), name="Whh")
    # Initial state: all ones.
    s_init = te.compute((1, batch_size, num_hidden), lambda _, i, j: 1.0, name="init")
    # Scan state placeholder and the per-step update s_state[t] = s_state[t-1] @ Whh.
    s_state = te.placeholder((num_step, batch_size, num_hidden))
    kh = te.reduce_axis((0, num_hidden), name="kh")
    s_update = te.compute(
        (num_step, batch_size, num_hidden),
        lambda t, i, j: te.sum(s_state[t - 1, i, kh] * Whh[kh, j], axis=kh),
        name="update",
    )
    s_scan = tvm.te.scan(s_init, s_update, s_state)
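    # How the scan fits together: s_scan[0] comes from s_init (all ones), and
    # for t >= 1 the update produces
    #     s_scan[t, i, j] = sum_k s_scan[t - 1, i, k] * Whh[k, j],
    # with s_state acting as a placeholder for the scan's own carried state.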

    # Schedule.
    s = te.create_schedule(s_scan.op)
    CL = s_update
    # Stage the scan state through shared memory and registers, and keep a
    # per-thread copy of Whh in registers.
    SS = s.cache_read(s_state, "shared", [CL])
    SL = s.cache_read(SS, "local", [CL])
    WhhL = s.cache_read(Whh, "local", [CL])
    # Split the reduction over kh across thread_y and rfactor it so that each
    # thread_y lane accumulates a partial sum.
    ko, ki = s[CL].split(s[CL].op.reduce_axis[0], nparts=num_thread_y)
    CLF = s.rfactor(CL, ko)
    block_x = te.thread_axis((0, num_sm), "blockIdx.x")
    thread_x = te.thread_axis((0, num_thread_x), "threadIdx.x")
    thread_y = te.thread_axis((0, num_thread_y), "threadIdx.y")
    if PERSIST_KERNEL:
        s[s_scan.op].env_threads([block_x, thread_y, thread_x])
    # Distribute the hidden dimension of the init and update stages over blocks
    # and x-threads.
    bx, xi = s[s_init].split(s_init.op.axis[2], nparts=num_sm)
    tx, xi = s[s_init].split(xi, nparts=num_thread_x)
    s[s_init].bind(bx, block_x)
    s[s_init].bind(tx, thread_x)
    bx, xi = s[s_update].split(s[CL].op.axis[2], nparts=num_sm)
    tx, xi = s[s_update].split(xi, nparts=num_thread_x)
    s[s_update].bind(bx, block_x)
    s[s_update].bind(tx, thread_x)
    # Cross-thread reduction: combine partial sums across thread_y and let only
    # the thread_y == 0 lane store the final value.
    s[CL].bind(s[CL].op.reduce_axis[0], thread_y)
    s[CLF].compute_at(s[CL], s[CL].op.reduce_axis[0])
    s[CL].set_store_predicate(thread_y.equal(0))
    if PERSIST_KERNEL:
        # With a persistent kernel, stage Whh into registers once, outside the
        # time loop, and unroll the copy.
        s[WhhL].compute_at(s[s_scan], thread_x)
        s[WhhL].unroll(WhhL.op.axis[0])
    else:
        s[WhhL].compute_at(s[CLF], CLF.op.axis[3])
    # Tile the rfactored reduction and attach the shared/local copies of the
    # state at the appropriate loop levels.
    kr, ki = s[CLF].split(CLF.op.reduce_axis[0], nparts=1)
    ko, ki = s[CLF].split(ki, factor=4)
    s[SS].compute_at(s[CLF], kr)
    s[SL].compute_at(s[CLF], ko)
    # Cooperative load of the state tile into shared memory by all threads.
    xo, xi = s[SS].split(SS.op.axis[2], factor=num_thread_x * num_thread_y * 3)
    ty, xi = s[SS].split(xi, nparts=num_thread_y)
    tx, xi = s[SS].split(xi, nparts=num_thread_x)
    s[SS].bind(ty, thread_y)
    s[SS].bind(tx, thread_x)
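
    # Informal summary (assumes the PERSIST_KERNEL path): the whole scan runs in
    # a single kernel launch, blockIdx.x / threadIdx.{x,y} are env threads of the
    # scan, and the "tir.detect_global_barrier" pass enabled below inserts the
    # global barriers that separate time steps inside that kernel.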

    def check_device(target):
with tvm.transform.PassContext(
config={
"tir.UnrollLoop": {
"auto_max_step": 128,
},
"tir.detect_global_barrier": detect_global_barrier,
}
):
f = tvm.build(s, [s_scan, Whh], target)
ctx = tvm.gpu(0) if target == "cuda" else tvm.cl(0)
# launch the kernel.
        # Whh is 2 / n_num_hidden in its left half and zero in its right half,
        # so the recurrence stays bounded across time steps.
        res_np = np.zeros((n_num_step, n_batch_size, n_num_hidden)).astype("float32")
        Whh_np = np.zeros((n_num_hidden, n_num_hidden)).astype("float32")
        Whh_np[:] = 2.0 / n_num_hidden
        Whh_np[:, n_num_hidden // 2 :] = 0
res_a = tvm.nd.array(res_np, ctx)
Whh_a = tvm.nd.array(Whh_np, ctx)
        # Do not time the first launch; it includes one-time setup cost.
        f(res_a, Whh_a)
        ctx.sync()
        # Measure the time of the second launch.
tstart = time.time()
f(res_a, Whh_a)
ctx.sync()
tgap = time.time() - tstart
print("Time cost=%g" % tgap)
        # Correctness check against a float64 NumPy recurrence.
if not SKIP_CHECK:
res_gpu = res_a.asnumpy()
res_cmp = np.ones_like(res_np).astype("float64")
Whh_np = Whh_np.astype("float64")
for t in range(1, n_num_step):
res_cmp[t][:] = np.dot(res_cmp[t - 1], Whh_np)
for i in range(n_num_step):
for j in range(n_num_hidden):
if abs(res_cmp[i, 0, j] - res_gpu[i, 0, j]) > 1e-5:
print("%d, %d: %g vs %g" % (i, j, res_cmp[i, 0, j], res_gpu[i, 0, j]))
tvm.testing.assert_allclose(res_gpu, res_cmp, rtol=1e-3)
check_device("cuda")
if __name__ == "__main__":
rnn_matexp()