Fixed bugs for SSD sorting and multibox detection (#1578)

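The GPU argsort used by NMS is rewritten as four chained tvm.extern stages
(segment offsets, flattening, odd-even transposition sort, output scatter),
and the multibox location transform is split into two stages. A rough NumPy
sketch of the per-batch argsort the new sort_gpu is meant to reproduce for a
2-D score tensor (sort_gpu_reference is a hypothetical helper for
illustration, not part of this patch):

    import numpy as np

    def sort_gpu_reference(score, valid_count, is_descend=True):
        # score: (batch_size, num_anchors) float32
        # valid_count: (batch_size,) int32, number of valid boxes per batch
        batch_size, num_anchors = score.shape
        out = np.empty((batch_size, num_anchors), dtype=np.int32)
        for b in range(batch_size):
            n = int(valid_count[b])
            keys = -score[b, :n] if is_descend else score[b, :n]
            out[b, :n] = np.argsort(keys)           # valid boxes ordered by score
            out[b, n:] = np.arange(n, num_anchors)  # padding keeps its own index
        return out
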
diff --git a/topi/python/topi/cuda/nms.py b/topi/python/topi/cuda/nms.py
index 4d4e402..361208b 100644
--- a/topi/python/topi/cuda/nms.py
+++ b/topi/python/topi/cuda/nms.py
@@ -7,19 +7,155 @@
 from topi.vision import nms
 
 
-def sort_ir(data, index, output, axis, is_descend):
-    """Low level IR to do sorting on the GPU, same usage as tvm.contrib.sort.argsort on the CPU.
+def sort_pre_ir(index, sizes_out, axis_mul_before, axis_mul_after):
+    """Low level IR routing subfunction 1/4 for computing segments' staring locatons.
+
+    Parameters
+    ----------
+    index : Buffer
+        Buffer of number of valid output boxes.
+
+    sizes_out : Buffer
+        Output buffer of start locations of each sorting segment.
+
+    axis_mul_before : int
+        The multiplication result of axis dimensions before axis.
+
+    axis_mul_after : int
+        The multiplication result of axis dimensions after axis.
+
+    Returns
+    -------
+    stmt : Stmt
+        The result IR statement.
+    """
+    max_threads = int(
+        tvm.target.current_target(allow_none=False).max_num_threads)
+    tx = tvm.thread_axis("threadIdx.x")
+    bx = tvm.thread_axis("blockIdx.x")
+    ib = tvm.ir_builder.create()
+    p_index = ib.buffer_ptr(index)
+    dshape = sizes_out.shape
+    sizes = ib.buffer_ptr(sizes_out)
+    nthread_tx = max_threads
+    nthread_bx = dshape[0] // max_threads + 1
+    ib.scope_attr(tx, "thread_extent", nthread_tx)
+    ib.scope_attr(bx, "thread_extent", nthread_bx)
+    tid = bx * max_threads + tx
+
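+    # each thread copies its segment's valid box count into sizes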
+    with ib.if_scope(tid < axis_mul_before * axis_mul_after):
+        sizes[tid] = p_index[tid]
+
+    # inclusive scan (single thread): sizes[i] becomes the end offset of segment i
+    with ib.if_scope(tid < 1):
+        with ib.for_range(0, axis_mul_before * axis_mul_after - 1, name="k") as k:
+            sizes[k + 1] += sizes[k]
+    body = ib.get()
+    return body
+
+
+def sort_pre_ir_data(data, index, sizes_in, data_out, index_out, \
+                     axis, axis_mul_before, axis_mul_after):
+    """Low level IR routing subfunction 2/4 for flattening data and indices into segmented format.
 
     Parameters
     ----------
     data: Buffer
-        2D Buffer of input boxes' score with shape [batch_size, num_anchors].
+        Buffer of input boxes' scores.
 
     index : Buffer
-        Buffer of number of valid number of boxes.
+        Buffer of number of valid output boxes.
 
-    output : Buffer
-        Output buffer of indicies of sorted tensor.
+    sizes_in : Buffer
+        Buffer of start locations of each sorting segment.
+
+    data_out : Buffer
+        Buffer of flattened segmented data.
+
+    index_out : Buffer
+        Buffer of flattened segmented indices.
+
+    axis : int
+        The axis used for sorting.
+
+    axis_mul_before : int
+        The multiplication result of axis dimensions before axis.
+
+    axis_mul_after : int
+        The multiplication result of axis dimensions after axis.
+
+    Returns
+    -------
+    stmt : Stmt
+        The result IR statement.
+    """
+    ib = tvm.ir_builder.create()
+    sizes = ib.buffer_ptr(sizes_in)
+    p_index = ib.buffer_ptr(index)
+    p_data = ib.buffer_ptr(data)
+    data_new = ib.buffer_ptr(data_out)
+    index_new = ib.buffer_ptr(index_out)
+    max_threads = int(
+        tvm.target.current_target(allow_none=False).max_num_threads)
+    tx = tvm.thread_axis("threadIdx.x")
+    bx = tvm.thread_axis("blockIdx.x")
+    dshape = tvm.max(sizes_in.shape[0], p_index[0])
+    nthread_tx = max_threads
+    nthread_bx = dshape // max_threads + 1
+    ib.scope_attr(tx, "thread_extent", nthread_tx)
+    ib.scope_attr(bx, "thread_extent", nthread_bx)
+    tid = bx * max_threads + tx
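+    # multi-segment case: compact each segment's valid entries into the
+    # contiguous range [start, start + valid_count) of the flat buffers;
+    # the single-segment case only needs the indices 0..valid_count-1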
+    with ib.if_scope(axis_mul_before * axis_mul_after > 1):
+        with ib.if_scope(tid < axis_mul_before * axis_mul_after):
+            i = tid / axis_mul_after
+            j = tid % axis_mul_after
+            current_sort_num = p_index[tid]
+            base_idx = i * data.shape[axis] * axis_mul_after + j
+            with ib.for_range(0, current_sort_num, name="k") as k:
+                full_idx = base_idx + k * axis_mul_after
+                with ib.if_scope(tid == 0):
+                    start = 0
+                with ib.else_scope():
+                    start = sizes[tid-1]
+                index_new[start + k] = k
+                data_new[start + k] = p_data[full_idx]
+    with ib.else_scope():
+        with ib.if_scope(tid == 0):
+            with ib.for_range(0, p_index[0], name="k") as k:
+                index_new[k] = k
+
+    body = ib.get()
+    return body
+
+def sort_oet_ir(data, index, new_data, new_index, loc, out_index, axis_mul_before, \
+                axis_mul_after, axis, is_descend):
+    """Low level IR routing subfunction 3/4 for Odd-Even-Transposition sorting.
+
+    Parameters
+    ----------
+    data: Buffer
+        Buffer of input boxes' scores.
+
+    index : Buffer
+        Buffer of number of valid output boxes.
+
+    new_data : Buffer
+        Buffer of flattened segmented data.
+
+    new_index : Buffer
+        Buffer of flattened segmented indices.
+
+    loc : Buffer
+        Buffer of start locations of each sorting segment.
+
+    out_index : Buffer
+        Output buffer of box indices sorted by score, in a flattened segmented format.
+
+    axis_mul_before : int
+        The multiplication result of axis dimensions before axis.
+
+    axis_mul_after : int
+        The multiplication result of axis dimensions after axis.
 
     axis : int
         The axis used for sorting.
@@ -32,15 +168,197 @@
     stmt : Stmt
         The result IR statement.
     """
-
     max_threads = int(
         tvm.target.current_target(allow_none=False).max_num_threads)
     tx = tvm.thread_axis("threadIdx.x")
     bx = tvm.thread_axis("blockIdx.x")
     ib = tvm.ir_builder.create()
+    dshape = loc.shape
+    fshape = data.shape[axis] * dshape[0]
+    temp_data = ib.allocate(
+        "float32", dshape, name="temp_data", scope="local")
     p_data = ib.buffer_ptr(data)
     p_index = ib.buffer_ptr(index)
+    data_new = ib.buffer_ptr(new_data)
+    index_new = ib.buffer_ptr(new_index)
+    index_out = ib.buffer_ptr(out_index)
+    sizes = ib.buffer_ptr(loc)
+    nthread_tx = max_threads
+    nthread_bx = fshape // max_threads + 1
+    ib.scope_attr(tx, "thread_extent", nthread_tx)
+    ib.scope_attr(bx, "thread_extent", nthread_bx)
+    tid = bx * max_threads + tx
+
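+    # segmented case: each thread runs an odd-even transposition sort over
+    # its own segment of the flattened data/index buffers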
+    with ib.if_scope(axis_mul_before * axis_mul_after > 1):
+        with ib.if_scope(tid < axis_mul_before * axis_mul_after):
+            with ib.if_scope(tid == 0):
+                start = 0
+            with ib.else_scope():
+                start = sizes[tid-1]
+            # OddEvenTransposeSort
+            with ib.for_range(0, p_index[tid], name="k") as k:
+                with ib.for_range(0, p_index[tid] - 1, name="i") as i:
+                    with ib.if_scope(i % 2 == k % 2):
+                        with ib.if_scope(((data_new[i+start] < data_new[i+start+1]) == is_descend)):
+                            temp_data[tid] = data_new[i+start]
+                            data_new[i+start] = data_new[i+start+1]
+                            data_new[i+start+1] = temp_data[tid]
+                            index_out[tid] = index_new[i+start]
+                            index_new[i+start] = index_new[i+start+1]
+                            index_new[i+start+1] = index_out[tid]
+        with ib.if_scope(tid < 1):
+            with ib.for_range(0, sizes[dshape[0] - 1], name="i") as i:
+                index_out[i] = index_new[i]
+    with ib.else_scope():
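+        # single-segment case: fshape parallel odd-even phases, ping-ponging
+        # between (p_data, index_new) and (data_new, index_out) each phase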
+        with ib.for_range(0, fshape, name="k", for_type="unroll") as k:
+            with ib.if_scope(tvm.all(k % 2 == tid % 2, tid < fshape)):
+                with ib.if_scope(k % 2 == 0):
+                    with ib.if_scope(tvm.all(tid + 1 < fshape, (p_data[tid] < p_data[tid+1]) \
+                                             == is_descend)):
+                        data_new[tid] = p_data[tid+1]
+                        index_out[tid] = index_new[tid+1]
+                    with ib.else_scope():
+                        data_new[tid] = p_data[tid]
+                        index_out[tid] = index_new[tid]
+                with ib.else_scope():
+                    with ib.if_scope(tvm.all(tid + 1 < fshape, (data_new[tid] < data_new[tid+1]) \
+                                             == is_descend)):
+                        p_data[tid] = data_new[tid+1]
+                        index_new[tid] = index_out[tid+1]
+                    with ib.else_scope():
+                        p_data[tid] = data_new[tid]
+                        index_new[tid] = index_out[tid]
+            with ib.if_scope(tvm.all(k % 2 != tid % 2, tid < fshape)):
+                with ib.if_scope(k % 2 == 0):
+                    with ib.if_scope(tvm.all(tid > 0, (p_data[tid-1] < p_data[tid]) == is_descend)):
+                        data_new[tid] = p_data[tid-1]
+                        index_out[tid] = index_new[tid-1]
+                    with ib.else_scope():
+                        data_new[tid] = p_data[tid]
+                        index_out[tid] = index_new[tid]
+                with ib.else_scope():
+                    with ib.if_scope(tvm.all(tid > 0, (data_new[tid-1] < data_new[tid]) \
+                                             == is_descend)):
+                        p_data[tid] = data_new[tid-1]
+                        index_new[tid] = index_out[tid-1]
+                    with ib.else_scope():
+                        p_data[tid] = data_new[tid]
+                        index_new[tid] = index_out[tid]
+        with ib.if_scope(fshape % 2 == 1):
+            with ib.if_scope(tid < 1):
+                with ib.for_range(0, fshape, name="k") as k:
+                    index_out[tid] = index_new[tid]
+    body = ib.get()
+    return body
+
+
+def sort_ir_out(data, index, new_index, loc, output, axis_mul_before, axis_mul_after, axis):
+    """Low level IR routing subfunction 4/4 for writing sorted indices to output format.
+
+    Parameters
+    ----------
+    data: Buffer
+        Buffer of input boxes' scores.
+
+    index : Buffer
+        Buffer of number of valid output boxes.
+
+    new_index : Buffer
+        Buffer of sorted indices in a flattened, segmented format.
+
+    loc : Buffer
+        Buffer of start locations of each sorting segment.
+
+    output : Buffer
+        Output buffer of box indices sorted by score.
+
+    axis_mul_before : int
+        The multiplication result of axis dimensions before axis.
+
+    axis_mul_after : int
+        The multiplication result of axis dimensions after axis.
+
+    axis : int
+        The axis used for sorting.
+
+    Returns
+    -------
+    stmt : Stmt
+        The result IR statement.
+    """
+    max_threads = int(
+        tvm.target.current_target(allow_none=False).max_num_threads)
+    tx = tvm.thread_axis("threadIdx.x")
+    bx = tvm.thread_axis("blockIdx.x")
+    ib = tvm.ir_builder.create()
+    dshape = tvm.max(loc.shape[0], data.shape[axis])
+    p_index = ib.buffer_ptr(index)
+    index_new = ib.buffer_ptr(new_index)
+    sizes = ib.buffer_ptr(loc)
     p_out = ib.buffer_ptr(output)
+    nthread_tx = max_threads
+    nthread_bx = dshape // max_threads + 1
+    ib.scope_attr(tx, "thread_extent", nthread_tx)
+    ib.scope_attr(bx, "thread_extent", nthread_bx)
+    tid = bx * max_threads + tx
+
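+    # scatter the sorted indices back into the original layout; positions at
+    # or beyond the valid count keep their own index k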
+    with ib.if_scope(axis_mul_before * axis_mul_after > 1):
+        with ib.if_scope(tid < axis_mul_before * axis_mul_after):
+            i = tid / axis_mul_after
+            j = tid % axis_mul_after
+            base_idx = i * data.shape[axis] * axis_mul_after + j
+            with ib.for_range(0, data.shape[axis], name="k") as k:
+                with ib.if_scope(tid == 0):
+                    start = 0
+                with ib.else_scope():
+                    start = sizes[tid-1]
+                p_out[base_idx + k * axis_mul_after] = tvm.select(
+                    k < p_index[tid], index_new[k+start], k)
+    with ib.else_scope():
+        with ib.if_scope(tid < data.shape[axis]):
+            p_out[tid] = tvm.select(tid < p_index[0], index_new[tid], tid)
+
+    body = ib.get()
+    return body
+
+
+def sort_gpu(data, data_buf, index, index_buf, output_buf, axis, is_descend):
+    """Function to generate low level IR to do sorting on the GPU, use it by calling sort_gpu.
+
+    Parameters
+    ----------
+    data: tvm.Tensor
+        2-D tensor of input boxes' scores with shape [batch_size, num_anchors].
+
+    data_buf: Buffer
+        2D Buffer of input boxes' score with shape [batch_size, num_anchors].
+
+    index : tvm.Tensor
+        1-D tensor of the number of valid boxes per batch.
+
+    index_buf : Buffer
+        Buffer of the number of valid boxes.
+
+    output_buf : Buffer
+        Output buffer for the indices of the sorted tensor.
+
+    axis : int
+        The axis used for sorting.
+
+    is_descend : bool
+        If the sorted data is in descending order.
+
+    Returns
+    -------
+    out : tvm.Tensor
+        2-D tensor with shape [batch_size, num_anchors].
+    """
+
     ndim = len(data.shape)
     assert data.dtype == "float32", "Currently only supports input dtype to be float32"
     assert axis < ndim, "Axis out of boundary for input ndim %d" % ndim
@@ -55,89 +373,60 @@
         elif i > axis:
             axis_mul_after *= data.shape[i]
 
-    dshape = 0
-    for i in range(0, len(index.shape)):
-        dshape += index.shape[i]
-    dshape = tvm.select(dshape > axis_mul_before*axis_mul_after, dshape,
-                        axis_mul_before*axis_mul_after)
+    dshape = axis_mul_before*axis_mul_after
+    fshape = data.shape[axis] * dshape
 
-    sizes_temp = ib.allocate(
-        "int32", dshape, name="sizes_temp", scope="global")
-    sizes = ib.allocate("int32", dshape, name="sizes", scope="global")
-    temp_index = ib.allocate("int32", dshape, name="temp_index", scope="local")
-    temp_data = ib.allocate("float32", dshape, name="temp_data", scope="local")
-    data_new = ib.allocate("float32", dshape, name="data_new", scope="global")
-    index_new = ib.allocate("int32", dshape, name="index_new", scope="global")
-    nthread_tx = max_threads
-    nthread_bx = dshape // max_threads + 1
-    ib.scope_attr(tx, "thread_extent", nthread_tx)
-    ib.scope_attr(bx, "thread_extent", nthread_bx)
-    tid = bx * max_threads + tx
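+    # intermediate buffers passed between the four sorting stages below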
+    loc_buf = api.decl_buffer(dshape, index.dtype, "sizes", data_alignment=8)
+    new_index_buf = api.decl_buffer(
+        fshape, index.dtype, "index_new", data_alignment=8)
+    out_index_buf = api.decl_buffer(
+        fshape, index.dtype, "index_out", data_alignment=8)
+    new_data_buf = api.decl_buffer(
+        dshape, data.dtype, "data_new", data_alignment=8)
 
-    with ib.if_scope(tid < axis_mul_before * axis_mul_after):
-        sizes[tid] = p_index[tid]
-        sizes_temp[tid] = p_index[tid]
+    loc = \
+        tvm.extern([(dshape,)],
+                   [index],
+                   lambda ins, outs: sort_pre_ir(
+                       ins[0], outs[0], axis_mul_before, axis_mul_after),
+                   dtype=[index.dtype],
+                   in_buffers=index_buf,
+                   out_buffers=[loc_buf],
+                   tag="sorting_prepare")
 
-    with ib.if_scope(tid < axis_mul_before * axis_mul_after):
-        with ib.for_range(0, tvm.floor(tvm.sqrt((axis_mul_before * axis_mul_after) \
-             .astype("float32"))) + 1, name="k") as k:
-            with ib.if_scope(tid - (tvm.const(1, "int32") << k) >= 0):
-                with ib.if_scope(k % 2 == 0):
-                    sizes[tid] += sizes_temp[tid - (
-                        tvm.const(1, "int32") << k)]
-                    sizes_temp[tid] = sizes[tid]
-                with ib.else_scope():
-                    sizes_temp[tid] += sizes[tid - (
-                        tvm.const(1, "int32") << k)]
-                    sizes[tid] = sizes_temp[tid]
+    data_new, index_new = \
+        tvm.extern([(dshape,), (fshape,)],
+                   [data, index, loc],
+                   lambda ins, outs: sort_pre_ir_data(
+                       ins[0], ins[1], ins[2], outs[0], outs[1], axis,
+                       axis_mul_before, axis_mul_after),
+                   dtype=[data.dtype, index.dtype],
+                   in_buffers=[data_buf, index_buf, loc_buf],
+                   out_buffers=[new_data_buf, new_index_buf],
+                   tag="sorting_data")
 
-    with ib.if_scope(tid < axis_mul_before * axis_mul_after):
-        i = tid / axis_mul_after
-        j = tid % axis_mul_after
-        current_sort_num = p_index[tid]
-        base_idx = i * data.shape[axis] * axis_mul_after + j
-        with ib.for_range(0, current_sort_num, name="k") as k:
-            full_idx = base_idx + k * axis_mul_after
-            with ib.if_scope(tid == 0):
-                start = 0
-            with ib.else_scope():
-                start = sizes[tid-1]
-            index_new[start + k] = k
-            data_new[start + k] = p_data[full_idx]
-
-    with ib.if_scope(tid < axis_mul_before * axis_mul_after):
-        with ib.if_scope(tid == 0):
-            start = 0
-        with ib.else_scope():
-            start = sizes[tid-1]
-        # OddEvenTransposeSort
-        with ib.for_range(0, p_index[tid], name="k") as k:
-            with ib.for_range(0, p_index[tid] - 1, name="i") as i:
-                with ib.if_scope(i % 2 == (k & 1)):
-                    with ib.if_scope(((data_new[i+start] < data_new[i+start+1]) ^
-                                      is_descend) == False):
-                        temp_data[tid] = data_new[i+start]
-                        data_new[i+start] = data_new[i+start+1]
-                        data_new[i+start+1] = temp_data[tid]
-                        temp_index[tid] = index_new[i+start]
-                        index_new[i+start] = index_new[i+start+1]
-                        index_new[i+start+1] = temp_index[tid]
-
-    with ib.if_scope(tid < axis_mul_before * axis_mul_after):
-        i = tid / axis_mul_after
-        j = tid % axis_mul_after
-        current_sort_num = p_index[tid]
-        base_idx = i * data.shape[axis] * axis_mul_after + j
-        with ib.for_range(0, data.shape[axis], name="k") as k:
-            with ib.if_scope(tid == 0):
-                start = 0
-            with ib.else_scope():
-                start = sizes[tid-1]
-            p_out[base_idx + k * axis_mul_after] = tvm.select(
-                k < current_sort_num,
-                index_new[k+start], k)
-    body = ib.get()
-    return body
+    index_out = \
+        tvm.extern([(fshape,)],
+                   [data, index, data_new, index_new, loc],
+                   lambda ins, outs: sort_oet_ir(
+                       ins[0], ins[1], ins[2], ins[3], ins[4], outs[0],
+                       axis_mul_before, axis_mul_after, axis, is_descend),
+                   dtype=[index.dtype],
+                   in_buffers=[data_buf, index_buf,
+                               new_data_buf, new_index_buf, loc_buf],
+                   out_buffers=[out_index_buf],
+                   tag="sorting_oet")
+    out = \
+        tvm.extern([data.shape],
+                   [data, index, index_out, loc],
+                   lambda ins, outs: sort_ir_out(
+                       ins[0], ins[1], ins[2], ins[3], outs[0],
+                       axis_mul_before, axis_mul_after, axis),
+                   dtype=[index.dtype],
+                   in_buffers=[data_buf, index_buf, out_index_buf, loc_buf],
+                   out_buffers=output_buf,
+                   tag="sorting_output")
+    return out
 
 
 def nms_ir(data, sort_result, valid_count, out, nms_threshold, force_suppress, nms_topk):
@@ -333,15 +622,8 @@
     sort_tensor_buf = api.decl_buffer(score_shape, sort_tensor_dtype,
                                       "sort_tensor_buf", data_alignment=8)
 
-    sort_tensor = \
-        tvm.extern(score_shape,
-                   [score_tensor, valid_count],
-                   lambda ins, outs: sort_ir(
-                       ins[0], ins[1], outs[0], score_axis, True),
-                   dtype=sort_tensor_dtype,
-                   in_buffers=[score_tensor_buf, valid_count_buf],
-                   out_buffers=sort_tensor_buf,
-                   name="nms_sort")
+    sort_tensor = sort_gpu(score_tensor, score_tensor_buf, valid_count,
+                           valid_count_buf, sort_tensor_buf, score_axis, True)
     out = \
         tvm.extern(data.shape,
                    [data, sort_tensor, valid_count],
diff --git a/topi/python/topi/cuda/ssd/multibox.py b/topi/python/topi/cuda/ssd/multibox.py
index c22e7a5..3c013c4 100644
--- a/topi/python/topi/cuda/ssd/multibox.py
+++ b/topi/python/topi/cuda/ssd/multibox.py
@@ -1,4 +1,4 @@
-# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, too-many-statements
+# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, too-many-statements, too-many-function-args
 """SSD multibox operators"""
 from __future__ import absolute_import as _abs
 import math
@@ -13,6 +13,7 @@
 from topi.vision.ssd import multibox_transform_loc
 from ..nms import nms
 
+
 def multibox_prior_ir(data, out, sizes, ratios, steps, offsets):
     """Low level IR routing for multibox_prior operator.
 
@@ -41,7 +42,8 @@
     stmt : Stmt
         The result IR statement.
     """
-    max_threads = int(math.sqrt(tvm.target.current_target(allow_none=False).max_num_threads))
+    max_threads = int(math.sqrt(
+        tvm.target.current_target(allow_none=False).max_num_threads))
     tx = tvm.thread_axis("threadIdx.x")
     ty = tvm.thread_axis("threadIdx.y")
     bx = tvm.thread_axis("blockIdx.x")
@@ -76,7 +78,8 @@
 
             for k in range(num_sizes + num_ratios - 1):
                 w = tvm.select(k < num_sizes,
-                               size_ratio_concat[k] * in_height / in_width / 2.0,
+                               size_ratio_concat[k] * in_height /
+                               in_width / 2.0,
                                size_ratio_concat[0] * in_height / in_width *
                                math.sqrt(size_ratio_concat[k + 1]) / 2.0)
                 h = tvm.select(k < num_sizes, size_ratio_concat[k] / 2.0,
@@ -93,7 +96,7 @@
 
 
 @multibox_prior.register(["cuda", "gpu"])
-def multibox_prior_gpu(data, sizes=(1,), ratios=(1,), steps=(-1, -1), \
+def multibox_prior_gpu(data, sizes=(1,), ratios=(1,), steps=(-1, -1),
                        offsets=(0.5, 0.5), clip=False):
     """Generate prior(anchor) boxes from data, sizes and ratios.
 
@@ -124,31 +127,114 @@
     """
     num_sizes = len(sizes)
     num_ratios = len(ratios)
-    oshape = (1, data.shape[2] * data.shape[3] * (num_sizes + num_ratios - 1), 4)
+    oshape = (
+        1, data.shape[2] * data.shape[3] * (num_sizes + num_ratios - 1), 4)
     out = tvm.extern(oshape, [data], lambda ins, outs:
-                     multibox_prior_ir(ins[0], outs[0], sizes, ratios, steps, offsets),
+                     multibox_prior_ir(
+                         ins[0], outs[0], sizes, ratios, steps, offsets),
                      tag="multibox_prior")
     if clip:
         out = topi.clip(out, 0, 1)
     return out
 
 
-def transform_loc_ir(cls_prob, loc_pred, anchor, valid_count, out, clip, threshold, variances):
-    """Low level IR routing for transform location in multibox_detection operator.
+def transform_loc_pre(cls_prob, valid_count, temp_flag, temp_id, temp_score_out, threshold):
+    """Low level IR routing for transform location data preparation.
 
     Parameters
     ----------
     cls_prob : Buffer
         Buffer of class probabilities.
 
+    valid_count : Buffer
+        Buffer of number of valid output boxes.
+
+    temp_flag : Buffer
+        Output intermediate buffer of per-anchor positive-class flags.
+
+    temp_id : Buffer
+        Output intermediate buffer of predicted class ids.
+
+    temp_score_out : Buffer
+        Output buffer of predicted class scores.
+
+    threshold : float
+        Threshold to be a positive prediction.
+
+    Returns
+    -------
+    stmt : Stmt
+        The result IR statement.
+    """
+    batch_size = cls_prob.shape[0]
+    num_classes = cls_prob.shape[1]
+    num_anchors = cls_prob.shape[2]
+
+    max_threads = int(
+        tvm.target.current_target(allow_none=False).max_num_threads)
+    ib = tvm.ir_builder.create()
+    score = ib.buffer_ptr(temp_score_out)
+    cls_id = ib.buffer_ptr(temp_id)
+    flag = ib.buffer_ptr(temp_flag)
+    tx = tvm.thread_axis("threadIdx.x")
+    bx = tvm.thread_axis("blockIdx.x")
+    nthread_tx = max_threads
+    nthread_bx = (batch_size * num_anchors * num_classes) // max_threads + 1
+    ib.scope_attr(tx, "thread_extent", nthread_tx)
+    ib.scope_attr(bx, "thread_extent", nthread_bx)
+    tid = bx * max_threads + tx
+    p_cls_prob = ib.buffer_ptr(cls_prob)
+    p_valid_count = ib.buffer_ptr(valid_count)
+
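+    # pick the best non-background class per anchor and apply the score
+    # threshold; positive flags are then prefix-summed per batch to give
+    # the valid box count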
+    with ib.if_scope(tid < batch_size * num_anchors):
+        n = tid / num_anchors  # batch index
+        i = tid % num_anchors  # anchor index
+        score[i] = -1.0
+        cls_id[i] = 0
+        p_valid_count[n] = 0
+        with ib.for_range(0, num_classes-1, name="k") as k:
+            temp = p_cls_prob[n * num_anchors * num_classes + (k + 1) * num_anchors + i]
+            with ib.if_scope(temp > score[i]):
+                cls_id[i] = k + 1
+                score[i] = temp
+        with ib.if_scope(tvm.all(cls_id[i] > 0, score[i] < threshold)):
+            cls_id[i] = 0
+        with ib.if_scope(cls_id[i] > 0):
+            flag[i] = 1
+        with ib.else_scope():
+            flag[i] = 0
+
+        with ib.if_scope(tid < batch_size):
+            with ib.for_range(0, num_anchors, name="k") as k:
+                with ib.if_scope(k > 0):
+                    flag[tid * num_anchors +
+                         k] += flag[tid * num_anchors + k - 1]
+            p_valid_count[n] = flag[tid * num_anchors + num_anchors - 1]
+
+    body = ib.get()
+    return body
+
+
+def transform_loc_ir(loc_pred, anchor, temp_flag, temp_id, temp_score_in, \
+                     out, clip, variances, batch_size, num_classes, num_anchors):
+    """Low level IR routing for transform location in multibox_detection operator.
+
+    Parameters
+    ----------
     loc_pred : Buffer
         Buffer of location regression predictions.
 
     anchor : Buffer
         Buffer of prior anchor boxes.
 
-    valid_count : Buffer
-        Buffer of number of valid output boxes.
+    temp_flag : Buffer
+        Intermediate buffer of per-anchor positive-class flags.
+
+    temp_id : Buffer
+        Intermediate buffer of predicted class ids.
+
+    temp_score_in : Buffer
+        Input buffer of predicted class scores.
 
     out : Buffer
         Output buffer.
@@ -156,12 +242,18 @@
     clip : boolean
         Whether to clip out-of-boundary boxes.
 
-    threshold : float
-        Threshold to be a positive prediction.
-
     variances : tuple of float
         Variances to be decoded from box regression output.
 
+    batch_size : int
+        Batch size.
+
+    num_classes : int
+        Number of classes.
+
+    num_anchors : int
+        Number of anchors.
+
     Returns
     -------
     stmt : Stmt
@@ -187,21 +279,16 @@
         ow = tvm.exp(pw * vw) * aw / 2.0
         oh = tvm.exp(ph * vh) * ah / 2.0
         return tvm.select(clip, tvm.make.Max(0, tvm.make.Min(1, ox - ow)), ox - ow), \
-               tvm.select(clip, tvm.make.Max(0, tvm.make.Min(1, oy - oh)), oy - oh), \
-               tvm.select(clip, tvm.make.Max(0, tvm.make.Min(1, ox + ow)), ox + ow), \
-               tvm.select(clip, tvm.make.Max(0, tvm.make.Min(1, oy + oh)), oy + oh)
+            tvm.select(clip, tvm.make.Max(0, tvm.make.Min(1, oy - oh)), oy - oh), \
+            tvm.select(clip, tvm.make.Max(0, tvm.make.Min(1, ox + ow)), ox + ow), \
+            tvm.select(clip, tvm.make.Max(0, tvm.make.Min(1, oy + oh)), oy + oh)
 
-    batch_size = cls_prob.shape[0]
-    num_classes = cls_prob.shape[1]
-    num_anchors = cls_prob.shape[2]
-
+    max_threads = int(
+        tvm.target.current_target(allow_none=False).max_num_threads)
     ib = tvm.ir_builder.create()
-    temp_score = ib.allocate('float32', (batch_size * (num_classes -1) * num_anchors, \
-                 ), name="temp_score", scope="global")
-    score = ib.allocate('float32', (batch_size * num_anchors, ), name="score", scope="local")
-    cls_id = ib.allocate('int32', (batch_size * num_anchors, ), name="id", scope="local")
-    flag = ib.allocate('int32', (batch_size * num_anchors, ), name="flag", scope="global")
-    max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads)
+    score = ib.buffer_ptr(temp_score_in)
+    cls_id = ib.buffer_ptr(temp_id)
+    flag = ib.buffer_ptr(temp_flag)
     tx = tvm.thread_axis("threadIdx.x")
     bx = tvm.thread_axis("blockIdx.x")
     nthread_tx = max_threads
@@ -209,42 +296,13 @@
     ib.scope_attr(tx, "thread_extent", nthread_tx)
     ib.scope_attr(bx, "thread_extent", nthread_bx)
     tid = bx * max_threads + tx
-    p_cls_prob = ib.buffer_ptr(cls_prob)
     p_loc_pred = ib.buffer_ptr(loc_pred)
     p_anchor = ib.buffer_ptr(anchor)
-    p_valid_count = ib.buffer_ptr(valid_count)
     p_out = ib.buffer_ptr(out)
-    with ib.if_scope(tid < batch_size * num_anchors * num_classes):
-        n = tid / (num_anchors * num_classes)
-        j = (tid % (num_anchors * num_classes)) / num_anchors
-        i = tid % num_anchors
-        with ib.if_scope(j > 0):
-            temp_score[n * num_anchors * num_classes + i * (num_classes - 1) + j-1] = \
-            p_cls_prob[tid]
-        p_valid_count[n] = 0
+
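+    # decode box coordinates for anchors with a positive class id and write
+    # [class_id, score, left, top, right, bottom] rows to the output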
     with ib.if_scope(tid < batch_size * num_anchors):
-        n = tid / num_anchors
-        i = tid % num_anchors
-        score[tid] = -1.0
-        cls_id[tid] = 0
-        with ib.for_range(0, num_classes-1, name="k") as k:
-            temp = temp_score[tid * (num_classes-1) + k]
-            cls_id[tid] = tvm.select(temp > score[tid], k + 1, cls_id[tid])
-            score[tid] = tvm.make.Max(temp, score[tid])
-        with ib.if_scope(tvm.all(cls_id[tid] > 0, score[tid] < threshold)):
-            cls_id[tid] = 0
-        with ib.if_scope(cls_id[tid] > 0):
-            flag[tid] = 1
-        with ib.else_scope():
-            flag[tid] = 0
-    with ib.if_scope(tid < batch_size):
-        with ib.for_range(0, num_anchors, name="k") as k:
-            with ib.if_scope(k > 0):
-                flag[tid * num_anchors + k] += flag[tid * num_anchors + k - 1]
-        p_valid_count[tid] = flag[tid * num_anchors + num_anchors - 1]
-    with ib.if_scope(tid < batch_size * num_anchors):
-        n = tid / num_anchors
-        i = tid % num_anchors
+        n = tid / num_anchors  # batch index
+        i = tid % num_anchors  # anchor index
         with ib.if_scope(cls_id[tid] > 0):
             with ib.if_scope(tid == 0):
                 out_base_idx = n * num_anchors * 6
@@ -253,17 +311,17 @@
             p_out[out_base_idx] = cls_id[tid] - 1.0
             p_out[out_base_idx + 1] = score[tid]
             p_out[out_base_idx + 2], p_out[out_base_idx + 3], p_out[out_base_idx + 4], \
-            p_out[out_base_idx + 5] = transform_loc(p_loc_pred, tid * 4, p_anchor, i*4,
-                                                    clip, variances[0], variances[1],
-                                                    variances[2], variances[3])
+                p_out[out_base_idx + 5] = transform_loc(p_loc_pred, tid * 4,
+                                                        p_anchor, i*4, clip, variances[0],
+                                                        variances[1], variances[2], variances[3])
 
     body = ib.get()
     return body
 
 
 @multibox_transform_loc.register(["cuda", "gpu"])
-def multibox_transform_loc_gpu(cls_prob, loc_pred, anchor, clip=True, threshold=0.01,
-                               variances=(0.1, 0.1, 0.2, 0.2)):
+def multibox_transform_loc_gpu(cls_prob, loc_pred, anchor, clip=True, \
+                               threshold=0.01, variances=(0.1, 0.1, 0.2, 0.2)):
     """Location transformation for multibox detection
 
     Parameters
@@ -297,20 +355,42 @@
         1-D tensor with shape (batch_size,), number of valid anchor boxes.
     """
     batch_size = cls_prob.shape[0]
-    num_anchors = anchor.shape[1]
+    num_classes = cls_prob.shape[1]
+    num_anchors = cls_prob.shape[2]
     oshape = (batch_size, num_anchors, 6)
     # Define data alignment for intermediate buffer
     valid_count_dtype = "int32"
     valid_count_buf = api.decl_buffer((batch_size,), valid_count_dtype,
                                       "valid_count_buf", data_alignment=4)
-    out_buf = api.decl_buffer(oshape, cls_prob.dtype, "out_buf", data_alignment=8)
-    valid_count, out = \
-        tvm.extern([(batch_size,), oshape],
-                   [cls_prob, loc_pred, anchor],
+    out_buf = api.decl_buffer(
+        oshape, cls_prob.dtype, "out_buf", data_alignment=8)
+    size = num_anchors
+    temp_flag_buf = api.decl_buffer(
+        (size,), valid_count_dtype, "flag", data_alignment=8)
+    temp_id_buf = api.decl_buffer(
+        (size,), valid_count_dtype, "cls_id", data_alignment=8)
+    temp_score_buf = api.decl_buffer(
+        (size,), cls_prob.dtype, "score", data_alignment=8)
+
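+    # stage 1 picks the best class per anchor and counts valid boxes;
+    # stage 2 decodes the box coordinates for the positive anchors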
+    valid_count, temp_flag, temp_id, temp_score = \
+        tvm.extern([(batch_size,), (size,), (size,), (size,)],
+                   [cls_prob],
+                   lambda ins, outs: transform_loc_pre(
+                       ins[0], outs[0], outs[1], outs[2], outs[3], threshold),
+                   dtype=[valid_count_dtype,
+                          valid_count_dtype, valid_count_dtype, cls_prob.dtype],
+                   out_buffers=[valid_count_buf,
+                                temp_flag_buf, temp_id_buf, temp_score_buf],
+                   tag="multibox_transform_loc_first_step")
+
+    out = \
+        tvm.extern([oshape],
+                   [loc_pred, anchor, temp_flag, temp_id, temp_score],
                    lambda ins, outs: transform_loc_ir(
-                       ins[0], ins[1], ins[2], outs[0], outs[1], clip, threshold, variances),
-                   dtype=[valid_count_dtype, cls_prob.dtype],
-                   out_buffers=[valid_count_buf, out_buf],
+                       ins[0], ins[1], ins[2], ins[3], ins[4], outs[0], clip, \
+                       variances, batch_size, num_classes, num_anchors),
+                   dtype=[cls_prob.dtype],
+                   out_buffers=[out_buf],
                    tag="multibox_transform_loc")
     return [out, valid_count]
 
@@ -356,5 +436,6 @@
     """
     inter_out = multibox_transform_loc(cls_prob, loc_pred, anchor,
                                        clip, threshold, variances)
-    out = nms(inter_out[0], inter_out[1], nms_threshold, force_suppress, nms_topk)
+    out = nms(
+        inter_out[0], inter_out[1], nms_threshold, force_suppress, nms_topk)
     return out