blob: 68d8fe7cef465bc60e86ea39a214b2f42fde1476 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, dangerous-default-value
"""Driver for partitioning and building a Relay module for CUTLASS offload."""
import logging
import os
import multiprocessing
import tvm
from tvm import runtime, relay
from tvm.contrib.nvcc import get_cuda_version
from tvm._ffi.registry import register_func
from .gen_gemm import CutlassGemmProfiler
from .gen_conv2d import CutlassConv2DProfiler
from .library import ConvKind
logger = logging.getLogger("cutlass")
def has_cutlass():
    """Returns true if the CUTLASS custom codegen is available"""
    # The codegen entry point is registered as a global PackedFunc only when TVM
    # was built with CUTLASS support; passing allow_missing=True avoids raising.
    codegen_entry = tvm.get_global_func("relay.ext.cutlass.create_c_source_module", True)
    return codegen_entry is not None
def _get_cutlass_path():
tvm_root = os.path.join(os.path.dirname(os.path.realpath(__file__)), "../../../../")
cutlass_path = os.path.join(tvm_root, "3rdparty/cutlass")
assert os.path.exists(
cutlass_path
), """The CUTLASS root directory not found in {}.
Currently, using CUTLASS requires building TVM from source.""".format(
cutlass_path
)
return cutlass_path
def _get_cutlass_compile_options(sm, threads, use_fast_math=False):
    """Assemble the nvcc invocation (compiler + flags) used to build generated CUTLASS code."""
    root = _get_cutlass_path()
    include_dirs = [
        os.path.join(root, "include"),
        os.path.join(root, "tools/util/include"),
    ]
    options = [
        "-c",
        "-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1",
        "-gencode=arch=compute_%d,code=[sm_%d,compute_%d]" % (sm, sm, sm),
        "-Xcompiler=-fPIC",
        "-Xcompiler=-Wconversion",
        "-Xcompiler=-fno-strict-aliasing",
        "-O3",
        "-std=c++17",
    ]
    options += ["-I" + directory for directory in include_dirs]
    if use_fast_math:
        options.append("-DCUTLASS_USE_TANH_FOR_SIGMOID")
    # nvcc gained the -t (parallel compilation) flag in CUDA 11.2; threads < 0
    # means "use every available core".
    if get_cuda_version() >= (11, 2):
        num_threads = multiprocessing.cpu_count() if threads < 0 else threads
        options.append("-t %d" % num_threads)
    return {"cc": "nvcc", "options": options}
class OpAnnotator(tvm.relay.ExprVisitor):
    """Annotates partitioned functions with shape and dtype information."""

    def __init__(self):
        super().__init__()
        # Filled during visit: "op_type" plus "arg{i}_shape"/"arg{i}_dtype" for
        # each composite-function parameter and "ret_shape"/"ret_dtype".
        self.signature = {}
        # Attrs of the underlying conv2d-family op, when one is found during the
        # visit. Initialized here so the attribute always exists; previously it
        # was only created inside visit_call, so touching it on a non-conv2d
        # function raised AttributeError.
        self.op_attrs = None

    def visit_call(self, call):
        """Record the composite signature and any conv2d attrs, then recurse into args."""
        op = call.op
        if isinstance(op, relay.Function) and "Composite" in op.attrs:
            self.signature["op_type"] = op.attrs["Composite"]
            for i, arg in enumerate(op.params):
                self.signature["arg%d_shape" % i] = arg.checked_type.shape
                self.signature["arg%d_dtype" % i] = arg.checked_type.dtype
            self.signature["ret_shape"] = op.ret_type.shape
            self.signature["ret_dtype"] = op.ret_type.dtype
            self.visit(op.body)

        if str(op) in ["nn.conv2d", "nn.conv2d_transpose", "nn.conv2d_backward_weight"]:
            self.op_attrs = call.attrs

        for arg in call.args:
            self.visit(arg)
