[TVM Basic] Extend generic func with get_packed_func() interface (#9784)

add test_target_temp_strategy unittest.

Co-authored-by: sqing <qing.siqi@intellif.com>
diff --git a/include/tvm/target/generic_func.h b/include/tvm/target/generic_func.h
index a310173..bd49861 100644
--- a/include/tvm/target/generic_func.h
+++ b/include/tvm/target/generic_func.h
@@ -86,7 +86,10 @@
    * \param ret The return value
    */
   TVM_DLL void CallPacked(runtime::TVMArgs args, runtime::TVMRetValue* ret) const;
-
+  /*!
+   * \brief Get the packed function specified for the current target context.
+   */
+  TVM_DLL PackedFunc GetPacked() const;
   /*!
    * \brief Find or register the GenericFunc instance corresponding to the give name
    * \param name The name of the registered GenericFunc
diff --git a/python/tvm/ir/op.py b/python/tvm/ir/op.py
index 2a62f34..49ac72b 100644
--- a/python/tvm/ir/op.py
+++ b/python/tvm/ir/op.py
@@ -59,6 +59,21 @@
         """
         return _ffi_api.OpGetAttr(self, attr_name)
 
+    def has_attr(self, attr_name):
+        """Check whether the operator has additional attribute.
+
+        Parameters
+        ----------
+        attr_name : str
+            The attribute name.
+
+        Returns
+        -------
+        value : bool
+            Whether the operator has additional attribute
+        """
+        return _ffi_api.OpHasAttr(self, attr_name)
+
     def set_attr(self, attr_name, value, plevel=10):
         """Set attribute about the operator.
 
@@ -157,6 +172,17 @@
         """
         _ffi_api.OpSetAttrsTypeKey(self, key)
 
+    @staticmethod
+    def list_op_names():
+        """List all the op names in the op registry.
+
+        Returns
+        -------
+        value : List[str]
+            The registered op names
+        """
+        return _ffi_api.ListOpNames()
+
 
 def register_op_attr(op_name, attr_key, value=None, level=10):
     """Register an operator property of an operator by name.
diff --git a/python/tvm/target/generic_func.py b/python/tvm/target/generic_func.py
index 932eaa4..7b6f916 100644
--- a/python/tvm/target/generic_func.py
+++ b/python/tvm/target/generic_func.py
@@ -76,6 +76,17 @@
         key_list = [key_list] if isinstance(key_list, str) else key_list
         _ffi_api.GenericFuncRegisterFunc(self, func, key_list, allow_override)
 
+    def get_packed_func(self):
+        """Get the packed function specified for the current target.
+
+        Returns
+        -------
+        func : PackedFunc
+            The function specified for the current target. Return the default
+            function if no specializations match the current target.
+        """
+        return _ffi_api.GenericFuncGetPackedFunc(self)
+
 
 def get_native_generic_func(name):
     """Get a generic function from the global registry. If no
@@ -266,7 +277,7 @@
         return _do_reg
 
     def dispatch_func(func, *args, **kwargs):
-        """The wrapped dispath function"""
+        """The wrapped dispatch function"""
         target = Target.current()
         if target is None:
             return func(*args, **kwargs)
@@ -275,8 +286,19 @@
                 return dispatch_dict[k](*args, **kwargs)
         return func(*args, **kwargs)
 
+    def get_packed_func():
+        """The wrapped to get dispatched function"""
+        target = Target.current()
+        if target is None:
+            return fdefault
+        for k in target.keys:
+            if k in dispatch_dict:
+                return dispatch_dict[k]
+        return fdefault
+
     fdecorate = decorate(fdefault, dispatch_func)
     fdecorate.register = register
     fdecorate.fdefault = fdefault
     fdecorate.dispatch_dict = dispatch_dict
+    fdecorate.get_packed_func = get_packed_func
     return fdecorate
diff --git a/src/ir/op.cc b/src/ir/op.cc
index fac15a7..e0bf561 100644
--- a/src/ir/op.cc
+++ b/src/ir/op.cc
@@ -90,6 +90,10 @@
   return rv;
 });
 
+TVM_REGISTER_GLOBAL("ir.OpHasAttr").set_body_typed([](Op op, String attr_name) -> bool {
+  return Op::HasAttrMap(attr_name);
+});
+
 TVM_REGISTER_GLOBAL("ir.OpSetAttr")
     .set_body_typed([](Op op, String attr_name, runtime::TVMArgValue value, int plevel) {
       auto& reg = OpRegistry::Global()->RegisterOrGet(op->name).set_name();
diff --git a/src/target/generic_func.cc b/src/target/generic_func.cc
index 4295715..a006567 100644
--- a/src/target/generic_func.cc
+++ b/src/target/generic_func.cc
@@ -118,6 +118,20 @@
   func.CallPacked(args, ret);
 }
 
+PackedFunc GenericFunc::GetPacked() const {
+  auto node = static_cast<const GenericFuncNode*>(get());
+  auto target = Target::Current(true);
+  if (target.defined()) {
+    for (auto& k : target->GetKeys()) {
+      auto iter = node->dispatch_dict_.find(k);
+      if (iter != node->dispatch_dict_.end()) {
+        return iter->second;
+      }
+    }
+  }
+  return node->generic_func_;
+}
+
 TVM_REGISTER_GLOBAL("target.GenericFuncCreate").set_body([](TVMArgs args, TVMRetValue* ret) {
   *ret = GenericFunc(make_object<GenericFuncNode>());
 });
@@ -158,4 +172,9 @@
   generic_func.CallPacked(func_args, ret);
 });
 
+TVM_REGISTER_GLOBAL("target.GenericFuncGetPackedFunc").set_body([](TVMArgs args, TVMRetValue* ret) {
+  GenericFunc generic_func = args[0];
+  *ret = generic_func.GetPacked();
+});
+
 }  // namespace tvm
diff --git a/tests/python/unittest/test_target_target.py b/tests/python/unittest/test_target_target.py
index 3a8cba5..199721b 100644
--- a/tests/python/unittest/test_target_target.py
+++ b/tests/python/unittest/test_target_target.py
@@ -63,22 +63,77 @@
 def test_target_dispatch():
     with tvm.target.cuda():
         assert mygeneric(1) == 3
+        assert mygeneric.get_packed_func()(1) == 3
 
     with tvm.target.rocm():
         assert mygeneric(1) == 4
+        assert mygeneric.get_packed_func()(1) == 4
 
     with tvm.target.Target("cuda"):
         assert mygeneric(1) == 3
+        assert mygeneric.get_packed_func()(1) == 3
 
     with tvm.target.arm_cpu():
         assert mygeneric(1) == 11
+        assert mygeneric.get_packed_func()(1) == 11
 
     with tvm.target.Target("metal"):
         assert mygeneric(1) == 3
+        assert mygeneric.get_packed_func()(1) == 3
 
     assert tvm.target.Target.current() is None
 
 
+@tvm.target.override_native_generic_func("test_target_temp_strategy")
+def target_generic(data):
+    # default generic function
+    return data + 1
+
+
+@target_generic.register(["cuda", "gpu"])
+def target_cuda_func(data):
+    return data + 2
+
+
+def temp_target_cuda_func(data):
+    return data + 3
+
+
+def test_target_temp_strategy():
+    class TempStrategy(object):
+        def __init__(self, name, target, fstrategy):
+            generic_fstrategy = tvm.target.get_native_generic_func(name)
+            self.target = target
+            self.name = name
+            self.origin_func = {}
+            with tvm.target.Target(target) as target_obj:
+                for tgt_key in target_obj.keys:
+                    self.origin_func[tgt_key] = generic_fstrategy.get_packed_func()
+                    generic_fstrategy.register(fstrategy, tgt_key, allow_override=True)
+
+        def __enter__(self):
+            return self
+
+        def __exit__(self, typ, value, traceback):
+            generic_fstrategy = tvm.target.get_native_generic_func(self.name)
+            with tvm.target.Target(self.target) as target_obj:
+                for tgt_key in target_obj.keys:
+                    generic_fstrategy.register(
+                        self.origin_func[tgt_key], tgt_key, allow_override=True
+                    )
+
+    with tvm.target.Target("cuda"):
+        assert target_generic(1) == 3
+
+    # The strategy func change to temp_target_cuda_func.
+    with TempStrategy("test_target_temp_strategy", "cuda", temp_target_cuda_func):
+        with tvm.target.Target("cuda"):
+            assert target_generic(1) == 4
+
+    with tvm.target.Target("cuda"):
+        assert target_generic(1) == 3
+
+
 def test_target_string_parse():
     target = tvm.target.Target("cuda -model=unknown -libs=cublas,cudnn")