# blob: 4c33d36d69b56b611f85616e154bcb5870e6dbfc [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=unused-argument
"""VTA specific buildin for runtime."""
import tvm
from . import ir_pass
from .environment import get_env
def lift_coproc_scope(x):
    """Lift the ``coproc_scope`` attribute in a statement.

    Runs VTA's ``lift_alloc_to_scope_begin`` pass on the input first, then
    applies TVM's ``LiftAttrScope`` pass for the ``"coproc_scope"``
    attribute key (with ``False`` as its final argument).

    Parameters
    ----------
    x : Stmt
        The statement to transform.

    Returns
    -------
    Stmt
        The statement produced by the two passes, applied in that order.
    """
    hoisted = ir_pass.lift_alloc_to_scope_begin(x)
    return tvm.tir.ir_pass.LiftAttrScope(hoisted, "coproc_scope", False)
def early_rewrite(stmt):
    """Attempt a storage rewrite early in the lowering pipeline.

    Parameters
    ----------
    stmt : Stmt
        The statement to rewrite.

    Returns
    -------
    Stmt
        The result of ``StorageRewrite`` when it succeeds; otherwise the
        original ``stmt`` unchanged.
    """
    try:
        rewritten = tvm.tir.ir_pass.StorageRewrite(stmt)
    except tvm.error.TVMError:
        # Best effort: if the rewrite fails, keep the statement as it was.
        return stmt
    return rewritten
def build_config(debug_flag=0, **kwargs):
    """Create a TVM build config carrying VTA's lowering passes.

    Parameters
    ----------
    debug_flag : int
        The debug flag passed to ``VTASetDebugMode``. When it is falsy
        (the default ``0``), no debug pass is added.

    kwargs : dict
        Additional configurations forwarded to ``tvm.target.build_config``.

    Returns
    -------
    build_config: BuildConfig
        The build config that can be used in TVM.

    Example
    --------
    .. code-block:: python

      # build a vta module.
      with vta.build_config():
          vta_module = tvm.build(s, ...)
    """
    env = get_env()

    def lift_uop_scope(stmt):
        # Lift the "coproc_uop_scope" attribute via TVM's LiftAttrScope pass.
        return tvm.tir.ir_pass.LiftAttrScope(stmt, "coproc_uop_scope", True)

    def add_debug(stmt):
        # Prepend an extern call that sets the VTA debug mode to debug_flag.
        set_debug_mode = tvm.tir.call_extern(
            "int32",
            "VTASetDebugMode",
            env.dev.command_handle,
            debug_flag,
        )
        return tvm.tir.stmt_seq(set_debug_mode, stmt)

    # Passes grouped by lowering phase; the position of each group in
    # `phases` is the phase number it is registered under.
    phase_zero = [ir_pass.inject_conv2d_transpose_skip]
    phase_one = [
        ir_pass.inject_dma_intrin,
        ir_pass.inject_skip_copy,
        ir_pass.annotate_alu_coproc_scope,
        lift_uop_scope,
        lift_coproc_scope,
        ir_pass.inject_coproc_sync,
        early_rewrite,
    ]
    if debug_flag:
        # The debug pass runs last among the phase-1 passes.
        phase_one.append(add_debug)
    phase_two = [ir_pass.inject_alu_intrin]
    phase_three = [
        tvm.tir.ir_pass.LowerStorageAccessInfo,
        ir_pass.fold_uop_loop,
        ir_pass.cpu_access_rewrite,
    ]
    phases = (phase_zero, phase_one, phase_two, phase_three)

    pass_list = []
    for phase, passes in enumerate(phases):
        pass_list.extend((phase, lower_pass) for lower_pass in passes)

    return tvm.target.build_config(add_lower_pass=pass_list, **kwargs)
def lower(*args, **kwargs):
    """Thin wrapper of tvm.lower

    If the current ``BuildConfig`` has no ``add_lower_pass`` entries, the
    call is made inside VTA's default ``build_config()``; otherwise the
    active config is used as is. All arguments are forwarded unchanged.

    See Also
    --------
    tvm.lower : The original TVM's lower function
    """
    if tvm.target.BuildConfig.current().add_lower_pass:
        # A config with lowering passes is already active; respect it.
        return tvm.lower(*args, **kwargs)
    with build_config():
        return tvm.lower(*args, **kwargs)
def build(*args, **kwargs):
    """Thin wrapper of tvm.build

    If the current ``BuildConfig`` has no ``add_lower_pass`` entries, the
    call is made inside VTA's default ``build_config()``; otherwise the
    active config is used as is. All arguments are forwarded unchanged.

    See Also
    --------
    tvm.build : The original TVM's build function
    """
    if tvm.target.BuildConfig.current().add_lower_pass:
        # A config with lowering passes is already active; respect it.
        return tvm.build(*args, **kwargs)
    with build_config():
        return tvm.build(*args, **kwargs)