def select_gemm_kernel(
    cutlass_profiler,
    op_type,
    MM,
    KK,
    NN,
    out_dtype,
    arg0_dtype,
    arg1_dtype,
    use_3xtf32,
    batched,
    find_first_valid,
    use_multiprocessing,
):
    """Run CUTLASS profiler to select the best kernel, or return the default one for dynamic
    workloads."""
    # Profiling requires concrete shapes; with any dynamic (Any) dimension we
    # cannot benchmark, so fall back to the profiler's default kernel.
    if any(isinstance(dim, tvm.tir.Any) for dim in (MM, KK, NN)):
        default = cutlass_profiler.get_default(
            op_type, out_dtype, arg0_dtype, arg1_dtype, use_3xtf32, batched=batched
        )
        kernel_name, op_def = default["name"], default["opdef"]
        logger.info("Picked the default kernel %s", kernel_name)
        return kernel_name, op_def

    kernel_name, op_def, _ = cutlass_profiler.profile(
        op_type,
        MM,
        NN,
        KK,
        out_dtype,
        arg0_dtype,
        arg1_dtype,
        use_3xtf32,
        batched=batched,
        find_first_valid=find_first_valid,
        use_multiprocessing=use_multiprocessing,
    )
    if find_first_valid:
        logger.info("Picked the first kernel found %s", kernel_name)
    else:
        logger.info("The best kernel is %s", kernel_name)
    return kernel_name, op_def
def handle_batch_matmul(
    cutlass_profiler,
    op_type,
    arg0_shape,
    arg1_shape,
    out_dtype,
    arg0_dtype,
    arg1_dtype,
    use_3xtf32,
    find_first_valid,
    use_multiprocessing,
):
    """Profile and select a kernel for batch_matmul op workload."""
    # arg0: (batch, M, K); arg1 has K on its last axis (ldb == "K"), with N on axis 1.
    batch = arg0_shape[0]
    MM, KK = arg0_shape[1], arg0_shape[2]
    NN = arg1_shape[1]
    kernel_name, op_def = select_gemm_kernel(
        cutlass_profiler,
        op_type,
        MM,
        KK,
        NN,
        out_dtype,
        arg0_dtype,
        arg1_dtype,
        use_3xtf32,
        True,
        find_first_valid,
        use_multiprocessing,
    )
    attrs = {
        "batch": batch,
        "batch_stride_A": MM * KK,
        "batch_stride_B": NN * arg1_shape[2],
        "batch_stride_C": MM * NN,
        "cutlass_op_def": op_def,
        "cutlass_op_name": kernel_name,
        "lda": "K",
        "ldb": "K",
        "ldc": "N",
    }
    return attrs
def handle_dense(
    cutlass_profiler,
    op_type,
    arg0_shape,
    arg1_shape,
    out_dtype,
    arg0_dtype,
    arg1_dtype,
    use_3xtf32,
    find_first_valid,
    use_multiprocessing,
):
    """Profile and select a kernel for dense op workload."""
    # Relay dense: data (M, K) and weight (N, K) -- the weight acts as a
    # column-major B operand, hence the "tn" layout assertion below.
    MM, KK = arg0_shape[0], arg0_shape[1]
    NN = arg1_shape[0]
    kernel_name, op_def = select_gemm_kernel(
        cutlass_profiler,
        op_type,
        MM,
        KK,
        NN,
        out_dtype,
        arg0_dtype,
        arg1_dtype,
        use_3xtf32,
        False,
        find_first_valid,
        use_multiprocessing,
    )
    assert "tn_align" in kernel_name, "Only supports (row_major, col_major) input layout for now."
    attrs = {
        "cutlass_op_def": op_def,
        "cutlass_op_name": kernel_name,
        "lda": "K",
        "ldb": "K",
        "ldc": "N",
    }
    return attrs
def handle_conv2d(
    cutlass_profiler,
    op_type,
    d_shape,
    w_shape,
    padding,
    strides,
    dilation,
    out_dtype,
    data_dtype,
    weight_dtype,
    use_3xtf32,
    split_k_slices,
    profile_all_alignments,
    find_first_valid,
    use_multiprocessing,
):
    """Profile and select a kernel for conv2d op workload."""
    # Map the composite op name to the CUTLASS convolution kind. The transpose
    # check must come first since both gradient variants mention "conv2d".
    if "conv2d_transpose" in op_type:
        conv_kind = ConvKind.Dgrad
    elif "backward_weight" in op_type:
        conv_kind = ConvKind.Wgrad
    else:
        conv_kind = ConvKind.Fprop

    # Dynamic data shapes cannot be benchmarked; use the profiler's default kernel.
    if any(isinstance(dim, tvm.tir.Any) for dim in d_shape):
        default = cutlass_profiler.get_default(
            op_type, out_dtype, data_dtype, weight_dtype, use_3xtf32, conv_kind, strides
        )
        kernel_name, op_def = default["name"], default["opdef"]
        logger.info("Picked the default kernel %s", kernel_name)
    else:
        kernel_name, op_def, _ = cutlass_profiler.profile(
            op_type,
            d_shape,
            w_shape,
            padding,
            strides,
            dilation,
            out_dtype,
            data_dtype,
            weight_dtype,
            use_3xtf32,
            conv_kind,
            split_k_slices,
            profile_all_alignments,
            find_first_valid=find_first_valid,
            use_multiprocessing=use_multiprocessing,
        )
        if find_first_valid:
            logger.info("Picked the first kernel found %s", kernel_name)
        else:
            logger.info("The best kernel is %s", kernel_name)

    return {
        "cutlass_op_def": op_def,
        "cutlass_op_name": kernel_name,
    }
