# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
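"""Element-wise logic operators (tensor-tensor and tensor-scalar) defined with TVM
compute/schedule primitives and registered for CPU and GPU targets via ``defop``."""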
import tvm
from .. import defop, AllTypes
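# Tensor-tensor logic ops: each name maps to a lambda that builds the element-wise
# TVM expression from the two input tensors `a`, `b` and an output index.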
_bin_logic_op_map = {
    'equal': lambda a, b, *idx: a[idx] == b[idx],
    'not_equal': lambda a, b, *idx: a[idx] != b[idx],
    'greater': lambda a, b, *idx: a[idx] > b[idx],
    'less': lambda a, b, *idx: a[idx] < b[idx],
    'greater_equal': lambda a, b, *idx: a[idx] >= b[idx],
    'less_equal': lambda a, b, *idx: a[idx] <= b[idx],
    'logical_and': lambda a, b, *idx: tvm.tir.all(a[idx] != 0, b[idx] != 0),
    'logical_or': lambda a, b, *idx: tvm.tir.any(a[idx] != 0, b[idx] != 0),
    'logical_xor': lambda a, b, *idx: tvm.tir.all(
        tvm.tir.any(a[idx] != 0, b[idx] != 0),
        tvm.tir.any(a[idx] == 0, b[idx] == 0)),
}
def _compute_binary_logic(op, dtype, ndim):
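    """Build the TVM compute declaration for a tensor-tensor logic op.

    Creates two placeholders of the given dtype and ndim, applies the element-wise
    expression selected from `_bin_logic_op_map`, and returns the default schedule
    together with the input and output tensors.
    """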
    a = tvm.te.placeholder([tvm.te.size_var() for _ in range(ndim)], dtype=dtype, name='a')
    b = tvm.te.placeholder([tvm.te.size_var() for _ in range(ndim)], dtype=dtype, name='b')
    c = tvm.te.compute([tvm.te.size_var() for _ in range(ndim)],
                       lambda *idx: _bin_logic_op_map[op](a, b, *idx), name='c')
    s = tvm.te.create_schedule(c.op)
    return s, a, b, c

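# Keyword arguments forwarded to `defop` when registering the tensor-tensor logic ops:
# input dtypes are `AllTypes` plus 'bool', ndim ranges over 0-5, and auto-broadcasting is enabled.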
_bin_logic_cpu_attrs = {
    'compute_func': _compute_binary_logic,
    'target': 'cpu',
    'auto_broadcast': True,
    'itype': AllTypes + ['bool'],
    'ndim': list(range(6))
}
_bin_logic_gpu_attrs = {
    'compute_func': _compute_binary_logic,
    'target': 'gpu',
    'auto_broadcast': True,
    'itype': AllTypes + ['bool'],
    'ndim': list(range(6))
}
def _binary_logic_cpu(compute_func, op, itype, ndim):
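    """CPU schedule: fuse all output axes into one and parallelize the fused axis."""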
    s, a, b, c = compute_func(op, itype, ndim)
    axes = [axis for axis in c.op.axis]
    fused = s[c].fuse(*axes)
    s[c].parallel(fused)
    return s, [a, b, c]

def _binary_logic_gpu(compute_func, op, itype, ndim):
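    """GPU schedule: fuse all output axes, split the fused axis by a factor of 64,
    and bind the outer and inner parts to blockIdx.x and threadIdx.x respectively."""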
    s, a, b, c = compute_func(op, itype, ndim)
    axes = [axis for axis in c.op.axis]
    fused = s[c].fuse(*axes)
    bx, tx = s[c].split(fused, factor=64)
    s[c].bind(bx, tvm.te.thread_axis('blockIdx.x'))
    s[c].bind(tx, tvm.te.thread_axis('threadIdx.x'))
    return s, [a, b, c]

# register binary element-wise logic ops with broadcasting supported
for op_name in _bin_logic_op_map.keys():
    defop(name='{}_cpu'.format(op_name), op=op_name, **_bin_logic_cpu_attrs)(_binary_logic_cpu)
    defop(name='{}_gpu'.format(op_name), op=op_name, **_bin_logic_gpu_attrs)(_binary_logic_gpu)

# Note that `b.dtype` is hard-coded as 'float64'.
# We should always promote `a`'s elements to `b.dtype`.
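# For example, 'equal_scalar' on an int32 tensor compares a[idx].astype('float64') == b,
# so comparing against a fractional scalar such as 2.5 does not silently truncate it.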
_bin_scalar_logic_op_map = {
    'equal_scalar': lambda a, b, *idx: a[idx].astype(b.dtype) == b,
    'not_equal_scalar': lambda a, b, *idx: a[idx].astype(b.dtype) != b,
    'greater_scalar': lambda a, b, *idx: a[idx].astype(b.dtype) > b,
    'less_scalar': lambda a, b, *idx: a[idx].astype(b.dtype) < b,
    'greater_equal_scalar': lambda a, b, *idx: a[idx].astype(b.dtype) >= b,
    'less_equal_scalar': lambda a, b, *idx: a[idx].astype(b.dtype) <= b,
    'logical_and_scalar': lambda a, b, *idx: tvm.tir.all(a[idx].astype(b.dtype) != 0, b != 0),
    'logical_or_scalar': lambda a, b, *idx: tvm.tir.any(a[idx].astype(b.dtype) != 0, b != 0),
    'logical_xor_scalar': lambda a, b, *idx: tvm.tir.all(
        tvm.tir.any(a[idx].astype(b.dtype) != 0, b != 0),
        tvm.tir.any(a[idx].astype(b.dtype) == 0, b == 0)),
}
def _compute_binary_scalar_logic(op, dtype, ndim):
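    """Build the TVM compute declaration for a tensor-scalar logic op.

    The tensor input `a` uses the given dtype and ndim while the scalar `b` is a
    float64 variable; the op lambdas promote `a[idx]` to `b.dtype` before comparing.
    """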
    a = tvm.te.placeholder([tvm.te.size_var() for _ in range(ndim)], name='a', dtype=dtype)
    b = tvm.te.var('b', dtype='float64')
    c = tvm.te.compute([tvm.te.size_var() for _ in range(ndim)],
                       lambda *idx: _bin_scalar_logic_op_map[op](a, b, *idx), name='c')
    s = tvm.te.create_schedule(c.op)
    return s, a, b, c

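# Registration attributes for the tensor-scalar logic ops; no `auto_broadcast` flag is
# needed here because the second operand is a scalar variable rather than a tensor.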
_bin_scalar_logic_cpu_attrs = {
    'compute_func': _compute_binary_scalar_logic,
    'target': 'cpu',
    'itype': AllTypes + ['bool'],
    'ndim': list(range(6))
}
_bin_scalar_logic_gpu_attrs = {
    'compute_func': _compute_binary_scalar_logic,
    'target': 'gpu',
    'itype': AllTypes + ['bool'],
    'ndim': list(range(6))
}
# register binary element-wise scalar logic ops
for op_name in _bin_scalar_logic_op_map.keys():
    defop(name='{}_cpu'.format(op_name), op=op_name,
          **_bin_scalar_logic_cpu_attrs)(_binary_logic_cpu)
    defop(name='{}_gpu'.format(op_name), op=op_name,
          **_bin_scalar_logic_gpu_attrs)(_binary_logic_gpu)
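# Note: `defop` only records these schedule-building functions; lowering and compilation
# into the final operator library is presumably handled by a separate build step that
# walks the registry (not shown in this file).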