blob: 2b8346d63ee6ffccfbf2af844506e52c6492cd9a [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""The build utils in python.
"""
import warnings
import tvm.tir
from tvm.runtime import ndarray
from tvm.ir import container
from tvm.ir import CallingConv
from tvm.ir.transform import PassContext
from tvm.target import codegen
from tvm.te import tensor
from tvm.te import schedule
from tvm.target import Target
def get_binds(args, compact=False, binds=None):
    """Internal helper that resolves buffer binds and the symbolic arg list.

    Parameters
    ----------
    args : list of Buffer or Tensor or Var
        The argument lists to the function.

    compact : bool
        If the statement has already bound to a compact buffer.

    binds : dict of :any:`Tensor` to :any:`Buffer`, optional
        Dictionary that maps the Tensor to Buffer which specified the data layout
        requirement of the function. By default, a new compact buffer is created
        for each tensor in the argument.

    Returns
    -------
    binds: dict
        The bind specification

    arg_list: list
        The list of symbolic buffers of arguments.
    """
    # Work on a copy so the caller's dict is never mutated.
    binds = binds.copy() if binds is not None else {}
    arg_list = []
    for arg in args:
        if isinstance(arg, tensor.Tensor):
            if arg in binds:
                arg_list.append(binds[arg])
                continue
            # Symbolic (Var) dimensions without a compact guarantee use the
            # auto-broadcast buffer type.
            has_symbolic_dim = any(isinstance(dim, tvm.tir.Var) for dim in arg.shape)
            buf_type = "auto_broadcast" if has_symbolic_dim and not compact else ""
            new_buf = tvm.tir.decl_buffer(
                arg.shape, dtype=arg.dtype, name=arg.name, buffer_type=buf_type
            )
            binds[arg] = new_buf
            arg_list.append(new_buf)
        elif isinstance(arg, (schedule.Buffer, tvm.tir.Var)):
            # Buffers and scalar vars pass straight through.
            arg_list.append(arg)
        else:
            raise ValueError("args must be Tensor, Buffer or Var")
    return binds, arg_list
def form_irmodule(sch, args, name, binds):
    """According to the given schedule, form a function.

    Parameters
    ----------
    sch : tvm.te.schedule.Schedule
        The given scheduler to form the raw body

    args : list of Buffer or Tensor or Var
        The argument lists to the function.

    name : str
        The name of result function.

    binds : dict of :any:`Tensor` to :any:`Buffer`, optional
        The binds information

    Returns
    -------
    The body formed according to the given schedule
    """
    ctx = PassContext.current()
    # Normalization must happen before bound inference.
    normalized = sch.normalize()
    bounds = schedule.InferBound(normalized)
    body = schedule.ScheduleOps(normalized, bounds)
    is_compact = schedule.VerifyCompactBuffer(body)
    binds, arg_list = get_binds(args, is_compact, binds)
    body = schedule.SchedulePostProcRewriteForTensorCore(body, normalized, binds)
    func = schedule.SchedulePostProcToPrimFunc(arg_list, body, binds)
    func = func.with_attr("global_symbol", name)
    # Honor the tir.noalias setting (defaults to True).
    if ctx.config.get("tir.noalias", True):
        func = func.with_attr("tir.noalias", True)
    return tvm.IRModule({name: func})
def lower(sch, args, name="main", binds=None, simple_mode=False):
    """Lowering step before build into target.

    Parameters
    ----------
    sch : tvm.te.schedule.Schedule
        The schedule to be built

    args : list of Buffer or Tensor or Var
        The argument lists to the function.

    name : str, optional
        The name of result function.

    binds : dict of :any:`Tensor` to :any:`Buffer`, optional
        Dictionary that maps the Tensor to Buffer which specified the data layout
        requirement of the function. By default, a new compact buffer is created
        for each tensor in the argument.

    simple_mode : bool, optional
        Whether only output simple and compact statement, this will skip
        LoopPartition, api wrapper generation and Unrolling.

    Returns
    -------
    m : IRModule or Stmt
        The result IRModule, if simple_mode=False
        Then the Stmt before make api is returned.
    """
    # Read lowering-related options from the current pass context.
    ctx = PassContext.current()
    bound_check = bool(ctx.config.get("tir.instrument_bound_checkers", False))
    no_vectorize = bool(ctx.config.get("tir.disable_vectorize", False))
    user_passes = ctx.config.get("tir.add_lower_pass", [])

    # Bucket user-supplied passes by phase: 0, 1, 2, and "after 2".
    phase_passes = {0: [], 1: [], 2: [], 3: []}
    for entry in user_passes:
        phase = entry[0]
        if phase in (0, 1, 2):
            phase_passes[phase].append(entry[1])
        elif phase > 2:
            phase_passes[3].append(entry[1])

    # Phase 0: form the module (or accept a pre-formed one) and run user passes.
    if isinstance(sch, schedule.Schedule):
        mod = form_irmodule(sch, args, name, binds)
    else:
        mod = sch
    passes = list(phase_passes[0])

    # Phase 1: flattening and basic legalization.
    passes.extend(
        [
            tvm.tir.transform.InjectPrefetch(),
            tvm.tir.transform.StorageFlatten(64, bound_check),
            tvm.tir.transform.BF16Legalize(),
            tvm.tir.transform.NarrowDataType(32),
            tvm.tir.transform.Simplify(),
        ]
    )
    passes.extend(phase_passes[1])

    # Phase 2: loop-level optimizations (partitioning skipped in simple mode).
    if not simple_mode:
        passes.append(tvm.tir.transform.LoopPartition())
    passes.extend(
        [
            tvm.tir.transform.VectorizeLoop(not no_vectorize),
            tvm.tir.transform.InjectVirtualThread(),
            tvm.tir.transform.InjectDoubleBuffer(),
            tvm.tir.transform.StorageRewrite(),
            tvm.tir.transform.UnrollLoop(),
        ]
    )
    passes.extend(phase_passes[2])

    # Phase 3: cleanup.
    passes.extend(
        [
            tvm.tir.transform.Simplify(),
            tvm.tir.transform.RemoveNoOp(),
            tvm.tir.transform.RewriteUnsafeSelect(),
            tvm.tir.transform.HoistIfThenElse(),
        ]
    )
    passes.extend(phase_passes[3])

    # Bound-checker instrumentation runs last so it sees the final loops.
    if bound_check:
        passes.append(tvm.tir.transform.InstrumentBoundCheckers())

    return tvm.transform.Sequential(passes)(mod)
def _build_for_device(input_mod, target, target_host):
    """Build the lowered functions for a device with the given compilation
    target.

    Parameters
    ----------
    input_mod : IRModule
        The schedule to be built.

    target : str or :any:`tvm.target.Target`
        The target and option of the compilation.

    target_host : str or :any:`tvm.target.Target`
        The host compilation target.

    Returns
    -------
    fhost : IRModule
        The host IRModule.

    mdev : tvm.module
        A module that contains device code.
    """
    target = Target(target)
    target_host = Target(target_host)
    dev_type = ndarray.context(target.kind.name, 0).device_type

    # Mixed-module passes: annotate the target, verify memory, thread sync,
    # then split host and device portions.
    mixed_mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", target))(input_mod)
    mixed_passes = [tvm.tir.transform.VerifyMemory()]
    if len(mixed_mod.functions) == 1:
        # A single-function module's function is the entry point.
        mixed_passes.append(
            tvm.tir.transform.Apply(lambda f: f.with_attr("tir.is_entry_func", True))
        )
    if PassContext.current().config.get("tir.detect_global_barrier", False):
        mixed_passes.append(tvm.tir.transform.ThreadSync("global"))
    mixed_passes.extend(
        [
            tvm.tir.transform.ThreadSync("shared"),
            tvm.tir.transform.ThreadSync("warp"),
            tvm.tir.transform.InferFragment(),
            tvm.tir.transform.LowerThreadAllreduce(),
            tvm.tir.transform.MakePackedAPI(),
            tvm.tir.transform.SplitHostDevice(),
        ]
    )
    mixed_mod = tvm.transform.Sequential(mixed_passes)(mixed_mod)

    # Device side: keep only kernel-launch functions, then lower them.
    def _is_device_func(func):
        return (
            "calling_conv" in func.attrs
            and func.attrs["calling_conv"].value == CallingConv.DEVICE_KERNEL_LAUNCH
        )

    opt_device = tvm.transform.Sequential(
        [
            tvm.tir.transform.Filter(_is_device_func),
            tvm.tir.transform.LowerWarpMemory(),
            tvm.tir.transform.Simplify(),
            tvm.tir.transform.LowerDeviceStorageAccessInfo(),
            tvm.tir.transform.LowerCustomDatatypes(),
            tvm.tir.transform.LowerIntrin(),
        ]
    )
    mod_dev = opt_device(mixed_mod)

    # Host side: everything that is not a kernel launch.
    opt_host = tvm.transform.Sequential(
        [
            tvm.tir.transform.Filter(lambda f: not _is_device_func(f)),
            tvm.tir.transform.Apply(lambda f: f.with_attr("target", target)),
            tvm.tir.transform.LowerTVMBuiltin(),
            tvm.tir.transform.LowerDeviceStorageAccessInfo(),
            tvm.tir.transform.LowerCustomDatatypes(),
            tvm.tir.transform.LowerIntrin(),
            tvm.tir.transform.CombineContextCall(),
        ]
    )
    mod_host = opt_host(mixed_mod)

    if dev_type == ndarray.cpu(0).device_type and target_host == target:
        # Pure-CPU self-hosted builds must not produce device kernels.
        assert len(mod_dev.functions) == 0
    if "gpu" in target.keys and len(mod_dev.functions) == 0:
        warnings.warn(
            "Specified target %s, but cannot find device code, did you do bind?" % target
        )

    rt_mod_dev = codegen.build_module(mod_dev, target) if len(mod_dev.functions) != 0 else None
    return mod_host, rt_mod_dev
def build(inputs, args=None, target=None, target_host=None, name="default_function", binds=None):
    """Build a function with arguments as signature. Code will be generated
    for devices coupled with target information.

    Parameters
    ----------
    inputs : tvm.te.Schedule, IRModule, or dict of target to IRModule
        The schedule to be built

    args : list of Buffer or Tensor or Var, optional
        The argument lists to the function.

    target : str or :any:`tvm.target.Target`, optional
        The target and option of the compilation.

    target_host : str or :any:`tvm.target.Target` optional
        Host compilation target, if target is device.
        When TVM compiles device specific program such as CUDA,
        we also need host(CPU) side code to interact with the driver
        setup the dimensions and parameters correctly.
        target_host is used to specify the host side codegen target.
        By default, llvm is used if it is enabled,
        otherwise a stackvm interpreter is used.

    name : str, optional
        The name of result function.

    binds : dict, optional
        Dictionary that maps the binding of symbolic buffer to Tensor.
        By default, a new buffer is created for each tensor in the argument.

    Returns
    -------
    ret : tvm.module
        A module that combines both host and device code.

    Examples
    --------
    There are two typical example uses of this function depending on the type
    of the argument `inputs`:

    1. it is an IRModule.

    .. code-block:: python

        n = 2
        A = te.placeholder((n,), name='A')
        B = te.placeholder((n,), name='B')
        C = te.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
        s = tvm.te.create_schedule(C.op)
        m = tvm.lower(s, [A, B, C], name="test_add")
        rt_mod = tvm.build(m, target="llvm")

    2. it is a dict of compilation target to IRModule.

    .. code-block:: python

        n = 2
        A = te.placeholder((n,), name='A')
        B = te.placeholder((n,), name='B')
        C = te.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
        s1 = tvm.te.create_schedule(C.op)
        with tvm.target.cuda() as cuda_tgt:
            s2 = topi.cuda.schedule_injective(cuda_tgt, [C])
            m1 = tvm.lower(s1, [A, B, C], name="test_add1")
            m2 = tvm.lower(s2, [A, B, C], name="test_add2")
            rt_mod = tvm.build({"llvm": m1, "cuda": m2}, target_host="llvm")

    Note
    ----
    See the note on :any:`tvm.target` on target string format.
    """
    # Normalize `inputs` into a single IRModule (dict inputs are handled below).
    if isinstance(inputs, schedule.Schedule):
        if args is None:
            raise ValueError("args must be given for build from schedule")
        input_mod = lower(inputs, args, name=name, binds=binds)
    elif isinstance(inputs, (list, tuple, container.Array)):
        # Merge a sequence of lowered modules into one.
        merged_mod = tvm.IRModule({})
        for x in inputs:
            merged_mod.update(x)
        input_mod = merged_mod
    elif isinstance(inputs, tvm.IRModule):
        input_mod = inputs
    elif not isinstance(inputs, (dict, container.Map)):
        raise ValueError("inputs must be Schedule, IRModule or dict of target to IRModule")

    if not isinstance(inputs, (dict, container.Map)):
        # Single-target build: resolve the target from the ambient scope,
        # falling back to llvm.
        target = Target.current() if target is None else target
        target = target if target else "llvm"
        target_input_mod = {target: input_mod}
    else:
        target_input_mod = inputs

    # Validate the target -> module mapping.
    for tar, mod in target_input_mod.items():
        if not isinstance(tar, (str, Target)):
            raise ValueError("The key of inputs must be str or " "Target when inputs is dict.")
        if not isinstance(mod, tvm.IRModule):
            # BUGFIX: the original message was built from adjacent string
            # literals without a separating space ("IRModule,or dict").
            raise ValueError("inputs must be Schedule, IRModule, or dict of str to IRModule.")

    if not target_host:
        # Default host target: the first CPU target among the inputs.
        for tar, _ in target_input_mod.items():
            tar = Target(tar)
            device_type = ndarray.context(tar.kind.name, 0).device_type
            if device_type == ndarray.cpu(0).device_type:
                target_host = tar
                break
    if not target_host:
        # No CPU target found: prefer llvm, else the stackvm interpreter.
        target_host = "llvm" if tvm.runtime.enabled("llvm") else "stackvm"

    mod_host_all = tvm.IRModule({})
    device_modules = []
    for tar, input_mod in target_input_mod.items():
        mod_host, mdev = _build_for_device(input_mod, tar, target_host)
        mod_host_all.update(mod_host)
        device_modules.append(mdev)

    # Generate a unified host module.
    rt_mod_host = codegen.build_module(mod_host_all, target_host)

    # Import all device modules into the host module.
    for mdev in device_modules:
        if mdev:
            rt_mod_host.import_module(mdev)
    return rt_mod_host