def num_cutlass_partitions(mod):
    """Return how many global functions in mod were partitioned for CUTLASS offload."""
    return sum(1 for var in mod.get_global_vars() if "cutlass" in var.name_hint)
def tune_cutlass_kernels(
    mod,
    sm,
    use_3xtf32=True,
    split_k_slices=[1],
    profile_all_alignments=False,
    find_first_valid=False,
    use_multiprocessing=False,
    tmp_dir="./tmp",
):
    """Given a module partitioned for CUTLASS offloading, profile each workload to select which
    kernels to emit.

    Parameters
    ----------
    mod : IRModule
        The Relay module with cutlass partitions.

    sm : int
        An integer specifying the compute capability. For example, 75 for Turing and
        80 or 86 for Ampere.

    use_3xtf32 : bool
        Whether or not to use the slower but very accurate (compared to tf32) 3xtf32 mode
        for fp32 inputs on tensorcore.

    split_k_slices : list of int
        Split factor candidates for split-K GEMM. If split-K > 1, the GEMM K-loop is computed
        in parallel across split-K blocks, and a separate global reduction kernel is launched
        to accumulate partial reductions. The profiler picks the best split-K factor from the
        given candidate list. Note that a larger split-K factor requires a larger workspace.
        Currently, parallel split-K has been tested only for wgrad. For GEMM and other conv2d
        kinds, split_k_slices is ignored.

    profile_all_alignments : bool
        When True, profile all kernel variants with smaller alignments than the largest
        possible.

    find_first_valid : bool
        Whether to profile all candidate kernels, or stop profiling after
        the first applicable kernel is found.

    use_multiprocessing : bool
        Whether to compile profiler executables for different kernels in parallel.

    tmp_dir : string, optional
        A temporary directory where intermediate compiled artifacts will be stored.

    Returns
    -------
    mod : IRModule
        The updated module annotated with cutlass profiling information.

    num_cutlass_partition : int
        The number of partitioned functions created for CUTLASS.
    """
    cutlass_path = _get_cutlass_path()
    # Profilers cache results, so they are shared across all partitions.
    gemm_profiler = CutlassGemmProfiler(sm, cutlass_path, tmp_dir)
    conv2d_profiler = CutlassConv2DProfiler(sm, cutlass_path, tmp_dir)

    num_cutlass_partition = 0
    for var in mod.get_global_vars():
        if "cutlass" not in var.name_hint:
            continue
        num_cutlass_partition += 1
        tuned = tune_cutlass_function(
            mod[var.name_hint],
            use_3xtf32,
            split_k_slices,
            profile_all_alignments,
            find_first_valid,
            use_multiprocessing,
            gemm_profiler,
            conv2d_profiler,
        )
        mod.update_func(var, tuned)
    return mod, num_cutlass_partition
