# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, too-many-statements, singleton-comparison, unused-argument, no-else-return
"""Sort related operators """
import tvm
from tvm import te
from .injective import schedule_injective_from_existing
from ..transform import strided_slice, transpose
from .. import tag
from ..utils import ceil_div, swap
from ..math import cast, ceil_log2
def _schedule_sort(outs):
"""Schedule for argsort operator.
Parameters
----------
outs: Array of Tensor
The computation graph description of argsort
in the format of an array of tensors.
Returns
-------
s: Schedule
The computation schedule for the op.
"""
outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
s = te.create_schedule([x.op for x in outs])
scheduled_ops = []
def traverse(op):
if tag.is_injective(op.tag):
schedule_injective_from_existing(s, op.output(0))
for tensor in op.input_tensors:
if tensor.op.input_tensors and tensor.op not in scheduled_ops:
traverse(tensor.op)
scheduled_ops.append(op)
for out in outs:
traverse(out.op)
return s
def _get_threads(ib, nthread_tx, nthread_bx, nthread_by, nthread_bz):
tx = te.thread_axis("threadIdx.x")
bx = te.thread_axis("blockIdx.x")
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "thread_extent", nthread_bx)
by = te.thread_axis("blockIdx.y")
bz = te.thread_axis("blockIdx.z")
ib.scope_attr(by, "thread_extent", nthread_by)
ib.scope_attr(bz, "thread_extent", nthread_bz)
return tx, bx, by, bz
def _sort_init(ib, shape, axis, keys_in, keys_out, values_out=None, value_init_func=None):
"""Initialize the output buffers by copying from inputs"""
axis_mul_before = 1
axis_mul_after = 1
if axis < 0:
axis = len(shape) + axis
for i, value in enumerate(shape, 0):
if i < axis:
axis_mul_before *= value
elif i > axis:
axis_mul_after *= value
# Set up threading
max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
nthread_tx = max_threads
nthread_bx = ceil_div(shape[axis], max_threads)
nthread_by = axis_mul_before
nthread_bz = axis_mul_after
# Copy the keys_in to initial output
with ib.new_scope():
tx, bx, by, bz = _get_threads(ib, nthread_tx, nthread_bx, nthread_by, nthread_bz)
tid = bx * nthread_tx + tx
idx = (by * shape[axis] + tid) * axis_mul_after + bz
with ib.if_scope(tid < shape[axis]):
keys_out[idx] = keys_in[idx]
if values_out is not None:
values_out[idx] = value_init_func(idx, tid)
return axis_mul_before, axis_mul_after
## TODO(mbrookhart): These are effective optimization hyperparameters.
## Perhaps we can autotune?
block_size = 128
thread_work = 4
def _odd_even_sort(
ib,
size,
axis_mul_before,
axis_mul_after,
is_ascend,
keys,
keys_swap,
values=None,
values_swap=None,
):
nthread_tx = block_size // 2
nthread_bx = ceil_div(size, block_size)
nthread_by = axis_mul_before
nthread_bz = axis_mul_after
with ib.new_scope():
ib.scope_attr(tvm.tir.const(0), "hand_threaded", 0)
tx, bx, by, bz = _get_threads(ib, nthread_tx, nthread_bx, nthread_by, nthread_bz)
tid = 2 * tx
start = bx * block_size
## Create shared memory as syncable thread scratch space
tmp_keys_swap = ib.allocate(
keys_swap.dtype,
(block_size,),
name="temp_keys_swap",
scope="shared",
)
if values_swap is not None:
tmp_values_swap = ib.allocate(
values_swap.dtype,
(block_size,),
name="temp_values_swap",
scope="shared",
)
## Create thread local data for swapping
temp_keys = ib.allocate(keys_swap.dtype, (1,), name="temp_keys", scope="local")
if values_swap is not None:
temp_values = ib.allocate(values_swap.dtype, (1,), name="temp_values", scope="local")
temp_cond1 = ib.allocate(keys_swap.dtype, (1,), name="temp_cond1", scope="local")
temp_cond2 = ib.allocate(keys_swap.dtype, (1,), name="temp_cond2", scope="local")
# Copy data to scratch space
base_idx = by * size * axis_mul_after + bz
with ib.for_range(0, 2) as n:
with ib.if_scope((tid + n + start) < size):
tmp_keys_swap[tid + n] = keys[base_idx + (tid + n + start) * axis_mul_after]
if values_swap is not None:
tmp_values_swap[tid + n] = values[base_idx + (tid + n + start) * axis_mul_after]
ib.emit(tvm.tir.Call(None, "tir.tvm_storage_sync", tvm.runtime.convert(["shared"])))
idxm = tvm.tir.indexmod
# OddEvenTransposeSort
current_sort_num = tvm.tir.min(block_size, size - start)
with ib.for_range(0, current_sort_num) as k:
n = idxm(tid + k, 2)
with ib.if_scope(tid + n < current_sort_num - 1):
temp_cond1[0] = tmp_keys_swap[tid + n]
temp_cond2[0] = tmp_keys_swap[tid + n + 1]
if is_ascend:
cond = temp_cond1[0] > temp_cond2[0]
else:
cond = temp_cond1[0] < temp_cond2[0]
with ib.if_scope(cond):
temp_keys[0] = tmp_keys_swap[tid + n]
tmp_keys_swap[tid + n] = tmp_keys_swap[tid + n + 1]
tmp_keys_swap[tid + n + 1] = temp_keys[0]
if values_swap is not None:
temp_values[0] = tmp_values_swap[tid + n]
tmp_values_swap[tid + n] = tmp_values_swap[tid + n + 1]
tmp_values_swap[tid + n + 1] = temp_values[0]
ib.emit(tvm.tir.Call(None, "tir.tvm_storage_sync", tvm.runtime.convert(["shared"])))
## Copy sorted data to output
with ib.for_range(0, 2) as n:
with ib.if_scope(tid + n + start < size):
keys[base_idx + (tid + n + start) * axis_mul_after] = tmp_keys_swap[tid + n]
keys_swap[base_idx + (tid + n + start) * axis_mul_after] = tmp_keys_swap[tid + n]
if values_swap is not None:
values[base_idx + (tid + n + start) * axis_mul_after] = tmp_values_swap[tid + n]
values_swap[base_idx + (tid + n + start) * axis_mul_after] = tmp_values_swap[
tid + n
]
def _sort_common(
ib,
size,
axis_mul_before,
axis_mul_after,
is_ascend,
keys,
keys_swap,
values=None,
values_swap=None,
):
"""Either sort only values or sort values by keys."""
## This function performs a multi-level mergesort:
## For blocks of length <= block_size it does an odd-even transposition sort
## in GPU shared memory.
## For intermediate block sizes (> block_size, < max_threads * thread_work)
## it uses the mergepath algorithm (https://arxiv.org/abs/1406.2628)
## to merge blocks in parallel.
## At some point, the size of the blocks to be merged is too big for max_threads,
## and we switch to a dual-level mergepath where the outer mergepath
## finds the start/end locations of the inner mergepath so that we can split
## the merge into more blocks.
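## Mergepath sketch: to merge sorted runs A (length aCount) and B (length bCount),
## each thread owns output positions [diag, diag + step_count). A binary search
## along the cross-diagonal `diag` finds a split point `first` such that the
## first `first` elements of A together with the first `diag - first` elements of B
## are exactly the `diag` smallest items, so every thread can then merge its own
## output slice serially and independently of the other threads.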
max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
nthread_by = axis_mul_before * axis_mul_after
nthread_bz = 1
nthread_tx = max_threads
nthread_bx = ceil_div(size, nthread_tx)
def compare(a, b):
"""
Compare a and b in proper ascending or descending order
"""
if is_ascend:
out = a <= b
else:
out = b <= a
return out
# Sort the lower levels of the merge using odd-even sort; it's fast for small inputs
lower_lim = ceil_log2(block_size)
_odd_even_sort(
ib,
size,
axis_mul_before * axis_mul_after,
1,
is_ascend,
keys,
keys_swap,
values,
values_swap,
)
upper_lim = ceil_log2(size)
def get_merge_begin(source, base_idx, aCount, bCount, aStart, bStart, diag, step_count):
first = ib.allocate("int64", (1,), name="first", scope="local")
mid = ib.allocate("int64", (1,), name="mid", scope="local")
last = ib.allocate("int64", (1,), name="last", scope="local")
first[0] = tvm.te.max(0, diag - bCount)
last[0] = tvm.te.min(diag, aCount)
with ib.while_loop(first[0] < last[0]):
mid = (first[0] + last[0]) >> 1
a = source[base_idx + (aStart + mid)]
b = source[base_idx + (bStart + diag - 1 - mid)]
with ib.if_scope(compare(a, b)):
first[0] = mid + 1
with ib.else_scope():
last[0] = mid
return first[0], last[0]
def serial_merge(
source,
dest,
source_idx,
dest_idx,
base_idx,
aCount,
bCount,
aStart,
bStart,
kStart,
diag,
step_count,
first,
last,
):
i = ib.allocate("int64", (1,), name="i", scope="local")
j = ib.allocate("int64", (1,), name="j", scope="local")
i[0] = aStart + first
j[0] = bStart + diag - last
with ib.for_range(0, tvm.te.min(aCount + bCount - diag, step_count)) as count:
i_idx = base_idx + i[0]
j_idx = base_idx + j[0]
k_idx = base_idx + (kStart + diag + count)
def assign_i():
"""assign i value to current output"""
dest[k_idx] = source[i_idx]
if values is not None:
dest_idx[k_idx] = source_idx[i_idx]
i[0] += 1
def assign_j():
"""assign j value to current output"""
dest[k_idx] = source[j_idx]
if values is not None:
dest_idx[k_idx] = source_idx[j_idx]
j[0] += 1
## if both of the iterators are in range
with ib.if_scope(tvm.tir.all(i[0] < aStart + aCount, j[0] < bStart + bCount)):
# compare them and insert whichever is next into the output
with ib.if_scope(compare(source[i_idx], source[j_idx])):
assign_i()
with ib.else_scope():
assign_j()
# otherwise, simply copy the remainder of the valid iterator to the output
with ib.else_scope():
with ib.if_scope(i[0] < aStart + aCount):
assign_i()
with ib.else_scope():
assign_j()
with ib.for_range(0, cast(upper_lim - lower_lim, "int64"), dtype="int64") as l2_width:
width = 2 << (l2_width + lower_lim)
# Define and launch the cuda kernel
with ib.new_scope():
target = tvm.target.Target.current()
if "vulkan" in str(target):
# Vulkan can't handle dynamic nthread, so we thread slightly differently
# for vulkan. We don't do this generally because it causes a 15% perf
# regression on other platforms
ntx = max_threads
nbx = tvm.tir.generic.cast(ceil_div(width, max_threads * thread_work), "int32")
nbz = tvm.tir.generic.cast(ceil_div(size, width), "int32")
tx, bx, by, bz = _get_threads(ib, ntx, nbx, nthread_by, nbz)
else:
ntx = tvm.tir.generic.cast(tvm.te.min(max_threads, width), "int32")
nbx = tvm.tir.generic.cast(ceil_div(width, max_threads * thread_work), "int32")
nbz = tvm.tir.generic.cast(ceil_div(size, width), "int32")
tx, bx, by, bz = _get_threads(ib, ntx, nbx, nthread_by, nbz)
def mergepath(
source,
dest,
source_idx,
dest_idx,
aCount,
bCount,
aStart,
bStart,
kStart,
step_count,
even,
):
# pylint: disable=arguments-out-of-order
def merge(source, dest, source_idx, dest_idx):
diag = tx * step_count
first, last = get_merge_begin(
source,
by * size,
aCount,
bCount,
aStart,
bStart,
diag,
step_count,
)
# iterate over the output loop
serial_merge(
source,
dest,
source_idx,
dest_idx,
by * size,
aCount,
bCount,
aStart,
bStart,
kStart,
diag,
step_count,
first,
last,
)
with ib.if_scope(even):
merge(source, dest, source_idx, dest_idx)
with ib.else_scope():
merge(dest, source, dest_idx, source_idx)
def mergesort(source, dest, source_idx, dest_idx, size, width, even):
# calculate the start, mid, and end points of this section
start = width * bz
middle = cast(tvm.te.min(start + tvm.tir.indexdiv(width, 2), size), "int64")
end = cast(tvm.te.min(start + width, size), "int64")
with ib.if_scope(start < size):
with ib.if_scope(nbx == 1):
## merge the start->middle and middle->end arrays
aCount = middle - start
bCount = end - middle
mergepath(
source,
dest,
source_idx,
dest_idx,
aCount,
bCount,
start,
middle,
start,
ceil_div(width, ntx),
even,
)
with ib.else_scope():
step_count = max_threads * thread_work
diag = bx * step_count
def do_merge(first, last):
aStart = start + first
bStart = middle + diag - last
aCount = tvm.te.min(middle - aStart, step_count)
bCount = tvm.te.min(end - bStart, step_count)
mergepath(
source,
dest,
source_idx,
dest_idx,
aCount,
bCount,
aStart,
bStart,
start + diag,
thread_work,
even,
)
with ib.if_scope(even):
first, last = get_merge_begin(
source,
by * size,
middle - start,
end - middle,
start,
middle,
diag,
step_count,
)
do_merge(first, last)
with ib.else_scope():
first, last = get_merge_begin(
dest,
by * size,
middle - start,
end - middle,
start,
middle,
diag,
step_count,
)
do_merge(first, last)
# Call the kernel
mergesort(
keys,
keys_swap,
values,
values_swap,
size,
width,
tvm.tir.indexmod(l2_width, 2) == 0,
)
nthread_by = axis_mul_before
nthread_bz = axis_mul_after
nthread_tx = max_threads
nthread_bx = ceil_div(size, nthread_tx)
## if the final sorted data ended up in the swap, copy it to the real output
with ib.if_scope(
tvm.tir.all(upper_lim > lower_lim, tvm.tir.indexmod(upper_lim - lower_lim, 2) == 1)
):
with ib.new_scope():
tx, bx, by, bz = _get_threads(ib, nthread_tx, nthread_bx, nthread_by, nthread_bz)
tid = bx * nthread_tx + tx
idx = (by * axis_mul_after + bz) * size + tid
with ib.if_scope(tid < size):
keys[idx] = keys_swap[idx]
if values is not None:
values[idx] = values_swap[idx]
def sort_ir(
data, values_out, values_out_swap, axis, is_ascend, indices_out=None, indices_out_swap=None
):
"""Low level IR to do sorting on the GPU, same usage as tvm.contrib.sort.argsort on the CPU.
Parameters
----------
data: Buffer
Buffer of input data. Data will be sorted in place.
values_out : Buffer
Output buffer of values of sorted tensor with same shape as data.
values_out_swap : Buffer
Output buffer of values with same shape as data to use as swap.
axis : Int
Axis along which to sort the input tensor.
is_ascend : Boolean
Whether to sort in ascending or descending order.
indices_out : Buffer
Output buffer of indices of sorted tensor with same shape as data.
indices_out_swap : Buffer
Output buffer of indices with same shape as data to use as swap.
Returns
-------
stmt : Stmt
The result IR statement.
"""
ib = tvm.tir.ir_builder.create()
shape = data.shape
data = ib.buffer_ptr(data)
values_out = ib.buffer_ptr(values_out)
values_out_swap = ib.buffer_ptr(values_out_swap)
if indices_out is not None:
indices_out = ib.buffer_ptr(indices_out)
assert indices_out_swap is not None
indices_out_swap = ib.buffer_ptr(indices_out_swap)
with ib.if_scope(shape[axis] > 0):
axis_mul_before, axis_mul_after = _sort_init(
ib,
shape,
axis,
data,
values_out,
indices_out,
value_init_func=lambda _, tid: tvm.tir.generic.cast(tid, indices_out.dtype),
)
_sort_common(
ib,
shape[axis],
axis_mul_before,
axis_mul_after,
is_ascend,
values_out,
values_out_swap,
values=indices_out,
values_swap=indices_out_swap,
)
return ib.get()
def sort_by_key_ir(
keys_in, values_in, keys_out, values_out, keys_out_swap, values_out_swap, axis, is_ascend
):
"""Low level IR to do sort by key on the GPU.
Parameters
----------
keys_in: Buffer
Buffer of input keys.
values_in: Buffer
Buffer of input values.
keys_out : Buffer
Buffer of output sorted keys.
values_out : Buffer
Buffer of output sorted values.
keys_out_swap : Buffer
Output buffer of values with same shape as keys_in to use as swap.
values_out_swap : Buffer
Output buffer of values with same shape as values_in to use as swap.
axis : Int
Axis along which to sort the input tensor.
is_ascend : Boolean
Whether to sort in ascending or descending order.
Returns
-------
stmt : Stmt
The result IR statement.
"""
ib = tvm.tir.ir_builder.create()
shape = keys_in.shape
keys_in = ib.buffer_ptr(keys_in)
values_in = ib.buffer_ptr(values_in)
keys_out = ib.buffer_ptr(keys_out)
keys_out_swap = ib.buffer_ptr(keys_out_swap)
values_out = ib.buffer_ptr(values_out)
values_out_swap = ib.buffer_ptr(values_out_swap)
with ib.if_scope(shape[axis] > 0):
axis_mul_before, axis_mul_after = _sort_init(
ib,
shape,
axis,
keys_in,
keys_out,
values_out,
value_init_func=lambda idx, _: values_in[idx],
)
_sort_common(
ib,
shape[axis],
axis_mul_before,
axis_mul_after,
is_ascend,
keys_out,
keys_out_swap,
values=values_out,
values_swap=values_out_swap,
)
return ib.get()
def sort(data, axis=-1, is_ascend=1):
"""Performs sorting along the given axis and returns an array of
sorted values with the same shape as the input data.
Parameters
----------
data: tvm.te.Tensor
The input array.
axis : int, optional
Axis along which to sort the input tensor.
is_ascend : boolean, optional
Whether to sort in ascending or descending order.
Returns
-------
out : tvm.te.Tensor
The output of this function.
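Examples
--------
A minimal usage sketch, assuming a CUDA-enabled build of TVM and that this
operator is exposed as ``tvm.topi.cuda.sort``:

.. code-block:: python

    import numpy as np
    import tvm
    from tvm import te, topi

    data = te.placeholder((1, 1024), name="data", dtype="float32")
    # The CUDA target must be current when the extern op is constructed.
    with tvm.target.Target("cuda"):
        out = topi.cuda.sort(data, axis=-1, is_ascend=True)
        s = topi.cuda.schedule_sort(out)
        f = tvm.build(s, [data, out], "cuda")

    dev = tvm.cuda(0)
    a = tvm.nd.array(np.random.rand(1, 1024).astype("float32"), dev)
    b = tvm.nd.empty((1, 1024), "float32", dev)
    f(a, b)  # each row of b is the corresponding row of a, sorted ascending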
"""
ndim = len(data.shape)
axis = ndim + axis if axis < 0 else axis
if axis != ndim - 1:
# Prepare for sorting along axis -1.
axes = swap(list(range(ndim)), axis)
data = transpose(data, axes)
value_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "value_buf", data_alignment=8)
value_buf_swap = tvm.tir.decl_buffer(data.shape, data.dtype, "value_buf_swap", data_alignment=8)
out = te.extern(
[data.shape, data.shape],
[data],
lambda ins, outs: sort_ir(ins[0], outs[0], outs[1], -1, is_ascend),
out_buffers=[value_buf, value_buf_swap],
name="sort_gpu",
tag="sort_gpu",
)[0]
if axis != ndim - 1:
axes = swap(list(range(ndim)), axis)
out = transpose(out, axes)
return out
def sort_thrust(data, axis=-1, is_ascend=1):
"""Performs sorting along the given axis and returns an array of
sorted values with the same shape as the input data.
Parameters
----------
data: tvm.te.Tensor
The input array.
axis : int, optional
Axis along which to sort the input tensor.
is_ascend : boolean, optional
Whether to sort in ascending or descending order.
Returns
-------
out : tvm.te.Tensor
The output of this function.
"""
dtype = "float32"
ndim = len(data.shape)
axis = ndim + axis if axis < 0 else axis
if axis != ndim - 1:
# Prepare for sorting along axis -1.
axes = swap(list(range(ndim)), axis)
data = transpose(data, axes)
value_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "value_buf", data_alignment=8)
indices_buf = tvm.tir.decl_buffer(data.shape, dtype, "out_buf", data_alignment=8)
out = te.extern(
[data.shape, data.shape],
[data],
## TODO(mbrookhart): This thrust function is actually doing argsort, not sort
## For performance, we should probably rename the contrib function and add
## a pure sort
lambda ins, outs: tvm.tir.call_packed(
"tvm.contrib.thrust.sort", ins[0], outs[0], outs[1], is_ascend
),
out_buffers=[value_buf, indices_buf],
name="sort_gpu",
tag="sort_gpu",
)[0]
if axis != ndim - 1:
axes = swap(list(range(ndim)), axis)
out = transpose(out, axes)
return out
def argsort(data, axis=-1, is_ascend=1, dtype="float32", ret_type="indices"):
"""Performs sorting along the given axis and returns an array of indices
having same shape as an input array that index data in sorted order.
Parameters
----------
data: tvm.te.Tensor
The input array.
axis : int, optional
Axis along which to sort the input tensor.
is_ascend : boolean, optional
Whether to sort in ascending or descending order.
dtype : string, optional
DType of the output indices.
ret_type : string, optional
The return type [both, indices].
"both": return both sorted data and indices.
"indices": return sorted indices only.
Returns
-------
out : tvm.te.Tensor
The output of this function.
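Examples
--------
A minimal usage sketch, assuming a CUDA-enabled build of TVM:

.. code-block:: python

    import tvm
    from tvm import te, topi

    data = te.placeholder((1, 1024), name="data", dtype="float32")
    with tvm.target.Target("cuda"):
        indices = topi.cuda.argsort(data, axis=-1, is_ascend=True, dtype="int32")
        s = topi.cuda.schedule_argsort(indices)
        f = tvm.build(s, [data, indices], "cuda")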
"""
ndim = len(data.shape)
axis = ndim + axis if axis < 0 else axis
if axis != ndim - 1:
# Prepare for sorting along axis -1.
axes = swap(list(range(ndim)), axis)
data = transpose(data, axes)
value_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "value_buf", data_alignment=8)
value_swap_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "value_swap_buf", data_alignment=8)
indices_buf = tvm.tir.decl_buffer(data.shape, dtype, "out_buf", data_alignment=8)
indices_swap_buf = tvm.tir.decl_buffer(data.shape, dtype, "out_swap_buf", data_alignment=8)
outs = te.extern(
[data.shape, data.shape, data.shape, data.shape],
[data],
lambda ins, outs: sort_ir(
ins[0],
outs[0],
outs[2],
-1,
is_ascend,
indices_out=outs[1],
indices_out_swap=outs[3],
),
out_buffers=[value_buf, indices_buf, value_swap_buf, indices_swap_buf],
name="argsort_gpu",
tag="argsort_gpu",
)
if axis != ndim - 1:
axes = swap(list(range(ndim)), axis)
outs = [transpose(out, axes) for out in outs]
if ret_type == "indices":
return outs[1]
return outs[0], outs[1]
def argsort_thrust(data, axis=-1, is_ascend=1, dtype="float32", ret_type="indices"):
"""Performs sorting along the given axis and returns an array of indices
having same shape as an input array that index data in sorted order.
Parameters
----------
data: tvm.te.Tensor
The input array.
axis : int, optional
Axis along which to sort the input tensor.
is_ascend : boolean, optional
Whether to sort in ascending or descending order.
dtype : string, optional
DType of the output indices.
ret_type : string, optional
The return type [both, indices].
"both": return both sorted data and indices.
"indices": return sorted indices only.
Returns
-------
out : tvm.te.Tensor
The output of this function.
"""
return topk_thrust(data, 0, axis, ret_type, is_ascend, dtype)
def schedule_sort(outs):
"""Schedule for sort operator.
Parameters
----------
outs: Array of Tensor
The computation graph description of sort
in the format of an array of tensors.
Returns
-------
s: Schedule
The computation schedule for the op.
"""
return _schedule_sort(outs)
def schedule_argsort(outs):
"""Schedule for argsort operator.
Parameters
----------
outs: Array of Tensor
The computation graph description of argsort
in the format of an array of tensors.
Returns
-------
s: Schedule
The computation schedule for the op.
"""
return _schedule_sort(outs)
def topk(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int64"):
"""Get the top k elements in an input tensor along the given axis.
Parameters
----------
data : tvm.te.Tensor
The input tensor.
k : int, optional
Number of top elements to select. Return all elements if k < 1.
axis : int, optional
Axis along which to sort the input tensor.
ret_type: str, optional
The return type [both, values, indices].
"both": return both top k data and indices.
"values": return top k data only.
"indices": return top k indices only.
is_ascend : boolean, optional
Whether to sort in ascending or descending order.
dtype : string, optional
The data type of the indices output.
Returns
-------
out : tvm.te.Tensor or List[tvm.te.Tensor]
The computed result.
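Examples
--------
A minimal usage sketch, assuming a CUDA-enabled build of TVM:

.. code-block:: python

    import tvm
    from tvm import te, topi

    data = te.placeholder((4, 1024), name="data", dtype="float32")
    with tvm.target.Target("cuda"):
        # ret_type="both" returns [values, indices]; is_ascend=False keeps the largest k
        values, indices = topi.cuda.topk(data, k=5, axis=-1, ret_type="both")
        s = topi.cuda.schedule_topk([values, indices])
        f = tvm.build(s, [data, values, indices], "cuda")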
"""
assert ret_type in ["both", "values", "indices"]
ndim = len(data.shape)
axis = axis + ndim if axis < 0 else axis
assert 0 <= axis < ndim
dshape = data.shape
if axis != ndim - 1:
axes = swap(list(range(ndim)), axis)
data = transpose(data, axes)
values_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "values_buf", data_alignment=8)
values_swap_buf = tvm.tir.decl_buffer(
data.shape, data.dtype, "values_swap_buf", data_alignment=8
)
indices_buf = tvm.tir.decl_buffer(data.shape, dtype, "indices_buf", data_alignment=8)
indices_swap_buf = tvm.tir.decl_buffer(data.shape, dtype, "indices_swap_buf", data_alignment=8)
if ret_type == "values":
output = te.extern(
[data.shape, data.shape],
[data],
lambda ins, outs: sort_ir(ins[0], outs[0], outs[1], -1, is_ascend),
out_buffers=[values_buf, values_swap_buf],
name="topk_gpu",
tag="topk_gpu",
)[0]
if axis != ndim - 1:
axes = swap(list(range(ndim)), axis)
output = transpose(output, axes)
else:
output = te.extern(
[data.shape, data.shape, data.shape, data.shape],
[data],
lambda ins, outs: sort_ir(
ins[0],
outs[0],
outs[2],
-1,
is_ascend,
indices_out=outs[1],
indices_out_swap=outs[3],
),
out_buffers=[values_buf, indices_buf, values_swap_buf, indices_swap_buf],
name="topk_gpu",
tag="topk_gpu",
)[0:2]
if axis != ndim - 1:
axes = swap(list(range(ndim)), axis)
output[0] = transpose(output[0], axes)
output[1] = transpose(output[1], axes)
if isinstance(k, int) and k < 1:
if ret_type == "indices":
return output[1]
return output
beg = [0] * ndim
end = []
strides = [1] * ndim
for i in range(ndim):
if i == axis:
end.append(k if isinstance(k, int) else tvm.te.size_var("dim"))
else:
end.append(dshape[i])
if ret_type == "both":
values_out, indices_out = output
values_out = strided_slice(values_out, beg, end, strides)
indices_out = strided_slice(indices_out, beg, end, strides)
output = [values_out, indices_out]
elif ret_type == "values":
output = [strided_slice(output, beg, end, strides)]
else: # ret_type == "indices"
indices_out = output[1]
output = [strided_slice(indices_out, beg, end, strides)]
return output
def topk_thrust(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int64"):
"""Get the top k elements in an input tensor along the given axis.
Parameters
----------
data : tvm.te.Tensor
The input tensor.
k : int, optional
Number of top elements to select. Return all elements if k < 1.
axis : int, optional
Axis along which to sort the input tensor.
ret_type: str, optional
The return type [both, values, indices].
"both": return both top k data and indices.
"values": return top k data only.
"indices": return top k indices only.
is_ascend : boolean, optional
Whether to sort in ascending or descending order.
dtype : string, optional
The data type of the indices output.
Returns
-------
out : tvm.te.Tensor or List[tvm.te.Tensor]
The computed result.
"""
assert ret_type in ["both", "values", "indices"]
ndim = len(data.shape)
axis = ndim + axis if axis < 0 else axis
if axis != ndim - 1:
# Prepare for sorting along axis -1.
axes = swap(list(range(ndim)), axis)
data = transpose(data, axes)
data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8)
out_bufs = [
tvm.tir.decl_buffer(data.shape, data.dtype, "value_buf", data_alignment=8),
tvm.tir.decl_buffer(data.shape, dtype, "indices_buf", data_alignment=8),
]
is_ascend = 1 if is_ascend else 0
out = te.extern(
[data.shape, data.shape],
[data],
lambda ins, outs: tvm.tir.call_packed(
"tvm.contrib.thrust.sort", ins[0], outs[0], outs[1], is_ascend
),
in_buffers=[data_buf],
out_buffers=out_bufs,
name="topk_gpu",
tag="topk_gpu",
)
if isinstance(k, tvm.tir.IntImm):
k = k.value
if not isinstance(k, int) or k > 0:
beg = [0] * ndim
end = data.shape[:-1] + [k if isinstance(k, int) else tvm.te.size_var("dim")]
strides = [1] * ndim
out = [strided_slice(o, beg, end, strides) for o in out]
if axis != ndim - 1:
axes = swap(list(range(ndim)), axis)
out = [transpose(o, axes) for o in out]
if ret_type == "values":
out = out[0]
elif ret_type == "indices":
out = out[1]
return out
def schedule_topk(outs):
"""Schedule for argsort operator.
Parameters
----------
outs: Array of Tensor
The computation graph description of argsort
in the format of an array of tensors.
Returns
-------
s: Schedule
The computation schedule for the op.
"""
return _schedule_sort(outs)
def sort_by_key(keys, values, axis=-1, is_ascend=1):
"""Sort values with respect to keys. Both keys and values will
be sorted and returned.
Parameters
----------
keys: tvm.te.Tensor
The input keys.
values : tvm.te.Tensor,
The input values.
axis : int, optional
Axis along which to sort the input tensor.
is_ascend : boolean, optional
Whether to sort in ascending or descending order.
Returns
-------
keys_sorted : tvm.te.Tensor
The sorted keys
values_sorted : tvm.te.Tensor
The values sorted with respect to the keys
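Examples
--------
A minimal usage sketch, assuming a CUDA-enabled build of TVM; scheduling the
result with the generic ``schedule_sort`` here is an assumption, not a dedicated API:

.. code-block:: python

    import tvm
    from tvm import te, topi

    keys = te.placeholder((1024,), name="keys", dtype="float32")
    values = te.placeholder((1024,), name="values", dtype="int32")
    with tvm.target.Target("cuda"):
        keys_sorted, values_sorted = topi.cuda.sort_by_key(keys, values)
        s = topi.cuda.schedule_sort([keys_sorted, values_sorted])
        f = tvm.build(s, [keys, values, keys_sorted, values_sorted], "cuda")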
"""
keys_buf = tvm.tir.decl_buffer(keys.shape, keys.dtype, "keys_buf", data_alignment=8)
values_buf = tvm.tir.decl_buffer(values.shape, values.dtype, "values_buf", data_alignment=8)
out_bufs = [
tvm.tir.decl_buffer(keys.shape, keys.dtype, "keys_buf", data_alignment=8),
tvm.tir.decl_buffer(values.shape, values.dtype, "values_buf", data_alignment=8),
tvm.tir.decl_buffer(keys.shape, keys.dtype, "keys_swap_buf", data_alignment=8),
tvm.tir.decl_buffer(values.shape, values.dtype, "values_swap_buf", data_alignment=8),
]
out = te.extern(
[keys.shape, values.shape, keys.shape, values.shape],
[keys, values],
lambda ins, outs: sort_by_key_ir(
ins[0], ins[1], outs[0], outs[1], outs[2], outs[3], axis, is_ascend
),
in_buffers=[keys_buf, values_buf],
out_buffers=out_bufs,
dtype=[keys.dtype, values.dtype],
name="sort_by_key",
tag="sort_by_key",
)
return out[0], out[1]
def stable_sort_by_key_thrust(keys, values, for_scatter=False):
"""Sort values with respect to keys using thrust.
Both keys and values will be sorted and returned.
Sorting is done via a stable sort, so relative ordering among
ties is preserved.
Parameters
----------
keys: tvm.te.Tensor
The 1D input keys.
values : tvm.te.Tensor,
The 1D input values.
for_scatter: bool, optional
If True, negative keys are interpreted as negative indices.
Before sorting, negative indices are converted to corresponding positive indices.
The output keys (indices) are all positive.
This option is introduced to optimize the scatter implementation.
Returns
-------
keys_sorted : tvm.te.Tensor
The sorted keys
values_sorted : tvm.te.Tensor
The values sorted with respect to the keys
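Examples
--------
A minimal usage sketch; this requires TVM built with thrust support
(e.g. USE_THRUST=ON) so that "tvm.contrib.thrust.stable_sort_by_key" is
registered, and scheduling via the generic ``schedule_sort`` is an assumption:

.. code-block:: python

    import tvm
    from tvm import te, topi

    keys = te.placeholder((1024,), name="keys", dtype="int32")
    values = te.placeholder((1024,), name="values", dtype="float32")
    with tvm.target.Target("cuda"):
        keys_sorted, values_sorted = topi.cuda.stable_sort_by_key_thrust(keys, values)
        s = topi.cuda.schedule_sort([keys_sorted, values_sorted])
        f = tvm.build(s, [keys, values, keys_sorted, values_sorted], "cuda")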
"""
keys_buf = tvm.tir.decl_buffer(keys.shape, keys.dtype, "keys_buf", data_alignment=8)
values_buf = tvm.tir.decl_buffer(values.shape, values.dtype, "values_buf", data_alignment=8)
out_bufs = [
tvm.tir.decl_buffer(keys.shape, keys.dtype, "keys_buf", data_alignment=8),
tvm.tir.decl_buffer(keys.shape, values.dtype, "values_buf", data_alignment=8),
]
out = te.extern(
[keys.shape, values.shape],
[keys, values],
lambda ins, outs: tvm.tir.call_packed(
"tvm.contrib.thrust.stable_sort_by_key", ins[0], ins[1], outs[0], outs[1], for_scatter
),
in_buffers=[keys_buf, values_buf],
out_buffers=out_bufs,
dtype=[keys.dtype, values.dtype],
name="stable_sort_by_key",
tag="stable_sort_by_key",
)
return out[0], out[1]