blob: 3b62edd1a9782889c26e6565d3fe8e6da10e370d [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=unused-argument, invalid-name
"""VTA specific buildin for runtime."""
import tvm
from . import transform
from .environment import get_env
def EarlyRewrite():
"""Try to do storage rewrite in early pass."""
def _transform(mod, ctx):
try:
return tvm.tir.transform.StorageRewrite()(mod)
except tvm.error.TVMError:
return mod
return tvm.transform.module_pass(_transform, opt_level=0, name="tir.vta.EarlyRewrite")
def build_config(debug_flag=0, **kwargs):
"""Build a build config for VTA.
Parameters
----------
debug_flag : int
The dbeug flag to be passed.
kwargs : dict
Additional configurations.
Returns
-------
build_config: tvm.transform.PassContext
The build config that can be used in TVM.
Example
--------
.. code-block:: python
# build a vta module.
with vta.build_config():
vta_module = tvm.build(s, ...)
"""
env = get_env()
@tvm.tir.transform.prim_func_pass(opt_level=0)
def add_debug(f, *_):
debug = tvm.tir.call_extern("int32", "VTASetDebugMode", env.dev.command_handle, debug_flag)
return f.with_body(tvm.tir.stmt_seq(debug, f.body))
pass_list = [
(0, transform.InjectConv2DTransposeSkip()),
(1, transform.InjectDMAIntrin()),
(1, transform.InjectSkipCopy()),
(1, transform.AnnotateALUCoProcScope()),
(1, tvm.tir.transform.LiftAttrScope("coproc_uop_scope")),
(1, transform.LiftAllocToScopeBegin()),
(1, tvm.tir.transform.LiftAttrScope("coproc_scope")),
(1, transform.InjectCoProcSync()),
(1, EarlyRewrite()),
]
if debug_flag:
pass_list.append((1, add_debug))
pass_list.append((2, transform.InjectALUIntrin()))
pass_list.append((3, tvm.tir.transform.LowerDeviceStorageAccessInfo()))
pass_list.append((3, transform.FoldUopLoop()))
pass_list.append((3, transform.CPUAccessRewrite()))
config = {"tir.add_lower_pass": pass_list}
if kwargs.get("config"):
config.update(kwargs[config])
del kwargs["config"]
return tvm.transform.PassContext(config=config, **kwargs)
def lower(*args, **kwargs):
"""Thin wrapper of tvm.lower
This wrapper automatically applies VTA's build_config
if there is no user specified build_config in context.
See Also
--------
tvm.lower : The original TVM's lower function
"""
pass_ctx = tvm.transform.PassContext.current()
if not pass_ctx.config.get("add_lower_pass"):
with build_config():
return tvm.lower(*args, **kwargs)
return tvm.lower(*args, **kwargs)
def build(*args, **kwargs):
"""Thin wrapper of tvm.build
This wrapper automatically applies VTA's build_config
if there is no user specified build_config in context.
See Also
--------
tvm.build : The original TVM's build function
"""
pass_ctx = tvm.transform.PassContext.current()
if not pass_ctx.config.get("tir.add_lower_pass"):
with build_config():
return tvm.build(*args, **kwargs)
return tvm.build(*args, **kwargs)
# Register key ops
tvm.ir.register_op_attr("tir.vta.coproc_sync", "TCallEffectKind", tvm.tir.CallEffectKind.Opaque)
tvm.ir.register_op_attr("tir.vta.coproc_dep_push", "TCallEffectKind", tvm.tir.CallEffectKind.Opaque)
tvm.ir.register_op_attr("tir.vta.coproc_dep_pop", "TCallEffectKind", tvm.tir.CallEffectKind.Opaque)
tvm.ir.register_op_attr("tir.vta.uop_push", "TCallEffectKind", tvm.tir.CallEffectKind.Opaque)
tvm.ir.register_op_attr("tir.vta.uop_push", "TGlobalSymbol", "VTAUopPush")
tvm.ir.register_op_attr("tir.vta.command_handle", "TGlobalSymbol", "VTATLSCommandHandle")
tvm.ir.register_op_attr("tir.vta.command_handle", "TCallEffectKind", tvm.tir.CallEffectKind.Opaque)