blob: 42d2134a032f76e6ff5e4cd9b98e669f2837f527 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Tag class for TVM operators."""
import warnings
from tvm._ffi.base import decorate
class TagScope(object):
"""Tag scope object to set tag for operators, working as context
manager and decorator both. See also tag_scope.
"""
_current = None
@classmethod
def get_current(cls):
if cls._current:
cls._current.accessed = True
return cls._current
def __init__(self, tag):
self._old_scope = None
self.tag = tag
self.accessed = False
def __enter__(self):
if TagScope._current is not None:
raise ValueError("nested op_tag is not allowed for now")
self._old_scope = TagScope._current
TagScope._current = self
return self
def __exit__(self, ptype, value, trace):
assert self._old_scope is None
if not self.accessed:
warnings.warn("Tag '%s' declared via TagScope was not used." % (self.tag,))
TagScope._current = self._old_scope
def __call__(self, fdecl):
def tagged_fdecl(func, *args, **kwargs):
with self:
return func(*args, **kwargs)
return decorate(fdecl, tagged_fdecl)
def tag_scope(tag):
"""The operator tag scope.
Parameters
----------
tag: str
The tag name.
Returns
-------
tag_scope: TagScope
The tag scope object, which can be used as decorator or
context manger.
Example
-------
.. code-block:: python
n = te.var('n')
m = te.var('m')
l = te.var('l')
A = te.placeholder((n, l), name='A')
B = te.placeholder((m, l), name='B')
k = te.reduce_axis((0, l), name='k')
with tvm.te.tag_scope(tag='matmul'):
C = te.compute((n, m), lambda i, j: te.sum(A[i, k] * B[j, k], axis=k))
# or use tag_scope as decorator
@tvm.te.tag_scope(tag="conv")
def compute_relu(data):
return te.compute(data.shape, lambda *i: tvm.tir.Select(data(*i) < 0, 0.0, data(*i)))
"""
return TagScope(tag)