# blob: 661e24d2a45cf314ca6bac6445ee6f392ad593d6 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Implementation of generic operators in the presence of Tensor"""
# pylint: disable=invalid-name, too-many-arguments
import tvm
from tvm import te
from . import broadcast as _broadcast
from . import math as _math
def _make_bop(broadcast_bop, orig_bop):
    """Make a specific overloaded binary operator of Tensor when applicable;
    apply the original operator if it is not supposed to be overloaded.

    Consider the following scenario:

        OP : + | - | * | /
        R0 : int | float | Expr | TensorSlice | Tensor (rank zero)
        R1 : Tensor (positive rank)

    In terms of (LHS OP RHS), we apply the following overloading rules:

    (1) We use broadcast_OP(LHS, RHS), when both LHS and RHS are R1.
    (2) We perform element-wise operation of Tensor and scalar,
        when one of LHS and RHS is R1 and another is R0.
    (3) We do not overload OP (i.e. stick to orig_bop) otherwise.

    The returned function itself only checks whether either operand is an
    instance of ``te.tensor.Tensor``; no rank check is made here. If at least
    one operand is a Tensor the call is forwarded to ``broadcast_bop`` (which
    therefore serves rules (1) and (2)), otherwise to ``orig_bop``.

    Parameters
    ----------
    broadcast_bop : operator function
        Operator for broadcast tensor-tensor operation, for rule (1).

    orig_bop: operator function
        Operator before overloading, for rule (3). Its ``__name__`` is
        substituted into the docstring of the returned function.

    Returns
    -------
    ret : operator function
        The overloaded operator function if applicable or orig_bop otherwise.
    """
    name = orig_bop.__name__

    def _tensor_bop_impl(lhs, rhs):
        """Overloaded {op} operator.

        If both operands are non-zero-rank Tensors, it performs
        tensor-tensor {op} operation, and broadcasts inputs when necessary.

        If one operand is non-zero-rank Tensor, while the other operand is
        scalar like type (e.g., numeric types, Expr, or TensorSlice),
        it performs tensor-scalar {op} operation on an element-wise basis.

        Otherwise, it performs default generic.{op} operation, as defined
        in tvm.generic module.

        Parameters
        ----------
        lhs : object
            Left operand.
        rhs : object
            Right operand.

        Returns
        -------
        ret : tvm.te.Tensor (if at least one operand is non-zero-rank Tensor)
              tvm.Expr (otherwise)
            The result of {op} operation.
        """
        if isinstance(lhs, te.tensor.Tensor) or isinstance(rhs, te.tensor.Tensor):
            return broadcast_bop(lhs, rhs)
        # Neither operand is a Tensor: keep the pre-overload behaviour.
        return orig_bop(lhs, rhs)

    # Docstrings are discarded under ``python -OO``; __doc__ is then None and
    # calling .format on it would raise AttributeError while importing.
    doc_template = _tensor_bop_impl.__doc__
    if doc_template is not None:
        _tensor_bop_impl.__doc__ = doc_template.format(op=name)
    return _tensor_bop_impl
def _bind_generic_ops():
    """Install the Tensor-aware operators into ``tvm.tir.generic``.

    ``tvm.tir.generic.__op_priority__`` acts as the guard: the rebinding is
    done only when this module's priority (1) is greater than the stored
    value, and the stored value is raised to 1 before anything is rebound,
    so a repeated call finds the priority already raised and does nothing.

    When the guard passes, ``add``, ``subtract``, ``multiply`` and ``divide``
    are each replaced by the result of ``_make_bop`` applied to the matching
    ``_broadcast`` function and the current generic function, and ``cast`` is
    replaced by ``_math.cast``.
    """
    generic = tvm.tir.generic
    priority = 1

    # Already bound at this priority (or higher): leave everything untouched.
    if not priority > generic.__op_priority__:
        return

    generic.__op_priority__ = priority
    for op_name in ("add", "subtract", "multiply", "divide"):
        overloaded = _make_bop(getattr(_broadcast, op_name), getattr(generic, op_name))
        setattr(generic, op_name, overloaded)
    generic.cast = _math.cast
# Run at import time so that importing this module is what installs the
# overloaded operators into tvm.tir.generic.
_bind_generic_ops()