# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
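"""Tests for offloading Relay dense, batch_matmul, conv2d (and their fused variants),
grouped GEMM, and FP8 GEMM workloads to CUTLASS through the BYOC partitioning flow."""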
import logging
import math
import tempfile
import ml_dtypes
import numpy as np
import tvm
import tvm.testing
from tvm import auto_scheduler, relay
from tvm.contrib.cudnn import conv_output_shape
from tvm.contrib.cutlass import (
finalize_modules,
finalize_modules_vm,
has_cutlass,
num_cutlass_partitions,
)
from tvm.contrib.pickle_memoize import memoize
from tvm.relay import op as _op
from tvm.relay.op.contrib.cutlass import partition_for_cutlass
from tvm.relay.transform import FirstOrderGradient, InferType, ToMixedPrecision
from tvm.runtime.vm import VirtualMachine
logging.basicConfig(level=logging.INFO)
def has_cublas():
    return tvm.get_global_func("tvm.contrib.cublas.matmul", True) is not None
def get_ref_rt_mod(mod, params, target="cuda"):
with tvm.transform.PassContext(opt_level=3):
lib = relay.build(mod, target=target, params=params)
dev = tvm.device(target, 0)
rt_mod = tvm.contrib.graph_executor.GraphModule(lib["default"](dev))
return rt_mod, dev
def get_ref_vm(mod, params, target="cuda"):
with tvm.transform.PassContext(opt_level=3):
vm_exec = relay.vm.compile(mod, target=target, params=params)
code, lib = vm_exec.save()
dev = tvm.device(target, 0)
vm_exec = tvm.runtime.vm.Executable.load_exec(code, lib)
return VirtualMachine(vm_exec, dev), dev
def get_output(rt_mod, names, inputs):
for name, inp in zip(names, inputs):
rt_mod.set_input(name, inp)
rt_mod.run()
return rt_mod.get_output(0).asnumpy()
def get_output_vm(vm, names, inputs):
params = dict(zip(names, inputs))
return vm.invoke("main", **params).numpy()
def get_dense_with_shape(
data_shape, weight_shape, out_dtype="float16", data_dtype="float16", weight_dtype="float16"
):
data = relay.var("data", shape=data_shape, dtype=data_dtype)
weight = relay.var("weight", shape=weight_shape, dtype=weight_dtype)
return relay.nn.dense(data, weight, out_dtype=out_dtype)
def get_dense(M, N, K, out_dtype="float16", data_dtype="float16", weight_dtype="float16"):
return get_dense_with_shape((M, K), (N, K), out_dtype, data_dtype, weight_dtype)
def get_dense_bias(M, N, K, out_dtype="float16"):
dense = get_dense(M, N, K, out_dtype=out_dtype)
bias = relay.var("bias", shape=(N,), dtype=out_dtype)
return relay.nn.bias_add(dense, bias)
def get_dense_bias_relu(M, N, K, out_dtype="float16"):
return relay.nn.relu(get_dense_bias(M, N, K, out_dtype=out_dtype))
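# GELU epilogue built from primitive ops: x * (0.5 * erf(x / sqrt(2)) + 0.5), i.e. the exact
# erf-based GELU applied to the dense + bias output. For fp16 the erf is computed in fp32 and
# cast back to fp16.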
def get_dense_bias_gelu(M, N, K, out_dtype="float16"):
bias_add = get_dense_bias(M, N, K, out_dtype)
mul = bias_add * relay.const((1.0 / math.sqrt(2.0)), dtype=out_dtype)
if out_dtype == "float16":
erf = relay.cast(relay.op.erf(relay.cast(mul, "float32")), "float16")
else:
erf = relay.op.erf(mul)
mul_half = erf * relay.const(0.5, dtype=out_dtype)
add = mul_half + relay.const(0.5, dtype=out_dtype)
return add * bias_add
def get_batch_matmul_with_shape(x_shape, y_shape, out_dtype="float16"):
x = relay.var("x", shape=x_shape, dtype="float16")
y = relay.var("y", shape=y_shape, dtype="float16")
return relay.nn.batch_matmul(x, y, out_dtype=out_dtype)
def get_batch_matmul(batch, M, N, K, out_dtype="float16"):
    return get_batch_matmul_with_shape((batch, M, K), (batch, N, K), out_dtype=out_dtype)
def get_conv2d_nchw(
d_shape,
w_shape,
padding,
strides=(1, 1),
out_dtype="float16",
data_dtype="float16",
weight_dtype="float16",
):
data = relay.var("data", shape=d_shape, dtype=data_dtype)
weight = relay.var("weight", shape=w_shape, dtype=weight_dtype)
out_channel = w_shape[0]
return relay.nn.conv2d(
data=data,
weight=weight,
kernel_size=w_shape[2:],
channels=out_channel,
padding=padding,
strides=strides,
out_dtype=out_dtype,
)
def get_conv2d_nchw_bias(d_shape, w_shape, padding, out_dtype="float16"):
conv2d = get_conv2d_nchw(d_shape, w_shape, padding, out_dtype=out_dtype)
bias = relay.var("bias", shape=(w_shape[0],), dtype=out_dtype)
return relay.nn.bias_add(conv2d, bias)
def silu(x):
return x * relay.sigmoid(x)
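# hardswish(x) = x * clip(x + 3, 0, 6) / 6, expressed with Relay primitives.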
def hardswish(x, out_dtype="float16"):
return x * (
relay.clip(x + relay.const(3, dtype=out_dtype), a_min=0, a_max=6)
/ relay.const(6, dtype=out_dtype)
)
def get_conv2d_nchw_bias_relu(d_shape, w_shape, padding, out_dtype="float16"):
return relay.nn.relu(get_conv2d_nchw_bias(d_shape, w_shape, padding, out_dtype=out_dtype))
def get_conv2d_nchw_bias_sigmoid(d_shape, w_shape, padding, out_dtype="float16"):
return relay.sigmoid(get_conv2d_nchw_bias(d_shape, w_shape, padding, out_dtype=out_dtype))
def get_conv2d_nchw_bias_silu(d_shape, w_shape, padding, out_dtype="float16"):
conv_out = get_conv2d_nchw_bias(d_shape, w_shape, padding, out_dtype=out_dtype)
return silu(conv_out)
def get_conv2d_nchw_bias_hardswish(d_shape, w_shape, padding, out_dtype="float16"):
conv_out = get_conv2d_nchw_bias(d_shape, w_shape, padding, out_dtype=out_dtype)
return hardswish(conv_out, out_dtype)
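# Returns both the conv2d + bias_add expression and the input data var, so tests can build
# residual-block patterns (e.g. relu(conv_bias + data)) that reuse the conv input as the residual.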
def get_conv2d_nchw_bias_residual(d_shape, w_shape, padding, out_dtype="float16"):
data = relay.var("data", shape=d_shape, dtype="float16")
weight = relay.var("weight", shape=w_shape, dtype="float16")
bias = relay.var("bias", shape=(w_shape[0],), dtype=out_dtype)
out_channel = w_shape[0]
conv2d = relay.nn.conv2d(
data=data,
weight=weight,
kernel_size=w_shape[2:],
channels=out_channel,
padding=padding,
out_dtype=out_dtype,
)
bias_add = relay.nn.bias_add(conv2d, bias)
return bias_add, data
def get_conv2d_transpose_nchw(
d_shape,
w_shape,
padding,
output_padding,
strides,
out_dtype="float32",
data_dtype="float32",
weight_dtype="float32",
):
data = relay.var("data", shape=d_shape, dtype=data_dtype)
weight = relay.var("weight", shape=w_shape, dtype=weight_dtype)
out_channel = w_shape[1]
return relay.nn.conv2d_transpose(
data=data,
weight=weight,
kernel_size=w_shape[2:],
channels=out_channel,
padding=padding,
output_padding=output_padding,
strides=strides,
out_dtype=out_dtype,
)
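# Weight gradient (wgrad): given the output gradient and the input data, compute the gradient
# with respect to the conv2d weight of shape w_shape.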
def get_conv2d_backward_weight(
d_shape,
w_shape,
o_shape,
padding,
strides,
out_dtype="float32",
data_dtype="float32",
weight_dtype="float32",
):
grad = relay.var("grad", shape=o_shape, dtype=weight_dtype)
data = relay.var("data", shape=d_shape, dtype=data_dtype)
out_channel = o_shape[1]
return relay.nn.conv2d_backward_weight(
grad=grad,
data=data,
kernel_size=w_shape[2:],
channels=out_channel,
padding=padding,
strides=strides,
out_dtype=out_dtype,
)
def get_dense_transpose_dense(M, N, K, dtype="float16"):
"""
output = nn.dense(_op.transpose(nn.dense(input, weight0), axes=(1, 0)), weight1)
dense0: [M, K] * [N, K] -> [M, N]
transpose: [M, N] -> [N, M]
dense1: [N, M] * [K, M] -> [N, K]
input: [M, K]
weight0: [N, K]
weight1: [K, M]
"""
input_shape = (M, K)
weight0_shape = (N, K)
weight1_shape = (K, M)
input = relay.var("input", shape=input_shape, dtype=dtype)
weight0 = relay.var("weight0", shape=weight0_shape, dtype=dtype)
weight1 = relay.var("weight1", shape=weight1_shape, dtype=dtype)
output0 = relay.nn.dense(input, weight0, out_dtype=dtype)
input1 = _op.transpose(output0, axes=(1, 0))
output = relay.nn.dense(input1, weight1, out_dtype=dtype)
return output
def convert_conv2d_layout(mod, desired_layouts):
with tvm.transform.PassContext(opt_level=3):
seq = tvm.transform.Sequential([relay.transform.ConvertLayout(desired_layouts)])
return seq(mod)
def get_random_ndarray(shape, dtype):
if dtype == "int8":
return np.random.randint(-128, 128, shape).astype(dtype)
elif dtype == "uint8":
return np.random.randint(0, 256, shape).astype(dtype)
return np.random.uniform(-1, 1, shape).astype(dtype)
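# Partition the module for CUTLASS, build it with the cuda + cutlass targets (optionally
# extracting and tuning auto-scheduler tasks when use_ansor=True), and compile/load the
# generated CUTLASS kernels via finalize_modules. Returns the graph executor module, the
# device, and the number of CUTLASS partitions created.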
def profile_and_build(
mod,
params,
sm,
split_k_slices=[1],
tmp_dir="./tmp",
use_fast_math=False,
use_3xtf32=True,
use_ansor=False,
ansor_tuning=False,
):
logging.info("before partitioning:\n%s", mod)
mod = partition_for_cutlass(mod)
logging.info("after partitioning:\n%s", mod)
num_cutlass_partition = num_cutlass_partitions(mod)
host = tvm.target.Target("llvm")
cuda = tvm.target.Target("cuda", host=host)
cutlass = tvm.target.Target(
{
"kind": "cutlass",
"sm": sm,
"use_3xtf32": use_3xtf32,
"split_k_slices": split_k_slices,
"profile_all_alignments": False,
"find_first_valid": True,
"use_multiprocessing": True,
"use_fast_math": use_fast_math,
"tmp_dir": tmp_dir,
},
host=host,
)
if use_ansor:
with tvm.transform.PassContext(
opt_level=3, config={"relay.backend.use_auto_scheduler": True}
):
tasks, task_weights = auto_scheduler.extract_tasks(
mod, params, cuda, include_simple_tasks=True, opt_level=3, other_targets=[cutlass]
)
for idx, (task, task_weight) in enumerate(zip(tasks, task_weights)):
logging.info(
f"==== Task {idx}: {task.desc} (weight {task_weight} key: {task.workload_key}) ====="
)
logging.info(task.compute_dag)
with tempfile.NamedTemporaryFile() as fp:
log_file = fp.name
# auto-tuning is disabled by default
if ansor_tuning:
measure_ctx = auto_scheduler.LocalRPCMeasureContext(
repeat=3, min_repeat_ms=200, timeout=10
)
tuner = auto_scheduler.TaskScheduler(tasks, task_weights)
tuner.tune(
auto_scheduler.TuningOptions(
num_measure_trials=100,
runner=measure_ctx.runner,
measure_callbacks=[
auto_scheduler.RecordToFile(log_file),
],
)
)
with auto_scheduler.ApplyHistoryBest(log_file):
with tvm.transform.PassContext(
opt_level=3,
config={"relay.backend.use_auto_scheduler": True},
):
lib = relay.build(
mod,
target=cuda,
target_host=host,
params=params,
)
else:
with tvm.transform.PassContext(opt_level=3):
lib = relay.build(mod, target=[cuda, cutlass], params=params)
lib = finalize_modules(lib, "compile.so", tmp_dir)
dev = tvm.device("cuda", 0)
rt_mod = tvm.contrib.graph_executor.GraphModule(lib["default"](dev))
return rt_mod, dev, num_cutlass_partition
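# VM variant of profile_and_build, used when the workload has dynamic shapes.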
def profile_and_build_vm(
mod,
params,
sm,
split_k_slices=[1],
tmp_dir="./tmp",
use_fast_math=False,
use_3xtf32=True,
):
mod = partition_for_cutlass(mod)
num_cutlass_partition = num_cutlass_partitions(mod)
host = tvm.target.Target("llvm")
cuda = tvm.target.Target("cuda", host=host)
cutlass = tvm.target.Target(
{
"kind": "cutlass",
"sm": sm,
"use_3xtf32": use_3xtf32,
"split_k_slices": split_k_slices,
"profile_all_alignments": False,
"find_first_valid": True,
"use_multiprocessing": True,
"use_fast_math": use_fast_math,
"tmp_dir": tmp_dir,
},
host=host,
)
with tvm.transform.PassContext(opt_level=3):
vm_exec = relay.vm.compile(mod, target=[cuda, cutlass], params=params)
vm_exec = finalize_modules_vm(vm_exec, "compile.so", tmp_dir)
dev = tvm.device("cuda", 0)
return VirtualMachine(vm_exec, dev), dev, num_cutlass_partition
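# Build the dense workload both through the CUTLASS BYOC path and with a plain TVM (or cuBLAS)
# reference, then compare outputs. Dynamic-shape cases are routed through the VM builders.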
def verify_dense(
func,
M,
N,
K,
ref_target="cuda",
sm=80,
atol=1e-5,
rtol=1e-5,
run_benchmark=False,
data_dtype="float16",
weight_dtype="float16",
use_3xtf32=True,
):
assert has_cutlass()
if sm < 80 and data_dtype == "float32":
return
mod = tvm.IRModule.from_expr(func)
typ = relay.transform.InferType()(mod)["main"].body.checked_type
out_dtype = typ.dtype
use_vm = any(isinstance(s, tvm.tir.Any) for s in typ.shape)
np_data = get_random_ndarray((M, K), data_dtype)
np_weight = get_random_ndarray((N, K), weight_dtype)
np_bias = get_random_ndarray((N,), out_dtype)
params = {"weight": np_weight, "bias": np_bias}
if use_vm:
if ref_target == "cuda" and out_dtype == "float16":
# Uncomment "return" below to see the accuracy difference of static vs dynamic TVM native fp16 dense
# The static one can use a tensorcore schedule, but the dynamic one cannot
rt_mod, dev = get_ref_vm(tvm.IRModule.from_expr(get_dense(M, N, K)), params)
num_partition = 1
logging.warning(
"The reference fp16 dense with dynamic shape using fp16 accumulation has accuracy issues."
)
return
else:
rt_mod, dev, num_partition = profile_and_build_vm(
mod, params, sm, use_3xtf32=use_3xtf32
)
rt_mod_ref, dev = get_ref_vm(mod, params, target=ref_target)
x = tvm.nd.array(np_data, device=dev)
out = get_output_vm(rt_mod, ["data"], [x])
ref_out = get_output_vm(rt_mod_ref, ["data"], [x])
else:
rt_mod_ref, dev = get_ref_rt_mod(mod, params, target=ref_target)
rt_mod, dev, num_partition = profile_and_build(mod, params, sm, use_3xtf32=use_3xtf32)
x = tvm.nd.array(np_data, device=dev)
out = get_output(rt_mod, ["data"], [x])
ref_out = get_output(rt_mod_ref, ["data"], [x])
assert num_partition > 0
np.testing.assert_allclose(out, ref_out, atol=atol, rtol=rtol)
if run_benchmark:
print("CUTLASS:", rt_mod.benchmark(dev, number=1, repeat=600))
print("TVM with target %s:" % ref_target, rt_mod_ref.benchmark(dev, number=1, repeat=600))
def verify_batch_matmul(
func, batch, M, N, K, ref_target="cuda", sm=80, atol=1e-5, rtol=1e-5, run_benchmark=False
):
assert has_cutlass()
mod = tvm.IRModule.from_expr(func)
typ = relay.transform.InferType()(mod)["main"].body.checked_type
use_vm = any(isinstance(s, tvm.tir.Any) for s in typ.shape)
x_np = np.random.uniform(-1, 1, (batch, M, K)).astype("float16")
y_np = np.random.uniform(-1, 1, (batch, N, K)).astype("float16")
if use_vm:
rt_mod, dev, num_partition = profile_and_build_vm(mod, {}, sm)
rt_mod_ref, dev = get_ref_vm(mod, {}, target=ref_target)
assert num_partition > 0
x = tvm.nd.array(x_np, device=dev)
y = tvm.nd.array(y_np, device=dev)
out = get_output_vm(rt_mod, ["x", "y"], [x, y])
ref_out = get_output_vm(rt_mod_ref, ["x", "y"], [x, y])
else:
rt_mod, dev, num_partition = profile_and_build(mod, {}, sm)
rt_mod_ref, dev = get_ref_rt_mod(mod, {})
assert num_partition > 0
x = tvm.nd.array(x_np, device=dev)
y = tvm.nd.array(y_np, device=dev)
out = get_output(rt_mod, ["x", "y"], [x, y])
ref_out = get_output(rt_mod_ref, ["x", "y"], [x, y])
np.testing.assert_allclose(out, ref_out, atol=atol, rtol=rtol)
if run_benchmark:
print("CUTLASS:", rt_mod.benchmark(dev, number=1, repeat=600))
print("TVM Tensorcore (no tuning):", rt_mod_ref.benchmark(dev, number=1, repeat=600))
M = 96
N = 64
K = 64
@tvm.testing.requires_cutlass
def test_dense():
verify_dense(get_dense(M, N, K), M, N, K)
verify_dense(get_dense(M, N, K, out_dtype="float32"), M, N, K)
# Test align1 case
verify_dense(get_dense_bias(M, N + 1, K), M, N + 1, K)
# int8
verify_dense(
get_dense(M, N, K, "int32", "int8", "int8"), M, N, K, data_dtype="int8", weight_dtype="int8"
)
dense_fp32 = get_dense(M, N, K, "float32", "float32", "float32")
# fp32
verify_dense(
dense_fp32,
M,
N,
K,
data_dtype="float32",
weight_dtype="float32",
use_3xtf32=False,
sm=75,
)
# tf32
verify_dense(
dense_fp32,
M,
N,
K,
data_dtype="float32",
weight_dtype="float32",
use_3xtf32=False,
atol=1e-2,
rtol=1e-2,
)
# 3xtf32
verify_dense(
dense_fp32,
M,
N,
K,
data_dtype="float32",
weight_dtype="float32",
)
@tvm.testing.requires_cutlass
def test_dense_bias():
verify_dense(get_dense_bias(M, N, K), M, N, K)
verify_dense(get_dense_bias(M, N, K, out_dtype="float32"), M, N, K)
@tvm.testing.requires_cutlass
def test_dense_bias_relu():
verify_dense(get_dense_bias_relu(M, N, K), M, N, K)
verify_dense(get_dense_bias_relu(M, N, K, out_dtype="float32"), M, N, K)
@tvm.testing.requires_cutlass
def test_dense_bias_gelu():
verify_dense(get_dense_bias_gelu(M, N, K), M, N, K, atol=1e-3, rtol=1e-3)
verify_dense(get_dense_bias_gelu(M, N, K, out_dtype="float32"), M, N, K, atol=1e-3, rtol=1e-3)
@tvm.testing.requires_cutlass
def test_dense_dynamic():
data_shape = (relay.Any(), K)
weight_shape = (relay.Any(), K)
if has_cublas():
# TVM native fp16 dense (without tensorcore), using fp16 accum, seems to have accuracy issues
# Use cublas as a reference
verify_dense(
get_dense_with_shape(data_shape, weight_shape),
M,
N,
K,
ref_target="cuda -libs=cublas",
)
verify_dense(
get_dense_with_shape(data_shape, weight_shape, out_dtype="float32"),
M,
N,
K,
atol=1e-4,
rtol=1e-4,
)
@tvm.testing.requires_cutlass
def test_batch_matmul():
batch = 8
verify_batch_matmul(get_batch_matmul(batch, M, N, K), batch, M, N, K)
verify_batch_matmul(get_batch_matmul(batch, M, N, K, out_dtype="float32"), batch, M, N, K)
if has_cublas():
# Test dynamic shape batch_matmul
# AutoTVM does not seem to support it
x_shape = (relay.Any(), relay.Any(), K)
y_shape = (relay.Any(), relay.Any(), K)
verify_batch_matmul(
get_batch_matmul_with_shape(x_shape, y_shape),
batch,
M,
N,
K,
ref_target="cuda -libs=cublas",
)
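# Common conv2d check: the CUTLASS path is built from an NHWC/OHWI-converted module, while the
# reference is built in NHWC/HWIO, or in NHWC/OHWI with cuDNN when use_cudnn_ref=True.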
def verify_conv2d_common(
expr_nchw, # can be dynamic batch
expr_ref, # always static batch
input_names,
inputs,
params,
sm=80,
split_k_slices=[1],
atol=1e-5,
rtol=1e-5,
use_cudnn_ref=False,
run_benchmark=False,
use_fast_math=False,
ref_target="cuda",
use_vm=False,
):
assert has_cutlass()
if sm < 80 and inputs[0].dtype == "float32":
return
mod_nchw = tvm.IRModule.from_expr(expr_nchw)
mod_ref = tvm.IRModule.from_expr(expr_ref)
if use_vm:
profile_and_build_func = profile_and_build_vm
get_output_func = get_output_vm
ref_build_func = get_ref_vm
else:
profile_and_build_func = profile_and_build
get_output_func = get_output
ref_build_func = get_ref_rt_mod
mod_weight_ohwi = convert_conv2d_layout(
mod_nchw,
{
"nn.conv2d": ["NHWC", "OHWI"],
"nn.conv2d_transpose": ["NHWC", "IHWO"],
"nn.conv2d_backward_weight": ["NHWC", "OHWI"],
},
)
rt_mod, _, num_cutlass_partition = profile_and_build_func(
mod_weight_ohwi, params, sm, split_k_slices, use_fast_math=use_fast_math
)
out = get_output_func(rt_mod, input_names, inputs)
assert num_cutlass_partition > 0
if use_cudnn_ref:
rt_mod_ref, dev = ref_build_func(
convert_conv2d_layout(mod_ref, {"nn.conv2d": ["NHWC", "OHWI"]}),
params,
target="cuda -libs=cudnn",
)
else:
rt_mod_ref, dev = ref_build_func(
convert_conv2d_layout(mod_ref, {"nn.conv2d": ["NHWC", "HWIO"]}),
params,
target=ref_target,
)
ref_out = get_output_func(rt_mod_ref, input_names, inputs)
if run_benchmark:
print("CUTLASS:", rt_mod.benchmark(dev, number=1, repeat=600))
print("TVM Tensorcore (no tuning):", rt_mod_ref.benchmark(dev, number=1, repeat=600))
np.testing.assert_allclose(out, ref_out, atol=atol, rtol=rtol)
def verify_conv2d(
expr_nchw, # can be dynamic batch
expr_ref, # always static batch
d_shape,
w_shape,
sm=80,
atol=1e-5,
rtol=1e-5,
use_cudnn_ref=False,
run_benchmark=False,
use_fast_math=False,
data_dtype="float16",
weight_dtype="float16",
ref_target="cuda",
use_vm=False,
):
mod_nchw = tvm.IRModule.from_expr(expr_nchw)
typ = relay.transform.InferType()(mod_nchw)["main"].body.checked_type
use_vm = use_vm or any(isinstance(s, tvm.tir.Any) for s in typ.shape)
np_data = get_random_ndarray(d_shape, data_dtype)
np_weight = get_random_ndarray(w_shape, weight_dtype)
np_bias = get_random_ndarray((w_shape[0],), typ.dtype)
params = {"weight": np_weight, "bias": np_bias}
split_k_slices = [1]
return verify_conv2d_common(
expr_nchw,
expr_ref,
["data"],
[np_data],
params,
sm,
split_k_slices,
atol,
rtol,
use_cudnn_ref,
run_benchmark,
use_fast_math,
ref_target,
use_vm,
)
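# Wgrad check. split_k_slices sets the CUTLASS split-K factors to try, i.e. how many slices the
# reduction dimension is split into during accumulation.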
def verify_conv2d_backward_weight(
expr_nchw, # can be dynamic batch
expr_ref, # always static batch
grad_shape,
data_shape,
sm=80,
split_k_slices=[1],
atol=1e-5,
rtol=1e-5,
use_cudnn_ref=False,
use_fast_math=False,
grad_dtype="float16",
data_dtype="float16",
ref_target="cuda",
use_vm=False,
):
np_grad = get_random_ndarray(grad_shape, grad_dtype)
np_data = get_random_ndarray(data_shape, data_dtype)
params = {}
input_names = ["grad", "data"]
return verify_conv2d_common(
expr_nchw,
expr_ref,
input_names,
[np_grad, np_data],
params,
sm,
split_k_slices,
atol,
rtol,
use_cudnn_ref,
False,
use_fast_math,
ref_target,
use_vm,
)
@tvm.testing.requires_cutlass
def test_conv2d():
d_shape = (16, 16, 32, 32)
w_shape = (32, 16, 3, 3)
padding = (1, 1)
for IC in [3, 16]:
d_shape = (16, IC, 32, 32)
w_shape = (32, IC, 3, 3)
mod_nchw = get_conv2d_nchw(d_shape, w_shape, padding)
verify_conv2d(
mod_nchw,
mod_nchw,
d_shape,
w_shape,
sm=80,
atol=1e-5,
rtol=1e-5,
            use_cudnn_ref=(IC == 3),  # The autotvm kernel has an accuracy issue with the IC == 3 case
run_benchmark=False,
)
dyn_batch_shape = (relay.Any(),) + d_shape[1:]
mod_nchw = get_conv2d_nchw(d_shape, w_shape, padding)
mod_dyn = get_conv2d_nchw(dyn_batch_shape, w_shape, padding)
verify_conv2d(
mod_dyn, mod_nchw, d_shape, w_shape, sm=80, atol=1e-5, rtol=1e-5, run_benchmark=False
)
for data_dtype, weight_dtype, out_dtype in [
("float32", "float32", "float32"), # 3xtf32
("int8", "int8", "int32"),
("uint8", "int8", "int32"),
]:
expr = get_conv2d_nchw(
d_shape,
w_shape,
padding,
out_dtype=out_dtype,
data_dtype=data_dtype,
weight_dtype=weight_dtype,
)
verify_conv2d(
expr,
expr,
d_shape,
w_shape,
sm=80,
atol=1e-5,
rtol=1e-5,
run_benchmark=False,
data_dtype=data_dtype,
weight_dtype=weight_dtype,
ref_target="llvm",
)
# align1 + int8 case
d_shape = (16, 3, 32, 32)
w_shape = (32, 3, 3, 3)
mod_nchw = get_conv2d_nchw(
d_shape, w_shape, padding, out_dtype="int32", data_dtype="uint8", weight_dtype="int8"
)
verify_conv2d(
mod_nchw,
mod_nchw,
d_shape,
w_shape,
sm=80,
atol=1e-5,
rtol=1e-5,
ref_target="llvm",
data_dtype="uint8",
weight_dtype="int8",
)
@tvm.testing.requires_cutlass
def test_conv2d_fusion():
d_shape = (16, 16, 32, 32)
w_shape = (32, 16, 3, 3)
padding = (1, 1)
mod_nchw = get_conv2d_nchw_bias(d_shape, w_shape, padding)
verify_conv2d(
mod_nchw, mod_nchw, d_shape, w_shape, sm=80, atol=1e-5, rtol=1e-5, run_benchmark=False
)
mod_nchw = get_conv2d_nchw_bias_relu(d_shape, w_shape, padding)
verify_conv2d(
mod_nchw, mod_nchw, d_shape, w_shape, sm=80, atol=1e-5, rtol=1e-5, run_benchmark=False
)
mod_nchw = get_conv2d_nchw_bias_sigmoid(d_shape, w_shape, padding, out_dtype="float16")
verify_conv2d(
mod_nchw, mod_nchw, d_shape, w_shape, sm=80, atol=1e-5, rtol=1e-5, run_benchmark=False
)
verify_conv2d(
mod_nchw,
mod_nchw,
d_shape,
w_shape,
sm=80,
atol=1e-3,
rtol=1e-3,
run_benchmark=False,
use_fast_math=True,
)
mod_nchw = get_conv2d_nchw_bias_sigmoid(d_shape, w_shape, padding, out_dtype="float32")
verify_conv2d(
mod_nchw, mod_nchw, d_shape, w_shape, sm=80, atol=1e-5, rtol=1e-5, run_benchmark=False
)
mod_nchw = get_conv2d_nchw_bias_silu(d_shape, w_shape, padding, out_dtype="float32")
verify_conv2d(
mod_nchw, mod_nchw, d_shape, w_shape, sm=80, atol=1e-5, rtol=1e-5, run_benchmark=False
)
mod_nchw = get_conv2d_nchw_bias_hardswish(d_shape, w_shape, padding, out_dtype="float16")
verify_conv2d(
mod_nchw, mod_nchw, d_shape, w_shape, sm=80, atol=5e-2, rtol=5e-2, run_benchmark=False
)
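# Residual-block epilogues: combine the conv2d + bias output with the conv input (the residual)
# through add/multiply and an activation, and check against the unfused reference.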
@tvm.testing.requires_cutlass
def test_conv2d_residual_block():
d_shape = (16, 16, 32, 32)
w_shape = (16, 16, 3, 3)
padding = (1, 1)
bias_add, residual_input = get_conv2d_nchw_bias_residual(d_shape, w_shape, padding)
for func, tol in [
(relay.nn.relu(bias_add + residual_input), 1e-5),
(relay.nn.relu(bias_add) + residual_input, 1e-5),
(relay.sigmoid(bias_add) * residual_input, 1e-5),
(relay.nn.relu(silu(bias_add) * residual_input), 1e-5),
        # HardSwish requires a higher tolerance due to how the residual block epilogue is
        # vectorized in cutlass.
        # TODO(masahi): Investigate this issue
(relay.nn.relu(hardswish(bias_add) + residual_input), 5e-2),
]:
verify_conv2d(func, func, d_shape, w_shape, sm=80, atol=tol, rtol=tol, run_benchmark=False)
@tvm.testing.requires_cutlass
def test_conv2d_transpose():
OC = 8
IC = 16
d_shape = (16, IC, 32, 32)
w_shape = (OC, IC, 3, 3)
padding = (1, 1)
dtype = "float32"
for strides in [(1, 1), (2, 2)]:
o_shape = conv_output_shape(
0, padding, strides, (1, 1), d_shape, (OC, IC, 3, 3), "float32", "float32"
)
output_padding = (1, 1) if strides[0] > 1 else (0, 0)
mod_nchw = get_conv2d_transpose_nchw(
o_shape,
w_shape,
padding,
output_padding,
strides,
out_dtype=dtype,
data_dtype=dtype,
weight_dtype=dtype,
)
verify_conv2d(
mod_nchw,
mod_nchw,
o_shape,
w_shape,
sm=80,
atol=1e-3,
rtol=1e-3,
use_cudnn_ref=False,
run_benchmark=False,
data_dtype=dtype,
weight_dtype=dtype,
)
@tvm.testing.requires_cutlass
def test_conv2d_backward_weight():
OC = 8
IC = 16
d_shape = (16, IC, 32, 32)
w_shape = (OC, IC, 3, 3)
dtype = "float16"
for strides in [(1, 1), (2, 2)]:
o_shape = (16, OC, 32 // strides[0], 32 // strides[1])
padding = (1, 1)
mod_nchw = get_conv2d_backward_weight(
d_shape,
w_shape,
o_shape,
padding,
strides,
out_dtype="float32",
data_dtype=dtype,
weight_dtype=dtype,
)
for split_k_slices in [1, 8]:
verify_conv2d_backward_weight(
mod_nchw,
mod_nchw,
o_shape,
d_shape,
sm=80,
split_k_slices=[split_k_slices],
atol=5e-3,
rtol=5e-3,
use_cudnn_ref=False,
grad_dtype=dtype,
data_dtype=dtype,
)
@tvm.testing.requires_cutlass
def test_conv2d_bwd():
IC = 16
OC = 8
dshape = (16, IC, 32, 32)
wshape = (OC, IC, 3, 3)
padding = (0, 0)
strides = (1, 1)
conv = get_conv2d_nchw(
dshape,
wshape,
padding,
strides=strides,
out_dtype="float32",
data_dtype="float32",
weight_dtype="float32",
)
fwd_mod = InferType()(tvm.IRModule.from_expr(conv))
    # Note: There is a large difference between TVM and cutlass Wgrad results if fp16 is used,
    # since cutlass Wgrad uses fp32 accumulation even when the output is fp16.
use_fp16 = False
verify_dgrad = False # False to verify wgrad
tol = 1e-5 if verify_dgrad else 1e-4 # Wgrad slightly less accurate
if use_fp16:
fwd_mod = ToMixedPrecision("float16")(fwd_mod)
fwd_bwd_func = FirstOrderGradient()(fwd_mod)["main"]
bwd_func = relay.Function(
fwd_bwd_func.params,
relay.TupleGetItem(relay.TupleGetItem(fwd_bwd_func.body, 1), 0 if verify_dgrad else 1),
)
verify_conv2d(
bwd_func,
bwd_func,
dshape,
wshape,
sm=80,
atol=1e-2 if use_fp16 else tol,
rtol=1e-2 if use_fp16 else tol,
use_cudnn_ref=False,
data_dtype="float32",
weight_dtype="float32",
use_vm=True,
)
def verify_dense_transpose_dense(
func,
M,
N,
K,
ref_target="cuda",
sm=80,
atol=1e-5,
rtol=1e-5,
run_benchmark=False,
dtype="float16",
use_3xtf32=True,
):
assert has_cutlass()
if sm < 80 and dtype == "float32":
return
mod = tvm.IRModule.from_expr(func)
typ = relay.transform.InferType()(mod)["main"].body.checked_type
np_data = get_random_ndarray((M, K), dtype)
np_weight0 = get_random_ndarray((N, K), dtype)
np_weight1 = get_random_ndarray((K, M), dtype)
params = {"weight0": np_weight0, "weight1": np_weight1}
rt_mod_ref, dev = get_ref_rt_mod(mod, params, target=ref_target)
cutlass_rt_mod, dev, num_partition = profile_and_build(
mod,
params,
sm,
use_3xtf32=use_3xtf32,
use_ansor=False,
)
cutlass_ansor_rt_mod, dev, num_partition = profile_and_build(
mod,
params,
sm,
use_3xtf32=use_3xtf32,
use_ansor=True,
)
x = tvm.nd.array(np_data, device=dev)
cutlass_out = get_output(cutlass_rt_mod, ["input"], [x])
cutlass_ansor_out = get_output(cutlass_ansor_rt_mod, ["input"], [x])
ref_out = get_output(rt_mod_ref, ["input"], [x])
assert num_partition > 0
np.testing.assert_allclose(cutlass_out, ref_out, atol=atol, rtol=rtol)
np.testing.assert_allclose(cutlass_ansor_out, ref_out, atol=atol, rtol=rtol)
if run_benchmark:
print("CUTLASS:", cutlass_rt_mod.benchmark(dev, number=1, repeat=600))
print("CUTLASS with Ansor:", cutlass_ansor_rt_mod.benchmark(dev, number=1, repeat=600))
print("TVM with target %s:" % ref_target, rt_mod_ref.benchmark(dev, number=1, repeat=600))
@tvm.testing.requires_cutlass
def test_dense_transpose_dense():
verify_dense_transpose_dense(get_dense_transpose_dense(M, N, K), M, N, K)
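# Grouped GEMM: rows of A are partitioned into num_groups contiguous blocks described by indptr
# (cumulative row offsets), and each block is multiplied by its own B matrix.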
def verify_group_gemm(
func_name, M, N, K, num_groups, x_dtype, weight_dtype, out_dtype, use_scale, rtol, atol
):
group_gemm_func = tvm.get_global_func(func_name, allow_missing=True)
if group_gemm_func is None:
print(f"Skipped as {func_name} is not available")
return
@memoize("tvm.contrib.cutlass.test_group_gemm_sm90")
def get_ref_data():
assert M % num_groups == 0
M_per_group = M // num_groups
a_np = get_random_ndarray((M, K), "float16")
b_np = get_random_ndarray((num_groups, N, K), "float16")
indptr_np = np.arange(1, num_groups + 1).astype("int64") * M_per_group
c_np = np.concatenate(
[a_np[i * M_per_group : (i + 1) * M_per_group] @ b_np[i].T for i in range(num_groups)],
axis=0,
)
return a_np, b_np, indptr_np, c_np
def to_numpy_dtype(dtype):
mapping = {"e5m2_float8": ml_dtypes.float8_e5m2, "e4m3_float8": ml_dtypes.float8_e4m3fn}
return mapping.get(dtype, dtype)
a_np, b_np, indptr_np, c_np = get_ref_data()
dev = tvm.cuda(0)
a_nd = tvm.nd.array(a_np.astype(to_numpy_dtype(x_dtype)), device=dev)
b_nd = tvm.nd.array(b_np.astype(to_numpy_dtype(weight_dtype)), device=dev)
c_nd = tvm.nd.empty(c_np.shape, dtype=out_dtype, device=dev)
indptr_nd = tvm.nd.array(indptr_np, device=dev)
workspace = tvm.nd.empty((4096 * 1024,), dtype="uint8", device=dev)
if use_scale:
scale = tvm.nd.array(np.array([1.0], dtype="float32"), device=dev)
group_gemm_func(a_nd, b_nd, indptr_nd, workspace, scale, c_nd)
else:
group_gemm_func(a_nd, b_nd, indptr_nd, workspace, c_nd)
tvm.testing.assert_allclose(c_nd.asnumpy(), c_np, rtol=rtol, atol=atol)
@tvm.testing.requires_cutlass
def test_group_gemm_sm90():
verify_group_gemm(
"cutlass.group_gemm_fp16_sm90",
8,
128,
128,
4,
"float16",
"float16",
"float16",
False,
rtol=1e-3,
atol=1e-3,
)
verify_group_gemm(
"cutlass.group_gemm_e5m2_e5m2_fp16",
8,
16,
16,
4,
"e5m2_float8",
"e5m2_float8",
"float16",
True,
rtol=1e-1,
atol=1,
)
verify_group_gemm(
"cutlass.group_gemm_e4m3_e4m3_fp16",
8,
16,
16,
4,
"e4m3_float8",
"e4m3_float8",
"float16",
True,
rtol=1e-1,
atol=1,
)
verify_group_gemm(
"cutlass.group_gemm_e5m2_e4m3_fp16",
8,
16,
16,
4,
"e5m2_float8",
"e4m3_float8",
"float16",
True,
rtol=1e-1,
atol=1,
)
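# FP8 GEMM check: inputs are generated in fp16, cast to the requested float8 dtypes on the host,
# and compared against the fp16 reference a @ b.T * scale with loose tolerances to absorb fp8
# quantization error.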
def verify_gemm(func_name, M, N, K, x_dtype, weight_dtype, out_dtype, scale_value, rtol, atol):
gemm_func = tvm.get_global_func(func_name, allow_missing=True)
if gemm_func is None:
print(f"Skipped as {func_name} is not available")
return
@memoize("tvm.contrib.cutlass.test_fp8_gemm_sm90")
def get_ref_data():
a_np = get_random_ndarray((M, K), "float16")
b_np = get_random_ndarray((N, K), "float16")
c_np = a_np @ b_np.T * scale_value
return a_np, b_np, c_np
def to_numpy_dtype(dtype):
mapping = {"e5m2_float8": ml_dtypes.float8_e5m2, "e4m3_float8": ml_dtypes.float8_e4m3fn}
return mapping.get(dtype, dtype)
a_np, b_np, c_np = get_ref_data()
dev = tvm.cuda(0)
a_nd = tvm.nd.array(a_np.astype(to_numpy_dtype(x_dtype)), device=dev)
b_nd = tvm.nd.array(b_np.astype(to_numpy_dtype(weight_dtype)), device=dev)
c_nd = tvm.nd.empty(c_np.shape, dtype=out_dtype, device=dev)
workspace = tvm.nd.empty((4096 * 1024,), dtype="uint8", device=dev)
scale = tvm.nd.array(np.array([scale_value], dtype="float32"), device=dev)
gemm_func(a_nd, b_nd, workspace, scale, c_nd)
tvm.testing.assert_allclose(c_nd.asnumpy(), c_np, rtol=rtol, atol=atol)
@tvm.testing.requires_cutlass
def test_fp8_gemm_sm90():
verify_gemm(
"cutlass.gemm_e5m2_e5m2_fp16",
8,
16,
16,
"e5m2_float8",
"e5m2_float8",
"float16",
1.5,
rtol=1e-1,
atol=1,
)
verify_gemm(
"cutlass.gemm_e4m3_e4m3_fp16",
8,
16,
16,
"e4m3_float8",
"e4m3_float8",
"float16",
1.5,
rtol=1e-1,
atol=1,
)
verify_gemm(
"cutlass.gemm_e4m3_e4m3_fp16",
32,
16,
16,
"e4m3_float8",
"e4m3_float8",
"float16",
1.5,
rtol=1e-1,
atol=1,
)
verify_gemm(
"cutlass.gemm_e5m2_e4m3_fp16",
8,
16,
16,
"e5m2_float8",
"e4m3_float8",
"float16",
1.5,
rtol=1e-1,
atol=1,
)
if __name__ == "__main__":
tvm.testing.main()