def tune_cutlass_function(
    func,
    use_3xtf32,
    split_k_slices,
    profile_all_alignments,
    find_first_valid,
    use_multiprocessing,
    gemm_profiler,
    conv2d_profiler,
):
    """Given a function intended to be offloaded to CUTLASS, profile each workload to select which
    kernels to emit.

    Parameters
    ----------
    func : Function
        The Relay Function to tune for.

    use_3xtf32 : bool
        Whether or not to use the slower but very accurate (compared to tf32) 3xtf32 mode
        for fp32 inputs on tensorcore.

    split_k_slices : list of int
        Split factor candidates for split-K GEMM. If split-K > 1, the GEMM K-loop is computed
        in parallel across split-K blocks, and a separate global reduction kernel is launched
        to accumulate partial reductions. The profiler picks the best split-K factor from the
        given candidate list. Note that a larger split-K factor requires a larger workspace.
        Currently, parallel split-K has been tested only for wgrad. For GEMM and other conv2d
        kinds, split_k_slices is ignored.

    profile_all_alignments : bool
        When True, profile all kernel variants with smaller alignments than the largest
        possible.

    find_first_valid : bool
        Whether to profile all candidate kernels, or stop profiling after
        the first applicable kernel is found.

    use_multiprocessing : bool
        Whether to compile profiler executables for different kernels in parallel.

    gemm_profiler : CutlassGemmProfiler
        Profiler for dense operators. May cache results between tuned functions.

    conv2d_profiler : CutlassConv2DProfiler
        Profiler for conv2d operators. May cache results between tuned functions.

    Returns
    -------
    annot_func : Function
        The input function with attributes capturing the best CUTLASS kernel found by tuning.
    """
    annotator = OpAnnotator()
    annotator.visit(func)
    signature = annotator.signature

    out_shape = signature["ret_shape"]
    out_dtype = signature["ret_dtype"]
    op_type = signature["op_type"]

    # Start from the annotated signature, then let the function's own attrs
    # override; kernel-selection results are merged in afterwards.
    new_attrs = {"op_type": op_type}
    new_attrs.update(signature)
    new_attrs.update(func.attrs)
    arg0_shape = new_attrs["arg0_shape"]
    arg1_shape = new_attrs["arg1_shape"]
    arg0_dtype = new_attrs["arg0_dtype"]
    arg1_dtype = new_attrs["arg1_dtype"]

    if "conv2d" in op_type:
        op_attrs = annotator.op_attrs
        new_attrs["padding"] = op_attrs.padding
        new_attrs["strides"] = op_attrs.strides
        new_attrs["dilation"] = op_attrs.dilation
        # Pick the (activation, filter) shapes to profile with, depending on the
        # conv2d variant: dgrad profiles with the output as "data", wgrad with
        # the output as the "weight".
        if "conv2d_transpose" in op_type:
            d_shape, w_shape = out_shape, arg1_shape
        elif "conv2d_backward_weight" in op_type:
            d_shape, w_shape = arg1_shape, out_shape
        else:
            d_shape, w_shape = arg0_shape, arg1_shape
        kernel_attrs = handle_conv2d(
            conv2d_profiler,
            op_type,
            d_shape,
            w_shape,
            op_attrs.padding,
            op_attrs.strides,
            op_attrs.dilation,
            out_dtype,
            arg0_dtype,
            arg1_dtype,
            use_3xtf32,
            split_k_slices,
            profile_all_alignments,
            find_first_valid,
            use_multiprocessing,
        )
    elif "batch_matmul" in op_type:
        kernel_attrs = handle_batch_matmul(
            gemm_profiler,
            op_type,
            arg0_shape,
            arg1_shape,
            out_dtype,
            arg0_dtype,
            arg1_dtype,
            use_3xtf32,
            find_first_valid,
            use_multiprocessing,
        )
    elif "dense" in op_type:
        kernel_attrs = handle_dense(
            gemm_profiler,
            op_type,
            arg0_shape,
            arg1_shape,
            out_dtype,
            arg0_dtype,
            arg1_dtype,
            use_3xtf32,
            find_first_valid,
            use_multiprocessing,
        )
    else:
        raise ValueError("%s unsupported composite" % op_type)
    new_attrs.update(kernel_attrs)

    return relay.Function(
        func.params,
        func.body,
        ret_type=func.ret_type,
        type_params=func.type_params,
        attrs=tvm.ir.make_node("DictAttrs", **new_attrs),
    )
