blob: 3460ffec90ca77b7138cb159fc82f13ece2af859 [file] [log] [blame]
# pylint: disable=invalid-name
"""Dilation operators"""
from __future__ import absolute_import as _abs
import tvm
from .. import util
from .. import tag
@tvm.tag_scope(tag=tag.INJECTIVE+",dilate")
def dilate(data, strides, name="DilatedInput", dilation_value=0.0):
    """Dilate data by inserting fill values between its elements.

    Along every axis ``i``, input element ``j`` is placed at output position
    ``j * strides[i]``. The ``strides[i] - 1`` positions between consecutive
    input elements are filled with ``dilation_value``. An axis whose stride is
    the constant 1 is copied unchanged.

    Parameters
    ----------
    data : tvm.Tensor
        n-D, can be any layout.

    strides : list / tuple of n ints
        Dilation stride on each dimension, 1 means no dilation.

    name : str, optional
        The name prefix operators generated

    dilation_value : int / float, optional
        Value written to the inserted positions, converted to ``data.dtype``
        via ``tvm.const``. Defaults to 0.0, which reproduces zero-dilation.

    Returns
    -------
    Output : tvm.Tensor
        n-D, the same layout as data. Axis ``i`` has extent
        ``(data.shape[i] - 1) * strides[i] + 1``.

    Raises
    ------
    ValueError
        If ``len(strides)`` differs from the number of dimensions of ``data``.
    """
    n = len(data.shape)
    if len(strides) != n:
        raise ValueError("data dimension and strides size mismatch : %d vs %d" % (
            n, len(strides)))

    # (extent - 1) gaps, each widened to `stride`, plus the final element.
    out_shape = tuple(
        tvm.ir_pass.Simplify((data.shape[i] - 1) * strides[i] + 1) for i in range(n))

    def _dilate(*indices):
        # Predicates "output index lies on the stride grid" for every dilated
        # axis; the element is read from `data` only when all of them hold.
        on_grid = []
        index_tuple = []
        for i in range(n):
            if not util.equal_const_int(strides[i], 1):
                # Map output index k back to input index k / stride.
                # NOTE(review): assumes `/` on integer index Exprs is integer
                # division in this TVM version (indices are non-negative).
                index_tuple.append(indices[i] / strides[i])
                on_grid.append((indices[i] % strides[i]).equal(0))
            else:
                # Stride 1: identity mapping, no predicate needed.
                index_tuple.append(indices[i])
        if on_grid:
            return tvm.if_then_else(
                tvm.all(*on_grid),
                data(*index_tuple),
                tvm.const(dilation_value, data.dtype))
        # Every stride is the constant 1: plain copy without a conditional.
        return data(*index_tuple)

    return tvm.compute(out_shape, _dilate, name=name)