removing nnvm dep from VTA sources (#4419)

diff --git a/python/vta/__init__.py b/python/vta/__init__.py
index a78db30..70c003c 100644
--- a/python/vta/__init__.py
+++ b/python/vta/__init__.py
@@ -15,11 +15,10 @@
 # specific language governing permissions and limitations
 # under the License.
 
-"""VTA Package is a TVM backend extension to support VTA hardwares
+"""VTA Package is a TVM backend extension to support VTA hardware.
 
-Besides the compiler toolchain.
-It also include utility functions to
-configure the hardware Environment and  access remote through RPC
+Besides the compiler toolchain, it also includes utility functions to
+configure the hardware environment and access remote device through RPC.
 """
 from __future__ import absolute_import as _abs
 
@@ -31,9 +30,8 @@
 
 __version__ = "0.1.0"
 
-# do not import nnvm/topi when running vta.exec.rpc_server
+# do not import topi when running vta.exec.rpc_server
 # to maintain minimum dependency on the board
 if sys.argv[0] not in ("-c", "-m"):
     from . import top
     from .build_module import build_config, lower, build
-    from . import graph
diff --git a/python/vta/graph.py b/python/vta/graph.py
deleted file mode 100644
index 1a626ee..0000000
--- a/python/vta/graph.py
+++ /dev/null
@@ -1,333 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""Graph transformation specific to accelerator.
-
-This module provide specific NNVM graph transformations
-to transform a generic NNVM graph to a version that can
-be executed on accelerator.
-"""
-
-import nnvm
-
-from nnvm.compiler import graph_attr, graph_util
-
-
-def _pack_batch_channel(data, dshape, bfactor, cfactor):
-    """Pack the data channel dimension.
-    """
-    assert dshape[0] % bfactor == 0
-    assert dshape[1] % cfactor == 0
-    data = nnvm.sym.reshape(data,
-                            shape=(dshape[0] // bfactor, bfactor,
-                                   dshape[1] // cfactor, cfactor,
-                                   dshape[2], dshape[3]))
-    data = nnvm.sym.transpose(
-        data, axes=(0, 2, 4, 5, 1, 3))
-    return data
-
-
-def _unpack_batch_channel(data, old_shape):
-    """Unpack the data channel dimension.
-    """
-    data = nnvm.sym.transpose(data, axes=(0, 4, 1, 5, 2, 3))
-    data = nnvm.sym.reshape(data, shape=old_shape)
-    return data
-
-
-def _pack_weight(data, dshape, cfactor):
-    """Pack the weight into packed format.
-    """
-    assert len(dshape) == 4
-    assert dshape[0] % cfactor == 0
-    assert dshape[1] % cfactor == 0
-    data = nnvm.sym.reshape(data,
-                            shape=(dshape[0] // cfactor, cfactor,
-                                   dshape[1] // cfactor, cfactor,
-                                   dshape[2], dshape[3]))
-    data = nnvm.sym.transpose(
-        data, axes=(0, 2, 4, 5, 1, 3))
-    return data
-
-
-def _pack_bias(data, dshape, bfactor, cfactor):
-    """Pack the bias parameter.
-    """
-    assert len(dshape) == 3
-    assert dshape[0] % cfactor == 0
-    data = nnvm.sym.reshape(data,
-                            shape=(dshape[0] // cfactor,
-                                   cfactor, dshape[1],
-                                   dshape[2], 1))
-    data = nnvm.sym.transpose(
-        data, axes=(0, 2, 3, 4, 1))
-    # broadcast batch dimension to bfactor
-    data = nnvm.sym.broadcast_to(
-        data,
-        shape=(dshape[0] // cfactor, dshape[1], dshape[2], bfactor, cfactor))
-    return data
-
-
-def _get_shape(sym, shape_dict):
-    """Get the shape of a node.
-    """
-    return graph_util.infer_shape(
-        nnvm.graph.create(sym), **shape_dict)[1][0]
-
-def clean_conv_fuse(graph):
-    """Cleanup the convolution's later fuse stages
-
-    Parameters
-    ----------
-    graph : Graph
-        Input graph
-
-    Returns
-    -------
-    graph : Graph
-        Optimized graph
-    """
-    def _clean_entry(entry):
-        node, flag = entry
-        if flag:
-            node = nnvm.symbol.clip(node, a_max=127, a_min=-127)
-            node = nnvm.symbol.cast(node, dtype="int8")
-            # Use copy as a hint to block conv2d schedules
-            node = nnvm.symbol.copy(node)
-            flag = False
-        return node, flag
-
-    gidx = graph.index
-    ref_count = {}
-    # count reference of each node
-    for nid, node in enumerate(gidx.nodes):
-        ref_count[nid] = 0
-        for elem in node["inputs"]:
-            ref_count[elem[0]] += 1
-    # construction remap
-    # entry_id->(new_node, conv_fuse)
-    # need_fold: bool indicates if we need fold
-    node_map = {}
-
-    for nid, node in enumerate(gidx.nodes):
-        children = [node_map[e[0]] for e in node["inputs"]]
-        attrs = node.get("attrs", {})
-        node_name = node["name"]
-        op_name = node["op"]
-        get_clone = lambda c, o_n, n_n, a: getattr(nnvm.symbol, o_n)(
-            *c, name=n_n, **a)
-
-        new_entry = None
-        if op_name == "null":
-            new_entry = (nnvm.symbol.Variable(node_name), False)
-        elif op_name in ("cast", "clip"):
-            if children[0][1]:
-                new_entry = children[0]
-            else:
-                new_entry = (
-                    get_clone([children[0][0]], op_name, node_name, attrs),
-                    False)
-        elif op_name == "conv2d" and attrs["out_dtype"] == "int32":
-            data, weight = children
-            data = _clean_entry(data)
-            new_node = nnvm.sym.conv2d(
-                data[0], weight[0], name=node_name, **attrs)
-            new_entry = (new_node, True)
-        elif op_name in ("__lshift_scalar__", "__rshift_scalar__", "relu"):
-            new_entry = (
-                get_clone([children[0][0]], op_name, node_name, attrs),
-                children[0][1])
-        elif op_name in ("broadcast_add", "broadcast_mul"):
-            rhs = children[1][0]
-            lhs, _ = _clean_entry(children[0])
-            lhs = nnvm.sym.cast(lhs, dtype="int32")
-            rhs = nnvm.sym.cast(rhs, dtype="int32")
-            new_entry = (
-                get_clone([lhs, rhs], op_name, node_name, attrs),
-                False)
-
-        if new_entry is None:
-            inputs = [_clean_entry(x) for x in children]
-            new_entry = (
-                get_clone([x[0] for x in inputs], op_name, node_name, attrs),
-                False)
-        if ref_count[nid] > 1:
-            new_entry = _clean_entry(new_entry)
-        node_map[nid] = new_entry
-
-    assert len(graph.index.output_entries) == 1
-    ret = node_map[graph.index.output_entries[0][0]][0]
-    ret = nnvm.graph.create(ret)
-    return ret
-
-def clean_cast(graph):
-    """
-    Move the casts to early part of graph,
-    remove uncessary clip operations when possible.
-    """
-    gidx = graph.index
-    node_map = {}
-
-    def _clean_cast(node, target_type):
-        op_name = node.attr("op_name")
-        if op_name == "cast":
-            return _clean_cast(node.get_children(), target_type)
-        if op_name == "relu":
-            data, has_clip = _clean_cast(
-                node.get_children(), target_type)
-            data = nnvm.sym.relu(data)
-            return data, has_clip
-        return nnvm.sym.cast(node, dtype=target_type), False
-
-    for nid, node in enumerate(gidx.nodes):
-        children = [node_map[e[0]] for e in node["inputs"]]
-        attrs = node.get("attrs", {})
-        node_name = node["name"]
-        op_name = node["op"]
-        get_clone = lambda c, o_n, n_n, a: getattr(nnvm.symbol, o_n)(
-            *c, name=n_n, **a)
-
-        if op_name == "null":
-            new_node = nnvm.symbol.Variable(node_name)
-        elif op_name == "cast":
-            dtype = attrs["dtype"]
-            new_node, _ = _clean_cast(children[0], dtype)
-        elif op_name == "conv2d" and attrs["out_dtype"] == "int32":
-            data, weight = children
-            data, _ = _clean_cast(data, "int8")
-            weight, _ = _clean_cast(weight, "int8")
-            new_node = nnvm.sym.conv2d(
-                data, weight, name=node_name, **attrs)
-        elif op_name == "elemwise_add":
-            lhs, rhs = children
-            rhs = nnvm.sym.cast(rhs, dtype="int8")
-            new_node = nnvm.sym.elemwise_add(lhs, rhs)
-        else:
-            new_node = get_clone(children, op_name, node_name, attrs)
-        node_map[nid] = new_node
-
-    assert len(graph.index.output_entries) == 1
-    ret = node_map[graph.index.output_entries[0][0]]
-    ret = nnvm.graph.create(ret)
-    return ret
-
-
-def pack(graph, shape_dict, bfactor, cfactor, start_name=None):
-    """Pack the graph into batch&channel packed format.
-
-    Parameters
-    ----------
-    graph : Graph
-       The input graph.
-
-    shape_dict : dict of str to shapex
-       The input shape.
-
-    bfactor : int
-       The packing factor in batch
-
-    cfactor : int
-       The packing factor in channel
-
-    start_name: str, optional
-       Start name start packing from certain known node.
-
-    Returns
-    -------
-    graph : Graph
-        The transformed graph.
-    """
-    graph = graph_attr.set_shape_inputs(graph, shape_dict)
-    graph = graph.apply("InferShape")
-    shape = graph.json_attr("shape")
-    gidx = graph.index
-    node_map = {}
-    dset = set()
-    counter = 0
-    start_pack = False
-
-    for nid, node in enumerate(gidx.nodes):
-        children = [node_map[e[0]] for e in node["inputs"]]
-        ishape = [shape[gidx.entry_id(e)] for e in node["inputs"]]
-        oshape = shape[gidx.entry_id(nid, 0)]
-        attrs = node.get("attrs", {})
-        node_name = node["name"]
-        op_name = node["op"]
-        get_clone = lambda c, o_n, n_n, a: getattr(nnvm.symbol, o_n)(
-            *c, name=n_n, **a)
-
-        if op_name == "null":
-            new_node = nnvm.symbol.Variable(node_name)
-            if start_name and node_name == start_name:
-                start_pack = True
-                new_node = _pack_batch_channel(new_node, oshape, bfactor, cfactor)
-        elif op_name == "max_pool2d":
-            assert not start_pack
-            start_pack = True
-            new_node = get_clone(children, op_name, node_name, attrs)
-            new_node = _pack_batch_channel(new_node, oshape, bfactor, cfactor)
-        elif op_name == "global_avg_pool2d":
-            if start_pack:
-                start_pack = False
-                children[0] = _unpack_batch_channel(children[0], ishape[0])
-                new_node = getattr(nnvm.symbol, op_name)(
-                    *children, name=node_name, **attrs)
-            else:
-                new_node = get_clone(children, op_name, node_name, attrs)
-        elif op_name == "conv2d" and attrs["out_dtype"] == "int32":
-            if start_pack:
-                attrs["layout"] = "NCHW%dn%dc" % (bfactor, cfactor)
-                attrs["kernel_layout"] = "OIHW%do%di" % (cfactor, cfactor)
-                data, weight = children
-                weight = _pack_weight(weight, ishape[1], cfactor)
-                new_node = nnvm.sym.conv2d(
-                    data, weight, name=node_name, **attrs)
-            elif counter == 1:
-                attrs["layout"] = "NCHW%dn%dc" % (bfactor, cfactor)
-                attrs["kernel_layout"] = "OIHW%do%di" % (cfactor, cfactor)
-                data, weight = children
-                data = _pack_batch_channel(data, ishape[0], bfactor, cfactor)
-                weight = _pack_weight(weight, ishape[1], cfactor)
-                new_node = nnvm.sym.conv2d(
-                    data, weight, name=node_name, **attrs)
-                new_node = _unpack_batch_channel(new_node, oshape)
-                counter = counter + 1
-            else:
-                new_node = get_clone(children, op_name, node_name, attrs)
-        elif op_name.startswith("broadcast"):
-            if start_pack:
-                assert len(ishape[1]) == 3
-                children[1] = _pack_bias(children[1], ishape[1], bfactor, cfactor)
-                new_node = getattr(nnvm.symbol, op_name)(
-                    *children, name=node_name, **attrs)
-            else:
-                new_node = get_clone(children, op_name, node_name, attrs)
-        elif op_name.startswith("elementwise_add"):
-            new_node = get_clone(children, op_name, node_name, attrs)
-        else:
-            new_node = get_clone(children, op_name, node_name, attrs)
-            dset.add(op_name)
-        node_map[nid] = new_node
-
-    assert len(graph.index.output_entries) == 1
-    ret = node_map[graph.index.output_entries[0][0]]
-    if start_pack:
-        oshape = shape[graph.index.output_entries[0][0]]
-        ret = _unpack_batch_channel(ret, oshape)
-    graph = nnvm.graph.create(ret)
-    graph = graph_attr.set_shape_inputs(graph, shape_dict)
-    graph = graph.apply("InferShape")
-    return graph
diff --git a/python/vta/top/__init__.py b/python/vta/top/__init__.py
index f269113..09d3101 100644
--- a/python/vta/top/__init__.py
+++ b/python/vta/top/__init__.py
@@ -24,8 +24,3 @@
 from . import vta_conv2d_transpose
 from . import vta_dense
 from . import util
-
-# NNVM is deprecated for VTA
-# from . import nnvm_bitpack
-# from .nnvm_graphpack import nnvm_graph_pack
-# from . import nnvm_op
diff --git a/python/vta/top/nnvm_bitpack.py b/python/vta/top/nnvm_bitpack.py
deleted file mode 100644
index 0dc2413..0000000
--- a/python/vta/top/nnvm_bitpack.py
+++ /dev/null
@@ -1,86 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-# pylint: disable=unused-argument
-"""Bit packing operators"""
-from __future__ import absolute_import as _abs
-
-import tvm
-from topi import util
-
-from nnvm.top import registry as reg, OpPattern
-from nnvm.top.tensor import _fschedule_broadcast
-
-def bitpack(data, bits, pack_type="int8", name="bitpack"):
-    """Packs lowest dimension into format needed by VTA
-    Parameters
-    ----------
-    pack_axis : int
-        index of the axis to pack in data
-    bit_axis : int
-        index of axis to place bit axis in resulting packed data
-    Returns
-    -------
-    packed : Tensor
-        The packed tensor.
-    """
-    shape_vec = list(data.shape)
-    if pack_type == 'int8':
-        data_width = 8
-    elif pack_type == 'int16':
-        data_width = 16
-    elif pack_type == 'int32':
-        data_width = 32
-    else:
-        raise RuntimeError("Unknown pack type %s" % pack_type)
-    assert data_width % bits == 0
-    lanes = data_width // bits
-
-    # Data must be in multiples of the data_width
-    assert util.get_const_int(shape_vec[-1]) % lanes == 0, "Not a multiple of word size"
-    shape_vec[-1] = shape_vec[-1] // lanes
-    oshape = tuple(shape_vec)
-
-    def _bitpack(*indices):
-        ret = None
-        mask = tvm.const((1 << bits) - 1, pack_type)
-        for k in range(lanes):
-            idx = list(indices)
-            idx[-1] = idx[-1] * lanes + k
-            elem = data(*idx).astype(pack_type)
-            if k == 0:
-                ret = elem & mask
-            else:
-                val = (elem & mask) << tvm.const(k * bits, pack_type)
-                ret = ret | val
-        return ret
-
-    return tvm.compute(
-        oshape, _bitpack, name=name, tag='bitpack')
-
-
-@reg.register_compute("bitpack", level=15)
-def compute_bitpack(attrs, inputs, out):
-    lanes = attrs.get_int("lanes")
-    dtype = inputs[0].dtype
-    assert dtype == "int8"
-    width = 8
-    assert width % lanes == 0
-    bits = 8 // lanes
-    return bitpack(inputs[0], bits, dtype)
-
-reg.register_schedule("bitpack", _fschedule_broadcast)
-reg.register_pattern("bitpack", OpPattern.INJECTIVE)
diff --git a/python/vta/top/nnvm_graphpack.py b/python/vta/top/nnvm_graphpack.py
deleted file mode 100644
index 427001f..0000000
--- a/python/vta/top/nnvm_graphpack.py
+++ /dev/null
@@ -1,223 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-"""An NNVM implementation of graph packing."""
-
-import nnvm
-from nnvm.compiler import graph_attr, graph_util
-
-def _pack_batch_channel(data, dshape, bfactor, cfactor):
-    """Pack the data channel dimension.
-    """
-    assert dshape[0] % bfactor == 0
-    assert dshape[1] % cfactor == 0
-    data = nnvm.sym.reshape(data,
-                            shape=(dshape[0] // bfactor, bfactor,
-                                   dshape[1] // cfactor, cfactor,
-                                   dshape[2], dshape[3]))
-    data = nnvm.sym.transpose(
-        data, axes=(0, 2, 4, 5, 1, 3))
-    return data
-
-
-def _unpack_batch_channel(data, old_shape):
-    """Unpack the data channel dimension.
-    """
-    data = nnvm.sym.transpose(data, axes=(0, 4, 1, 5, 2, 3))
-    data = nnvm.sym.reshape(data, shape=old_shape)
-    return data
-
-
-def _pack_weight(data, dshape, cfactor):
-    """Pack the weight into packed format.
-    """
-    assert len(dshape) == 4
-    assert dshape[0] % cfactor == 0
-    assert dshape[1] % cfactor == 0
-    data = nnvm.sym.reshape(data,
-                            shape=(dshape[0] // cfactor, cfactor,
-                                   dshape[1] // cfactor, cfactor,
-                                   dshape[2], dshape[3]))
-    data = nnvm.sym.transpose(
-        data, axes=(0, 2, 4, 5, 1, 3))
-    return data
-
-
-def _pack_weight_conv2d_transpose(data, dshape, cfactor):
-    """Pack the weight into packed format.
-    """
-    assert len(dshape) == 4
-    assert dshape[0] % cfactor == 0
-    assert dshape[1] % cfactor == 0
-    data = nnvm.sym.reshape(data,
-                            shape=(dshape[0] // cfactor, cfactor,
-                                   dshape[1] // cfactor, cfactor,
-                                   dshape[2], dshape[3]))
-    data = nnvm.sym.transpose(
-        data, axes=(2, 0, 4, 5, 3, 1))
-    return data
-
-
-def _pack_bias(data, dshape, bfactor, cfactor):
-    """Pack the bias parameter.
-    """
-    assert len(dshape) == 3
-    assert dshape[0] % cfactor == 0
-    data = nnvm.sym.reshape(data,
-                            shape=(dshape[0] // cfactor,
-                                   cfactor, dshape[1],
-                                   dshape[2], 1))
-    data = nnvm.sym.transpose(
-        data, axes=(0, 2, 3, 4, 1))
-    # broadcast batch dimension to bfactor
-    data = nnvm.sym.broadcast_to(
-        data,
-        shape=(dshape[0] // cfactor, dshape[1], dshape[2], bfactor, cfactor))
-    return data
-
-
-def _get_shape(sym, shape_dict):
-    """Get the shape of a node.
-    """
-    return graph_util.infer_shape(
-        nnvm.graph.create(sym), **shape_dict)[1][0]
-
-
-def nnvm_graph_pack(graph,
-                    shape_dict,
-                    bfactor,
-                    cfactor,
-                    weight_bits,
-                    start_name="max_pool2d0",
-                    stop_name="global_avg_pool2d0"):
-    """Pack the graph into batch&channel packed format.
-
-    Parameters
-    ----------
-    graph : Graph
-       The input graph.
-
-    shape_dict : dict of str to shape
-       The input shape.
-
-    bfactor : int
-       The packing factor in batch
-
-    cfactor : int
-       The packing factor in channel
-
-    start_name: str, optional
-       Start packing from certain known node.
-
-    start_name: str, optional
-       Stop packing from certain known node.
-
-    Returns
-    -------
-    graph : Graph
-        The transformed graph.
-    """
-    graph = graph_attr.set_shape_inputs(graph, shape_dict)
-    graph = graph.apply("InferShape")
-    shape = graph.json_attr("shape")
-    gidx = graph.index
-    node_map = {}
-    dset = set()
-    start_pack = False
-
-    for nid, node in enumerate(gidx.nodes):
-        children = [node_map[e[0]] for e in node["inputs"]]
-        ishape = [shape[gidx.entry_id(e)] for e in node["inputs"]]
-        oshape = shape[gidx.entry_id(nid, 0)]
-        attrs = node.get("attrs", {})
-        node_name = node["name"]
-        op_name = node["op"]
-        get_clone = lambda c, o_n, n_n, a: getattr(nnvm.symbol, o_n)(
-            *c, name=n_n, **a)
-        if op_name == "null":
-            new_node = nnvm.symbol.Variable(node_name)
-            if start_name and node_name == start_name:
-                start_pack = True
-                new_node = _pack_batch_channel(new_node, oshape, bfactor, cfactor)
-            if start_pack and "_begin_state_" in node_name: # RNN -> CNN, pack
-                new_node = _pack_batch_channel(new_node, oshape, bfactor, cfactor)
-        elif node_name == start_name:
-            assert not start_pack
-            start_pack = True
-            new_node = get_clone(children, op_name, node_name, attrs)
-            new_node = _pack_batch_channel(new_node, oshape, bfactor, cfactor)
-        elif node_name == stop_name:
-            if start_pack:
-                start_pack = False
-                children[0] = _unpack_batch_channel(children[0], ishape[0])
-                new_node = getattr(nnvm.symbol, op_name)(
-                    *children, name=node_name, **attrs)
-            else:
-                new_node = get_clone(children, op_name, node_name, attrs)
-        elif op_name == "conv2d" and attrs.get("out_dtype", None) == "int32":
-            assert 8 % weight_bits == 0
-            w_lanes = 8 // weight_bits
-            if start_pack:
-                attrs["layout"] = "NCHW%dn%dc" % (bfactor, cfactor)
-                attrs["kernel_layout"] = "OIHW%do%di%dp" % (cfactor, cfactor, w_lanes)
-                data, weight = children
-                weight = _pack_weight(weight, ishape[1], cfactor)
-                # insert bit packing when necessary
-                if w_lanes != 1:
-                    assert 8 % w_lanes == 0
-                    weight = nnvm.sym.bitpack(weight, lanes=w_lanes)
-                new_node = nnvm.sym.conv2d(
-                    data, weight, name=node_name, **attrs)
-            else:
-                new_node = get_clone(children, op_name, node_name, attrs)
-        elif op_name == "conv2d_transpose" and attrs.get("out_dtype", None) == "int32":
-            assert 8 % weight_bits == 0
-            w_lanes = 8 // weight_bits
-            if start_pack:
-                attrs["layout"] = "NCHW%dn%dc" % (bfactor, cfactor)
-                attrs["kernel_layout"] = "IOHW%di%do%dp" % (cfactor, cfactor, w_lanes)
-                data, weight = children
-                weight = _pack_weight_conv2d_transpose(weight, ishape[1], cfactor)
-                new_node = nnvm.sym.conv2d_transpose(
-                    data, weight, name=node_name, **attrs)
-            else:
-                new_node = get_clone(children, op_name, node_name, attrs)
-        elif op_name.startswith("broadcast_") and tuple(ishape[0]) == tuple(ishape[1]):
-            new_node = get_clone(children, op_name, node_name, attrs)
-        elif op_name.startswith("broadcast") and len(ishape[1]) == 3:
-            if start_pack:
-                children[1] = _pack_bias(children[1], ishape[1], bfactor, cfactor)
-                new_node = getattr(nnvm.symbol, op_name)(
-                    *children, name=node_name, **attrs)
-            else:
-                new_node = get_clone(children, op_name, node_name, attrs)
-        elif op_name.startswith("elementwise_add"):
-            new_node = get_clone(children, op_name, node_name, attrs)
-        else:
-            new_node = get_clone(children, op_name, node_name, attrs)
-            dset.add(op_name)
-        node_map[nid] = new_node
-
-    assert len(graph.index.output_entries) == 1
-    ret = node_map[graph.index.output_entries[0][0]]
-    if start_pack:
-        oshape = shape[graph.index.output_entries[0][0]]
-        ret = _unpack_batch_channel(ret, oshape)
-    graph = nnvm.graph.create(ret)
-    graph = graph_attr.set_shape_inputs(graph, shape_dict)
-    graph = graph.apply("InferShape")
-    return graph
diff --git a/python/vta/top/nnvm_op.py b/python/vta/top/nnvm_op.py
deleted file mode 100644
index a38b217..0000000
--- a/python/vta/top/nnvm_op.py
+++ /dev/null
@@ -1,130 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-"""Namespace for supporting packed_conv2d + ewise variant of nnvm."""
-from __future__ import absolute_import as _abs
-
-import logging
-
-import tvm
-import topi
-
-from nnvm.top import registry as reg, OpPattern
-from nnvm.top import nn as _nn
-
-from .vta_conv2d import is_packed_layout
-from ..environment import get_env
-
-@tvm.register_func("nnvm.compiler.build_target", override=True)
-def _build(funcs, target, target_host):
-    tvm_t = tvm.target.create(target)
-    if tvm_t.device_name == "vta":
-        return tvm.build(funcs, target="ext_dev", target_host=target_host)
-    if tvm_t.device_name == "rasp" or tvm_t.device_name == "vtacpu":
-        return tvm.build(funcs, target=target_host)
-    return tvm.build(funcs, target=target)
-
-@tvm.register_func("nnvm.compiler.lower", override=True)
-def _lower(sch, inputs, func_name, graph):
-    import traceback
-    # pylint: disable=broad-except
-    try:
-        f = tvm.lower(sch, inputs, name=func_name)
-        if "quantized_conv2d" in func_name:
-            logging.info(graph.ir(join_entry_attrs=["shape"]))
-    except Exception:
-        msg = traceback.format_exc()
-        msg += "Error during compile graph\n"
-        msg += "--------------------------\n"
-        msg += graph.ir(join_entry_attrs=["shape"])
-        raise RuntimeError(msg)
-    return f if isinstance(
-        f, (tvm.container.Array, tuple, list)) else [f]
-
-# override to force partition at copy
-reg.register_pattern("copy", OpPattern.INJECTIVE, level=15)
-
-@reg.register_compute("clip", level=15)
-def compute_clip(attrs, inputs, _):
-    """ Clip operator. """
-    x = inputs[0]
-    a_min = attrs.get_float("a_min")
-    a_max = attrs.get_float("a_max")
-    const_min = tvm.const(a_min, x.dtype)
-    const_max = tvm.const(a_max, x.dtype)
-    with tvm.tag_scope(topi.tag.ELEMWISE):
-        x = tvm.compute(
-            x.shape, lambda *i: tvm.min(x(*i), const_max), name="clipA")
-        x = tvm.compute(
-            x.shape, lambda *i: tvm.max(x(*i), const_min), name="clipB")
-    return x
-
-@reg.register_compute("conv2d", level=15)
-def compute_conv2d(attrs, inputs, out):
-    """ Compute definition of conv2d """
-    padding = attrs.get_int_tuple("padding")
-    strides = attrs.get_int_tuple("strides")
-    dilation = attrs.get_int_tuple("dilation")
-    groups = attrs.get_int("groups")
-    layout = attrs["layout"]
-    out_dtype = attrs['out_dtype']
-
-    assert dilation == (1, 1), "not support dilate now"
-    if is_packed_layout(layout):
-        if groups == 1:
-            assert groups == 1
-            env = get_env()
-            assert env.LOG_INP_WIDTH == 3, "only support 8bit inp for now"
-            assert env.LOG_OUT_WIDTH == 3, "only support 8bit inp for now"
-            inputs = list(inputs)
-            assert inputs[1].dtype == "int8"
-            return topi.nn.conv2d(inputs[0], inputs[1], strides,
-                                  padding, dilation, layout, out_dtype)
-        return topi.nn.group_conv2d_nchw(inputs[0], inputs[1], strides,
-                                         padding, dilation, groups, out_dtype)
-
-    with tvm.target.arm_cpu(tvm.target.current_target().model):
-        return _nn.compute_conv2d(attrs, inputs, out)
-
-@reg.register_schedule("conv2d", level=15)
-def schedule_conv2d(attrs, outs, target):
-    """ Schedule definition of conv2d """
-    layout = attrs["layout"]
-    groups = attrs.get_int('groups')
-
-    if is_packed_layout(layout):
-        target = tvm.target.create(target)
-        if target.device_name == "vta":
-            if groups == 1:
-                return topi.generic.schedule_conv2d_nchw(outs)
-            return topi.generic.schedule_group_conv2d_nchw(outs)
-        elif str(target).startswith("llvm"):
-            return tvm.create_schedule([x.op for x in outs])
-        else:
-            raise RuntimeError("not support target %s" % target)
-
-    with tvm.target.arm_cpu(tvm.target.current_target().model):
-        return _nn.schedule_conv2d(attrs, outs, tvm.target.current_target())
-
-@reg.register_alter_op_layout("conv2d", level=15)
-def alter_conv2d_layout(attrs, inputs, out):
-    layout = attrs['layout']
-    if is_packed_layout(layout):
-        return None
-
-    with tvm.target.arm_cpu(tvm.target.current_target().model):
-        return _nn.alter_conv2d_layout(attrs, inputs, out)
diff --git a/python/vta/top/op.py b/python/vta/top/op.py
index a02f62b..6aca07e 100644
--- a/python/vta/top/op.py
+++ b/python/vta/top/op.py
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 # pylint: disable=unused-argument, ungrouped-imports
-"""Namespace for supporting packed_conv2d + ewise variant of nnvm."""
+"""Namespace for supporting Relay operators on VTA."""
 from __future__ import absolute_import as _abs
 
 import tvm
diff --git a/scripts/tune_resnet_nnvm.py b/scripts/tune_resnet_nnvm.py
deleted file mode 100644
index d95ef43..0000000
--- a/scripts/tune_resnet_nnvm.py
+++ /dev/null
@@ -1,256 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-"""Perform ResNet autoTVM tuning on VTA using NNVM."""
-
-import argparse
-import os
-import time
-import numpy as np
-
-import tvm
-from tvm import rpc, autotvm
-from tvm.autotvm.measure.measure_methods import request_remote
-from tvm.autotvm.tuner import XGBTuner, GATuner, RandomTuner, GridSearchTuner
-from tvm.contrib import graph_runtime, util
-from tvm.contrib.download import download
-
-import topi
-import nnvm.compiler
-import vta
-import vta.testing
-
-env = vta.get_env()
-
-def register_vta_tuning_tasks():
-    from tvm.autotvm.task.topi_integration import TaskExtractEnv, deserialize_args
-
-    @tvm.tag_scope(tag=topi.tag.ELEMWISE)
-    def my_clip(x, a_min, a_max):
-        """Unlike topi's current clip, put min and max into two stages."""
-        const_min = tvm.const(a_min, x.dtype)
-        const_max = tvm.const(a_max, x.dtype)
-        x = tvm.compute(x.shape, lambda *i: tvm.min(x(*i), const_max), name="clipA")
-        x = tvm.compute(x.shape, lambda *i: tvm.max(x(*i), const_min), name="clipB")
-        return x
-
-    # init autotvm env to register VTA operator
-    TaskExtractEnv()
-
-    @autotvm.task.register("topi_nn_conv2d", override=True)
-    def _topi_nn_conv2d(*args, **kwargs):
-        assert not kwargs, "Do not support kwargs in template function call"
-        args = deserialize_args(args)
-        A, W = args[:2]
-
-        with tvm.target.vta():
-            res = topi.nn.conv2d(*args, **kwargs)
-            res = topi.right_shift(res, 8)
-            res = my_clip(res, 0, 127)
-            res = topi.cast(res, "int8")
-
-        if tvm.target.current_target().device_name == 'vta':
-            s = topi.generic.schedule_conv2d_nchw([res])
-        else:
-            s = tvm.create_schedule([res.op])
-        return s, [A, W, res]
-
-
-
-def generate_graph(sym, params, target, target_host):
-    # Populate the shape and data type dictionary
-    shape_dict = {"data": (1, 3, 224, 224)}
-    dtype_dict = {"data": 'float32'}
-    shape_dict.update({k: v.shape for k, v in params.items()})
-    dtype_dict.update({k: str(v.dtype) for k, v in params.items()})
-
-    # Apply NNVM graph optimization passes
-    sym = vta.graph.clean_cast(sym)
-    sym = vta.graph.clean_conv_fuse(sym)
-    assert env.BLOCK_IN == env.BLOCK_OUT
-    sym = vta.graph.pack(sym, shape_dict, env.BATCH, env.BLOCK_OUT)
-
-    # Compile NNVM graph
-    with nnvm.compiler.build_config(opt_level=3):
-        with vta.build_config():
-            graph, lib, params = nnvm.compiler.build(
-                sym, target, shape_dict, dtype_dict,
-                params=params, target_host=target_host)
-
-    return graph, lib, params
-
-
-def extract_tasks(sym, params, target, target_host):
-    # Populate the shape and data type dictionary
-    shape_dict = {"data": (1, 3, 224, 224)}
-    dtype_dict = {"data": 'float32'}
-    shape_dict.update({k: v.shape for k, v in params.items()})
-    dtype_dict.update({k: str(v.dtype) for k, v in params.items()})
-
-    # Apply NNVM graph optimization passes
-    sym = vta.graph.clean_cast(sym)
-    sym = vta.graph.clean_conv_fuse(sym)
-    assert env.BLOCK_IN == env.BLOCK_OUT
-    sym = vta.graph.pack(sym, shape_dict, env.BATCH, env.BLOCK_OUT)
-
-    with vta.build_config():
-        tasks = autotvm.task.extract_from_graph(graph=sym, shape=shape_dict, dtype=dtype_dict, target=target,
-                                                params=params, symbols=(nnvm.sym.conv2d,), target_host=target_host)
-    return tasks
-
-
-def download_model():
-    url = "https://github.com/uwsaml/web-data/raw/master/vta/models/"
-    categ_fn = 'synset.txt'
-    graph_fn = 'resnet18_qt8.json'
-    params_fn = 'resnet18_qt8.params'
-    data_dir = '_data'
-    if not os.path.exists(data_dir):
-        os.makedirs(data_dir)
-
-    for file in [categ_fn, graph_fn, params_fn]:
-        if not os.path.isfile(file):
-            download(os.path.join(url, file), os.path.join(data_dir, file))
-
-    sym = nnvm.graph.load_json(open(os.path.join(data_dir, graph_fn)).read())
-    params = nnvm.compiler.load_param_dict(open(os.path.join(data_dir, params_fn), 'rb').read())
-
-    return sym, params
-
-
-def tune_tasks(tasks,
-               measure_option,
-               tuner='xgb',
-               n_trial=1000,
-               early_stopping=None,
-               log_filename='tuning.log',
-               use_transfer_learning=True,
-               try_winograd=True):
-    # create tmp log file
-    tmp_log_file = log_filename + ".tmp"
-    if os.path.exists(tmp_log_file):
-        os.remove(tmp_log_file)
-
-    for i, tsk in enumerate(reversed(tasks)):
-        prefix = "[Task %2d/%2d] " % (i+1, len(tasks))
-
-        # create tuner
-        if tuner == 'xgb' or tuner == 'xgb-rank':
-            tuner_obj = XGBTuner(tsk, loss_type='rank')
-        elif tuner == 'ga':
-            tuner_obj = GATuner(tsk, pop_size=50)
-        elif tuner == 'random':
-            tuner_obj = RandomTuner(tsk)
-        elif tuner == 'gridsearch':
-            tuner_obj = GridSearchTuner(tsk)
-        else:
-            raise ValueError("Invalid tuner: " + tuner)
-
-        if use_transfer_learning:
-            if os.path.isfile(tmp_log_file):
-                tuner_obj.load_history(autotvm.record.load_from_file(tmp_log_file))
-
-        # do tuning
-        n_trial_ = min(n_trial, len(tsk.config_space))
-        tuner_obj.tune(n_trial_,
-                       early_stopping=early_stopping,
-                       measure_option=measure_option,
-                       callbacks=[
-                           autotvm.callback.progress_bar(n_trial_, prefix=prefix),
-                           autotvm.callback.log_to_file(tmp_log_file)])
-
-    # pick best records to a cache file
-    autotvm.record.pick_best(tmp_log_file, log_filename)
-    os.remove(tmp_log_file)
-
-if __name__ == '__main__':
-
-    # Get tracker info from env
-    tracker_host = os.environ.get("TVM_TRACKER_HOST", None)
-    tracker_port = os.environ.get("TVM_TRACKER_PORT", None)
-    if not tracker_host or not tracker_port:
-        print("Set your AutoTVM tracker node host and port variables to run the autotuner")
-        exit()
-
-    # Download model
-    sym, params = download_model()
-
-    # Register VTA tuning tasks
-    register_vta_tuning_tasks()
-
-    # Extract tasks
-    print("Extracting tasks...")
-    target = tvm.target.vta()
-    target_host = env.target_host
-    tasks = extract_tasks(sym, params, target, target_host)
-
-    # Perform Autotuning
-    print("Tuning...")
-    tuning_opt = {
-        'log_filename': 'resnet-18.log',
-
-        'tuner': 'random',
-        'n_trial': 1e9,
-        'early_stopping': None,
-
-        'measure_option':  autotvm.measure_option(
-                builder=autotvm.LocalBuilder(build_func=vta.vta_autotvm_build_func),
-                runner=autotvm.RPCRunner(env.TARGET, tracker_host, int(tracker_port),
-                    number=4, repeat=3, timeout=60,
-                    check_correctness=True))
-    }
-    tune_tasks(tasks, **tuning_opt)
-
-    # compile kernels with history best records
-    with autotvm.tophub.context(target, extra_files=[tuning_opt['log_filename']]):
-
-        # ResNet parameters
-        input_shape = (1, 3, 224, 224)
-        dtype = 'float32'\
-
-        # Compile network
-        print("Compiling network with best tuning parameters...")
-        graph, lib, params = generate_graph(sym, params, target, target_host)
-        input_shape = (1, 3, 224, 224)
-        dtype = 'float32'
-
-        # Export library
-        tmp = util.tempdir()
-        filename = "net.tar"
-        lib.export_library(tmp.relpath(filename))
-
-        # Upload module to device
-        print("Upload...")
-        remote = autotvm.measure.request_remote(env.TARGET, tracker_host, tracker_port, timeout=10000)
-        remote.upload(tmp.relpath(filename))
-        rlib = remote.load_module(filename)
-
-        # Upload parameters to device
-        ctx = remote.context(str(target), 0)
-        rparams = {k: tvm.nd.array(v, ctx) for k, v in params.items()}
-        data_tvm = tvm.nd.array((np.random.uniform(size=input_shape)).astype(dtype))
-        module = graph_runtime.create(graph, rlib, ctx)
-        module.set_input('data', data_tvm)
-        module.set_input(**rparams)
-
-        # Evaluate
-        print("Evaluate inference time cost...")
-        ftimer = module.module.time_evaluator("run", ctx, number=3, repeat=3)
-        prof_res = np.array(ftimer().results) * 1000  # convert to millisecond
-        print("Mean inference time (std dev): %.2f ms (%.2f ms)" %
-              (np.mean(prof_res), np.std(prof_res)))
-