blob: de9f8403de2aa3179d7f4f3275f95795f18758d5 [file] [log] [blame]
"""Tag class for TVM operators."""
import warnings
from ._ffi.base import decorate
class TagScope(object):
"""Tag scope object to set tag for operators, working as context
manager and decorator both. See also tag_scope.
"""
_current = None
@classmethod
def get_current(cls):
if cls._current:
cls._current.accessed = True
return cls._current
def __init__(self, tag):
self._old_scope = None
self.tag = tag
self.accessed = False
def __enter__(self):
if TagScope._current is not None:
raise ValueError("nested op_tag is not allowed for now")
self._old_scope = TagScope._current
TagScope._current = self
return self
def __exit__(self, ptype, value, trace):
assert self._old_scope is None
if not self.accessed:
warnings.warn("Tag '%s' declared via TagScope was not used." % (self.tag,))
TagScope._current = self._old_scope
def __call__(self, fdecl):
def tagged_fdecl(func, *args, **kwargs):
with self:
return func(*args, **kwargs)
return decorate(fdecl, tagged_fdecl)
def tag_scope(tag):
"""The operator tag scope.
Parameters
----------
tag: str
The tag name.
Returns
-------
tag_scope: TagScope
The tag scope object, which can be used as decorator or
context manger.
Example
-------
.. code-block:: python
n = tvm.var('n')
m = tvm.var('m')
l = tvm.var('l')
A = tvm.placeholder((n, l), name='A')
B = tvm.placeholder((m, l), name='B')
k = tvm.reduce_axis((0, l), name='k')
with tvm.tag_scope(tag='matmul'):
C = tvm.compute((n, m), lambda i, j: tvm.sum(A[i, k] * B[j, k], axis=k))
# or use tag_scope as decorator
@tvm.tag_scope(tag="conv")
def compute_relu(data):
return tvm.compute(data.shape, lambda *i: tvm.select(data(*i) < 0, 0.0, data(*i)))
"""
return TagScope(tag)