[TVM Basic] Extend generic func with get_packed_func() interface (#9784)
add test_target_temp_strategy unittest.
Co-authored-by: sqing <qing.siqi@intellif.com>
diff --git a/include/tvm/target/generic_func.h b/include/tvm/target/generic_func.h
index a310173..bd49861 100644
--- a/include/tvm/target/generic_func.h
+++ b/include/tvm/target/generic_func.h
@@ -86,7 +86,10 @@
* \param ret The return value
*/
TVM_DLL void CallPacked(runtime::TVMArgs args, runtime::TVMRetValue* ret) const;
-
+ /*!
+ * \brief Get the packed function specified for the current target context.
+ */
+ TVM_DLL PackedFunc GetPacked() const;
/*!
* \brief Find or register the GenericFunc instance corresponding to the give name
* \param name The name of the registered GenericFunc
diff --git a/python/tvm/ir/op.py b/python/tvm/ir/op.py
index 2a62f34..49ac72b 100644
--- a/python/tvm/ir/op.py
+++ b/python/tvm/ir/op.py
@@ -59,6 +59,21 @@
"""
return _ffi_api.OpGetAttr(self, attr_name)
+ def has_attr(self, attr_name):
+ """Check whether the operator has additional attribute.
+
+ Parameters
+ ----------
+ attr_name : str
+ The attribute name.
+
+ Returns
+ -------
+ value : bool
+ Whether the operator has additional attribute
+ """
+ return _ffi_api.OpHasAttr(self, attr_name)
+
def set_attr(self, attr_name, value, plevel=10):
"""Set attribute about the operator.
@@ -157,6 +172,17 @@
"""
_ffi_api.OpSetAttrsTypeKey(self, key)
+ @staticmethod
+ def list_op_names():
+ """List all the op names in the op registry.
+
+ Returns
+ -------
+ value : List[str]
+ The registered op names
+ """
+ return _ffi_api.ListOpNames()
+
def register_op_attr(op_name, attr_key, value=None, level=10):
"""Register an operator property of an operator by name.
diff --git a/python/tvm/target/generic_func.py b/python/tvm/target/generic_func.py
index 932eaa4..7b6f916 100644
--- a/python/tvm/target/generic_func.py
+++ b/python/tvm/target/generic_func.py
@@ -76,6 +76,17 @@
key_list = [key_list] if isinstance(key_list, str) else key_list
_ffi_api.GenericFuncRegisterFunc(self, func, key_list, allow_override)
+ def get_packed_func(self):
+ """Get the packed function specified for the current target.
+
+ Returns
+ -------
+ func : PackedFunc
+ The function specified for the current target. Return the default
+ function if no specializations match the current target.
+ """
+ return _ffi_api.GenericFuncGetPackedFunc(self)
+
def get_native_generic_func(name):
"""Get a generic function from the global registry. If no
@@ -266,7 +277,7 @@
return _do_reg
def dispatch_func(func, *args, **kwargs):
- """The wrapped dispath function"""
+ """The wrapped dispatch function"""
target = Target.current()
if target is None:
return func(*args, **kwargs)
@@ -275,8 +286,19 @@
return dispatch_dict[k](*args, **kwargs)
return func(*args, **kwargs)
+ def get_packed_func():
+ """The wrapped to get dispatched function"""
+ target = Target.current()
+ if target is None:
+ return fdefault
+ for k in target.keys:
+ if k in dispatch_dict:
+ return dispatch_dict[k]
+ return fdefault
+
fdecorate = decorate(fdefault, dispatch_func)
fdecorate.register = register
fdecorate.fdefault = fdefault
fdecorate.dispatch_dict = dispatch_dict
+ fdecorate.get_packed_func = get_packed_func
return fdecorate
diff --git a/src/ir/op.cc b/src/ir/op.cc
index fac15a7..e0bf561 100644
--- a/src/ir/op.cc
+++ b/src/ir/op.cc
@@ -90,6 +90,10 @@
return rv;
});
+TVM_REGISTER_GLOBAL("ir.OpHasAttr").set_body_typed([](Op op, String attr_name) -> bool {
+ return Op::HasAttrMap(attr_name);
+});
+
TVM_REGISTER_GLOBAL("ir.OpSetAttr")
.set_body_typed([](Op op, String attr_name, runtime::TVMArgValue value, int plevel) {
auto& reg = OpRegistry::Global()->RegisterOrGet(op->name).set_name();
diff --git a/src/target/generic_func.cc b/src/target/generic_func.cc
index 4295715..a006567 100644
--- a/src/target/generic_func.cc
+++ b/src/target/generic_func.cc
@@ -118,6 +118,20 @@
func.CallPacked(args, ret);
}
+PackedFunc GenericFunc::GetPacked() const {
+ auto node = static_cast<const GenericFuncNode*>(get());
+ auto target = Target::Current(true);
+ if (target.defined()) {
+ for (auto& k : target->GetKeys()) {
+ auto iter = node->dispatch_dict_.find(k);
+ if (iter != node->dispatch_dict_.end()) {
+ return iter->second;
+ }
+ }
+ }
+ return node->generic_func_;
+}
+
TVM_REGISTER_GLOBAL("target.GenericFuncCreate").set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = GenericFunc(make_object<GenericFuncNode>());
});
@@ -158,4 +172,9 @@
generic_func.CallPacked(func_args, ret);
});
+TVM_REGISTER_GLOBAL("target.GenericFuncGetPackedFunc").set_body([](TVMArgs args, TVMRetValue* ret) {
+ GenericFunc generic_func = args[0];
+ *ret = generic_func.GetPacked();
+});
+
} // namespace tvm
diff --git a/tests/python/unittest/test_target_target.py b/tests/python/unittest/test_target_target.py
index 3a8cba5..199721b 100644
--- a/tests/python/unittest/test_target_target.py
+++ b/tests/python/unittest/test_target_target.py
@@ -63,22 +63,77 @@
def test_target_dispatch():
with tvm.target.cuda():
assert mygeneric(1) == 3
+ assert mygeneric.get_packed_func()(1) == 3
with tvm.target.rocm():
assert mygeneric(1) == 4
+ assert mygeneric.get_packed_func()(1) == 4
with tvm.target.Target("cuda"):
assert mygeneric(1) == 3
+ assert mygeneric.get_packed_func()(1) == 3
with tvm.target.arm_cpu():
assert mygeneric(1) == 11
+ assert mygeneric.get_packed_func()(1) == 11
with tvm.target.Target("metal"):
assert mygeneric(1) == 3
+ assert mygeneric.get_packed_func()(1) == 3
assert tvm.target.Target.current() is None
+@tvm.target.override_native_generic_func("test_target_temp_strategy")
+def target_generic(data):
+ # default generic function
+ return data + 1
+
+
+@target_generic.register(["cuda", "gpu"])
+def target_cuda_func(data):
+ return data + 2
+
+
+def temp_target_cuda_func(data):
+ return data + 3
+
+
+def test_target_temp_strategy():
+ class TempStrategy(object):
+ def __init__(self, name, target, fstrategy):
+ generic_fstrategy = tvm.target.get_native_generic_func(name)
+ self.target = target
+ self.name = name
+ self.origin_func = {}
+ with tvm.target.Target(target) as target_obj:
+ for tgt_key in target_obj.keys:
+ self.origin_func[tgt_key] = generic_fstrategy.get_packed_func()
+ generic_fstrategy.register(fstrategy, tgt_key, allow_override=True)
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, typ, value, traceback):
+ generic_fstrategy = tvm.target.get_native_generic_func(self.name)
+ with tvm.target.Target(self.target) as target_obj:
+ for tgt_key in target_obj.keys:
+ generic_fstrategy.register(
+ self.origin_func[tgt_key], tgt_key, allow_override=True
+ )
+
+ with tvm.target.Target("cuda"):
+ assert target_generic(1) == 3
+
+ # The strategy func change to temp_target_cuda_func.
+ with TempStrategy("test_target_temp_strategy", "cuda", temp_target_cuda_func):
+ with tvm.target.Target("cuda"):
+ assert target_generic(1) == 4
+
+ with tvm.target.Target("cuda"):
+ assert target_generic(1) == 3
+
+
def test_target_string_parse():
target = tvm.target.Target("cuda -model=unknown -libs=cublas,cudnn")