blob: 75c19af35e5c4cbea244e51995aae6a4a6822c3b [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, too-many-arguments, too-many-nested-blocks
"""Argwhere operator"""
from tvm.te import hybrid
@hybrid.script
def hybrid_argwhere_1d(output_shape, condition):
    """Find the indices of elements of a 1-D tensor that are non-zero.

    Parameters
    ----------
    output_shape : shape
        Shape forwarded unchanged to ``output_tensor`` for the result.
        The kernel writes ``a[row, 0]``, so the result must be 2-D with at
        least one column. Presumably the caller sizes it as
        (number of non-zero elements, 1) -- confirm against the caller.

    condition : tvm.te.Tensor
        1-D tensor with boolean values. An element counts as set when it
        compares ``!= 0``.

    Returns
    -------
    out : tvm.te.Tensor
        "int32" tensor holding one row per non-zero element, in increasing
        index order. Rows past the number of non-zero elements are never
        written by this kernel.
    """
    # ``output_tensor`` is not defined in this file; it is presumably provided
    # by the hybrid-script runtime that ``@hybrid.script`` installs.
    a = output_tensor(output_shape, "int32")
    a1 = condition.shape[0]
    # Next free row of ``a``; advances only when a non-zero element is found.
    valid_index = 0
    for i1 in range(a1):
        if condition[i1] != 0:
            a[valid_index, 0] = i1
            valid_index += 1
    return a
@hybrid.script
def hybrid_argwhere_2d(output_shape, condition):
    """Find the indices of elements of a 2-D tensor that are non-zero.

    Parameters
    ----------
    output_shape : shape
        Shape forwarded unchanged to ``output_tensor`` for the result.
        The kernel writes columns 0 and 1 of each row, so the result must be
        2-D with at least two columns. Presumably the caller sizes it as
        (number of non-zero elements, 2) -- confirm against the caller.

    condition : tvm.te.Tensor
        2-D tensor with boolean values. An element counts as set when it
        compares ``!= 0``.

    Returns
    -------
    out : tvm.te.Tensor
        "int32" tensor holding one ``(i1, i2)`` row per non-zero element,
        visited with the last axis varying fastest. Rows past the number of
        non-zero elements are never written by this kernel.
    """
    # ``output_tensor`` is not defined in this file; it is presumably provided
    # by the hybrid-script runtime that ``@hybrid.script`` installs.
    a = output_tensor(output_shape, "int32")
    a1 = condition.shape[0]
    a2 = condition.shape[1]
    # Next free row of ``a``; advances only when a non-zero element is found.
    valid_index = 0
    for i1 in range(a1):
        for i2 in range(a2):
            if condition[i1, i2] != 0:
                # Store the full coordinate of the hit, one axis per column.
                a[valid_index, 0] = i1
                a[valid_index, 1] = i2
                valid_index += 1
    return a
@hybrid.script
def hybrid_argwhere_3d(output_shape, condition):
    """Find the indices of elements of a 3-D tensor that are non-zero.

    Parameters
    ----------
    output_shape : shape
        Shape forwarded unchanged to ``output_tensor`` for the result.
        The kernel writes columns 0 to 2 of each row, so the result must be
        2-D with at least three columns. Presumably the caller sizes it as
        (number of non-zero elements, 3) -- confirm against the caller.

    condition : tvm.te.Tensor
        3-D tensor with boolean values. An element counts as set when it
        compares ``!= 0``.

    Returns
    -------
    out : tvm.te.Tensor
        "int32" tensor holding one ``(i1, i2, i3)`` row per non-zero element,
        visited with the last axis varying fastest. Rows past the number of
        non-zero elements are never written by this kernel.
    """
    # ``output_tensor`` is not defined in this file; it is presumably provided
    # by the hybrid-script runtime that ``@hybrid.script`` installs.
    a = output_tensor(output_shape, "int32")
    a1 = condition.shape[0]
    a2 = condition.shape[1]
    a3 = condition.shape[2]
    # Next free row of ``a``; advances only when a non-zero element is found.
    valid_index = 0
    for i1 in range(a1):
        for i2 in range(a2):
            for i3 in range(a3):
                if condition[i1, i2, i3] != 0:
                    # Store the full coordinate of the hit, one axis per column.
                    a[valid_index, 0] = i1
                    a[valid_index, 1] = i2
                    a[valid_index, 2] = i3
                    valid_index += 1
    return a
@hybrid.script
def hybrid_argwhere_4d(output_shape, condition):
    """Find the indices of elements of a 4-D tensor that are non-zero.

    Parameters
    ----------
    output_shape : shape
        Shape forwarded unchanged to ``output_tensor`` for the result.
        The kernel writes columns 0 to 3 of each row, so the result must be
        2-D with at least four columns. Presumably the caller sizes it as
        (number of non-zero elements, 4) -- confirm against the caller.

    condition : tvm.te.Tensor
        4-D tensor with boolean values. An element counts as set when it
        compares ``!= 0``.

    Returns
    -------
    out : tvm.te.Tensor
        "int32" tensor holding one ``(i1, i2, i3, i4)`` row per non-zero
        element, visited with the last axis varying fastest. Rows past the
        number of non-zero elements are never written by this kernel.
    """
    # ``output_tensor`` is not defined in this file; it is presumably provided
    # by the hybrid-script runtime that ``@hybrid.script`` installs.
    a = output_tensor(output_shape, "int32")
    a1 = condition.shape[0]
    a2 = condition.shape[1]
    a3 = condition.shape[2]
    a4 = condition.shape[3]
    # Next free row of ``a``; advances only when a non-zero element is found.
    valid_index = 0
    for i1 in range(a1):
        for i2 in range(a2):
            for i3 in range(a3):
                for i4 in range(a4):
                    if condition[i1, i2, i3, i4] != 0:
                        # Store the full coordinate of the hit, one axis per column.
                        a[valid_index, 0] = i1
                        a[valid_index, 1] = i2
                        a[valid_index, 2] = i3
                        a[valid_index, 3] = i4
                        valid_index += 1
    return a
@hybrid.script
def hybrid_argwhere_5d(output_shape, condition):
    """Find the indices of elements of a 5-D tensor that are non-zero.

    Parameters
    ----------
    output_shape : shape
        Shape forwarded unchanged to ``output_tensor`` for the result.
        The kernel writes columns 0 to 4 of each row, so the result must be
        2-D with at least five columns. Presumably the caller sizes it as
        (number of non-zero elements, 5) -- confirm against the caller.

    condition : tvm.te.Tensor
        5-D tensor with boolean values. An element counts as set when it
        compares ``!= 0``.

    Returns
    -------
    out : tvm.te.Tensor
        "int32" tensor holding one ``(i1, i2, i3, i4, i5)`` row per non-zero
        element, visited with the last axis varying fastest. Rows past the
        number of non-zero elements are never written by this kernel.
    """
    # ``output_tensor`` is not defined in this file; it is presumably provided
    # by the hybrid-script runtime that ``@hybrid.script`` installs.
    a = output_tensor(output_shape, "int32")
    a1 = condition.shape[0]
    a2 = condition.shape[1]
    a3 = condition.shape[2]
    a4 = condition.shape[3]
    a5 = condition.shape[4]
    # Next free row of ``a``; advances only when a non-zero element is found.
    valid_index = 0
    for i1 in range(a1):
        for i2 in range(a2):
            for i3 in range(a3):
                for i4 in range(a4):
                    for i5 in range(a5):
                        if condition[i1, i2, i3, i4, i5] != 0:
                            # Store the full coordinate of the hit, one axis per column.
                            a[valid_index, 0] = i1
                            a[valid_index, 1] = i2
                            a[valid_index, 2] = i3
                            a[valid_index, 3] = i4
                            a[valid_index, 4] = i5
                            valid_index += 1
    return a
def argwhere(output_shape, condition):
    """Find the indices of elements of a tensor that are non-zero.

    Dispatches on the rank of ``condition`` to the matching
    ``hybrid_argwhere_<rank>d`` kernel defined in this module.

    Parameters
    ----------
    output_shape : tvm.te.Tensor
        Object whose ``.shape`` attribute is forwarded to the kernel as the
        shape of the result. Only ``.shape`` is read here; presumably this is
        a tensor describing the output -- confirm against the caller.

    condition : tvm.te.Tensor
        Tensor with boolean values, of rank 1 to 5.

    Returns
    -------
    out : tvm.te.Tensor
        Indices of non-zero elements.

    Raises
    ------
    ValueError
        If the rank of ``condition`` is 0 or higher than 5.
    """
    # Compute the rank once instead of re-evaluating it for every branch.
    rank = len(condition.shape)
    if rank == 1:
        return hybrid_argwhere_1d(output_shape.shape, condition)
    if rank == 2:
        return hybrid_argwhere_2d(output_shape.shape, condition)
    if rank == 3:
        return hybrid_argwhere_3d(output_shape.shape, condition)
    if rank == 4:
        return hybrid_argwhere_4d(output_shape.shape, condition)
    if rank == 5:
        return hybrid_argwhere_5d(output_shape.shape, condition)
    if rank > 5:
        raise ValueError("Does not support rank higher than 5 in argwhere")
    # Only rank 0 can reach this point; report it accurately rather than
    # claiming the rank is too high.
    raise ValueError("Does not support rank 0 in argwhere")