[Relay][AutoTVM] Relay op strategy (#4644)

* relay op strategy

fix lint

bitpack strategy

bitserial_dense (#6)

* update strategy

* address comments

fix a few topi tests

Dense strategy (#5)

* dense

* add bifrost; remove comments

* address comment

Refactor x86 conv2d_NCHWc (#4)

* Refactor x86 conv2d

* Add x86 depthwise_conv2d_NCHWc

* Add back topi x86 conv2d_nchw

* Merge x86 conv2d_nchw and conv2d_NCHWc

* Minor fix for x86 conv2d

fix more strategies

Add x86 conv2d_NCHWc_int8 strategy (#8)

* Add x86 conv2d_NCHWc_int8 strategy

* Remove contrib_conv2d_nchwc_int8

* Fix generic conv2d_NCHWc for int8

* Fix topi arm_cpu conv2d_NCHWc_int8

update x86 conv2d

enable specifying relay ops to be tuned for autotvm

add cuda conv2d strategy

add conv2d strategy for rocm

add conv2d strategy for hls

add conv2d strategy for arm cpu

add conv2d strategy for mali

add conv2d strategy for bifrost

add conv2d strategy for intel graphics

clean up and fix lint

remove template keys from autotvm

remove 2 in the func name

address comments

fix

* fix bugs

* lint

* address comments

* add name to op implementation

* Modify topi tests (#9)

* Add pooling, reorg, softmax and vision

* Add lrn

* fix topi test

* fix more topi tests

* lint

* address comments

* x

* fix more tests & bugs

* Modify more tests (#10)

* Modify tests for bitserial_conv2d, bitserial_dense, bitserial_conv2d_rasp and bnn

* Minor fix

* More minor fix

* fix more tests

* try to update vta to use op strategy

* fix cpptest

* x

* fix rebase err

* Fix two tests (#11)

* change autotvm log format

* lint

* minor fix

* try to fix vta test

* fix rebase err

* tweak

* tmp hack for vta pass

* fix tutorial

* fix

* fix more tutorials

* fix vta tutorial

* minor

* address comments

* fix

* address comments

* fix cpptest

* fix docs

* change data structure name and api

* address comments

* lint

* fix rebase err

* updates

* fix winograd test

* fix doc

* rebase

* upgrade tophub version number

* fix bug

* re-enable vta tsim test after tophub is upgraded

* fix vta test to use the correct args so the config can be found in tophub

Co-authored-by: Yao Wang <kevinthesunwy@gmail.com>
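
For reference, the core of the change is that every relay op now carries an FTVMStrategy attribute: a per-target strategy function returns an OpStrategy whose implementations pair an AutoTVM compute template with a schedule, both registered under plain string names instead of TOPI symbols and template keys. Below is a minimal sketch of the registration pattern, reusing the VTA conv2d names from the diff that follows (compute and schedule bodies elided, so this is illustrative only, not a drop-in implementation):

    from tvm import autotvm
    from tvm.relay.op import strategy as _strategy
    from tvm.relay.op.op import OpStrategy

    # AutoTVM templates are now keyed by a string name ("<workload>.<target>")
    # rather than by a TOPI function symbol plus a template key.
    @autotvm.register_topi_compute("conv2d_packed.vta")
    def conv2d_packed(cfg, data, kernel, strides, padding, dilation, layout, out_dtype):
        ...  # packed conv2d compute, see python/vta/top/vta_conv2d.py

    @autotvm.register_topi_schedule("conv2d_packed.vta")
    def schedule_conv2d_packed(cfg, outs):
        ...  # packed conv2d schedule, see python/vta/top/vta_conv2d.py

    # Each target contributes a strategy function; relay picks an implementation
    # from the returned OpStrategy at compile time.
    @_strategy.conv2d_strategy.register("vta")
    def conv2d_strategy_vta(attrs, inputs, out_type, target):
        strategy = OpStrategy()
        strategy.add_implementation(
            _strategy.wrap_compute_conv2d(conv2d_packed, True),  # True: forward data_layout
            _strategy.wrap_topi_schedule(schedule_conv2d_packed),
            name="conv2d_packed.vta")
        return strategy
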
diff --git a/python/vta/ir_pass.py b/python/vta/ir_pass.py
index 36d8e41..0c9b2ea 100644
--- a/python/vta/ir_pass.py
+++ b/python/vta/ir_pass.py
@@ -662,8 +662,12 @@
                                          0, 0,
                                          0, 0, 0))
                 inner = irb.get()
-                args = op.body.body.args
-                res_tensor = op.body.body.func.output(0)
+                # TODO(@tmoreau89): This is only a temporary fix, please take a look.
+                body = op.body.body
+                while isinstance(body, tvm.stmt.IfThenElse):
+                    body = body.then_case
+                args = body.args
+                res_tensor = body.func.output(0)
                 tpl = (args[0], 1, args[1], 1, args[2], 1, args[3], 1, 0, 1, 0, env.BLOCK_OUT)
                 inner = tvm.tir.AttrStmt(
                     [dout, res_tensor], 'buffer_bind_scope',
diff --git a/python/vta/top/__init__.py b/python/vta/top/__init__.py
index 7fdf27f..6f62aff 100644
--- a/python/vta/top/__init__.py
+++ b/python/vta/top/__init__.py
@@ -20,8 +20,8 @@
 from . import bitpack
 from .graphpack import graph_pack
 from . import op
-from . import vta_conv2d
-from . import vta_conv2d_transpose
-from . import vta_group_conv2d
-from . import vta_dense
+from .vta_conv2d import conv2d_packed, schedule_conv2d_packed
+from .vta_conv2d_transpose import conv2d_transpose_packed, schedule_conv2d_transpose_packed
+from .vta_group_conv2d import group_conv2d_packed, schedule_group_conv2d_packed
+from .vta_dense import dense_packed, schedule_dense_packed
 from . import util
diff --git a/python/vta/top/bitpack.py b/python/vta/top/bitpack.py
index d4748fa..6e9d57b 100644
--- a/python/vta/top/bitpack.py
+++ b/python/vta/top/bitpack.py
@@ -22,9 +22,8 @@
 import tvm
 from topi import util
 
-from tvm.relay.op.op import register_compute, register_schedule
+from tvm.relay.op.op import register_compute, register_injective_schedule
 from tvm.relay.op.op import register_pattern, OpPattern
-from tvm.relay.op.op import schedule_injective
 
 def bitpack(data, bits, pack_type="int8", name="bitpack"):
     """Packs lowest dimension into format needed by VTA
@@ -86,5 +85,5 @@
     bits = 8 // lanes
     return bitpack(inputs[0], bits, dtype)
 
-register_schedule("bitpack", schedule_injective)
+register_injective_schedule("bitpack")
 register_pattern("bitpack", OpPattern.INJECTIVE)
diff --git a/python/vta/top/op.py b/python/vta/top/op.py
index bf6409c..04e14b1 100644
--- a/python/vta/top/op.py
+++ b/python/vta/top/op.py
@@ -22,19 +22,22 @@
 import topi
 
 from tvm.relay.op import op as reg
-from tvm.relay.op.op import OpPattern
-from tvm.relay.op.nn import _nn
+from tvm.relay.op import strategy as _strategy
+from tvm.relay.op.op import OpPattern, OpStrategy
 
 from .util import is_packed_layout
+from .vta_conv2d import conv2d_packed, schedule_conv2d_packed
+from .vta_conv2d_transpose import conv2d_transpose_packed, schedule_conv2d_transpose_packed
+from .vta_group_conv2d import group_conv2d_packed, schedule_group_conv2d_packed
+from .vta_dense import dense_packed, schedule_dense_packed
 from ..environment import get_env
 
 
 # override to force partition at copy
 reg.register_pattern("copy", OpPattern.INJECTIVE, level=15)
 
-
-@reg.register_compute("clip", level=15)
-def compute_clip(attrs, inputs, output_type, target):
+# add clip vta strategy
+def compute_clip_vta(attrs, inputs, output_type):
     """ Clip operator. """
     x = inputs[0]
     a_min = attrs.a_min
@@ -48,139 +51,79 @@
             x.shape, lambda *i: tvm.max(x(*i), const_min), name="clipB")
     return [x]
 
+def clip_strategy_vta(attrs, inputs, out_type, target):
+    strategy = OpStrategy()
+    strategy.add_implementation(
+        compute_clip_vta,
+        _strategy.wrap_topi_schedule(topi.generic.schedule_injective),
+        name="clip.vta")
+    return strategy
 
-@reg.register_compute("nn.conv2d", level=15)
-def compute_conv2d(attrs, inputs, output_type, target):
-    """ Compute definition of conv2d """
-    padding = topi.util.get_const_tuple(attrs.padding)
-    strides = topi.util.get_const_tuple(attrs.strides)
-    dilation = tuple([int(d) for d in attrs.dilation])
-    groups = attrs.groups
-    layout = attrs.data_layout
-    out_dtype = attrs.out_dtype
+reg.get("clip").get_attr("FTVMStrategy").register(clip_strategy_vta, "vta")
 
-    if target.device_name == "vta":
-        assert dilation == (1, 1), "support for dilation limited to (1, 1)"
-        if is_packed_layout(layout):
-            if groups == 1:
-                assert groups == 1
-                env = get_env()
-                assert env.LOG_INP_WIDTH == 3, "only support 8bit inp for now"
-                assert env.LOG_WGT_WIDTH == 3, "only support 8bit wgt for now"
-                inputs = list(inputs)
-                assert inputs[1].dtype == "int8"
-                return [topi.nn.conv2d(inputs[0],
-                                       inputs[1],
-                                       strides,
-                                       padding,
-                                       dilation,
-                                       layout,
-                                       out_dtype)]
-            return [topi.nn.group_conv2d_nchw(inputs[0],
-                                              inputs[1],
-                                              strides,
-                                              padding,
-                                              dilation,
-                                              groups,
-                                              out_dtype)]
-        # If it's not packed, run on ARM CPU
-        with tvm.target.arm_cpu(tvm.target.Target.current().model):
-            return _nn.compute_conv2d(attrs, inputs, output_type, target)
-
-    # If VTA is not the target, default to _nn def
-    return _nn.compute_conv2d(attrs, inputs, output_type, target)
-
-
-@reg.register_schedule("nn.conv2d", level=15)
-def schedule_conv2d(attrs, outs, target):
-    """ Schedule definition of conv2d """
+@_strategy.conv2d_strategy.register("vta")
+def conv2d_strategy_vta(attrs, inputs, out_type, target):
+    """conv2d vta strategy"""
+    strategy = OpStrategy()
+    kernel = inputs[1]
+    dilation = topi.util.get_const_tuple(attrs.dilation)
     groups = attrs.groups
     layout = attrs.data_layout
 
-    if target.device_name == "vta":
-        if is_packed_layout(layout):
-            target = tvm.target.create(target)
-            assert target.device_name == "vta"
-            if groups == 1:
-                return topi.generic.schedule_conv2d_nchw(outs)
-            return topi.generic.schedule_group_conv2d_nchw(outs)
-        # If it's not packed, run on ARM CPU
-        with tvm.target.arm_cpu(tvm.target.Target.current().model):
-            return _nn.schedule_conv2d(attrs, outs, tvm.target.Target.current())
+    assert dilation == (1, 1), "support for dilation limited to (1, 1)"
+    if is_packed_layout(layout):
+        if groups == 1:
+            env = get_env()
+            assert env.LOG_INP_WIDTH == 3, "only support 8bit inp for now"
+            assert env.LOG_WGT_WIDTH == 3, "only support 8bit wgt for now"
+            assert kernel.dtype == "int8"
 
-    # If VTA is not the target, default to _nn def
-    return _nn.schedule_conv2d(attrs, outs, target)
+            strategy.add_implementation(
+                _strategy.wrap_compute_conv2d(conv2d_packed, True),
+                _strategy.wrap_topi_schedule(schedule_conv2d_packed),
+                name="conv2d_packed.vta")
+        else: # group_conv2d
+            strategy.add_implementation(
+                _strategy.wrap_compute_conv2d(group_conv2d_packed, has_groups=True),
+                _strategy.wrap_topi_schedule(schedule_group_conv2d_packed),
+                name="group_conv2d_packed.vta")
+        return strategy
+
+    # If it's not packed, run on ARM CPU
+    arm_tgt = tvm.target.arm_cpu(target.model)
+    return _strategy.arm_cpu.conv2d_strategy_arm_cpu(attrs, inputs, out_type, arm_tgt)
 
 
-@reg.register_compute("nn.conv2d_transpose", level=15)
-def compute_conv2d_transpose(attrs, inputs, output_type, target):
-    """ 2D convolution algorithm.
-    """
-    padding = topi.util.get_const_tuple(attrs.padding)
-    strides = topi.util.get_const_tuple(attrs.strides)
-    dilation = tuple([int(d) for d in attrs.dilation])
+@_strategy.conv2d_transpose_strategy.register("vta")
+def conv2d_transpose_strategy_vta(attrs, inputs, out_type, target):
+    """conv2d_transpose vta strategy"""
+    dilation = topi.util.get_const_tuple(attrs.dilation)
     layout = attrs.data_layout
-    out_dtype = attrs.out_dtype
+    assert dilation == (1, 1), "support for dilation limited to (1, 1)"
 
-    if target.device_name == "vta":
-        assert dilation == (1, 1), "support for dilation limited to (1, 1)"
-        if is_packed_layout(layout):
-            return [topi.nn.conv2d_transpose_nchw(
-                inputs[0], inputs[1], strides, padding, out_dtype)]
-        # If it's not packed, run on ARM CPU
-        with tvm.target.arm_cpu(tvm.target.Target.current().model):
-            return _nn.compute_conv2d_transpose(attrs, inputs, output_type, target)
+    if is_packed_layout(layout):
+        strategy = OpStrategy()
+        strategy.add_implementation(
+            _strategy.wrap_compute_conv2d_transpose(conv2d_transpose_packed),
+            _strategy.wrap_topi_schedule(schedule_conv2d_transpose_packed),
+            name="conv2d_transpose_packed.vta")
+        return strategy
 
-    # If VTA is not the target, default to _nn def
-    return _nn.compute_conv2d_transpose(attrs, inputs, output_type, target)
+    # If it's not packed, run on ARM CPU
+    arm_tgt = tvm.target.arm_cpu(target.model)
+    return _strategy.arm_cpu.conv2d_transpose_strategy_arm_cpu(attrs, inputs, out_type, arm_tgt)
 
 
-@reg.register_schedule("nn.conv2d_transpose", level=15)
-def schedule_conv2d_transpose(attrs, outputs, target):
-    """ 2D convolution schedule.
-    """
-    layout = attrs.data_layout
-
-    if target.device_name == "vta":
-        if is_packed_layout(layout):
-            return topi.nn.schedule_conv2d_transpose_nchw(outputs)
-        # If it's not packed, run on ARM CPU
-        with tvm.target.arm_cpu(tvm.target.Target.current().model):
-            return _nn.schedule_conv2d_transpose(attrs, outputs, tvm.target.Target.current())
-
-    # If VTA is not the target, default to _nn def
-    return _nn.schedule_conv2d_transpose(attrs, outputs, tvm.target.Target.current())
-
-
-@reg.register_compute("nn.dense", level=15)
-def compute_dense(attrs, inputs, out_type, target):
-    """Compute definition of dense"""
-    out_dtype = attrs.out_dtype
-    out_dtype = inputs[0].dtype if out_dtype == "" else out_dtype
-
-    if target.device_name == "vta":
-        if inputs[0].shape == 4: # this implies the layout is packed
-            target = tvm.target.create(target)
-            return [topi.nn.dense(inputs[0], inputs[1], None, out_dtype)]
-        # If it's not packed, run on ARM CPU
-        with tvm.target.arm_cpu(tvm.target.Target.current().model):
-            return _nn.compute_dense(attrs, inputs, out_type, target)
-
-    # If VTA is not the target, default to _nn def
-    return _nn.compute_dense(attrs, inputs, out_type, target)
-
-
-@reg.register_schedule("nn.dense", level=15)
-def schedule_dense(attrs, outs, target):
-    """Schedule definition of dense"""
-    if target.device_name == "vta":
-        if outs[0].shape == 4: # this implies the layout is packed
-            target = tvm.target.create(target)
-            assert target.device_name == "vta"
-            return topi.generic.schedule_dense(outs)
-        # If it's not packed, run on ARM CPU
-        with tvm.target.arm_cpu(tvm.target.Target.current().model):
-            return _nn.schedule_dense(attrs, outs, tvm.target.Target.current())
-
-    # If VTA is not the target, default to _nn def
-    return _nn.schedule_dense(attrs, outs, target)
+@_strategy.dense_strategy.register("vta")
+def dense_strategy_vta(attrs, inputs, out_type, target):
+    """dense vta strategy"""
+    if inputs[0].shape == 4: # this implies the layout is packed
+        strategy = OpStrategy()
+        strategy.add_implementation(
+            _strategy.wrap_compute_dense(dense_packed),
+            _strategy.wrap_topi_schedule(schedule_dense_packed),
+            name="dense_packed.vta")
+        return strategy
+    # If it's not packed, run on ARM CPU
+    arm_tgt = tvm.target.arm_cpu(target.model)
+    return _strategy.x86.dense_strategy_cpu(attrs, inputs, out_type, arm_tgt)
diff --git a/python/vta/top/vta_conv2d.py b/python/vta/top/vta_conv2d.py
index e15f6c1..ba93b05 100644
--- a/python/vta/top/vta_conv2d.py
+++ b/python/vta/top/vta_conv2d.py
@@ -25,15 +25,8 @@
 from .util import is_packed_layout
 from ..environment import get_env
 
-@autotvm.register_topi_compute(topi.nn.conv2d, 'vta', 'direct')
-def _declaration_conv2d(cfg,
-                        data,
-                        kernel,
-                        strides,
-                        padding,
-                        dilation,
-                        layout,
-                        out_dtype):
+@autotvm.register_topi_compute("conv2d_packed.vta")
+def conv2d_packed(cfg, data, kernel, strides, padding, dilation, layout, out_dtype):
     """ Packed conv2d function."""
     if not is_packed_layout(layout):
         raise topi.InvalidShapeError()
@@ -69,8 +62,9 @@
 
     return res
 
-@autotvm.register_topi_schedule(topi.generic.schedule_conv2d_nchw, 'vta', 'direct')
-def _schedule_conv2d(cfg, outs):
+@autotvm.register_topi_schedule("conv2d_packed.vta")
+def schedule_conv2d_packed(cfg, outs):
+    """Schedule packed conv2d"""
     assert len(outs) == 1
     output = outs[0]
     const_ops = []
diff --git a/python/vta/top/vta_conv2d_transpose.py b/python/vta/top/vta_conv2d_transpose.py
index a2750dc..a3fd7ac 100644
--- a/python/vta/top/vta_conv2d_transpose.py
+++ b/python/vta/top/vta_conv2d_transpose.py
@@ -26,13 +26,9 @@
 
 from ..environment import get_env
 
-@autotvm.register_topi_compute(topi.nn.conv2d_transpose_nchw, 'vta', 'direct')
-def _declatation_conv2d_transpose(cfg,
-                                  data,
-                                  kernel,
-                                  strides,
-                                  padding,
-                                  out_dtype):
+@autotvm.register_topi_compute("conv2d_transpose_packed.vta")
+def conv2d_transpose_packed(cfg, data, kernel, strides, padding, out_dtype):
+    """Packed conv2d_transpose compute"""
     ishape = get_const_tuple(data.shape)
     kshape = get_const_tuple(kernel.shape)
     b, c_i, i_h, i_w, t_b, t_ci = ishape
@@ -75,8 +71,9 @@
 
     return out
 
-@autotvm.register_topi_schedule(topi.generic.schedule_conv2d_transpose_nchw, 'vta', 'direct')
-def _schedule_conv2d_transpose(cfg, outs):
+@autotvm.register_topi_schedule("conv2d_transpose_packed.vta")
+def schedule_conv2d_transpose_packed(cfg, outs):
+    """Schedule packed conv2d_transpose"""
     assert len(outs) == 1
     output = outs[0]
     ewise_inputs = []
diff --git a/python/vta/top/vta_dense.py b/python/vta/top/vta_dense.py
index 9d6c19c..e239104 100644
--- a/python/vta/top/vta_dense.py
+++ b/python/vta/top/vta_dense.py
@@ -32,12 +32,8 @@
         return True
     return False
 
-@autotvm.register_topi_compute(topi.nn.dense, 'vta', 'direct')
-def _declaration_dense(cfg,
-                       data,
-                       weight,
-                       bias=None,
-                       out_dtype=None):
+@autotvm.register_topi_compute("dense_packed.vta")
+def dense_packed(cfg, data, weight, bias=None, out_dtype=None):
     """Dense function declaration."""
 
     # Make sure that the dense operator is packed
@@ -67,8 +63,8 @@
 
     return res
 
-@autotvm.register_topi_schedule(topi.generic.schedule_dense, 'vta', 'direct')
-def _schedule_dense(cfg, outs):
+@autotvm.register_topi_schedule("dense_packed.vta")
+def schedule_dense_packed(cfg, outs):
     """Packed dense schedule."""
 
     assert len(outs) == 1
diff --git a/python/vta/top/vta_group_conv2d.py b/python/vta/top/vta_group_conv2d.py
index e54637f..aa06c61 100644
--- a/python/vta/top/vta_group_conv2d.py
+++ b/python/vta/top/vta_group_conv2d.py
@@ -24,8 +24,8 @@
 
 from ..environment import get_env
 
-@autotvm.register_topi_compute(topi.nn.group_conv2d_nchw, 'vta', 'direct')
-def packed_group_conv2d(cfg,
+@autotvm.register_topi_compute("group_conv2d_packed.vta")
+def group_conv2d_packed(cfg,
                         data,
                         kernel,
                         strides,
@@ -74,8 +74,8 @@
     return out
 
 
-@autotvm.register_topi_schedule(topi.generic.schedule_group_conv2d_nchw, 'vta', 'direct')
-def schedule_packed_group_conv2d(cfg, outs):
+@autotvm.register_topi_schedule("group_conv2d_packed.vta")
+def schedule_group_conv2d_packed(cfg, outs):
     """ Schedule the packed conv2d.
     """
     assert len(outs) == 1
diff --git a/scripts/tune_resnet.py b/scripts/tune_resnet.py
index b9edc30..cf6f426 100644
--- a/scripts/tune_resnet.py
+++ b/scripts/tune_resnet.py
@@ -246,7 +246,7 @@
     print("Extracting tasks...")
     tasks = extract_from_program(func=relay_prog,
                                  params=params,
-                                 ops=(tvm.relay.op.nn.conv2d,),
+                                 ops=(relay.op.get("nn.conv2d"),),
                                  target=target,
                                  target_host=env.target_host)
 
diff --git a/tests/python/integration/test_benchmark_topi_conv2d.py b/tests/python/integration/test_benchmark_topi_conv2d.py
index af71561..6935e47 100644
--- a/tests/python/integration/test_benchmark_topi_conv2d.py
+++ b/tests/python/integration/test_benchmark_topi_conv2d.py
@@ -20,10 +20,12 @@
 import json
 import os
 
+import pytest
 import numpy as np
 from collections import namedtuple
 
 import tvm
+from tvm import relay
 from tvm import autotvm
 from tvm.contrib import util
 from tvm.contrib.pickle_memoize import memoize
@@ -79,9 +81,13 @@
     if "arm_cpu" in target.keys:
         data_pack = False
         layout = "NCHW"
+        conv2d_fcompute = topi.arm_cpu.conv2d_nchw_spatial_pack
+        conv2d_fschedule = topi.arm_cpu.schedule_conv2d_nchw_spatial_pack
     elif "vta" in target.keys:
         data_pack = True
         layout = "NCHW%dn%dc" % (env.BATCH, env.BLOCK_IN)
+        conv2d_fcompute = vta.top.conv2d_packed
+        conv2d_fschedule = vta.top.schedule_conv2d_packed
 
     # Derive shapes depending upon packing
     a_shape = (wl.batch, wl.in_filter, wl.height, wl.width)
@@ -101,18 +107,24 @@
     data = tvm.placeholder(data_shape, name="data", dtype=env.inp_dtype)
     kernel = tvm.placeholder(kernel_shape, name="kernel", dtype=env.wgt_dtype)
     bias = tvm.placeholder(bias_shape, name="bias", dtype=env.acc_dtype)
+    padding = relay.nn.get_pad_tuple2d((wl.hpad, wl.wpad))
 
     # Define base computation schedule
     with target:
-        res = topi.nn.conv2d(
-            data, kernel, (wl.hstride, wl.wstride), (wl.hpad, wl.wpad), (1, 1),
-            layout, env.acc_dtype)
+        if data_pack:
+            res = conv2d_fcompute(
+                data, kernel, (wl.hstride, wl.wstride), padding, (1, 1),
+                layout, env.acc_dtype)
+        else:
+            res = conv2d_fcompute(
+                data, kernel, (wl.hstride, wl.wstride), padding, (1, 1),
+                env.acc_dtype)
         res = topi.right_shift(res, 8)
         res = topi.add(res, bias)
         res = my_clip(res, 0, (1 << env.OUT_WIDTH - 1) - 1)
         res = topi.cast(res, env.out_dtype)
         # Derive base schedule
-        s = topi.generic.schedule_conv2d_nchw([res])
+        s = conv2d_fschedule([res])
         if print_ir:
             print(vta.lower(s, [data, kernel, bias, res], simple_mode=True))
 
@@ -222,7 +234,8 @@
 
     return correct, cost, stats
 
-def test_conv2d(device="vta"):
+@pytest.mark.parametrize("device", ["vta", "arm_cpu"])
+def test_conv2d(device):
     def _run(env, remote):
         if device == "vta":
             target = env.target
diff --git a/tests/python/integration/test_benchmark_topi_conv2d_transpose.py b/tests/python/integration/test_benchmark_topi_conv2d_transpose.py
index d729fa5..2d96a73 100644
--- a/tests/python/integration/test_benchmark_topi_conv2d_transpose.py
+++ b/tests/python/integration/test_benchmark_topi_conv2d_transpose.py
@@ -20,10 +20,12 @@
 import json
 import os
 
+import pytest
 import numpy as np
 from collections import namedtuple
 
 import tvm
+from tvm import relay
 from tvm import autotvm
 from tvm.contrib import util
 from tvm.contrib.pickle_memoize import memoize
@@ -80,14 +82,18 @@
     if "arm_cpu" in target.keys:
         data_pack = False
         layout = "NCHW"
+        fcompute = topi.arm_cpu.conv2d_transpose_nchw
+        fschedule = topi.arm_cpu.schedule_conv2d_transpose_nchw
     elif "vta" in target.keys:
         data_pack = True
         layout = "NCHW%dn%dc" % (env.BATCH, env.BLOCK_IN)
+        fcompute = vta.top.conv2d_transpose_packed
+        fschedule = vta.top.schedule_conv2d_transpose_packed
 
     # Derive shapes depending upon packing
 
     a_shape = (wl.batch, wl.in_filter, wl.height, wl.width)
-    w_shape = (wl.out_filter, wl.in_filter, wl.hkernel, wl.wkernel)
+    w_shape = (wl.in_filter, wl.out_filter, wl.hkernel, wl.wkernel)
     if data_pack:
         data_shape = (wl.batch//env.BATCH, wl.in_filter//env.BLOCK_IN,
                       wl.height, wl.width, env.BATCH, env.BLOCK_IN)
@@ -98,16 +104,17 @@
         kernel_shape = w_shape
     data = tvm.placeholder(data_shape, name="data", dtype=env.inp_dtype)
     kernel = tvm.placeholder(kernel_shape, name="kernel", dtype=env.wgt_dtype)
+    padding = relay.nn.get_pad_tuple2d((wl.hpad, wl.wpad))
 
     # Define base computation schedule
     with target:
-        res = topi.nn.conv2d_transpose_nchw(
-            data, kernel, (wl.hstride, wl.wstride), (wl.hpad, wl.wpad), env.acc_dtype)
+        res = fcompute(
+            data, kernel, (wl.hstride, wl.wstride), padding, env.acc_dtype)
         res = topi.right_shift(res, env.WGT_WIDTH)
         res = my_clip(res, 0, (1 << env.OUT_WIDTH - 1) - 1)
         res = topi.cast(res, env.out_dtype)
         # Derive base schedule
-        s = topi.generic.schedule_conv2d_transpose_nchw([res])
+        s = fschedule([res])
         if print_ir:
             print(vta.lower(s, [data, kernel, res], simple_mode=True))
 
@@ -210,7 +217,8 @@
 
     return correct, cost, stats
 
-def test_conv2d_transpose(device="vta"):
+@pytest.mark.parametrize("device", ["vta", "arm_cpu"])
+def test_conv2d_transpose(device):
     def _run(env, remote):
         if device == "vta":
             target = env.target
@@ -227,5 +235,5 @@
     vta.testing.run(_run)
 
 if __name__ == "__main__":
-    # test_conv2d_transpose(device="arm_cpu")
+    test_conv2d_transpose(device="arm_cpu")
     test_conv2d_transpose(device="vta")
diff --git a/tests/python/integration/test_benchmark_topi_dense.py b/tests/python/integration/test_benchmark_topi_dense.py
index b0ee2f5..a0acdc3 100644
--- a/tests/python/integration/test_benchmark_topi_dense.py
+++ b/tests/python/integration/test_benchmark_topi_dense.py
@@ -63,21 +63,25 @@
                       env.BATCH, env.BLOCK_IN)
         kernel_shape = (out_feat//env.BLOCK_OUT, in_feat//env.BLOCK_IN,
                         env.BLOCK_OUT, env.BLOCK_IN)
+        fcompute = vta.top.dense_packed
+        fschedule = vta.top.schedule_dense_packed
     else:
         data_shape = a_shape
         kernel_shape = w_shape
+        fcompute = topi.x86.dense_nopack
+        fschedule = topi.x86.schedule_dense_nopack
     data = tvm.placeholder(data_shape, name="data", dtype=env.inp_dtype)
     kernel = tvm.placeholder(kernel_shape, name="kernel", dtype=env.wgt_dtype)
 
     # Define base computation schedule
     with target:
-        res = topi.nn.dense(
-            data, kernel, out_dtype=env.acc_dtype)
+        res = fcompute(
+            data, kernel, None, env.acc_dtype)
         res = topi.right_shift(res, 8)
         res = my_clip(res, 0, (1 << env.OUT_WIDTH - 1) - 1)
         res = topi.cast(res, env.out_dtype)
         # Derive base schedule
-        s = topi.generic.schedule_dense([res])
+        s = fschedule([res])
         if print_ir:
             print(vta.lower(s, [data, kernel, res], simple_mode=True))
 
diff --git a/tests/python/integration/test_benchmark_topi_group_conv2d.py b/tests/python/integration/test_benchmark_topi_group_conv2d.py
index 7bba244..31fef49 100644
--- a/tests/python/integration/test_benchmark_topi_group_conv2d.py
+++ b/tests/python/integration/test_benchmark_topi_group_conv2d.py
@@ -20,10 +20,12 @@
 import json
 import os
 
+import pytest
 import numpy as np
 from collections import namedtuple
 
 import tvm
+from tvm import relay
 from tvm import autotvm
 from tvm.contrib import util
 import topi
@@ -75,9 +77,13 @@
     if "arm_cpu" in target.keys:
         data_pack = False
         layout = "NCHW"
+        fcompute = topi.nn.group_conv2d_nchw
+        fschedule = topi.generic.schedule_group_conv2d_nchw
     elif "vta" in target.keys:
         data_pack = True
         layout = "NCHW%dn%dc" % (env.BATCH, env.BLOCK_IN)
+        fcompute = vta.top.group_conv2d_packed
+        fschedule = vta.top.schedule_group_conv2d_packed
 
     # Derive shapes depending upon packing
     CI_G = wl.in_filter // wl.groups
@@ -98,17 +104,19 @@
     data = tvm.placeholder(data_shape, name="data", dtype=env.inp_dtype)
     kernel = tvm.placeholder(kernel_shape, name="kernel", dtype=env.wgt_dtype)
     bias = tvm.placeholder(bias_shape, name="bias", dtype=env.acc_dtype)
+    padding = relay.nn.get_pad_tuple2d((wl.hpad, wl.wpad))
+
     # Define base computation schedule
     with target:
-        res = topi.nn.group_conv2d_nchw(
-            data, kernel, (wl.hstride, wl.wstride), (wl.hpad, wl.wpad), (1, 1),
+        res = fcompute(
+            data, kernel, (wl.hstride, wl.wstride), padding, (1, 1),
             wl.groups, env.acc_dtype)
         res = topi.right_shift(res, 8)
         res = topi.add(res, bias)
         res = my_clip(res, 0, (1 << env.OUT_WIDTH - 1) - 1)
         res = topi.cast(res, env.out_dtype)
         # Derive base schedule
-        s = topi.generic.schedule_group_conv2d_nchw([res])
+        s = fschedule([res])
         if print_ir:
             print(vta.lower(s, [data, kernel, bias, res], simple_mode=True))
 
@@ -219,7 +227,8 @@
 
     return correct, cost, stats
 
-def test_conv2d(device="vta"):
+@pytest.mark.parametrize("device", ["vta", "arm_cpu"])
+def test_conv2d(device):
     def _run(env, remote):
         if device == "vta":
             target = env.target
diff --git a/tutorials/autotvm/tune_relay_vta.py b/tutorials/autotvm/tune_relay_vta.py
index 94fba3d..a20b8ec 100644
--- a/tutorials/autotvm/tune_relay_vta.py
+++ b/tutorials/autotvm/tune_relay_vta.py
@@ -295,7 +295,7 @@
 
 
 def register_vta_tuning_tasks():
-    from tvm.autotvm.task.topi_integration import TaskExtractEnv, deserialize_args
+    from tvm.autotvm.task import TaskExtractEnv
 
     @tvm.tag_scope(tag=topi.tag.ELEMWISE)
     def my_clip(x, a_min, a_max):
@@ -309,20 +309,19 @@
     # init autotvm env to register VTA operator
     TaskExtractEnv()
 
-    @autotvm.task.register("topi_nn_conv2d", override=True)
+    @autotvm.register_customized_task("conv2d_packed.vta")
     def _topi_nn_conv2d(*args, **kwargs):
         assert not kwargs, "Do not support kwargs in template function call"
-        args = deserialize_args(args)
         A, W = args[:2]
 
         with tvm.target.vta():
-            res = topi.nn.conv2d(*args, **kwargs)
+            res = vta.top.conv2d_packed(*args, **kwargs)
             res = topi.right_shift(res, 8)
             res = my_clip(res, 0, 127)
             res = topi.cast(res, "int8")
 
         if tvm.target.Target.current().device_name == 'vta':
-            s = topi.generic.schedule_conv2d_nchw([res])
+            s = vta.top.schedule_conv2d_packed([res])
         else:
             s = tvm.create_schedule([res.op])
         return s, [A, W, res]
@@ -356,10 +355,13 @@
     mod = tvm.IRModule.from_expr(relay_prog)
     tasks = autotvm.task.extract_from_program(mod,
                                               params=params,
-                                              ops=(tvm.relay.op.nn.conv2d, ),
+                                              ops=(relay.op.get("nn.conv2d"),),
                                               target=target,
                                               target_host=env.target_host)
 
+    # filter out non-packed conv2d task
+    tasks = list(filter(lambda t: len(t.args[0][1]) > 4, tasks))
+
     # We should have extracted 10 convolution tasks
     assert len(tasks) == 10
     print("Extracted {} conv2d tasks:".format(len(tasks)))