blob: 1a441a6f03c2a7c6c5dd3f36040cfaf06176713c [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""Patterns supported by CUTLASS."""
from functools import partial

from tvm import relay
from tvm.ir.transform import Sequential, PassContext
from tvm.relay import transform
from tvm.relay.build_module import bind_params_by_name
from tvm.relay.op.contrib.register import register_pattern_table  # type: ignore

from ...dataflow_pattern import wildcard, is_op, is_constant
def make_gelu_pattern(bias_out, out_dtype="float16"):
    """Create a pattern for the erf-based GELU activation applied to ``bias_out``.

    The matched graph is ``bias_out * (erf(bias_out * c0) * c1 + c2)``, where each
    ``c*`` operand may be a constant or any other expression.

    Parameters
    ----------
    bias_out : DFPattern
        Pattern producing the GELU input (make_gemm_pattern passes dense[+bias]).
    out_dtype : str
        When ``"float16"``, ``erf`` is expected to be wrapped as
        ``cast(erf(cast(x)))``; for any other dtype a bare ``erf`` is matched.

    Returns
    -------
    DFPattern
        Pattern rooted at the final ``multiply`` of the GELU expression.
    """
    mul = is_op("multiply")(bias_out, is_constant() | wildcard())
    if out_dtype == "float16":
        # NOTE(review): presumably erf is evaluated in a wider dtype for fp16
        # graphs, hence the cast -> erf -> cast sandwich; confirm with frontends.
        erf = is_op("cast")(is_op("erf")(is_op("cast")(mul)))
    else:
        erf = is_op("erf")(mul)
    mul_half = is_op("multiply")(erf, is_constant() | wildcard())
    add = is_op("add")(mul_half, is_constant() | wildcard())
    return is_op("multiply")(add, bias_out)
def make_gemm_pattern(with_bias=True, with_act=None, out_dtype="float16"):
    """Create a pattern for dense op followed by activations.

    Parameters
    ----------
    with_bias : bool
        If True, ``nn.dense`` must be followed by ``add`` or ``nn.bias_add``.
    with_act : Optional[str]
        ``None`` for no activation, or one of ``"relu"`` / ``"gelu"``.
    out_dtype : str
        Forwarded to ``make_gelu_pattern`` to select the fp16 or non-fp16 GELU form.

    Returns
    -------
    DFPattern
        The composed dense[+bias][+activation] pattern.

    Raises
    ------
    ValueError
        If ``with_act`` is not ``None``, ``"relu"`` or ``"gelu"``.
    """
    data = wildcard()
    weight = wildcard()
    bias = wildcard()
    gemm = is_op("nn.dense")(data, weight)
    if with_bias:
        add_or_bias_add = is_op("add") | is_op("nn.bias_add")
        gemm_out = add_or_bias_add(gemm, bias)
    else:
        gemm_out = gemm

    if with_act is None:
        return gemm_out
    if with_act == "relu":
        return is_op("nn.relu")(gemm_out)
    if with_act == "gelu":
        return make_gelu_pattern(gemm_out, out_dtype)
    # Explicit raise (instead of assert) so the check survives ``python -O``.
    raise ValueError("Unknown activation %s." % with_act)
def make_batch_matmul_pattern():
    """Match a bare ``nn.batch_matmul`` call on any two operands."""
    lhs, rhs = wildcard(), wildcard()
    return is_op("nn.batch_matmul")(lhs, rhs)
def make_conv2d_pattern(with_bias=False, with_act=None):
    """Create a pattern for conv2d op followed by an optional bias add and activation.

    Parameters
    ----------
    with_bias : bool
        If True, ``nn.conv2d`` must be followed by ``add`` or ``nn.bias_add``.
    with_act : Optional[str]
        ``None`` for no activation, or one of ``"relu"``, ``"sigmoid"``,
        ``"silu"``, ``"hardswish"``.

    Returns
    -------
    DFPattern
        The composed conv2d[+bias][+activation] pattern.

    Raises
    ------
    ValueError
        If ``with_act`` is not one of the supported activation names.
    """
    data = wildcard()
    weight = wildcard()
    bias = wildcard()
    conv2d = is_op("nn.conv2d")(data, weight)
    if with_bias:
        add_or_bias_add = is_op("add") | is_op("nn.bias_add")
        conv2d_out = add_or_bias_add(conv2d, bias)
    else:
        conv2d_out = conv2d

    if with_act is None:
        return conv2d_out
    if with_act == "relu":
        return is_op("nn.relu")(conv2d_out)
    if with_act == "sigmoid":
        return is_op("sigmoid")(conv2d_out)
    if with_act == "silu":
        # x * sigmoid(x)
        return is_op("multiply")(conv2d_out, is_op("sigmoid")(conv2d_out))
    if with_act == "hardswish":
        # x * (clip(x + c0) / c1) with constant c0 and c1.
        rhs = is_op("divide")(
            is_op("clip")(is_op("add")(conv2d_out, is_constant())), is_constant()
        )
        return is_op("multiply")(conv2d_out, rhs)
    raise ValueError("Unknown activation %s." % with_act)
def make_conv2d_transpose_pattern():
    """Match a bare ``nn.conv2d_transpose`` call on any two operands (no epilogue)."""
    data, weight = wildcard(), wildcard()
    return is_op("nn.conv2d_transpose")(data, weight)
def make_conv2d_backward_weight_pattern():
    """Match a bare ``nn.conv2d_backward_weight`` call on any two operands (no epilogue)."""
    grad, data = wildcard(), wildcard()
    return is_op("nn.conv2d_backward_weight")(grad, data)
def make_residual_block_pattern(tensor_op_out, binary_op="add", with_act="relu"):
    """Add pattern for residual blocks.

    Matches ``binary_op(tensor_op_out, residual)`` with the two operands in either
    order, where ``residual`` is any expression, optionally followed by ``nn.relu``.

    Parameters
    ----------
    tensor_op_out : DFPattern
        Pattern for the main tensor-op branch (e.g. a conv2d pattern).
    binary_op : str
        Name of the residual binary op to match.
    with_act : Optional[str]
        ``"relu"`` appends an ``nn.relu``; any other value (including ``None``)
        adds no activation.

    Returns
    -------
    DFPattern
    """
    residual_input = wildcard()
    binary_out = is_op(binary_op)(tensor_op_out, residual_input) | is_op(binary_op)(
        residual_input, tensor_op_out
    )
    if with_act == "relu":
        return is_op("nn.relu")(binary_out)
    return binary_out
def check_dtype(lhs, rhs):
    """Check if dtypes in the given workload are supported by CUTLASS.

    Accepted combinations: both float16, both float32, or each operand one of
    int8/uint8 (mixing int8 with uint8 is allowed).

    Parameters
    ----------
    lhs, rhs
        Objects exposing a ``dtype`` attribute (e.g. Relay checked types).

    Returns
    -------
    bool
    """
    int8_like = ("int8", "uint8")
    return (
        (lhs.dtype == "float16" and rhs.dtype == "float16")
        or (lhs.dtype == "float32" and rhs.dtype == "float32")
        or (lhs.dtype in int8_like and rhs.dtype in int8_like)
    )
def get_root_call(call, root_op_name):
    """Follow the first-argument chain from ``call`` to the op named ``root_op_name``.

    Returns the first ``relay.Call`` on that chain whose op prints as
    ``root_op_name``, or ``None`` once a non-call node is reached.
    """
    node = call
    while isinstance(node, relay.Call):
        if str(node.op) == root_op_name:
            return node
        node = node.args[0]
    return None
def check_gemm(call):
    """Check if the given dense workload can be offloaded to CUTLASS.

    Locates the ``nn.dense`` call on the first-argument chain of ``call`` and
    accepts it when its two operand types pass ``check_dtype``.
    """
    dense_call = get_root_call(call, "nn.dense")
    operand_types = [dense_call.args[i].checked_type for i in (0, 1)]
    return check_dtype(*operand_types)
def check_batch_matmul(call):
    """Check if the given batch_matmul workload can be offloaded to CUTLASS.

    Requires supported operand dtypes, ``transpose_a`` false and
    ``transpose_b`` true.
    """
    bmm = get_root_call(call, "nn.batch_matmul")
    dtypes_ok = check_dtype(bmm.args[0].checked_type, bmm.args[1].checked_type)
    attrs = bmm.attrs
    return dtypes_ok and not attrs.transpose_a and attrs.transpose_b
def is_depthwise_conv2d(ic, oc, groups):
    """Return True when input channels, output channels and groups all coincide."""
    return ic == oc and oc == groups
def check_conv2d_common(op_name, expected_kernel_layout, call):
    """Check if the given conv2d workload can be offloaded to CUTLASS.

    Parameters
    ----------
    op_name : str
        Name of the root op to locate on the first-argument chain of ``call``.
    expected_kernel_layout : str
        Required value of the op's ``kernel_layout`` attribute.
    call : relay.Call
        The matched expression handed to the pattern predicate.

    Returns
    -------
    bool
        True when the data layout is NHWC, the kernel layout matches, the operand
        dtypes pass ``check_dtype``, and the op is not a depthwise convolution.
    """
    conv2d = get_root_call(call, op_name)
    data_layout = conv2d.attrs.data_layout
    kernel_layout = conv2d.attrs.kernel_layout
    data = conv2d.args[0].checked_type
    weight = conv2d.args[1].checked_type
    if (
        data_layout != "NHWC"
        or kernel_layout != expected_kernel_layout
        or not check_dtype(data, weight)
    ):
        return False
    # Channel counts: last axis of NHWC data, leading axis of the weight.
    IC = data.shape[3]
    OC = weight.shape[0]
    return not is_depthwise_conv2d(IC, OC, conv2d.attrs.groups)
def check_conv2d(call):
    """Predicate for forward ``nn.conv2d`` workloads (NHWC data, OHWI kernel)."""
    return check_conv2d_common(
        op_name="nn.conv2d", expected_kernel_layout="OHWI", call=call
    )
def check_conv2d_transpose(call):
    """Predicate for ``nn.conv2d_transpose`` workloads (NHWC data, IHWO kernel)."""
    # conv2d_transpose is implemented as dgrad, needs to swap the roles of C and K
    return check_conv2d_common(
        op_name="nn.conv2d_transpose", expected_kernel_layout="IHWO", call=call
    )
def check_conv2d_backward_weight(call):
    """Predicate for ``nn.conv2d_backward_weight`` workloads (NHWC data and kernel)."""
    return check_conv2d_common(
        op_name="nn.conv2d_backward_weight", expected_kernel_layout="NHWC", call=call
    )
def check_conv2d_residual(call, binary_op):
    """Check if the given conv2d residual-block workload can be offloaded to CUTLASS.

    Parameters
    ----------
    call : relay.Call
        The matched residual-block expression.
    binary_op : str
        Name of the residual binary op (pattern_table uses "add" and "multiply").

    Returns
    -------
    bool
        True when the conv2d itself passes ``check_conv2d``, the residual input is
        not itself derived from the same root conv2d, and both binary-op operands
        have identical shapes.
    """
    conv2d = get_root_call(call, "nn.conv2d")
    if not check_conv2d(call):
        return False

    residual_binop = get_root_call(call, binary_op)
    lhs = residual_binop.args[0]
    rhs = residual_binop.args[1]

    # residual_input is pattern-matched as a wildcard. Make sure it does not sit between
    # residual binary op and the root conv2d of this pattern.
    # If the root conv2d is the parent of both lhs and rhs, we should reject this pattern.
    if get_root_call(lhs, "nn.conv2d") == conv2d and get_root_call(rhs, "nn.conv2d") == conv2d:
        return False

    lhs_shape = lhs.checked_type.shape
    rhs_shape = rhs.checked_type.shape
    # zip() stops at the shorter shape, so without this an operand of a different
    # rank (e.g. a broadcast residual) whose leading dims happen to match would pass.
    if len(lhs_shape) != len(rhs_shape):
        return False
    return all(x == y for (x, y) in zip(lhs_shape, rhs_shape))
@register_pattern_table("cutlass")
def pattern_table():
    """Returns list of triples describing the name, dataflow pattern and predicate for all
    the CUTLASS-supported operators.

    The function is registered under "cutlass" so that
    ``relay.op.contrib.get_pattern_table("cutlass")`` can retrieve it. Within each
    family, patterns with more fused ops are listed before their shorter prefixes.
    """
    dense_pat = ("cutlass.dense", make_gemm_pattern(False, None), check_gemm)
    dense_bias_pat = ("cutlass.dense_bias", make_gemm_pattern(True, None), check_gemm)
    dense_bias_relu_pat = ("cutlass.dense_bias_relu", make_gemm_pattern(True, "relu"), check_gemm)
    dense_bias_gelu_fp16_pat = (
        "cutlass.dense_bias_gelu_fp16",
        make_gemm_pattern(True, "gelu"),
        check_gemm,
    )
    dense_bias_gelu_fp32_pat = (
        "cutlass.dense_bias_gelu_fp32",
        make_gemm_pattern(True, "gelu", out_dtype="float32"),
        check_gemm,
    )

    dense_patterns = [
        dense_bias_gelu_fp16_pat,
        dense_bias_gelu_fp32_pat,
        dense_bias_relu_pat,
        dense_bias_pat,
        dense_pat,
        ("cutlass.batch_matmul", make_batch_matmul_pattern(), check_batch_matmul),
    ]

    conv2d_patterns = [
        (
            "cutlass.conv2d_bias_hardswish",
            make_conv2d_pattern(with_bias=True, with_act="hardswish"),
            check_conv2d,
        ),
        (
            "cutlass.conv2d_bias_silu",
            make_conv2d_pattern(with_bias=True, with_act="silu"),
            check_conv2d,
        ),
        (
            "cutlass.conv2d_bias_relu",
            make_conv2d_pattern(with_bias=True, with_act="relu"),
            check_conv2d,
        ),
        (
            "cutlass.conv2d_bias_sigmoid",
            make_conv2d_pattern(with_bias=True, with_act="sigmoid"),
            check_conv2d,
        ),
        ("cutlass.conv2d_bias", make_conv2d_pattern(with_bias=True), check_conv2d),
        ("cutlass.conv2d", make_conv2d_pattern(), check_conv2d),
    ]

    # For now, no fusion for grad kernels
    conv2d_grad_patterns = [
        ("cutlass.conv2d_transpose", make_conv2d_transpose_pattern(), check_conv2d_transpose),
        (
            "cutlass.conv2d_backward_weight",
            make_conv2d_backward_weight_pattern(),
            check_conv2d_backward_weight,
        ),
    ]

    # Residual variants of every conv2d pattern except the last (plain, unbiased conv2d).
    residual_block_patterns = []
    for with_act, postfix in [("relu", "_relu"), (None, "")]:
        for name, pat, _ in conv2d_patterns[:-1]:
            for bin_op in ["add", "multiply"]:
                residual_block_patterns.append(
                    (
                        name + "_residual_" + bin_op + postfix,
                        make_residual_block_pattern(pat, bin_op, with_act=with_act),
                        partial(check_conv2d_residual, binary_op=bin_op),
                    )
                )

    return residual_block_patterns + dense_patterns + conv2d_patterns + conv2d_grad_patterns
def partition_for_cutlass(mod, params=None):
    """Partition the input module into CUTLASS-supported subgraphs.

    Parameters
    ----------
    mod : tvm.IRModule
        The module to partition. When ``params`` is given, ``mod["main"]`` is
        reassigned in place on the caller's module before the passes run.
    params : Optional[Dict[str, NDArray]]
        Constant parameters to bind into ``main``. When given, they are bound and
        the module is simplified/constant-folded (``remove_bn_pass``) first.

    Returns
    -------
    tvm.IRModule
        The module after composite merging, "cutlass" target annotation and
        graph partitioning.
    """
    if params is not None:
        mod["main"] = bind_params_by_name(mod["main"], params)
        remove_bn_pass = Sequential(
            [
                transform.InferType(),
                transform.SimplifyInference(),
                transform.FoldConstant(),
                transform.FoldScaleAxis(),
            ]
        )
        with PassContext(opt_level=3):
            mod = remove_bn_pass(mod)

    cutlass_patterns = relay.op.contrib.get_pattern_table("cutlass")

    seq = Sequential(
        [
            transform.InferType(),
            transform.MergeComposite(cutlass_patterns),
            transform.AnnotateTarget(["cutlass"], include_non_call_ops=False),
            transform.PartitionGraph(bind_constants=False),
        ]
    )

    return seq(mod)