@register_func("relay.ext.cutlass.compile_for_cutlass")
def compile_for_cutlass(mod, cutlass_target):
    """Given an IRModule with at least one Compiler='cutlass' Relay function, return a
    LibraryModule with all such functions compiled into their PackedFunc-compatible form.

    - First runs CUTLASS tuning to decide on the best kernels, which itself requires the
      repeated compilation and execution of CUDA code using nvcc. The results of this
      is captured as annotation on each relevant function. Kernel performance is cached
      overall all functions.
    - Then generates a single CSourceModule containing C code implementing all the
      Compiler='cutlass' Relay functions, accounting for the tuning done above.
    - Then compiles that CSourceModule with the appropriate nvcc arguments to yield
      a static .o library. An export_library step will be required on the final runtime
      module to link that library into the overall .so library.

    See CompileForCutlass in src/relay/backend/contrib/cutlass/codegen.cc for where this
    helper function is used to implement the RelayToTIR pass hook for CUTLASS."""
    # Recover options from the current 'cutlass' Target
    assert cutlass_target.kind.name == "cutlass"
    target_attrs = cutlass_target.attrs
    tuning_keys = (
        "sm",
        "use_3xtf32",
        "split_k_slices",
        "profile_all_alignments",
        "find_first_valid",
        "use_multiprocessing",
    )
    tuning_config = {key: target_attrs.get(key) for key in tuning_keys}
    compile_config = {key: target_attrs.get(key) for key in ("sm", "threads", "use_fast_math")}
    tmp_dir = target_attrs.get("tmp_dir")

    # Tune
    logger.info("Tuning for CUTLASS")
    mod, _ = tune_cutlass_kernels(mod, tmp_dir=tmp_dir, **tuning_config)

    # Compile
    logger.info("Creating CSource module for CUTLASS")
    create_c_source_module = tvm._ffi.get_global_func("relay.ext.cutlass.create_c_source_module")
    c_module = create_c_source_module(mod)
    function_names = c_module.get_function("get_func_names")()
    compile_options = _get_cutlass_compile_options(**compile_config)
    lib_path = os.path.join(tmp_dir, "cutlass.o")
    logger.info("Compiling generated CUTLASS code")
    c_module.export_library(lib_path, workspace_dir=tmp_dir, **compile_options)

    # Recover static library
    logger.info("Loading compiled CUTLASS code")
    final_mod = tvm.runtime.load_static_library(lib_path, function_names)
    logger.info("Done with CUTLASS compilation")
    return final_mod
def finalize_modules(lib, lib_path="compile.so", tmp_dir="./tmp"):
    """Returns lib with any C source, LLVM and static library modules compiled and linked in,
    ready for use by the graph or AOT executors. This method is not specific to CUTLASS,
    however it does assume nvcc will be used for final compilation and linking. It is provided
    here for convenience.

    Parameters
    ----------
    lib : runtime.Module
        The output from relay.build.

    lib_path : string
        The path to a shared library which will be generated as the result of the build process.

    tmp_dir : string
        A temporary directory where intermediate compiled artifacts will be stored.

    Returns
    -------
    updated_lib : runtime.Module
        The updated library with all compilation and linking completed.
    """
    out_path = os.path.join(tmp_dir, lib_path)
    lib.export_library(out_path, workspace_dir=tmp_dir, cc="nvcc")
    return runtime.load_module(out_path)
def finalize_modules_vm(vm_exec, lib_path="compile.so", vmcode_path="vmcode.ro", tmp_dir="./tmp"):
    """Returns vm_exec with any C source, LLVM and static library modules compiled and linked in
    ready for use by the VM executor. This method is not specific to CUTLASS, however it does
    assume nvcc will be used for final compilation and linking. It is provided here for
    convenience.

    Parameters
    ----------
    vm_exec : vm.Executable
        The output from relay.vm.compile containing compiled host code and kernels.

    lib_path : string
        The path to a shared library which will be generated as the result of the build process.

    vmcode_path : string
        The path where the VM bytecode will be serialized to as a side-effect.

    tmp_dir : string
        A temporary directory where intermediate compiled artifacts will be stored.

    Returns
    -------
    updated_vm_exec : vm.Executable
        The updated VM executable with all compilation and linking completed.
    """
    code, lib = vm_exec.save()
    lib_path = os.path.join(tmp_dir, lib_path)
    vmcode_path = os.path.join(tmp_dir, vmcode_path)
    lib.export_library(lib_path, workspace_dir=tmp_dir, cc="nvcc")
    # Serialize the VM bytecode alongside the shared library (a side effect).
    with open(vmcode_path, "wb") as fo:
        fo.write(code)
    # Use the module-level `runtime` alias (as finalize_modules does) instead of
    # the longer `tvm.runtime` path, for consistency within this file.
    lib = runtime.load_module(lib_path)
    return runtime.vm.Executable.load_exec(code, lib)