Fixed bugs for SSD sorting and multbox detection (#1578)
diff --git a/topi/python/topi/cuda/nms.py b/topi/python/topi/cuda/nms.py
index 4d4e402..361208b 100644
--- a/topi/python/topi/cuda/nms.py
+++ b/topi/python/topi/cuda/nms.py
@@ -7,19 +7,155 @@
from topi.vision import nms
-def sort_ir(data, index, output, axis, is_descend):
- """Low level IR to do sorting on the GPU, same usage as tvm.contrib.sort.argsort on the CPU.
+def sort_pre_ir(index, sizes_out, axis_mul_before, axis_mul_after):
+ """Low level IR routing subfunction 1/4 for computing segments' staring locatons.
+
+ Parameters
+ ----------
+ index : Buffer
+ Buffer of number of valid output boxes.
+
+ sizes_out : Buffer
+ Output buffer of start locations of each sorting segment.
+
+ axis_mul_before : int
+ The multiplication result of axis dimensions before axis.
+
+ axis_mul_after : int
+ The multiplication result of axis dimensions after axis.
+
+ Returns
+ -------
+ stmt : Stmt
+ The result IR statement.
+ """
+ max_threads = int(
+ tvm.target.current_target(allow_none=False).max_num_threads)
+ tx = tvm.thread_axis("threadIdx.x")
+ bx = tvm.thread_axis("blockIdx.x")
+ ib = tvm.ir_builder.create()
+ p_index = ib.buffer_ptr(index)
+ dshape = sizes_out.shape
+ sizes = ib.buffer_ptr(sizes_out)
+ nthread_tx = max_threads
+ nthread_bx = dshape[0] // max_threads + 1
+ ib.scope_attr(tx, "thread_extent", nthread_tx)
+ ib.scope_attr(bx, "thread_extent", nthread_bx)
+ tid = bx * max_threads + tx
+
+ with ib.if_scope(tid < axis_mul_before * axis_mul_after):
+ sizes[tid] = p_index[tid]
+
+ # scan
+ with ib.if_scope(tid < 1):
+ with ib.for_range(0, axis_mul_before * axis_mul_after - 1, name="k") as k:
+ sizes[k + 1] += sizes[k]
+ body = ib.get()
+ return body
+
+
+def sort_pre_ir_data(data, index, sizes_in, data_out, index_out, \
+ axis, axis_mul_before, axis_mul_after):
+ """Low level IR routing subfunction 2/4 for flattening data and indices into segmented format.
Parameters
----------
data: Buffer
- 2D Buffer of input boxes' score with shape [batch_size, num_anchors].
+ Buffer of output boxes with class and score.
index : Buffer
- Buffer of number of valid number of boxes.
+ Buffer of number of valid output boxes.
- output : Buffer
- Output buffer of indicies of sorted tensor.
+ sizes_in : Buffer
+ Buffer of start locations of each sorting segment.
+
+ data_out : Buffer
+ Buffer of flattened segmented data.
+
+ index_out : Buffer
+ Buffer of flattened segmented indices.
+
+ axis : int
+ The axis used for sorting.
+
+ axis_mul_before : int
+ The multiplication result of axis dimensions before axis.
+
+ axis_mul_after : int
+ The multiplication result of axis dimensions after axis.
+
+ Returns
+ -------
+ stmt : Stmt
+ The result IR statement.
+ """
+ ib = tvm.ir_builder.create()
+ sizes = ib.buffer_ptr(sizes_in)
+ p_index = ib.buffer_ptr(index)
+ p_data = ib.buffer_ptr(data)
+ data_new = ib.buffer_ptr(data_out)
+ index_new = ib.buffer_ptr(index_out)
+ max_threads = int(
+ tvm.target.current_target(allow_none=False).max_num_threads)
+ tx = tvm.thread_axis("threadIdx.x")
+ bx = tvm.thread_axis("blockIdx.x")
+ dshape = tvm.max(sizes_in.shape[0], p_index[0])
+ nthread_tx = max_threads
+ nthread_bx = dshape // max_threads + 1
+ ib.scope_attr(tx, "thread_extent", nthread_tx)
+ ib.scope_attr(bx, "thread_extent", nthread_bx)
+ tid = bx * max_threads + tx
+ with ib.if_scope(axis_mul_before * axis_mul_after > 1):
+ with ib.if_scope(tid < axis_mul_before * axis_mul_after):
+ i = tid / axis_mul_after
+ j = tid % axis_mul_after
+ current_sort_num = p_index[tid]
+ base_idx = i * data.shape[axis] * axis_mul_after + j
+ with ib.for_range(0, current_sort_num, name="k") as k:
+ full_idx = base_idx + k * axis_mul_after
+ with ib.if_scope(tid == 0):
+ start = 0
+ with ib.else_scope():
+ start = sizes[tid-1]
+ index_new[start + k] = k
+ data_new[start + k] = p_data[full_idx]
+ with ib.else_scope():
+ with ib.if_scope(tid == 0):
+ with ib.for_range(0, p_index[0], name="k") as k:
+ index_new[k] = k
+
+ body = ib.get()
+ return body
+
+def sort_oet_ir(data, index, new_data, new_index, loc, out_index, axis_mul_before, \
+ axis_mul_after, axis, is_descend):
+ """Low level IR routing subfunction 3/4 for Odd-Even-Transposition sorting.
+
+ Parameters
+ ----------
+ data: Buffer
+ Buffer of output boxes with class and score.
+
+ index : Buffer
+ Buffer of number of valid output boxes.
+
+ new_data : Buffer
+ Buffer of flattened segmented data.
+
+ new_index : Buffer
+ Buffer of flattened segmented indices.
+
+ loc : Buffer
+ Buffer of start locations of each sorting segment.
+
+ out_index : Buffer
+ Output buffer of output box indexes sorted by score in a flattened segmented format.
+
+ axis_mul_before : int
+ The multiplication result of axis dimensions before axis.
+
+ axis_mul_after : int
+ The multiplication result of axis dimensions after axis.
axis : int
The axis used for sorting.
@@ -32,15 +168,197 @@
stmt : Stmt
The result IR statement.
"""
-
max_threads = int(
tvm.target.current_target(allow_none=False).max_num_threads)
tx = tvm.thread_axis("threadIdx.x")
bx = tvm.thread_axis("blockIdx.x")
ib = tvm.ir_builder.create()
+ dshape = loc.shape
+ fshape = data.shape[axis] * dshape[0]
+ temp_data = ib.allocate(
+ "float32", dshape, name="temp_data", scope="local")
p_data = ib.buffer_ptr(data)
p_index = ib.buffer_ptr(index)
+ data_new = ib.buffer_ptr(new_data)
+ index_new = ib.buffer_ptr(new_index)
+ index_out = ib.buffer_ptr(out_index)
+ sizes = ib.buffer_ptr(loc)
+ nthread_tx = max_threads
+ nthread_bx = fshape // max_threads + 1
+ ib.scope_attr(tx, "thread_extent", nthread_tx)
+ ib.scope_attr(bx, "thread_extent", nthread_bx)
+ tid = bx * max_threads + tx
+
+ with ib.if_scope(axis_mul_before * axis_mul_after > 1):
+ with ib.if_scope(tid < axis_mul_before * axis_mul_after):
+ with ib.if_scope(tid == 0):
+ start = 0
+ with ib.else_scope():
+ start = sizes[tid-1]
+ # OddEvenTransposeSort
+ with ib.for_range(0, p_index[tid], name="k") as k:
+ with ib.for_range(0, p_index[tid] - 1, name="i") as i:
+ with ib.if_scope(i % 2 == k % 2):
+ with ib.if_scope(((data_new[i+start] < data_new[i+start+1]) == is_descend)):
+ temp_data[tid] = data_new[i+start]
+ data_new[i+start] = data_new[i+start+1]
+ data_new[i+start+1] = temp_data[tid]
+ index_out[tid] = index_new[i+start]
+ index_new[i+start] = index_new[i+start+1]
+ index_new[i+start+1] = index_out[tid]
+ with ib.if_scope(tid < 1):
+ with ib.for_range(0, sizes[dshape[0] - 1], name="i") as i:
+ index_out[i] = index_new[i]
+ with ib.else_scope():
+ with ib.for_range(0, fshape, name="k", for_type="unroll") as k:
+ with ib.if_scope(tvm.all(k % 2 == tid % 2, tid < fshape)):
+ with ib.if_scope(k % 2 == 0):
+ with ib.if_scope(tvm.all(tid + 1 < fshape, (p_data[tid] < p_data[tid+1]) \
+ == is_descend)):
+ data_new[tid] = p_data[tid+1]
+ index_out[tid] = index_new[tid+1]
+ with ib.else_scope():
+ data_new[tid] = p_data[tid]
+ index_out[tid] = index_new[tid]
+ with ib.else_scope():
+ with ib.if_scope(tvm.all(tid + 1 < fshape, (data_new[tid] < data_new[tid+1]) \
+ == is_descend)):
+ p_data[tid] = data_new[tid+1]
+ index_new[tid] = index_out[tid+1]
+ with ib.else_scope():
+ p_data[tid] = data_new[tid]
+ index_new[tid] = index_out[tid]
+ with ib.if_scope(tvm.all(k % 2 != tid % 2, tid < fshape)):
+ with ib.if_scope(k % 2 == 0):
+ with ib.if_scope(tvm.all(tid > 0, (p_data[tid-1] < p_data[tid]) == is_descend)):
+ data_new[tid] = p_data[tid-1]
+ index_out[tid] = index_new[tid-1]
+ with ib.else_scope():
+ data_new[tid] = p_data[tid]
+ index_out[tid] = index_new[tid]
+ with ib.else_scope():
+ with ib.if_scope(tvm.all(tid > 0, (data_new[tid-1] < data_new[tid]) \
+ == is_descend)):
+ p_data[tid] = data_new[tid-1]
+ index_new[tid] = index_out[tid-1]
+ with ib.else_scope():
+ p_data[tid] = data_new[tid]
+ index_new[tid] = index_out[tid]
+ with ib.if_scope(fshape % 2 == 1):
+ with ib.if_scope(tid < 1):
+ with ib.for_range(0, fshape, name="k") as k:
+ index_out[tid] = index_new[tid]
+ body = ib.get()
+ return body
+
+
+def sort_ir_out(data, index, new_index, loc, output, axis_mul_before, axis_mul_after, axis):
+ """Low level IR routing subfunction 4/4 for writing sorted indices to output format.
+
+ Parameters
+ ----------
+ data: Buffer
+ Buffer of output boxes with class and score.
+
+ index : Buffer
+ Buffer of number of valid output boxes.
+
+ new_index : Buffer
+ Buffer of sorted indices in a flatten format.
+
+ loc : Buffer
+ Buffer of start locations of each sorting segment.
+
+ output : Buffer
+ Output buffer of output box indexes sorted by score.
+
+ axis_mul_before : int
+ The multiplication result of axis dimensions before axis.
+
+ axis_mul_after : int
+ The multiplication result of axis dimensions after axis.
+
+ axis : int
+ The axis used for sorting.
+
+ is_descend : bool
+ If the sorted data is in descending order.
+
+ Returns
+ -------
+ stmt : Stmt
+ The result IR statement.
+ """
+ max_threads = int(
+ tvm.target.current_target(allow_none=False).max_num_threads)
+ tx = tvm.thread_axis("threadIdx.x")
+ bx = tvm.thread_axis("blockIdx.x")
+ ib = tvm.ir_builder.create()
+ dshape = tvm.max(loc.shape[0], data.shape[axis])
+ p_index = ib.buffer_ptr(index)
+ index_new = ib.buffer_ptr(new_index)
+ sizes = ib.buffer_ptr(loc)
p_out = ib.buffer_ptr(output)
+ nthread_tx = max_threads
+ nthread_bx = dshape // max_threads + 1
+ ib.scope_attr(tx, "thread_extent", nthread_tx)
+ ib.scope_attr(bx, "thread_extent", nthread_bx)
+ tid = bx * max_threads + tx
+
+ with ib.if_scope(axis_mul_before * axis_mul_after > 1):
+ with ib.if_scope(tid < axis_mul_before * axis_mul_after):
+ i = tid / axis_mul_after
+ j = tid % axis_mul_after
+ base_idx = i * data.shape[axis] * axis_mul_after + j
+ with ib.for_range(0, data.shape[axis], name="k") as k:
+ with ib.if_scope(tid == 0):
+ start = 0
+ with ib.else_scope():
+ start = sizes[tid-1]
+ p_out[base_idx + k * axis_mul_after] = tvm.select(
+ k < p_index[tid], index_new[k+start], k)
+ with ib.else_scope():
+ with ib.if_scope(tid < data.shape[axis]):
+ p_out[tid] = tvm.select(tid < p_index[0], index_new[tid], tid)
+
+ body = ib.get()
+ return body
+
+
+def sort_gpu(data, data_buf, index, index_buf, output_buf, axis, is_descend):
+ """Function to generate low level IR to do sorting on the GPU, use it by calling sort_gpu.
+
+ Parameters
+ ----------
+ data: tvm.Tensor
+ 3-D tensor with shape [batch_size, num_anchors, 6].
+ The last dimension should be in format of
+ [class_id, score, box_left, box_top, box_right, box_bottom].
+
+ data_buf: Buffer
+ 2D Buffer of input boxes' score with shape [batch_size, num_anchors].
+
+ index : tvm.Tensor
+ 1-D tensor for valid number of boxes.
+
+ index_buf : Buffer
+ Buffer of number of valid number of boxes.
+
+ output_buf : Buffer
+ Output buffer of indicies of sorted tensor.
+
+ axis : int
+ The axis used for sorting.
+
+ is_descend : bool
+ If the sorted data is in descending order.
+
+ Returns
+ -------
+ out : tvm.Tensor
+ 3-D tensor with shape [batch_size, num_anchors].
+ """
+
ndim = len(data.shape)
assert data.dtype == "float32", "Currently only supports input dtype to be float32"
assert axis < ndim, "Axis out of boundary for input ndim %d" % ndim
@@ -55,89 +373,60 @@
elif i > axis:
axis_mul_after *= data.shape[i]
- dshape = 0
- for i in range(0, len(index.shape)):
- dshape += index.shape[i]
- dshape = tvm.select(dshape > axis_mul_before*axis_mul_after, dshape,
- axis_mul_before*axis_mul_after)
+ dshape = axis_mul_before*axis_mul_after
+ fshape = data.shape[axis] * dshape
- sizes_temp = ib.allocate(
- "int32", dshape, name="sizes_temp", scope="global")
- sizes = ib.allocate("int32", dshape, name="sizes", scope="global")
- temp_index = ib.allocate("int32", dshape, name="temp_index", scope="local")
- temp_data = ib.allocate("float32", dshape, name="temp_data", scope="local")
- data_new = ib.allocate("float32", dshape, name="data_new", scope="global")
- index_new = ib.allocate("int32", dshape, name="index_new", scope="global")
- nthread_tx = max_threads
- nthread_bx = dshape // max_threads + 1
- ib.scope_attr(tx, "thread_extent", nthread_tx)
- ib.scope_attr(bx, "thread_extent", nthread_bx)
- tid = bx * max_threads + tx
+ loc_buf = api.decl_buffer(dshape, index.dtype, "sizes", data_alignment=8)
+ new_index_buf = api.decl_buffer(
+ fshape, index.dtype, "index_new", data_alignment=8)
+ out_index_buf = api.decl_buffer(
+ fshape, index.dtype, "index_out", data_alignment=8)
+ new_data_buf = api.decl_buffer(
+ dshape, data.dtype, "data_new", data_alignment=8)
- with ib.if_scope(tid < axis_mul_before * axis_mul_after):
- sizes[tid] = p_index[tid]
- sizes_temp[tid] = p_index[tid]
+ loc = \
+ tvm.extern([(dshape,)],
+ [index],
+ lambda ins, outs: sort_pre_ir(
+ ins[0], outs[0], axis_mul_before, axis_mul_after),
+ dtype=[index.dtype],
+ in_buffers=index_buf,
+ out_buffers=[loc_buf],
+ tag="sorting_prepare")
- with ib.if_scope(tid < axis_mul_before * axis_mul_after):
- with ib.for_range(0, tvm.floor(tvm.sqrt((axis_mul_before * axis_mul_after) \
- .astype("float32"))) + 1, name="k") as k:
- with ib.if_scope(tid - (tvm.const(1, "int32") << k) >= 0):
- with ib.if_scope(k % 2 == 0):
- sizes[tid] += sizes_temp[tid - (
- tvm.const(1, "int32") << k)]
- sizes_temp[tid] = sizes[tid]
- with ib.else_scope():
- sizes_temp[tid] += sizes[tid - (
- tvm.const(1, "int32") << k)]
- sizes[tid] = sizes_temp[tid]
+ data_new, index_new = \
+ tvm.extern([(dshape,), (fshape,)],
+ [data, index, loc],
+ lambda ins, outs: sort_pre_ir_data(
+ ins[0], ins[1], ins[2], outs[0], outs[1], axis,
+ axis_mul_before, axis_mul_after),
+ dtype=[data.dtype, index.dtype],
+ in_buffers=[data_buf, index_buf, loc_buf],
+ out_buffers=[new_data_buf, new_index_buf],
+ tag="sorting_data")
- with ib.if_scope(tid < axis_mul_before * axis_mul_after):
- i = tid / axis_mul_after
- j = tid % axis_mul_after
- current_sort_num = p_index[tid]
- base_idx = i * data.shape[axis] * axis_mul_after + j
- with ib.for_range(0, current_sort_num, name="k") as k:
- full_idx = base_idx + k * axis_mul_after
- with ib.if_scope(tid == 0):
- start = 0
- with ib.else_scope():
- start = sizes[tid-1]
- index_new[start + k] = k
- data_new[start + k] = p_data[full_idx]
-
- with ib.if_scope(tid < axis_mul_before * axis_mul_after):
- with ib.if_scope(tid == 0):
- start = 0
- with ib.else_scope():
- start = sizes[tid-1]
- # OddEvenTransposeSort
- with ib.for_range(0, p_index[tid], name="k") as k:
- with ib.for_range(0, p_index[tid] - 1, name="i") as i:
- with ib.if_scope(i % 2 == (k & 1)):
- with ib.if_scope(((data_new[i+start] < data_new[i+start+1]) ^
- is_descend) == False):
- temp_data[tid] = data_new[i+start]
- data_new[i+start] = data_new[i+start+1]
- data_new[i+start+1] = temp_data[tid]
- temp_index[tid] = index_new[i+start]
- index_new[i+start] = index_new[i+start+1]
- index_new[i+start+1] = temp_index[tid]
-
- with ib.if_scope(tid < axis_mul_before * axis_mul_after):
- i = tid / axis_mul_after
- j = tid % axis_mul_after
- current_sort_num = p_index[tid]
- base_idx = i * data.shape[axis] * axis_mul_after + j
- with ib.for_range(0, data.shape[axis], name="k") as k:
- with ib.if_scope(tid == 0):
- start = 0
- with ib.else_scope():
- start = sizes[tid-1]
- p_out[base_idx + k * axis_mul_after] = tvm.select(
- k < current_sort_num,
- index_new[k+start], k)
- body = ib.get()
- return body
+ index_out = \
+ tvm.extern([(fshape,)],
+ [data, index, data_new, index_new, loc],
+ lambda ins, outs: sort_oet_ir(
+ ins[0], ins[1], ins[2], ins[3], ins[4], outs[0],
+ axis_mul_before, axis_mul_after, axis, is_descend),
+ dtype=[index.dtype],
+ in_buffers=[data_buf, index_buf,
+ new_data_buf, new_index_buf, loc_buf],
+ out_buffers=[out_index_buf],
+ tag="sorting_oet")
+ out = \
+ tvm.extern([data.shape],
+ [data, index, index_out, loc],
+ lambda ins, outs: sort_ir_out(
+ ins[0], ins[1], ins[2], ins[3], outs[0],
+ axis_mul_before, axis_mul_after, axis),
+ dtype=[index.dtype],
+ in_buffers=[data_buf, index_buf, out_index_buf, loc_buf],
+ out_buffers=output_buf,
+ tag="sorting_output")
+ return out
def nms_ir(data, sort_result, valid_count, out, nms_threshold, force_suppress, nms_topk):
@@ -333,15 +622,8 @@
sort_tensor_buf = api.decl_buffer(score_shape, sort_tensor_dtype,
"sort_tensor_buf", data_alignment=8)
- sort_tensor = \
- tvm.extern(score_shape,
- [score_tensor, valid_count],
- lambda ins, outs: sort_ir(
- ins[0], ins[1], outs[0], score_axis, True),
- dtype=sort_tensor_dtype,
- in_buffers=[score_tensor_buf, valid_count_buf],
- out_buffers=sort_tensor_buf,
- name="nms_sort")
+ sort_tensor = sort_gpu(score_tensor, score_tensor_buf, valid_count,
+ valid_count_buf, sort_tensor_buf, score_axis, True)
out = \
tvm.extern(data.shape,
[data, sort_tensor, valid_count],
diff --git a/topi/python/topi/cuda/ssd/multibox.py b/topi/python/topi/cuda/ssd/multibox.py
index c22e7a5..3c013c4 100644
--- a/topi/python/topi/cuda/ssd/multibox.py
+++ b/topi/python/topi/cuda/ssd/multibox.py
@@ -1,4 +1,4 @@
-# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, too-many-statements
+# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, too-many-statements, too-many-function-args
"""SSD multibox operators"""
from __future__ import absolute_import as _abs
import math
@@ -13,6 +13,7 @@
from topi.vision.ssd import multibox_transform_loc
from ..nms import nms
+
def multibox_prior_ir(data, out, sizes, ratios, steps, offsets):
"""Low level IR routing for multibox_prior operator.
@@ -41,7 +42,8 @@
stmt : Stmt
The result IR statement.
"""
- max_threads = int(math.sqrt(tvm.target.current_target(allow_none=False).max_num_threads))
+ max_threads = int(math.sqrt(
+ tvm.target.current_target(allow_none=False).max_num_threads))
tx = tvm.thread_axis("threadIdx.x")
ty = tvm.thread_axis("threadIdx.y")
bx = tvm.thread_axis("blockIdx.x")
@@ -76,7 +78,8 @@
for k in range(num_sizes + num_ratios - 1):
w = tvm.select(k < num_sizes,
- size_ratio_concat[k] * in_height / in_width / 2.0,
+ size_ratio_concat[
+ k] * in_height / in_width / 2.0,
size_ratio_concat[0] * in_height / in_width *
math.sqrt(size_ratio_concat[k + 1]) / 2.0)
h = tvm.select(k < num_sizes, size_ratio_concat[k] / 2.0,
@@ -93,7 +96,7 @@
@multibox_prior.register(["cuda", "gpu"])
-def multibox_prior_gpu(data, sizes=(1,), ratios=(1,), steps=(-1, -1), \
+def multibox_prior_gpu(data, sizes=(1,), ratios=(1,), steps=(-1, -1),
offsets=(0.5, 0.5), clip=False):
"""Generate prior(anchor) boxes from data, sizes and ratios.
@@ -124,31 +127,114 @@
"""
num_sizes = len(sizes)
num_ratios = len(ratios)
- oshape = (1, data.shape[2] * data.shape[3] * (num_sizes + num_ratios - 1), 4)
+ oshape = (
+ 1, data.shape[2] * data.shape[3] * (num_sizes + num_ratios - 1), 4)
out = tvm.extern(oshape, [data], lambda ins, outs:
- multibox_prior_ir(ins[0], outs[0], sizes, ratios, steps, offsets),
+ multibox_prior_ir(
+ ins[0], outs[0], sizes, ratios, steps, offsets),
tag="multibox_prior")
if clip:
out = topi.clip(out, 0, 1)
return out
-def transform_loc_ir(cls_prob, loc_pred, anchor, valid_count, out, clip, threshold, variances):
- """Low level IR routing for transform location in multibox_detection operator.
+def transform_loc_pre(cls_prob, valid_count, temp_flag, temp_id, temp_score_out, threshold):
+ """Low level IR routing for transform location data preparation.
Parameters
----------
cls_prob : Buffer
Buffer of class probabilities.
+ valid_count : Buffer
+ Buffer of number of valid output boxes.
+
+ temp_flag : Buffer
+ Output intermediate result buffer
+
+ temp_id : Buffer
+ Output intermediate result buffer
+
+ temp_score_out : Buffer
+ Output buffer
+
+ threshold : float
+ Threshold to be a positive prediction.
+
+ Returns
+ -------
+ stmt : Stmt
+ The result IR statement.
+ """
+ batch_size = cls_prob.shape[0]
+ num_classes = cls_prob.shape[1]
+ num_anchors = cls_prob.shape[2]
+
+ max_threads = int(
+ tvm.target.current_target(allow_none=False).max_num_threads)
+ ib = tvm.ir_builder.create()
+ score = ib.buffer_ptr(temp_score_out)
+ cls_id = ib.buffer_ptr(temp_id)
+ flag = ib.buffer_ptr(temp_flag)
+ tx = tvm.thread_axis("threadIdx.x")
+ bx = tvm.thread_axis("blockIdx.x")
+ nthread_tx = max_threads
+ nthread_bx = (batch_size * num_anchors * num_classes) // max_threads + 1
+ ib.scope_attr(tx, "thread_extent", nthread_tx)
+ ib.scope_attr(bx, "thread_extent", nthread_bx)
+ tid = bx * max_threads + tx
+ p_cls_prob = ib.buffer_ptr(cls_prob)
+ p_valid_count = ib.buffer_ptr(valid_count)
+
+ with ib.if_scope(tid < batch_size * num_anchors):
+ n = tid / num_anchors # number of batches
+ i = tid % num_anchors # number of anchors
+ score[i] = -1.0
+ cls_id[i] = 0
+ p_valid_count[n] = 0
+ with ib.for_range(0, num_classes-1, name="k") as k:
+ temp = p_cls_prob[n * num_anchors * num_classes + (k + 1) * num_anchors + i]
+ with ib.if_scope(temp > score[i]):
+ cls_id[i] = k + 1
+ score[i] = temp
+ with ib.if_scope(tvm.all(cls_id[i] > 0, score[i] < threshold)):
+ cls_id[i] = 0
+ with ib.if_scope(cls_id[i] > 0):
+ flag[i] = 1
+ with ib.else_scope():
+ flag[i] = 0
+
+ with ib.if_scope(tid < batch_size):
+ with ib.for_range(0, num_anchors, name="k") as k:
+ with ib.if_scope(k > 0):
+ flag[tid * num_anchors +
+ k] += flag[tid * num_anchors + k - 1]
+ p_valid_count[n] = flag[tid * num_anchors + num_anchors - 1]
+
+ body = ib.get()
+ return body
+
+
+def transform_loc_ir(loc_pred, anchor, temp_flag, temp_id, temp_score_in, \
+ out, clip, variances, batch_size, num_classes, num_anchors):
+ """Low level IR routing for transform location in multibox_detection operator.
+
+ Parameters
+ ----------
loc_pred : Buffer
Buffer of location regression predictions.
anchor : Buffer
Buffer of prior anchor boxes.
- valid_count : Buffer
- Buffer of number of valid output boxes.
+ temp_flag : Buffer
+ Intermediate result buffer.
+
+ temp_id : Buffer
+ Intermediate result buffer.
+
+ temp_score_in : Buffer
+ Input buffer which stores intermediate results.
out : Buffer
Output buffer.
@@ -156,12 +242,18 @@
clip : boolean
Whether to clip out-of-boundary boxes.
- threshold : float
- Threshold to be a positive prediction.
-
variances : tuple of float
Variances to be decoded from box regression output.
+ batch_size : int
+ Batch size
+
+ num_classes : int
+ Number of classes
+
+ num_anchors : int
+ Number of anchors
+
Returns
-------
stmt : Stmt
@@ -187,21 +279,16 @@
ow = tvm.exp(pw * vw) * aw / 2.0
oh = tvm.exp(ph * vh) * ah / 2.0
return tvm.select(clip, tvm.make.Max(0, tvm.make.Min(1, ox - ow)), ox - ow), \
- tvm.select(clip, tvm.make.Max(0, tvm.make.Min(1, oy - oh)), oy - oh), \
- tvm.select(clip, tvm.make.Max(0, tvm.make.Min(1, ox + ow)), ox + ow), \
- tvm.select(clip, tvm.make.Max(0, tvm.make.Min(1, oy + oh)), oy + oh)
+ tvm.select(clip, tvm.make.Max(0, tvm.make.Min(1, oy - oh)), oy - oh), \
+ tvm.select(clip, tvm.make.Max(0, tvm.make.Min(1, ox + ow)), ox + ow), \
+ tvm.select(clip, tvm.make.Max(0, tvm.make.Min(1, oy + oh)), oy + oh)
- batch_size = cls_prob.shape[0]
- num_classes = cls_prob.shape[1]
- num_anchors = cls_prob.shape[2]
-
+ max_threads = int(
+ tvm.target.current_target(allow_none=False).max_num_threads)
ib = tvm.ir_builder.create()
- temp_score = ib.allocate('float32', (batch_size * (num_classes -1) * num_anchors, \
- ), name="temp_score", scope="global")
- score = ib.allocate('float32', (batch_size * num_anchors, ), name="score", scope="local")
- cls_id = ib.allocate('int32', (batch_size * num_anchors, ), name="id", scope="local")
- flag = ib.allocate('int32', (batch_size * num_anchors, ), name="flag", scope="global")
- max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads)
+ score = ib.buffer_ptr(temp_score_in)
+ cls_id = ib.buffer_ptr(temp_id)
+ flag = ib.buffer_ptr(temp_flag)
tx = tvm.thread_axis("threadIdx.x")
bx = tvm.thread_axis("blockIdx.x")
nthread_tx = max_threads
@@ -209,42 +296,13 @@
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "thread_extent", nthread_bx)
tid = bx * max_threads + tx
- p_cls_prob = ib.buffer_ptr(cls_prob)
p_loc_pred = ib.buffer_ptr(loc_pred)
p_anchor = ib.buffer_ptr(anchor)
- p_valid_count = ib.buffer_ptr(valid_count)
p_out = ib.buffer_ptr(out)
- with ib.if_scope(tid < batch_size * num_anchors * num_classes):
- n = tid / (num_anchors * num_classes)
- j = (tid % (num_anchors * num_classes)) / num_anchors
- i = tid % num_anchors
- with ib.if_scope(j > 0):
- temp_score[n * num_anchors * num_classes + i * (num_classes - 1) + j-1] = \
- p_cls_prob[tid]
- p_valid_count[n] = 0
+
with ib.if_scope(tid < batch_size * num_anchors):
- n = tid / num_anchors
- i = tid % num_anchors
- score[tid] = -1.0
- cls_id[tid] = 0
- with ib.for_range(0, num_classes-1, name="k") as k:
- temp = temp_score[tid * (num_classes-1) + k]
- cls_id[tid] = tvm.select(temp > score[tid], k + 1, cls_id[tid])
- score[tid] = tvm.make.Max(temp, score[tid])
- with ib.if_scope(tvm.all(cls_id[tid] > 0, score[tid] < threshold)):
- cls_id[tid] = 0
- with ib.if_scope(cls_id[tid] > 0):
- flag[tid] = 1
- with ib.else_scope():
- flag[tid] = 0
- with ib.if_scope(tid < batch_size):
- with ib.for_range(0, num_anchors, name="k") as k:
- with ib.if_scope(k > 0):
- flag[tid * num_anchors + k] += flag[tid * num_anchors + k - 1]
- p_valid_count[tid] = flag[tid * num_anchors + num_anchors - 1]
- with ib.if_scope(tid < batch_size * num_anchors):
- n = tid / num_anchors
- i = tid % num_anchors
+ n = tid / num_anchors # number of batches
+ i = tid % num_anchors # number of anchors
with ib.if_scope(cls_id[tid] > 0):
with ib.if_scope(tid == 0):
out_base_idx = n * num_anchors * 6
@@ -253,17 +311,17 @@
p_out[out_base_idx] = cls_id[tid] - 1.0
p_out[out_base_idx + 1] = score[tid]
p_out[out_base_idx + 2], p_out[out_base_idx + 3], p_out[out_base_idx + 4], \
- p_out[out_base_idx + 5] = transform_loc(p_loc_pred, tid * 4, p_anchor, i*4,
- clip, variances[0], variances[1],
- variances[2], variances[3])
+ p_out[out_base_idx + 5] = transform_loc(p_loc_pred, tid * 4,
+ p_anchor, i*4, clip, variances[0],
+ variances[1], variances[2], variances[3])
body = ib.get()
return body
@multibox_transform_loc.register(["cuda", "gpu"])
-def multibox_transform_loc_gpu(cls_prob, loc_pred, anchor, clip=True, threshold=0.01,
- variances=(0.1, 0.1, 0.2, 0.2)):
+def multibox_transform_loc_gpu(cls_prob, loc_pred, anchor, clip=True, \
+ threshold=0.01, variances=(0.1, 0.1, 0.2, 0.2)):
"""Location transformation for multibox detection
Parameters
@@ -297,20 +355,42 @@
1-D tensor with shape (batch_size,), number of valid anchor boxes.
"""
batch_size = cls_prob.shape[0]
- num_anchors = anchor.shape[1]
+ num_classes = cls_prob.shape[1]
+ num_anchors = cls_prob.shape[2]
oshape = (batch_size, num_anchors, 6)
# Define data alignment for intermediate buffer
valid_count_dtype = "int32"
valid_count_buf = api.decl_buffer((batch_size,), valid_count_dtype,
"valid_count_buf", data_alignment=4)
- out_buf = api.decl_buffer(oshape, cls_prob.dtype, "out_buf", data_alignment=8)
- valid_count, out = \
- tvm.extern([(batch_size,), oshape],
- [cls_prob, loc_pred, anchor],
+ out_buf = api.decl_buffer(
+ oshape, cls_prob.dtype, "out_buf", data_alignment=8)
+ size = num_anchors
+ temp_flag_buf = api.decl_buffer(
+ (size,), valid_count_dtype, "flag", data_alignment=8)
+ temp_id_buf = api.decl_buffer(
+ (size,), valid_count_dtype, "cls_id", data_alignment=8)
+ temp_score_buf = api.decl_buffer(
+ (size,), cls_prob.dtype, "score", data_alignment=8)
+
+ valid_count, temp_flag, temp_id, temp_score = \
+ tvm.extern([(batch_size,), (size,), (size,), (size,)],
+ [cls_prob],
+ lambda ins, outs: transform_loc_pre(
+ ins[0], outs[0], outs[1], outs[2], outs[3], threshold),
+ dtype=[valid_count_dtype,
+ valid_count_dtype, valid_count_dtype, cls_prob.dtype],
+ out_buffers=[valid_count_buf,
+ temp_flag_buf, temp_id_buf, temp_score_buf],
+ tag="multibox_transform_loc_first_step")
+
+ out = \
+ tvm.extern([oshape],
+ [loc_pred, anchor, temp_flag, temp_id, temp_score],
lambda ins, outs: transform_loc_ir(
- ins[0], ins[1], ins[2], outs[0], outs[1], clip, threshold, variances),
- dtype=[valid_count_dtype, cls_prob.dtype],
- out_buffers=[valid_count_buf, out_buf],
+ ins[0], ins[1], ins[2], ins[3], ins[4], outs[0], clip, \
+ variances, batch_size, num_classes, num_anchors),
+ dtype=[cls_prob.dtype],
+ out_buffers=[out_buf],
tag="multibox_transform_loc")
return [out, valid_count]
@@ -356,5 +436,6 @@
"""
inter_out = multibox_transform_loc(cls_prob, loc_pred, anchor,
clip, threshold, variances)
- out = nms(inter_out[0], inter_out[1], nms_threshold, force_suppress, nms_topk)
+ out = nms(
+ inter_out[0], inter_out[1], nms_threshold, force_suppress, nms_topk)
return out