# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Backend compiler related feature registration"""
# pylint: disable=invalid-name,unused-argument, len-as-condition, too-many-nested-blocks,
# pylint: disable=too-many-locals, too-many-arguments, no-else-return
from __future__ import absolute_import
import tvm
from tvm import te, topi
from tvm.runtime import convert
from tvm.te.hybrid import script
from tvm.topi.utils import get_const_int, get_const_tuple
from . import op as _reg
from . import strategy
from ._tensor import elemwise_shape_func
from .op import OpPattern
_reg.register_broadcast_schedule("broadcast_to")
_reg.register_broadcast_schedule("broadcast_to_like")
_reg.register_broadcast_schedule("expand_dims")
_reg.register_broadcast_schedule("repeat")
_reg.register_broadcast_schedule("tile")
_reg.register_broadcast_schedule("where")
_reg.register_injective_schedule("squeeze")
_reg.register_injective_schedule("reshape")
_reg.register_injective_schedule("reshape_like")
_reg.register_injective_schedule("full")
_reg.register_injective_schedule("full_like")
_reg.register_injective_schedule("arange")
_reg.register_injective_schedule("meshgrid")
_reg.register_injective_schedule("reverse")
_reg.register_injective_schedule("reverse_sequence")
_reg.register_injective_schedule("cast")
_reg.register_injective_schedule("cast_like")
_reg.register_injective_schedule("reinterpret")
_reg.register_injective_schedule("strided_slice")
_reg.register_injective_schedule("slice_like")
_reg.register_injective_schedule("split")
_reg.register_injective_schedule("take")
_reg.register_injective_schedule("stack")
_reg.register_injective_schedule("contrib_reverse_reshape")
_reg.register_injective_schedule("gather")
_reg.register_injective_schedule("gather_nd")
_reg.register_injective_schedule("sequence_mask")
_reg.register_injective_schedule("one_hot")
_reg.register_reduce_schedule("collapse_sum_like")
_reg.register_reduce_schedule("collapse_sum_to")
_reg.register_injective_schedule("unravel_index")
_reg.register_injective_schedule("sparse_to_dense")
_reg.register_injective_schedule("matrix_set_diag")
_reg.register_injective_schedule("adv_index")
# concatenate
@_reg.register_compute("concatenate")
def compute_concat(attrs, inputs, output_type):
    """Compute definition of concatenate"""
    return [topi.concatenate(inputs, attrs.axis)]
_reg.register_strategy("concatenate", strategy.concatenate_strategy)
# sliding_window
@_reg.register_compute("sliding_window")
def compute_sliding_window(attrs, inputs, output_type):
"""Compute definition of sliding_window"""
return [topi.sliding_window(inputs[0], attrs.axis, attrs.window_shape, attrs.strides)]
_reg.register_strategy("sliding_window", strategy.sliding_window_strategy)
# strided_set
@_reg.register_compute("strided_set")
def compute_strided_set(attrs, inputs, output_type):
"""Compute definition of strided_set"""
return [topi.strided_set(inputs[0], inputs[1], inputs[2], inputs[3], inputs[4])]
_reg.register_injective_schedule("strided_set")
# layout_transform
_reg.register_injective_schedule("layout_transform")
_reg.register_pattern("layout_transform", OpPattern.INJECTIVE)
_reg.register_injective_schedule("auto_scheduler_layout_transform")
_reg.register_pattern("auto_scheduler_layout_transform", OpPattern.INJECTIVE)
_reg.register_injective_schedule("meta_schedule_layout_transform")
_reg.register_pattern("meta_schedule_layout_transform", OpPattern.INJECTIVE)
# argwhere
_reg.register_strategy("argwhere", strategy.argwhere_strategy)
# scatter
@_reg.register_compute("scatter")
def compute_scatter(attrs, inputs, output_type):
"""Compute definition of scatter"""
return [topi.scatter(inputs[0], inputs[1], inputs[2], attrs.axis)]
_reg.register_strategy("scatter", strategy.scatter_strategy)
# sparse_fill_empty_rows
@_reg.register_compute("sparse_fill_empty_rows")
def compute_sparse_fill_empty_rows(attrs, inputs, output_type):
"""Compute definition of sparse_fill_empty_rows"""
return topi.sparse_fill_empty_rows(
inputs[0],
inputs[1],
inputs[2],
inputs[3],
output_type.fields[0].shape,
output_type.fields[1].shape,
output_type.fields[2].shape,
)
_reg.register_strategy("sparse_fill_empty_rows", strategy.sparse_fill_empty_rows_strategy)
# sparse_reshape
@_reg.register_compute("sparse_reshape")
def compute_sparse_reshape(attrs, inputs, output_type):
"""Compute definition of sparse_reshape"""
return topi.sparse_reshape(
inputs[0],
inputs[1],
inputs[2],
output_type.fields[0].shape,
output_type.fields[1].shape,
)
_reg.register_strategy("sparse_reshape", strategy.sparse_reshape_strategy)
# stft
@_reg.register_compute("stft")
def compute_stft(attrs, inputs, output_type):
"""Compute definition of stft"""
return topi.stft(
inputs[0],
attrs.n_fft,
attrs.hop_length,
attrs.win_length,
attrs.window,
attrs.normalized,
attrs.onesided,
output_type.shape,
)
_reg.register_strategy("stft", strategy.stft_strategy)
@script
def _stft_shape_func(data, n_fft, hop_length, onesided):
output_shape = output_tensor((4,), "int64")
output_shape[0] = int64(data.shape[0])
if onesided:
output_shape[1] = int64(int64(n_fft) // int64(2)) + int64(1)
else:
output_shape[1] = int64(n_fft)
output_shape[2] = int64(int64(data.shape[1] - n_fft) // int64(hop_length)) + int64(1)
output_shape[3] = int64(2)
return output_shape
@_reg.register_shape_func("stft", True)
def stft_shape_func(attrs, inputs, _):
"""
Shape func for stft.
"""
return [
_stft_shape_func(
inputs[0], convert(attrs.n_fft), convert(attrs.hop_length), convert(attrs.onesided)
)
]
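# Worked example for the stft shape function above (hypothetical sizes, for
# illustration only): with data of shape (1, 400), n_fft=16, hop_length=4 and
# onesided=True, the output shape is (1, 16 // 2 + 1, (400 - 16) // 4 + 1, 2)
# = (1, 9, 97, 2).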
# trilu
_reg.register_strategy("trilu", strategy.trilu_strategy)
# scatter_add
@_reg.register_compute("scatter_add")
def compute_scatter_add(attrs, inputs, output_type):
"""Compute definition of scatter_add"""
return [topi.scatter_add(inputs[0], inputs[1], inputs[2], attrs.axis)]
_reg.register_strategy("scatter_add", strategy.scatter_add_strategy)
# scatter_nd
@_reg.register_compute("scatter_nd")
def compute_scatter_nd(attrs, inputs, output_type):
"""Compute definition of scatter_nd"""
return [topi.scatter_nd(inputs[0], inputs[1], inputs[2], attrs.mode)]
_reg.register_strategy("scatter_nd", strategy.scatter_nd_strategy)
# cumsum
@_reg.register_compute("cumsum")
def compute_cumsum(attrs, inputs, output_type):
"""Compute definition of cumsum"""
return [topi.cumsum(inputs[0], attrs.axis, attrs.dtype, attrs.exclusive)]
_reg.register_strategy("cumsum", strategy.cumsum_strategy)
_reg.register_shape_func("cumsum", False, elemwise_shape_func)
# cumprod
@_reg.register_compute("cumprod")
def compute_cumprod(attrs, inputs, output_type):
"""Compute definition of cumprod"""
return [topi.cumprod(inputs[0], attrs.axis, attrs.dtype, attrs.exclusive)]
_reg.register_strategy("cumprod", strategy.cumprod_strategy)
_reg.register_shape_func("cumprod", False, elemwise_shape_func)
@_reg.register_compute("unique")
def compute_unique(attrs, inputs, output_type):
"""Compute definition of unique"""
return topi.unique(inputs[0], attrs.sorted, attrs.return_counts)
_reg.register_strategy("unique", strategy.unique_strategy)
# invert_permutation
_reg.register_strategy("invert_permutation", strategy.invert_permutation_strategy)
_reg.register_shape_func("invert_permutation", False, elemwise_shape_func)
#####################
# Shape functions #
#####################
@script
def _arange_shape_func(start, stop, step):
out = output_tensor((1,), "int64")
if step[()] < 0:
out[0] = int64(ceil_div((int64(start[()]) - int64(stop[()])), int64(-step[()])))
else:
out[0] = int64(ceil_div((int64(stop[()]) - int64(start[()])), int64(step[()])))
return out
@_reg.register_shape_func("arange", True)
def arange_shape_func(attrs, inputs, _):
"""
Shape func for arange
"""
return [_arange_shape_func(*inputs)]
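# Worked example for the arange shape function above (illustrative values):
# start=2, stop=10, step=3 gives ceil_div(10 - 2, 3) = 3 elements, i.e. an
# output shape of (3,); with start=10, stop=2, step=-3 the negative-step
# branch likewise gives ceil_div(10 - 2, 3) = 3.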
@script
def _strided_slice_shape_func_input_shape(data_shape, begin, end, strides, slice_mode):
ndim = len(data_shape)
out = output_tensor((ndim,), "int64")
for i in const_range(ndim):
dim_size = int64(data_shape[i])
cbegin = int64(0)
cend = dim_size
cstride = int64(1)
if len(strides) > i:
cstride = int64(strides[i])
if len(begin) > i:
cbegin = int64(begin[i])
elif cstride < 0:
cbegin = dim_size
if len(end) <= i:
if cstride < 0:
cend = int64(0)
elif slice_mode != 0:
cstride = int64(1)
if end[i] < 0:
cend = dim_size
else:
cend = cbegin + int64(end[i])
else:
if end[i] > data_shape[i]:
cend = dim_size
else:
cend = int64(end[i])
assert cstride != 0, "Strides can't be zero."
if cbegin < 0:
cbegin += dim_size
if cend < 0:
cend += dim_size
if cstride < 0:
if cend < 0:
cend = int64(-1)
if cbegin > dim_size - 1:
cbegin = dim_size - 1
slice_range = cbegin - cend
step = -cstride
else:
slice_range = cend - cbegin
step = cstride
out[i] = int64(ceil_div(slice_range, step))
return out
@script
def _strided_slice_shape_func_with_axes(data_shape, begin, end, strides, slice_mode, axes):
ndim = data_shape.shape[0]
out = output_tensor((ndim,), "int64")
for i in const_range(ndim):
out[i] = data_shape[i]
for i in const_range(len(axes)):
dim_size = int64(data_shape[axes[i]])
cbegin = int64(0)
cend = dim_size
cstride = int64(1)
if len(strides) > i:
cstride = int64(strides[i])
if len(begin) > i:
cbegin = int64(begin[i])
elif cstride < 0:
cbegin = dim_size
if len(end) <= i:
cend = dim_size
elif slice_mode != 0:
cstride = int64(1)
if end[i] < 0:
cend = dim_size
else:
cend = cbegin + int64(end[i])
else:
if end[i] > data_shape[axes[i]]:
cend = dim_size
else:
cend = int64(end[i])
assert cstride != 0, "Strides can't be zero."
if cbegin < 0:
cbegin += dim_size
if cend < 0:
cend += dim_size
if cstride < 0:
if cend < 0:
cend = int64(-1)
if cbegin > dim_size - 1:
cbegin = dim_size - 1
slice_range = cbegin - cend
step = -cstride
else:
slice_range = cend - cbegin
step = cstride
out[axes[i]] = int64(ceil_div(slice_range, step))
return out
@_reg.register_shape_func("strided_slice", False)
def strided_slice_shape_func(attrs, inputs, _):
"""
Shape func for strided_slice
"""
slice_mode = convert(0 if attrs.slice_mode == "end" else 1)
if attrs.axes is None:
return [
_strided_slice_shape_func_input_shape(
inputs[0], attrs.begin, attrs.end, attrs.strides, slice_mode
)
]
return [
_strided_slice_shape_func_with_axes(
inputs[0], attrs.begin, attrs.end, attrs.strides, slice_mode, attrs.axes
)
]
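# Worked example for the strided_slice shape functions above (hypothetical
# shapes): data of shape (10, 20) with begin=[2, 0], end=[8, 20],
# strides=[2, 5] and slice_mode="end" yields
# (ceil_div(8 - 2, 2), ceil_div(20 - 0, 5)) = (3, 4).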
@script
def _one_hot_shape_func(indices_shape, depth, axis):
in_ndim = indices_shape.shape[0]
out_ndim = in_ndim + 1
true_axis = in_ndim if axis == -1 else axis
indices_i = 0
out = output_tensor((out_ndim,), "int64")
for i in range(out_ndim):
if i == true_axis:
out[i] = int64(depth)
else:
out[i] = int64(indices_shape[indices_i])
indices_i += 1
return out
@_reg.register_shape_func("one_hot", False)
def one_hot_shape_func(attrs, inputs, _):
"""
Shape func for one_hot
"""
shape_func = [_one_hot_shape_func(inputs[0], convert(attrs.depth), convert(attrs.axis))]
return shape_func
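# Worked example for the one_hot shape function above (illustrative shapes):
# indices of shape (2, 3) with depth=5 and axis=-1 place the new axis last,
# giving an output shape of (2, 3, 5).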
@script
def _concatenate_shape_func(inputs, axis):
ndim = inputs[0].shape[0]
out = output_tensor((ndim,), "int64")
for i in const_range(ndim):
if i != axis:
out[i] = inputs[0][i]
for j in const_range(1, len(inputs)):
assert out[i] == inputs[j][i], "Dims mismatch in the inputs of concatenate."
else:
out[i] = int64(0)
for j in const_range(len(inputs)):
out[i] += inputs[j][i]
return out
@_reg.register_shape_func("concatenate", False)
def concatenate_shape_func(attrs, inputs, _):
    """
    Shape func for concatenate.
    """
axis = get_const_int(attrs.axis)
if axis < 0:
axis += inputs[0].shape[0]
return [_concatenate_shape_func(inputs, convert(axis))]
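# Worked example for the concatenate shape function above (hypothetical
# shapes): inputs of shape (2, 3) and (4, 3) concatenated along axis=0 sum
# the axis dims and keep the rest, giving (6, 3); all non-axis dims must
# match, as asserted in the hybrid script.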
@script
def _reshape_shape_func_input_shape(data_shape, newshape, ndim, allowzero):
out = output_tensor((ndim,), "int64")
src_idx = 0
dst_idx = 0
infer_idx = -1
copy = False
skip = 0
for i in const_range(len(newshape)):
if skip > 0:
skip -= 1
elif newshape[i] > 0:
out[dst_idx] = int64(newshape[i])
src_idx += 1
dst_idx += 1
elif newshape[i] == 0:
if allowzero:
out[dst_idx] = int64(newshape[i])
else:
out[dst_idx] = data_shape[src_idx]
src_idx += 1
dst_idx += 1
elif newshape[i] == -1:
assert infer_idx < 0, "One and only one dim can be inferred"
out[dst_idx] = int64(1)
infer_idx = i
src_idx += 1
dst_idx += 1
elif newshape[i] == -2:
copy = True
elif newshape[i] == -3:
assert data_shape.shape[0] - src_idx > 1, "Not enough dims in input shape for -3"
out[dst_idx] = data_shape[src_idx] * data_shape[src_idx + 1]
src_idx += 2
dst_idx += 1
elif newshape[i] == -4:
assert len(newshape) - i > 2, "Not enough dims in new shape for -4"
if newshape[i + 1] == -1:
assert newshape[i + 2] != -1, "Split dims cannot both be -1."
out[dst_idx] = data_shape[src_idx] // int64(newshape[i + 2])
out[dst_idx + 1] = int64(newshape[i + 2])
else:
out[dst_idx] = int64(newshape[i + 1])
if newshape[i + 2] == -1:
out[dst_idx + 1] = data_shape[src_idx] // int64(newshape[i + 1])
else:
out[dst_idx + 1] = int64(newshape[i + 2])
assert (
data_shape[src_idx] == out[dst_idx] * out[dst_idx + 1]
), "Product of split dims doesn't match to input dim"
src_idx += 1
dst_idx += 2
skip = 2
else:
assert False, "Invalid special values in new shape"
if len(data_shape.shape) > 0:
# if data is not constant, we can then handle -1 and -2
if copy:
for i in range(src_idx, data_shape.shape[0]):
out[dst_idx] = data_shape[i]
dst_idx += 1
if infer_idx >= 0:
old_size = int64(1)
for i in const_range(data_shape.shape[0]):
old_size *= data_shape[i]
new_size = int64(1)
for i in const_range(out.shape[0]):
new_size *= out[i]
out[infer_idx] = old_size // new_size
return out
@_reg.register_shape_func("reshape", False)
def reshape_shape_func(attrs, inputs, out_ndims):
newshape = get_const_tuple(attrs.newshape)
allowzero = attrs.allowzero
return [
_reshape_shape_func_input_shape(
inputs[0], convert(newshape), out_ndims[0], convert(allowzero)
)
]
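# Worked examples for the reshape shape function above (illustrative shapes,
# allowzero=False): for data of shape (2, 3, 4),
#   newshape=(6, 4)  -> (6, 4)
#   newshape=(0, -1) -> (2, 12)   # 0 copies the input dim, -1 is inferred
#   newshape=(-3, 4) -> (6, 4)    # -3 merges two consecutive input dims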
@script
def _take_no_axis_shape_func(indices_shape, out_ndim):
out = output_tensor((out_ndim,), "int64")
for i in const_range(out_ndim):
out[i] = indices_shape[i]
return out
@script
def _take_with_axis_shape_func(data_shape, indices_shape, axis, batch_dims, out_ndim):
out = output_tensor((out_ndim,), "int64")
for i in const_range(axis):
out[i] = data_shape[i]
if len(indices_shape.shape) == 0:
# indices is constant
for i in const_range(axis + 1, len(data_shape)):
out[i - 1] = data_shape[i]
else:
for i in const_range(len(indices_shape) - batch_dims):
out[axis + i] = indices_shape[i + batch_dims]
for i in const_range(axis + 1, len(data_shape)):
out[len(indices_shape) + i - 1 - batch_dims] = data_shape[i]
return out
@_reg.register_shape_func("take", False)
def take_shape_func(attrs, inputs, out_ndims):
"""
Shape function for take op.
"""
if attrs.axis is None:
return [_take_no_axis_shape_func(inputs[1], out_ndims[0])]
axis = get_const_int(attrs.axis)
batch_dims = get_const_int(attrs.batch_dims)
data_ndim = int(inputs[0].shape[0])
if inputs[1].shape:
indices_ndim = int(inputs[1].shape[0])
if axis < 0:
axis += data_ndim
assert 0 <= axis < data_ndim
if batch_dims < 0:
batch_dims += indices_ndim
return [_take_with_axis_shape_func(*inputs, convert(axis), convert(batch_dims), out_ndims[0])]
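# Worked example for the take shape functions above (hypothetical shapes):
# data of shape (4, 5, 6), indices of shape (2, 3), axis=1 and batch_dims=0
# replace the gathered axis with the indices shape, giving (4, 2, 3, 6).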
@_reg.register_legalize("take")
def legalize_take(attrs, inputs, types):
"""Legalize take op.
Parameters
----------
attrs : tvm.ir.Attrs
Attributes of current op
inputs : list of tvm.relay.Expr
The args of the Relay expr to be legalized
types : list of types
List of input and output types
Returns
-------
result : tvm.relay.Expr
The legalized expr
"""
return topi.take_legalize(attrs, inputs, types)
@script
def _argwhere_shape_func_1d(condition):
out = output_tensor((2,), "int64")
out[0] = int64(0)
out[1] = int64(1)
for i1 in range(condition.shape[0]):
if condition[i1] != 0:
out[0] += int64(1)
return out
@script
def _argwhere_shape_func_2d(condition):
out = output_tensor((2,), "int64")
out[0] = int64(0)
out[1] = int64(2)
for i1 in range(condition.shape[0]):
for i2 in range(condition.shape[1]):
if condition[i1, i2] != 0:
out[0] += int64(1)
return out
@script
def _argwhere_shape_func_3d(condition):
out = output_tensor((2,), "int64")
out[0] = int64(0)
out[1] = int64(3)
for i1 in range(condition.shape[0]):
for i2 in range(condition.shape[1]):
for i3 in range(condition.shape[2]):
if condition[i1, i2, i3] != 0:
out[0] += int64(1)
return out
@script
def _argwhere_shape_func_4d(condition):
out = output_tensor((2,), "int64")
out[0] = int64(0)
out[1] = int64(4)
for i1 in range(condition.shape[0]):
for i2 in range(condition.shape[1]):
for i3 in range(condition.shape[2]):
for i4 in range(condition.shape[3]):
if condition[i1, i2, i3, i4] != 0:
out[0] += int64(1)
return out
@script
def _argwhere_shape_func_5d(condition):
out = output_tensor((2,), "int64")
out[0] = int64(0)
out[1] = int64(5)
for i1 in range(condition.shape[0]):
for i2 in range(condition.shape[1]):
for i3 in range(condition.shape[2]):
for i4 in range(condition.shape[3]):
for i5 in range(condition.shape[4]):
if condition[i1, i2, i3, i4, i5] != 0:
out[0] += int64(1)
return out
@_reg.register_shape_func("argwhere", True)
def argwhere_shape_func(attrs, inputs, out_ndims):
"""
Shape function for argwhere.
"""
if len(inputs[0].shape) == 1:
return [_argwhere_shape_func_1d(inputs[0])]
if len(inputs[0].shape) == 2:
return [_argwhere_shape_func_2d(inputs[0])]
if len(inputs[0].shape) == 3:
return [_argwhere_shape_func_3d(inputs[0])]
if len(inputs[0].shape) == 4:
return [_argwhere_shape_func_4d(inputs[0])]
if len(inputs[0].shape) == 5:
return [_argwhere_shape_func_5d(inputs[0])]
    raise ValueError("Does not support rank higher than 5 in argwhere")
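# Worked example for the argwhere shape functions above (illustrative values):
# the data-dependent output shape is (number of non-zero elements, input rank),
# so a (3, 4) condition with five non-zero entries yields (5, 2).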
_reg.register_shape_func("scatter", False, elemwise_shape_func)
_reg.register_shape_func("scatter_add", False, elemwise_shape_func)
_reg.register_shape_func("scatter_nd", False, elemwise_shape_func)
@script
def _sparse_fill_empty_rows_shape_func(sparse_indices, dense_shape):
new_sparse_indices_shape = output_tensor((2,), "int64")
new_sparse_values_shape = output_tensor((1,), "int64")
empty_row_indicator_shape = output_tensor((1,), "int64")
num_dense_rows = int64(dense_shape[0])
if int64(sparse_indices.shape[0]) == int64(0): # Handle Empty Case
# Total rows will equal dense_shape[0]
new_sparse_indices_shape[0] = num_dense_rows
new_sparse_indices_shape[1] = int64(sparse_indices.shape[1])
new_sparse_values_shape[0] = num_dense_rows
empty_row_indicator_shape[0] = num_dense_rows
return (new_sparse_indices_shape, new_sparse_values_shape, empty_row_indicator_shape)
else:
count = int64(sparse_indices.shape[0]) # Add count of all rows already in sparse_indices
for i in range(1, int64(sparse_indices.shape[0])):
index = int64(sparse_indices[i, 0])
prev_index = int64(sparse_indices[i - 1, 0] + 1)
if index > prev_index:
count += index - prev_index # Add count of all rows between two consecutive indices
count += int64(sparse_indices[0, 0]) # Add count from 0 to first row id in sparse_indices
count += int64(
num_dense_rows - 1 - sparse_indices[sparse_indices.shape[0] - 1, 0]
) # Add count from last row id to dense_shape - 1
new_sparse_indices_shape[0] = int64(count)
new_sparse_indices_shape[1] = int64(sparse_indices.shape[1])
new_sparse_values_shape[0] = int64(count)
empty_row_indicator_shape[0] = num_dense_rows
return (new_sparse_indices_shape, new_sparse_values_shape, empty_row_indicator_shape)
@_reg.register_shape_func("sparse_fill_empty_rows", True)
def sparse_fill_empty_rows_func(attrs, inputs, _):
    """
    Shape func for sparse_fill_empty_rows.
    """
return _sparse_fill_empty_rows_shape_func(inputs[0], inputs[2])
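# Worked example for the sparse_fill_empty_rows shape function above
# (hypothetical values): with sparse_indices of shape (2, 2) covering rows
# 1 and 3 and dense_shape=[5, 3], rows 0, 2 and 4 are empty, so
# new_sparse_indices is (2 + 3, 2) = (5, 2), new_sparse_values is (5,) and
# empty_row_indicator is (num_dense_rows,) = (5,).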
@script
def _sparse_reshape_shape_func(sparse_indices_shape, prev_shape_shape, new_shape_shape):
indices_shape = output_tensor((2,), "int64")
indices_shape[0] = int64(sparse_indices_shape[0])
indices_shape[1] = int64(new_shape_shape[0])
shape_tensor = output_tensor((1,), "int64")
shape_tensor[0] = int64(new_shape_shape[0])
return (indices_shape, shape_tensor)
@_reg.register_shape_func("sparse_reshape", False)
def sparse_reshape_shape_func(attrs, inputs, _):
"""
Shape func for sparse_reshape.
"""
return _sparse_reshape_shape_func(inputs[0], inputs[1], inputs[2])
@script
def _layout_transform_shape_func(
data_shape, out_layout_len, dst_equal_list, dst_mul_list, dst_div_list, dst_mix_list
):
out = output_tensor((out_layout_len,), "int64")
for i in const_range(len(dst_equal_list)):
out[dst_equal_list[i][0]] = data_shape[dst_equal_list[i][1]]
for i in const_range(len(dst_mul_list)):
out[dst_mul_list[i][0]] = data_shape[dst_mul_list[i][1]] * data_shape[dst_mul_list[i][2]]
for i in const_range(len(dst_div_list)):
out[dst_div_list[i][0]] = data_shape[dst_div_list[i][1]] // dst_div_list[i][3]
out[dst_div_list[i][2]] = int64(dst_div_list[i][3])
for i in const_range(len(dst_mix_list)):
out[dst_mix_list[i][0]] = (
data_shape[dst_mix_list[i][1]] * dst_mix_list[i][2] // dst_mix_list[i][4]
)
out[dst_mix_list[i][3]] = int64(dst_mix_list[i][4])
return out
@_reg.register_shape_func("layout_transform", False)
def layout_transform_shape_func(attrs, inputs, _):
"""
Shape function for layout_transform op.
"""
def _fetch_axis(layout):
major_axes = []
minor_axes = {}
num_start = -1
for i, item in enumerate(layout):
if "A" <= item <= "Z":
major_axes.append(item)
elif "a" <= item <= "z":
last_num = int(layout[num_start:i])
minor_axes[item] = last_num
num_start = -1
elif num_start < 0:
num_start = i
return major_axes, minor_axes
_, src_minor_axes = _fetch_axis(attrs.src_layout)
dst_major_axes, dst_minor_axes = _fetch_axis(attrs.dst_layout)
src_letter_list = []
dst_letter_list = []
for item in attrs.src_layout:
if "A" <= item <= "Z" or "a" <= item <= "z":
src_letter_list.append(item)
for item in attrs.dst_layout:
if "A" <= item <= "Z" or "a" <= item <= "z":
dst_letter_list.append(item)
out_layout_len = len(dst_major_axes) + len(dst_minor_axes)
dst_equal_list = []
dst_mul_list = []
dst_div_list = []
dst_mix_list = []
for key in dst_major_axes:
if key.lower() not in dst_minor_axes:
if key.lower() not in src_minor_axes:
dst_equal_list.append((dst_letter_list.index(key), src_letter_list.index(key)))
else:
dst_mul_list.append(
(
dst_letter_list.index(key),
src_letter_list.index(key),
src_letter_list.index(key.lower()),
)
)
else:
if key.lower() not in src_minor_axes:
dst_div_list.append(
(
dst_letter_list.index(key),
src_letter_list.index(key),
dst_letter_list.index(key.lower()),
dst_minor_axes[key.lower()],
)
)
else:
dst_mix_list.append(
(
dst_letter_list.index(key),
src_letter_list.index(key),
src_minor_axes[key.lower()],
dst_letter_list.index(key.lower()),
dst_minor_axes[key.lower()],
)
)
return [
_layout_transform_shape_func(
inputs[0],
convert(out_layout_len),
convert(dst_equal_list),
convert(dst_mul_list),
convert(dst_div_list),
convert(dst_mix_list),
)
]
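# Worked example for the layout_transform shape function above (hypothetical
# shapes): transforming data of shape (1, 8, 32, 32) from "NCHW" to "NCHW4c"
# hits the split (dst_div_list) case for C, giving (1, 8 // 4, 32, 32, 4)
# = (1, 2, 32, 32, 4).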
@script
def _expand_dim_shape_func(data_shape, ndim, axis, num_newaxis):
out = output_tensor((ndim + num_newaxis,), "int64")
for i in const_range(out.shape[0]):
if i < axis:
out[i] = data_shape[i]
elif i < axis + num_newaxis:
out[i] = int64(1)
else:
out[i] = data_shape[i - num_newaxis]
return out
@_reg.register_shape_func("expand_dims", False)
def expand_dim_shape_func(attrs, inputs, _):
"""
Shape function for expand_dim op.
"""
axis = get_const_int(attrs.axis)
num_newaxis = get_const_int(attrs.num_newaxis)
if axis < 0:
axis = inputs[0].shape[0] + axis + 1
ndim = inputs[0].shape[0] if inputs[0].shape else 0
return [_expand_dim_shape_func(inputs[0], convert(ndim), convert(axis), convert(num_newaxis))]
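# Worked example for the expand_dims shape function above (illustrative
# shapes): data of shape (2, 3) with axis=1 and num_newaxis=2 inserts two
# unit dims after the first axis, giving (2, 1, 1, 3).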
@script
def _transpose_shape_func(data_shape, axes):
out = output_tensor((data_shape.shape[0],), "int64")
for i in const_range(len(axes)):
out[i] = data_shape[axes[i]]
return out
@_reg.register_shape_func("transpose", False)
def transpose_shape_func(attrs, inputs, _):
"""
Shape function for transpose op.
"""
axes = attrs.axes if attrs.axes is None else get_const_tuple(attrs.axes)
if axes is None:
axes = list(range(inputs[0].shape[0].value))
axes.reverse()
axes = list(axes)
for i, axis in enumerate(axes):
if axis < 0:
axes[i] = inputs[0].shape[0] + axis
return [_transpose_shape_func(inputs[0], convert(axes))]
_reg.register_schedule("transpose", strategy.schedule_transpose)
@script
def _squeeze_shape_func(data_shape, keep_axes, remove_axes):
out = output_tensor((len(keep_axes),), "int64")
for i in const_range(len(keep_axes)):
out[i] = data_shape[keep_axes[i]]
for i in const_range(len(remove_axes)):
assert data_shape[remove_axes[i]] == 1, "Removed dimension must have size 1"
return out
@_reg.register_shape_func("squeeze", False)
def squeeze_shape_func(attrs, inputs, _):
"""
Shape function for squeeze op.
"""
axis = attrs.axis if attrs.axis is None else get_const_tuple(attrs.axis)
keep_axes = []
remove_axes = []
if axis is not None:
for i in range(inputs[0].shape[0].value):
if i not in axis:
keep_axes.append(i)
else:
remove_axes.append(i)
# Due to current relay type system, it is possible even
# a static kernel function needs shape function. To handle
# this case, we allow axis to be None in squeeze shape func
# for now.
# TODO(kevinthesun): Enhance relay type system to avoid this.
if keep_axes:
out = _squeeze_shape_func(inputs[0], convert(keep_axes), convert(remove_axes))
else:
out = te.compute((), lambda *indices: 0)
return [out]
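# Worked example for the squeeze shape function above (illustrative shapes):
# data of shape (1, 3, 1, 4) with axis=(0, 2) keeps axes 1 and 3, giving
# (3, 4); each removed axis is asserted to have size 1.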
@script
def _reshape_like_shape_func(target_shape):
out = output_tensor((target_shape.shape[0],), "int64")
for i in const_range(target_shape.shape[0]):
out[i] = target_shape[i]
return out
@_reg.register_shape_func("reshape_like", False)
def reshape_like_shape_func(attrs, inputs, _):
"""
Shape function for reshape_like op.
"""
return [_reshape_like_shape_func(inputs[1])]
@script
def _tile_shape_func(data, reps, ndim, tndim, rndim):
out = output_tensor((tndim,), "int64")
if ndim == rndim:
for i in const_range(tndim):
out[i] = data[i] * int64(reps[i])
elif ndim > rndim:
ngap = ndim - rndim
for i in const_range(ndim):
if i < ngap:
out[i] = data[i]
else:
out[i] = data[i] * int64(reps[i - ngap])
else:
rgap = rndim - ndim
for i in const_range(rndim):
if i < rgap:
out[i] = int64(reps[i])
else:
out[i] = int64(reps[i]) * data[i - rgap]
return out
@_reg.register_shape_func("tile", False)
def tile_shape_func(attrs, inputs, _):
"""
Shape function for tile op.
"""
reps = get_const_tuple(attrs.reps)
ndim = inputs[0].shape[0].value
rndim = len(reps)
tndim = ndim if ndim > rndim else rndim
return [
_tile_shape_func(inputs[0], convert(reps), convert(ndim), convert(tndim), convert(rndim))
]
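# Worked example for the tile shape function above (hypothetical shapes):
# data of shape (2, 3) with reps=(2, 2, 2) pads the data rank on the left,
# giving (2, 2 * 2, 2 * 3) = (2, 4, 6), matching numpy.tile semantics.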
@script
def _split_shape_func(data_shape, index, indices_or_sections, param_is_indices, axis):
out = output_tensor((data_shape.shape[0],), "int64")
if param_is_indices:
for i in const_range(data_shape.shape[0]):
if i == axis:
assert (
data_shape[axis] % indices_or_sections[0] == 0
), "num_sections must be an integer factor of the size of axis"
out[i] = ceil_div(data_shape[axis], indices_or_sections[0])
else:
out[i] = data_shape[i]
else:
start = int64(0)
if index > 0:
start = int64(indices_or_sections[index - 1])
end = data_shape[axis]
if index < len(indices_or_sections):
end = int64(indices_or_sections[index])
for i in const_range(data_shape.shape[0]):
if i == axis:
out[i] = end - start
else:
out[i] = data_shape[i]
return out
@_reg.register_shape_func("split", False)
def split_shape_func(attrs, inputs, _):
"""
Shape function for split op.
"""
if isinstance(attrs.indices_or_sections, (int, tvm.tir.IntImm)):
indices_or_sections = get_const_int(attrs.indices_or_sections)
assert indices_or_sections > 0, "Slice count must be > 0"
else:
indices_or_sections = list(get_const_tuple(attrs.indices_or_sections))
assert sorted(indices_or_sections)[0] > 0 and indices_or_sections == sorted(
indices_or_sections
), "split_indices must be sorted"
axis = get_const_int(attrs.axis)
if axis < 0:
axis += get_const_int(inputs[0].shape[0])
num_out = (
indices_or_sections
if isinstance(indices_or_sections, int)
else len(indices_or_sections) + 1
)
param_is_indices = isinstance(indices_or_sections, int)
if param_is_indices:
indices_or_sections = [indices_or_sections]
return [
_split_shape_func(
inputs[0],
convert(i),
convert(indices_or_sections),
convert(param_is_indices),
convert(axis),
)
for i in range(num_out)
]
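# Worked examples for the split shape function above (illustrative shapes):
# data of shape (10, 4) with indices_or_sections=[3, 7] and axis=0 yields
# output shapes (3, 4), (4, 4) and (3, 4); with indices_or_sections=5 it
# yields five outputs of shape (2, 4).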
@script
def _repeat_shape_func(data_shape, repeats, axis):
out = output_tensor((data_shape.shape[0],), "int64")
for i in const_range(data_shape.shape[0]):
if i == axis:
out[i] = int64(data_shape[i] * repeats)
else:
out[i] = data_shape[i]
return out
@_reg.register_shape_func("repeat", False)
def repeat_shape_func(attrs, inputs, _):
"""
Shape func for repeat.
"""
axis = get_const_int(attrs.axis)
if axis < 0:
axis = inputs[0].shape[0] + axis
return [_repeat_shape_func(inputs[0], attrs.repeats, convert(axis))]
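# Worked example for the repeat shape function above (hypothetical shapes):
# data of shape (2, 3) with repeats=2 and axis=1 scales only that axis,
# giving (2, 6).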
@_reg.register_shape_func("broadcast_to_like", False)
def broadcast_to_like_shape_func(attrs, inputs, _):
"""
Shape func for broadcast_to_like.
"""
return [topi.math.identity(inputs[1])]
@script
def _stack_shape_func(data_shape, axis, num_inputs):
out = output_tensor((data_shape.shape[0] + 1,), "int64")
for i in const_range(data_shape.shape[0] + 1):
if i == axis:
out[i] = int64(num_inputs)
elif i < axis:
out[i] = data_shape[i]
else:
out[i] = data_shape[i - 1]
return out
@_reg.register_shape_func("stack", False)
def stack_shape_func(attrs, inputs, _):
"""
Shape func for stack.
"""
axis = get_const_int(attrs.axis)
if axis < 0:
axis += inputs[0].shape[0] + 1
return [_stack_shape_func(inputs[0], convert(axis), convert(len(inputs)))]
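# Worked example for the stack shape function above (illustrative shapes):
# stacking three tensors of shape (2, 4) along axis=1 inserts a new dim of
# size len(inputs), giving (2, 3, 4).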
@script
def _broadcast_shape_tensors(shape_tensor1, shape_tensor2):
rank1 = shape_tensor1.shape[0]
rank2 = shape_tensor2.shape[0]
out_rank = max(rank1, rank2)
bcast_shape_tensor = output_tensor((out_rank,), "int64")
for index in const_range(out_rank):
dim1 = int64(1)
dim2 = int64(1)
if rank1 == out_rank:
dim1 = shape_tensor1[index]
elif rank1 - (out_rank - index) >= 0:
dim1 = shape_tensor1[rank1 - (out_rank - index)]
if rank2 == out_rank:
dim2 = shape_tensor2[index]
elif rank2 - (out_rank - index) >= 0:
dim2 = shape_tensor2[rank2 - (out_rank - index)]
assert dim1 == dim2 or dim1 == 1 or dim2 == 1, "Invalid broadcast shapes"
bcast_shape_tensor[index] = max(dim1, dim2)
return bcast_shape_tensor
@_reg.register_shape_func("where", False)
def where_shape_func(attrs, inputs, _):
"""
Shape func for where.
"""
def ensure_tensor(tensor):
if len(tensor.shape) == 0:
return topi.full((1,), "int64", 1)
return tensor
cond_shape = ensure_tensor(inputs[0])
x_shape = ensure_tensor(inputs[1])
y_shape = ensure_tensor(inputs[2])
bcast_shape = _broadcast_shape_tensors(x_shape, y_shape)
out_shape = _broadcast_shape_tensors(bcast_shape, cond_shape)
return [out_shape]
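# Worked example for the where shape function above (hypothetical shapes):
# a condition of shape (3, 1), x of shape (3, 4) and y of shape (1, 4)
# broadcast pairwise to an output shape of (3, 4).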
@script
def _adv_index_post_process(data_shape, bcast_shape, num_indices):
data_rank = data_shape.shape[0]
bcast_rank = bcast_shape.shape[0]
out = output_tensor((data_rank + bcast_rank - num_indices,), "int64")
for i in const_range(bcast_rank):
out[i] = bcast_shape[i]
for i in const_range(data_rank - num_indices):
out[i + bcast_rank] = data_shape[i + num_indices]
return out
@_reg.register_shape_func("adv_index", False)
def adv_index_shape_func(attrs, inputs, _):
"""
Shape func for adv_index.
"""
bcast_shape = inputs[1]
for i in inputs[2:]:
bcast_shape = _broadcast_shape_tensors(bcast_shape, i)
return [_adv_index_post_process(inputs[0], bcast_shape, convert(len(inputs) - 1))]
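# Worked example for the adv_index shape function above (illustrative shapes):
# data of shape (4, 5, 6) indexed by a single index tensor of shape (2, 3)
# prepends the broadcast index shape and keeps the untouched trailing data
# dims, giving (2, 3, 5, 6).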
@script
def _unique_shape(data_shape):
unique_shape = output_tensor((1,), "int64")
indices_shape = output_tensor((1,), "int64")
inverse_indices_shape = output_tensor((1,), "int64")
num_unique_shape = output_tensor((1,), "int64")
unique_shape[0] = data_shape[0]
indices_shape[0] = data_shape[0]
inverse_indices_shape[0] = data_shape[0]
num_unique_shape[0] = int64(1)
return (unique_shape, indices_shape, inverse_indices_shape, num_unique_shape)
@script
def _unique_with_counts_shape(data_shape):
unique_shape = output_tensor((1,), "int64")
indices_shape = output_tensor((1,), "int64")
inverse_indices_shape = output_tensor((1,), "int64")
num_unique_shape = output_tensor((1,), "int64")
counts_shape = output_tensor((1,), "int64")
unique_shape[0] = data_shape[0]
indices_shape[0] = data_shape[0]
inverse_indices_shape[0] = data_shape[0]
num_unique_shape[0] = int64(1)
counts_shape[0] = data_shape[0]
return (unique_shape, indices_shape, inverse_indices_shape, num_unique_shape, counts_shape)
@_reg.register_shape_func("unique", False)
def unique_shape_func(attrs, inputs, _):
"""
Shape func for unique operator.
"""
if attrs.return_counts:
return _unique_with_counts_shape(inputs[0])
else:
return _unique_shape(inputs[0])
@script
def _gather_nd_shape(data_shape, indices_shape, batch_dims, index_rank):
ndim = data_shape.shape[0]
# using mdim = indices_shape[0] wouldn't work because a rank cannot
# depend on a runtime shape dimension of indices tensor, even if the
# dimension is always a known, fixed value. As a workaround, we assume that
# the fixed gather dimension (the size of an indexing tuple) is recorded
# in gather_nd op attributes.
mdim = index_rank
kdim = indices_shape.shape[0] - 1
out_shape = output_tensor((kdim + ndim - (mdim + batch_dims),), "int64")
for i in range(1, kdim + 1):
out_shape[i - 1] = indices_shape[i]
for i in range(mdim + batch_dims, ndim):
out_shape[kdim + i - (mdim + batch_dims)] = data_shape[i]
return out_shape
@_reg.register_shape_func("gather_nd", False)
def gather_nd_shape_func(attrs, inputs, _):
"""
Shape func for gather_nd operator.
"""
batch_dims = get_const_int(attrs.batch_dims)
index_rank = get_const_int(attrs.index_rank)
assert index_rank > 0, "index_rank needs to be specified for dynamic gather_nd"
return [_gather_nd_shape(inputs[0], inputs[1], convert(batch_dims), convert(index_rank))]
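# Worked example for the gather_nd shape function above (hypothetical shapes):
# data of shape (4, 5, 6) with indices of shape (2, 7), index_rank=2 and
# batch_dims=0 gathers over the first two data axes, giving (7, 6).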
@script
def _gather_shape(data_shape, indices_shape, axis):
out_shape = output_tensor((data_shape.shape[0],), "int64")
for i in range(data_shape.shape[0]):
if i != axis:
assert (
data_shape[i] == indices_shape[i]
), "data and indices size at non-gather axes must be the same"
out_shape[i] = indices_shape[i]
return out_shape
@_reg.register_shape_func("gather", False)
def gather_shape_func(attrs, inputs, _):
"""
Shape func for gather operator.
"""
return [_gather_shape(inputs[0], inputs[1], attrs.axis)]
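# Worked example for the gather shape function above (illustrative shapes):
# data of shape (4, 5) with indices of shape (4, 3) and axis=1 yields the
# indices shape (4, 3); all non-gather dims must match the data shape.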