| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| # pylint: disable=redefined-builtin, invalid-name |
| """Operators used in TIR expression.""" |
| import warnings |
| from typing import Any, Optional |
| import tvm._ffi |
| from tvm.ir.base import Span |
| from tvm.runtime import convert, const |
| from tvm.ir import Array, Op |
| |
| from .buffer import Buffer |
| from .expr import Call, PrimExprWithOp, StringImm, Var, CommReducer |
| from . import _ffi_api |
| |
| |
def _pack_buffer(buf, span=None):
    """Build the intrinsic call that packs a buffer into a stack array handle.

    Parameters
    ----------
    buf : Buffer
        The buffer whose data, shape, strides, dtype and element offset are packed.

    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    call : PrimExpr
        A ``tir.tvm_stack_make_array`` call expression of type ``handle``.
    """
    shape_handle = Call("handle", "tir.tvm_stack_make_shape", buf.shape, span)
    # When the buffer carries no strides, a literal 0 is passed in their place.
    if buf.strides:
        strides_handle = Call("handle", "tir.tvm_stack_make_shape", buf.strides, span)
    else:
        strides_handle = 0
    array_fields = [
        buf.data,
        shape_handle,
        strides_handle,
        len(buf.shape),
        # Zero constant typed as the buffer dtype; it fills the dtype slot of the call.
        const(0, dtype=buf.dtype),
        buf.elem_offset,
    ]
    make_array = Op.get("tir.tvm_stack_make_array")
    return Call("handle", make_array, array_fields, span)
| |
| |
def call_packed_lowered(*args, span=None):
    """Lowered version of call packed.

    Each argument to the packed function can be an Expr or a Buffer.
    An Expr argument is passed as the corresponding POD type.
    For a Buffer argument, the corresponding PackedFunc
    will receive a TVMArrayHandle whose content is valid during the callback period.
    If the PackedFunc is a python callback, then the corresponding argument is NDArray.

    Parameters
    ----------
    args : list of Expr or Buffer.
        Positional arguments.

    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    call : PrimExpr
        The call expression.

    See Also
    --------
    te.extern : Create tensor with extern function call.
    """
    packed_args = []
    for arg in args:
        if isinstance(arg, Buffer):
            # Buffers are wrapped by _pack_buffer; everything else passes through.
            arg = _pack_buffer(arg)
        packed_args.append(arg)
    intrin = Op.get("tir.tvm_call_packed_lowered")
    return Call("int32", intrin, packed_args, span)
| |
| |
def call_cpacked_lowered(*args, span=None):
    """Lowered version of call c-packed.

    Same as call_packed, except that the first argument is the function name
    (as in call_extern), and the last argument is the resource handle.

    Parameters
    ----------
    args : list of Expr or Buffer.
        Positional arguments.

    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    call : PrimExpr
        The call expression.

    See Also
    --------
    te.extern : Create tensor with extern function call.
    """
    packed_args = []
    for arg in args:
        if isinstance(arg, Buffer):
            # Buffers are wrapped by _pack_buffer; everything else passes through.
            arg = _pack_buffer(arg)
        packed_args.append(arg)
    intrin = Op.get("tir.tvm_call_cpacked_lowered")
    return Call("int32", intrin, packed_args, span)
| |
| |
def call_packed(*args, span=None):
    """Build an expression that calls an external packed function.

    Each argument to the packed function can be an Expr or a Buffer.
    An Expr argument is passed as the corresponding POD type.

    For a Buffer argument, the corresponding PackedFunc
    will receive a TVMArrayHandle whose content is valid during the callback period.
    If the PackedFunc is a python callback, then the corresponding argument is NDArray.

    Parameters
    ----------
    args : list of Expr or Buffer.
        Positional arguments.

    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    call : PrimExpr
        The call expression.

    See Also
    --------
    te.extern : Create tensor with extern function call.
    """
    packed_args = []
    for arg in args:
        if isinstance(arg, Buffer):
            # Buffers are wrapped by _pack_buffer; everything else passes through.
            arg = _pack_buffer(arg)
        packed_args.append(arg)
    intrin = Op.get("tir.tvm_call_packed")
    return Call("int32", intrin, packed_args, span)
| |
| |
def call_cpacked(*args, span=None):
    """Build an expression that calls an external c-packed function.

    Same as call_packed, except that the first argument is the function name
    (as in call_extern), and the last argument is the resource handle.

    Parameters
    ----------
    args : list of Expr or Buffer.
        Positional arguments.

    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    call : PrimExpr
        The call expression.

    See Also
    --------
    te.extern : Create tensor with extern function call.
    """
    packed_args = []
    for arg in args:
        if isinstance(arg, Buffer):
            # Buffers are wrapped by _pack_buffer; everything else passes through.
            arg = _pack_buffer(arg)
        packed_args.append(arg)
    intrin = Op.get("tir.tvm_call_cpacked")
    return Call("int32", intrin, packed_args, span)
| |
| |
def call_intrin(dtype, func_name, *args, span=None):
    """Build an expression that calls an intrinsic function.

    Intrinsics can be overloaded with multiple data types via
    the intrinsic translation rule.

    Parameters
    ----------
    dtype : str
        The data type of the result.

    func_name: str
        The intrinsic function name.

    args : list
        Positional arguments.

    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    # The positional arguments are run through ``convert`` before building the call.
    converted_args = convert(args)
    return Call(dtype, func_name, converted_args, span)
| |
| |
def call_pure_extern(dtype, func_name, *args, span=None):
    """Build an expression that calls a pure extern function.

    Parameters
    ----------
    dtype : str
        The data type of the result.

    func_name: str
        The extern function name.

    args : list
        Positional arguments.

    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    extern_op = Op.get("tir.call_pure_extern")
    # The function name travels as the first call argument, wrapped in a StringImm.
    extern_args = (StringImm(func_name), *args)
    return Call(dtype, extern_op, convert(extern_args), span)
| |
| |
def call_extern(dtype, func_name, *args, span=None):
    """Build an expression that calls an extern function.

    Parameters
    ----------
    dtype : str
        The data type of the result.

    func_name: str
        The extern function name.

    args : list
        Positional arguments.

    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    extern_op = Op.get("tir.call_extern")
    # The function name travels as the first call argument, wrapped in a StringImm.
    extern_args = (StringImm(func_name), *args)
    return Call(dtype, extern_op, convert(extern_args), span=span)
| |
| |
def call_llvm_intrin(dtype, name, *args, span=None):
    """Build an expression that calls an llvm intrinsic function.

    Parameters
    ----------
    dtype : str
        The data type of the result.

    name : str
        The name of the llvm intrinsic function. A string is looked up with
        ``codegen.llvm_lookup_intrinsic_id``; an ``IntImm`` contributes its
        ``value``; any other object is used directly as the intrinsic id.

    args : list
        Positional arguments.

    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    # pylint: disable=import-outside-toplevel
    from tvm.target import codegen

    from .expr import IntImm

    # Start from the raw value and refine it for the two recognised types.
    llvm_id = name
    if isinstance(name, str):
        llvm_id = codegen.llvm_lookup_intrinsic_id(name)
    elif isinstance(name, IntImm):
        llvm_id = name.value

    # An id of 0 is reported with a warning, but the call is still built.
    if llvm_id == 0:
        warnings.warn(f"Unknown llvm intrinsic function {name}, falling back to 0")

    intrin_op = Op.get("tir.call_llvm_intrin")
    id_const = tvm.tir.const(llvm_id, "uint32")
    return call_intrin(dtype, intrin_op, id_const, *args, span=span)
| |
| |
def call_llvm_pure_intrin(dtype, name, *args, span=None):
    """Build an expression that calls a pure llvm intrinsic function.

    Parameters
    ----------
    dtype : str
        The data type of the result.

    name : str
        The name of the llvm intrinsic function. A string is looked up with
        ``codegen.llvm_lookup_intrinsic_id``; an ``IntImm`` contributes its
        ``value``; any other object is used directly as the intrinsic id.

    args : list
        Positional arguments.

    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    # pylint: disable=import-outside-toplevel
    from tvm.target import codegen

    from .expr import IntImm

    # Start from the raw value and refine it for the two recognised types.
    llvm_id = name
    if isinstance(name, str):
        llvm_id = codegen.llvm_lookup_intrinsic_id(name)
    elif isinstance(name, IntImm):
        llvm_id = name.value

    # An id of 0 is reported with a warning, but the call is still built.
    if llvm_id == 0:
        warnings.warn(f"Unknown llvm intrinsic function {name}, falling back to 0")

    intrin_op = Op.get("tir.call_llvm_pure_intrin")
    id_const = tvm.tir.const(llvm_id, "uint32")
    return call_intrin(dtype, intrin_op, id_const, *args, span=span)
| |
| |
def tvm_stack_alloca(dtype_str, num):
    """Return a new on-stack dtype[num].

    Parameters
    ----------
    dtype_str : str
        The data type of array.

    num : int
        The size of array.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    alloca_args = (dtype_str, num)
    return call_intrin("handle", "tir.tvm_stack_alloca", *alloca_args)
| |
| |
def tvm_stack_make_shape(*args):
    """Allocate a shape tuple on the stack and return its handle.

    Parameters
    ----------
    args : int
        The tuple shape.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    intrin_name = "tir.tvm_stack_make_shape"
    return call_intrin("handle", intrin_name, *args)
| |
| |
def tvm_stack_make_array(data, shape, strides, ndim, arr_dtype, elem_offset):
    """Allocate a NDArray(DLTensor) on the stack and return its handle.

    Parameters
    ----------
    data : Expr
        The data of array.

    shape : Expr
        The shape of array.

    strides : Expr
        The strides of array.

    ndim : Expr
        The dimensions of array.

    arr_dtype : Expr
        The data type of array.

    elem_offset : Expr
        The element offset of array.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    array_fields = (data, shape, strides, ndim, arr_dtype, elem_offset)
    return call_intrin("handle", "tir.tvm_stack_make_array", *array_fields)
| |
| |
def assume(cond=None):
    """Provide a true statement that can be used for simplifications.

    Parameters
    ----------
    cond : Expr
        The constraint condition.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    intrin_name = "tir.assume"
    return call_intrin("int32", intrin_name, cond)
| |
| |
def undef():
    """Return an initialized but arbitrary value.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    intrin_name = "tir.undef"
    return call_intrin("int32", intrin_name)
| |
| |
def tvm_tuple(*value):
    """Create a tuple structure in the value field of AttrStmt.

    Parameters
    ----------
    value : Expr
        The value in tuple.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    intrin_name = "tir.tvm_tuple"
    return call_intrin("handle", intrin_name, *value)
| |
| |
def tvm_struct_get(arr, index, field, dtype):
    """Get a struct field value in an array.

    Parameters
    ----------
    arr : StructType*
        The array of struct.

    index : int
        The index of struct.

    field : int
        The field of struct.

    dtype : str
        The data type of the result.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    # Note that ``dtype`` comes last in the signature but first in the intrinsic call.
    locator = (arr, index, field)
    return call_intrin(dtype, "tir.tvm_struct_get", *locator)
| |
| |
def tvm_struct_set(arr, index, field, value):
    """Set the value of a struct field in an array.

    Parameters
    ----------
    arr : StructType*
        The array of struct.

    index : int
        The index of struct.

    field : int
        The field of struct.

    value : Expr
        The value to be set in field.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    setter_args = (arr, index, field, value)
    return call_intrin("handle", "tir.tvm_struct_set", *setter_args)
| |
| |
def address_of(buffer_load, span=None):
    """Return the address of an element in the buffer.

    Parameters
    ----------
    buffer_load: BufferLoad
        The buffer load.

    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    intrin_name = "tir.address_of"
    return call_intrin("handle", intrin_name, buffer_load, span=span)
| |
| |
def lookup_param(param_name, span=None):
    """Return the param by name.

    Parameters
    ----------
    param_name : str
        The name of param.

    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    intrin_name = "tir.lookup_param"
    return call_intrin("handle", intrin_name, param_name, span=span)
| |
| |
def tvm_thread_allreduce(*freduce_args):
    """Build a call to the ``tir.tvm_thread_allreduce`` intrinsic.

    Parameters
    ----------
    freduce_args : Expr
        The args.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    intrin_name = "tir.tvm_thread_allreduce"
    return call_intrin("handle", intrin_name, *freduce_args)
| |
| |
def type_annotation(dtype):
    """Create a type annotation expression.

    Parameters
    ----------
    dtype : Expr
        The data type.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    # The annotation carries no arguments; only the result dtype is supplied.
    intrin_name = "tir.type_annotation"
    return call_intrin(dtype, intrin_name)
| |
| |
def tvm_access_ptr(ptype, data, offset, extent, rw_mask):
    """Get the head access address with memory access pattern info.

    Parameters
    ----------
    ptype : Expr
        The data type of pointer.

    data : DType*
        The data of pointer.

    offset : int
        The offset of pointer.

    extent : int
        The extent of pointer.

    rw_mask : int
        The read write mask.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    access_args = (ptype, data, offset, extent, rw_mask)
    return call_intrin("handle", "tir.tvm_access_ptr", *access_args)
| |
| |
def tvm_throw_last_error():
    """Throw TVMGetLastError().

    Returns
    -------
    ret : PrimExpr
        The return expression
    """
    intrin_name = "tir.tvm_throw_last_error"
    return call_intrin("handle", intrin_name)
| |
| |
def ret(val):
    """Create a tir return expression.

    Parameters
    ----------
    val : Expr
        The returned tir expression, whose data type is int, float or void pointer.

    Returns
    -------
    ret : PrimExpr
        The return expression
    """
    # The return expression takes the dtype of the value being returned.
    value_dtype = val.dtype
    return call_intrin(value_dtype, "tir.ret", val)
| |
| |
def any(*args, span=None):
    """Create a new expression of the union of all conditions in the arguments.

    Parameters
    ----------
    args : list
        List of symbolic boolean expressions

    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    expr: Expr
        Expression

    Raises
    ------
    ValueError
        If no argument is given.
    """
    if not args:
        raise ValueError("Any must take at least 1 argument")
    # Fold left to right; a single argument is returned unchanged.
    combined = args[0]
    for operand in args[1:]:
        combined = _ffi_api._OpOr(combined, operand, span)  # type: ignore
    return combined
| |
| |
def all(*args, span=None):
    """Create a new expression of the intersection of all conditions in the
    arguments.

    Parameters
    ----------
    args : list
        List of symbolic boolean expressions

    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    expr: Expr
        Expression

    Raises
    ------
    ValueError
        If no argument is given.
    """
    if not args:
        # The message names this function; it previously said "Any" by mistake.
        raise ValueError("All must take at least 1 argument")
    # Fold left to right; a single argument is returned unchanged.
    combined = args[0]
    for operand in args[1:]:
        combined = _ffi_api._OpAnd(combined, operand, span)  # type: ignore
    return combined
| |
| |
@tvm._ffi.register_func("tvm.default_trace_action")
def _tvm_default_trace_action(*args):
    """Default trace action: print the received arguments as a list."""
    received = list(args)
    print(received)
| |
| |
def trace(args, trace_action="tvm.default_trace_action"):
    """Trace tensor data at the runtime.

    The trace function allows tracing a specific tensor at
    runtime. The value to trace should come as the last argument.
    The trace action should be specified; by default
    tvm.default_trace_action is used.

    Parameters
    ----------
    args : list of Expr or Buffers.
        Positional arguments. Must be a non-empty ``list``; the dtype of the
        last element becomes the dtype of the resulting call.

    trace_action : str.
        The name of the trace action.

    Returns
    -------
    call : PrimExpr
        The call expression.

    Raises
    ------
    TypeError
        If ``args`` is not a list.

    ValueError
        If ``args`` is empty, so there is no value to trace.

    See Also
    --------
    tvm.tir.call_packed : Creates packed function.
    """
    if not isinstance(args, list):
        # TypeError is still an Exception, so existing ``except Exception`` callers keep working.
        raise TypeError("tvm.tir.trace consumes the args as list type")
    if not args:
        # Fail early with a clear message instead of an IndexError on args[-1] below.
        raise ValueError("tvm.tir.trace requires at least one argument: the value to trace")
    # The trace action name goes first, followed by the (possibly packed) arguments.
    call_args = [trace_action]
    call_args.extend(_pack_buffer(x) if isinstance(x, Buffer) else x for x in args)
    return tvm.tir.Call(args[-1].dtype, Op.get("tir.tvm_call_trace_packed"), call_args)
| |
| |
def min_value(dtype, span=None):
    """Return the minimum value of dtype.

    Parameters
    ----------
    dtype : str
        The data type.

    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    value : tvm.Expr
        The minimum value of dtype.
    """
    lowest = _ffi_api.min_value(dtype, span)  # type: ignore
    return lowest
| |
| |
def max_value(dtype: str, span: Optional[Span] = None) -> Any:
    """Return the maximum value of dtype.

    Parameters
    ----------
    dtype : str
        The data type.

    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    value : tvm.Expr
        The maximum value of dtype.
    """
    highest = _ffi_api.max_value(dtype, span)  # type: ignore
    return highest
| |
| |
def infinity(dtype: str, span: Optional[Span] = None) -> Any:
    """Return the infinity value of dtype.

    Parameters
    ----------
    dtype : str
        The data type.

    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    value : tvm.Expr
        The infinity value of dtype.
    """
    inf_expr = _ffi_api.infinity(dtype, span)  # type: ignore
    return inf_expr
| |
| |
def reinterpret(dtype, value, span=None) -> Any:
    """Build a reinterpret cast of value to dtype.

    Parameters
    ----------
    dtype : str
        The data type.

    value : PrimExpr
        The input value.

    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    value : tvm.Expr
        The reinterpret cast value of dtype.
    """
    # ``span`` defaults to None, so existing two-argument callers are unaffected.
    return call_intrin(dtype, "tir.reinterpret", value, span=span)
| |
| |
def exp(x):
    """Take the exponential of input x.

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    Returns
    -------
    y : PrimExpr
        The result, with the same dtype as x.
    """
    out_dtype = x.dtype
    return call_intrin(out_dtype, "tir.exp", x)
| |
| |
def exp2(x):
    """Calculate 2**x.

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    Returns
    -------
    y : PrimExpr
        The result, with the same dtype as x.
    """
    out_dtype = x.dtype
    return call_intrin(out_dtype, "tir.exp2", x)
| |
| |
def exp10(x):
    """Calculate 10**x.

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    Returns
    -------
    y : PrimExpr
        The result, with the same dtype as x.
    """
    out_dtype = x.dtype
    return call_intrin(out_dtype, "tir.exp10", x)
| |
| |
def erf(x):
    """Take the Gauss error function of the input x.

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    Returns
    -------
    y : PrimExpr
        The result, with the same dtype as x.
    """
    out_dtype = x.dtype
    return call_intrin(out_dtype, "tir.erf", x)
| |
| |
def tanh(x):
    """Take the hyperbolic tangent of input x.

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    Returns
    -------
    y : PrimExpr
        The result, with the same dtype as x.
    """
    out_dtype = x.dtype
    return call_intrin(out_dtype, "tir.tanh", x)
| |
| |
def sigmoid(x):
    """Quick function to get the sigmoid of input x.

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    Returns
    -------
    y : PrimExpr
        The result, with the same dtype as x.
    """
    out_dtype = x.dtype
    return call_intrin(out_dtype, "tir.sigmoid", x)
| |
| |
def log(x):
    """Take the log of input x.

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    Returns
    -------
    y : PrimExpr
        The result, with the same dtype as x.
    """
    out_dtype = x.dtype
    return call_intrin(out_dtype, "tir.log", x)
| |
| |
def log2(x):
    """Take the log2 of input x.

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    Returns
    -------
    y : PrimExpr
        The result, with the same dtype as x.
    """
    out_dtype = x.dtype
    return call_intrin(out_dtype, "tir.log2", x)
| |
| |
def log10(x):
    """Take the log10 of input x.

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    Returns
    -------
    y : PrimExpr
        The result, with the same dtype as x.
    """
    out_dtype = x.dtype
    return call_intrin(out_dtype, "tir.log10", x)
| |
| |
def log1p(x):
    """Take log(x + 1) with respect to input x.

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    Returns
    -------
    y : PrimExpr
        The result, with the same dtype as x.
    """
    out_dtype = x.dtype
    return call_intrin(out_dtype, "tir.log1p", x)
| |
| |
def tan(x):
    """Take the tan of input x.

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    Returns
    -------
    y : PrimExpr
        The result, with the same dtype as x.
    """
    out_dtype = x.dtype
    return call_intrin(out_dtype, "tir.tan", x)
| |
| |
def cos(x):
    """Take the cos of input x.

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    Returns
    -------
    y : PrimExpr
        The result, with the same dtype as x.
    """
    out_dtype = x.dtype
    return call_intrin(out_dtype, "tir.cos", x)
| |
| |
def cosh(x):
    """Take the cosh of input x.

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    Returns
    -------
    y : PrimExpr
        The result, with the same dtype as x.
    """
    out_dtype = x.dtype
    return call_intrin(out_dtype, "tir.cosh", x)
| |
| |
def acos(x):
    """Take the acos of input x.

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    Returns
    -------
    y : PrimExpr
        The result, with the same dtype as x.
    """
    out_dtype = x.dtype
    return call_intrin(out_dtype, "tir.acos", x)
| |
| |
def acosh(x):
    """Take the acosh of input x.

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    Returns
    -------
    y : PrimExpr
        The result, with the same dtype as x.
    """
    out_dtype = x.dtype
    return call_intrin(out_dtype, "tir.acosh", x)
| |
| |
def sin(x):
    """Take the sin of input x.

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    Returns
    -------
    y : PrimExpr
        The result, with the same dtype as x.
    """
    out_dtype = x.dtype
    return call_intrin(out_dtype, "tir.sin", x)
| |
| |
def sinh(x):
    """Take the sinh of input x.

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    Returns
    -------
    y : PrimExpr
        The result, with the same dtype as x.
    """
    out_dtype = x.dtype
    return call_intrin(out_dtype, "tir.sinh", x)
| |
| |
def asin(x):
    """Take the asin of input x.

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    Returns
    -------
    y : PrimExpr
        The result, with the same dtype as x.
    """
    out_dtype = x.dtype
    return call_intrin(out_dtype, "tir.asin", x)
| |
| |
def asinh(x):
    """Take the asinh of input x.

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    Returns
    -------
    y : PrimExpr
        The result, with the same dtype as x.
    """
    out_dtype = x.dtype
    return call_intrin(out_dtype, "tir.asinh", x)
| |
| |
def atan(x):
    """Take the atan of input x.

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    Returns
    -------
    y : PrimExpr
        The result, with the same dtype as x.
    """
    out_dtype = x.dtype
    return call_intrin(out_dtype, "tir.atan", x)
| |
| |
def atanh(x):
    """Take the atanh of input x.

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    Returns
    -------
    y : PrimExpr
        The result, with the same dtype as x.
    """
    out_dtype = x.dtype
    return call_intrin(out_dtype, "tir.atanh", x)
| |
| |
def atan2(x1, x2):
    """Take arctan2(x1, x2).

    Parameters
    ----------
    x1 : PrimExpr
        Input argument.

    x2 : PrimExpr
        Input argument.

    Returns
    -------
    y : PrimExpr
        The result, with the same dtype as x1.
    """
    # The result dtype follows the first operand.
    out_dtype = x1.dtype
    return call_intrin(out_dtype, "tir.atan2", x1, x2)
| |
| |
def sqrt(x):
    """Take the square root of input x.

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    Returns
    -------
    y : PrimExpr
        The result, with the same dtype as x.
    """
    out_dtype = x.dtype
    return call_intrin(out_dtype, "tir.sqrt", x)
| |
| |
def rsqrt(x):
    """Take the reciprocal of the square root of input x.

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    Returns
    -------
    y : PrimExpr
        The result, with the same dtype as x.
    """
    out_dtype = x.dtype
    return call_intrin(out_dtype, "tir.rsqrt", x)
| |
| |
def clz(x):
    """Count the leading zero bits of an integer x.

    Parameters
    ----------
    x : PrimExpr
        Input 32 or 64 bit integer.
        The result is undefined if the input is 0.

    Returns
    -------
    y : PrimExpr
        The result, built as an ``int32`` call expression.
    """
    intrin_name = "tir.clz"
    return call_intrin("int32", intrin_name, x)
| |
| |
def floor(x: PrimExprWithOp, span=None):
    """Take the floor of float input x.

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    y : PrimExpr
        The result.
    """
    floored = _ffi_api.floor(x, span)  # type: ignore
    return floored
| |
| |
def ceil(x, span=None):
    """Take the ceil of float input x.

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    y : PrimExpr
        The result.
    """
    ceiled = _ffi_api.ceil(x, span)  # type: ignore
    return ceiled
| |
| |
def trunc(x, span=None):
    """Get the truncated value of the input.

    The truncated value of the scalar x is the
    nearest integer i which is closer to zero than x is.

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    y : PrimExpr
        The result.
    """
    truncated = _ffi_api.trunc(x, span)  # type: ignore
    return truncated
| |
| |
def abs(x, span=None):
    """Get the absolute value of the input element-wise.

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    y : PrimExpr
        The result.
    """
    magnitude = _ffi_api.abs(x, span)  # type: ignore
    return magnitude
| |
| |
def round(x, span=None):
    """Round elements of the array to the nearest integer.

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    y : PrimExpr
        The result.
    """
    rounded = _ffi_api.round(x, span)  # type: ignore
    return rounded
| |
| |
def nearbyint(x, span=None):
    """Round elements of the array to the nearest integer.

    This intrinsic uses llvm.nearbyint instead of llvm.round,
    which is faster but will produce results different from te.round.
    Notably nearbyint rounds according to the rounding mode,
    whereas te.round (llvm.round) ignores that.
    For differences between the two see:
    https://en.cppreference.com/w/cpp/numeric/math/round
    https://en.cppreference.com/w/cpp/numeric/math/nearbyint

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    y : PrimExpr
        The result.
    """
    rounded = _ffi_api.nearbyint(x, span)  # type: ignore
    return rounded
| |
| |
def nextafter(x1, x2):
    """Return the next floating-point value after x1 towards x2.

    Parameters
    ----------
    x1 : PrimExpr
        Input argument.

    x2 : PrimExpr
        Input argument.

    Returns
    -------
    y : PrimExpr
        The result, with the same dtype as x1.
    """
    # The result dtype follows the first operand.
    out_dtype = x1.dtype
    return call_intrin(out_dtype, "tir.nextafter", x1, x2)  # type: ignore
| |
| |
def hypot(x1, x2):
    """Equivalent to sqrt(x1**2 + x2**2), element-wise.

    Parameters
    ----------
    x1 : PrimExpr
        Input argument.

    x2 : PrimExpr
        Input argument.

    Returns
    -------
    y : PrimExpr
        The result, with the same dtype as x1.
    """
    # The result dtype follows the first operand.
    out_dtype = x1.dtype
    return call_intrin(out_dtype, "tir.hypot", x1, x2)  # type: ignore
| |
| |
def copysign(x1, x2):
    """Change the sign of x1 to that of x2, element-wise.

    Parameters
    ----------
    x1 : PrimExpr
        Input argument.

    x2 : PrimExpr
        Input argument.

    Returns
    -------
    y : PrimExpr
        The result, with the same dtype as x1.
    """
    # The result dtype follows the first operand.
    out_dtype = x1.dtype
    return call_intrin(out_dtype, "tir.copysign", x1, x2)  # type: ignore
| |
| |
def ldexp(x1, x2):
    """Return x1 * (2 ** x2).

    Parameters
    ----------
    x1 : PrimExpr
        Input argument.

    x2 : PrimExpr
        Input argument.

    Returns
    -------
    y : PrimExpr
        The result, with the same dtype as x1.
    """
    # The result dtype follows the first operand.
    out_dtype = x1.dtype
    return call_intrin(out_dtype, "tir.ldexp", x1, x2)  # type: ignore
| |
| |
def likely(cond, span=None):
    """Mark a condition as likely.

    Parameters
    ----------

    cond : PrimExpr
        Input argument.

    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    y : PrimExpr
        The marked expression.
    """
    marked = _ffi_api.likely(cond, span)  # type: ignore
    return marked
| |
| |
def isnan(x, span=None):
    """Check whether the input value is NaN.

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    y : PrimExpr
        The result.
    """
    check = _ffi_api.isnan(x, span)  # type: ignore
    return check
| |
| |
def isnullptr(x, span=None):
    """Check whether the input value is nullptr.

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    y : PrimExpr
        The result, built as a ``bool`` call expression.
    """
    intrin_name = "tir.isnullptr"
    return call_intrin("bool", intrin_name, x, span=span)  # type: ignore
| |
| |
def isfinite(x, span=None):
    """Check whether the input value is finite.

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    y : PrimExpr
        The result.
    """
    check = _ffi_api.isfinite(x, span)  # type: ignore
    return check
| |
| |
def isinf(x, span=None):
    """Check whether the input value is infinite.

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    y : PrimExpr
        The result.
    """
    check = _ffi_api.isinf(x, span)  # type: ignore
    return check
| |
| |
def power(x, y, span=None):
    """Raise x to the power y.

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    y : PrimExpr
        The exponent

    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    z : PrimExpr
        The result.
    """
    # Both operands go through ``convert`` before reaching the FFI call.
    base = convert(x)
    exponent = convert(y)
    return _ffi_api._OpPow(base, exponent, span)  # type: ignore
| |
| |
def popcount(x):
    """Count the number of set bits in input x.

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    Returns
    -------
    y : PrimExpr
        The result, with the same dtype as x.
    """
    out_dtype = x.dtype
    return call_intrin(out_dtype, "tir.popcount", x)
| |
| |
def q_multiply_shift(x, y, q, s):
    """Execute a multiplication between two Q-numbers x and y
    followed by a right shift s. The mathematical expression is:

       out = round(x*y*2^-s)

    More about Q-numbers here: https://en.wikipedia.org/wiki/Q_(number_format)
    The rounding rule is to the nearest value, rounding half up
    (i.e., round(x.1) = x and round (x.5) = x+1)

    Parameters
    ----------
    x : PrimExpr
        First Q-number
    y : PrimExpr
        Second Q-number
    q : PrimExpr
        Number of fractional bits in x and y. Needs to be > 0
    s : PrimExpr
        Integer shift

    Returns
    -------
    y : PrimExpr
        The result, built as an ``int32`` call expression.
    """
    operands = (x, y, q, s)
    return call_intrin("int32", "tir.q_multiply_shift", *operands)
| |
| |
def fmod(x, y, span=None):
    """Return the remainder of x divided by y with the same sign as x.

    Parameters
    ----------
    x : PrimExpr
        Input argument. Plain Python numbers are accepted and converted
        to an expression first.
    y : PrimExpr
        Input argument. Plain Python numbers are accepted and converted
        to an expression first.
    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    z : PrimExpr
        The result, with the same dtype as the (converted) ``x``.
    """
    # Convert up front so that ``x.dtype`` is available even when the caller
    # passes a plain Python scalar instead of a PrimExpr.
    x = convert(x)
    y = convert(y)
    return call_intrin(x.dtype, "tir.fmod", x, y, span=span)
| |
| |
def if_then_else(cond, t, f, span=None):
    """Build a conditional selection expression.

    Parameters
    ----------
    cond : PrimExpr
        The condition to test.

    t : PrimExpr
        The value of the expression when ``cond`` is true.

    f : PrimExpr
        The value of the expression when ``cond`` is false.

    span : Optional[Span]
        Source location to attach to the generated expression.

    Returns
    -------
    result : Node
        The conditional expression.

    Note
    ----
    Unlike Select, if_then_else will not execute
    the branch that does not satisfy the condition.
    You can use it to guard against out of bound access.
    Unlike Select, if_then_else cannot be vectorized
    if some lanes in the vector have different conditions.
    """
    # All three operands go through ``convert`` before the FFI call.
    condition = convert(cond)
    true_value = convert(t)
    false_value = convert(f)
    return _ffi_api._OpIfThenElse(condition, true_value, false_value, span)  # type: ignore
| |
| |
def div(a, b, span=None):
    """Compute a / b as in C/C++ semantics.

    Parameters
    ----------
    a : PrimExpr
        The left hand operand (dividend).

    b : PrimExpr
        The right hand operand (divisor).

    span : Optional[Span]
        Source location to attach to the generated expression.

    Returns
    -------
    res : PrimExpr
        The quotient expression.

    Note
    ----
    When operands are integers, returns truncdiv(a, b, span).
    """
    quotient = _ffi_api._OpDiv(a, b, span)  # type: ignore
    return quotient
| |
| |
def indexdiv(a, b, span=None):
    """Compute floor(a / b) for non-negative a and b.

    Parameters
    ----------
    a : PrimExpr
        The left hand operand, known to be non-negative.

    b : PrimExpr
        The right hand operand, known to be non-negative.

    span : Optional[Span]
        Source location to attach to the generated expression.

    Returns
    -------
    res : PrimExpr
        The quotient expression.

    Note
    ----
    Intended for splitting non-negative indices; the operator
    may take advantage of the operands' non-negativeness.
    """
    quotient = _ffi_api._OpIndexDiv(a, b, span)  # type: ignore
    return quotient
| |
| |
def indexmod(a, b, span=None):
    """Compute the remainder of indexdiv for non-negative a and b.

    Parameters
    ----------
    a : PrimExpr
        The left hand operand, known to be non-negative.

    b : PrimExpr
        The right hand operand, known to be non-negative.

    span : Optional[Span]
        Source location to attach to the generated expression.

    Returns
    -------
    res : PrimExpr
        The remainder expression.

    Note
    ----
    Intended for splitting non-negative indices; the operator
    may take advantage of the operands' non-negativeness.
    """
    remainder = _ffi_api._OpIndexMod(a, b, span)  # type: ignore
    return remainder
| |
| |
def truncdiv(a, b, span=None):
    """Compute the truncating division of two expressions.

    Parameters
    ----------
    a : PrimExpr
        The left hand operand (dividend).

    b : PrimExpr
        The right hand operand (divisor).

    span : Optional[Span]
        Source location to attach to the generated expression.

    Returns
    -------
    res : PrimExpr
        The quotient expression.

    Note
    ----
    This is the default integer division behavior in C.
    """
    quotient = _ffi_api._OpTruncDiv(a, b, span)  # type: ignore
    return quotient
| |
| |
def truncmod(a, b, span=None):
    """Compute the remainder of the truncating division of two expressions.

    Parameters
    ----------
    a : PrimExpr
        The left hand operand (dividend).

    b : PrimExpr
        The right hand operand (divisor).

    span : Optional[Span]
        Source location to attach to the generated expression.

    Returns
    -------
    res : PrimExpr
        The remainder expression.

    Note
    ----
    This is the remainder matching the default integer division behavior in C.
    """
    remainder = _ffi_api._OpTruncMod(a, b, span)  # type: ignore
    return remainder
| |
| |
def floordiv(a, b, span=None):
    """Compute the floor division of two expressions.

    Parameters
    ----------
    a : PrimExpr
        The left hand operand (dividend).

    b : PrimExpr
        The right hand operand (divisor).

    span : Optional[Span]
        Source location to attach to the generated expression.

    Returns
    -------
    res : PrimExpr
        The quotient expression.
    """
    quotient = _ffi_api._OpFloorDiv(a, b, span)  # type: ignore
    return quotient
| |
| |
def floormod(a, b, span=None):
    """Compute the remainder of the floor division of two expressions.

    Parameters
    ----------
    a : PrimExpr
        The left hand operand (dividend).

    b : PrimExpr
        The right hand operand (divisor).

    span : Optional[Span]
        Source location to attach to the generated expression.

    Returns
    -------
    res : PrimExpr
        The remainder expression.
    """
    remainder = _ffi_api._OpFloorMod(a, b, span)  # type: ignore
    return remainder
| |
| |
def ceildiv(lhs, rhs, span=None):
    """Generic ceiling-division operator.

    Parameters
    ----------
    lhs : object
        The left operand (dividend).
    rhs : object
        The right operand (divisor).
    span : Optional[Span]
        Source location to attach to the generated expression.

    Returns
    -------
    op : tvm.Expr
        The result Expr of the ceildiv operation.
    """
    quotient = _ffi_api._OpCeilDiv(lhs, rhs, span)  # type: ignore
    return quotient
| |
| |
def comm_reducer(fcombine, fidentity, name="reduce"):
    """Create a commutative reducer for reduction.

    Parameters
    ----------
    fcombine : function(Expr -> Expr -> Expr)
        A binary function which takes two Expr as input to return a Expr.
        It must be a Python function with exactly two positional parameters;
        the parameter names are reused as the names of the reduction variables.

    fidentity : function(str -> Expr)
        A function which takes a type string as input to return a const Expr.

    name : str
        The reducer name substituted into the docstring of the returned function.

    Returns
    -------
    reducer : function
        A function which creates a reduce expression over axis.
        There are two ways to use it:

        1. accept (expr, axis, where) to produce an Reduce Expr on
           specified axis;
        2. simply use it with multiple Exprs.

    Example
    -------
    .. code-block:: python

        n = te.var("n")
        m = te.var("m")
        mysum = te.comm_reducer(lambda x, y: x+y,
            lambda t: tvm.tir.const(0, dtype=t), name="mysum")
        A = te.placeholder((n, m), name="A")
        k = te.reduce_axis((0, m), name="k")
        B = te.compute((n,), lambda i: mysum(A[i, k], axis=k), name="B")
    """

    def _reduce_directly(*args):
        """Fold ``fcombine`` over the given operands from left to right."""
        num = len(args)
        # A trailing ``where=None`` forwarded by ``reducer`` is not an operand.
        if num == 3 and args[2] is None:
            num = 2
        res = args[0]
        for operand in args[1:num]:
            res = fcombine(res, operand)
        return res

    def _make_reduce(expr, axis, where=None, init=None):
        """Build the Reduce expression(s) of ``expr`` over ``axis``."""
        code = fcombine.__code__
        assert code.co_argcount == 2
        # Types accepted for an initial value (or for each element of a tuple of them).
        init_types = (tvm.tir.ProducerLoad, tvm.tir.IntImm, tvm.tir.FloatImm)
        expr = convert(expr)
        if init is not None:
            init = convert(init)
        if isinstance(expr, Array):
            # Multi-value reduction: one (lhs, rhs) variable pair per source expression.
            size = len(expr)
            larr = []
            rarr = []
            dtypes = []
            for i in range(size):
                dtype = expr[i].dtype
                dtypes.append(dtype)
                lname = code.co_varnames[0] + "_" + str(i)
                larr.append(Var(lname, dtype))
                rname = code.co_varnames[1] + "_" + str(i)
                rarr.append(Var(rname, dtype))
            if init is not None:
                assert isinstance(init, Array)
                assert len(init) == size
                # Validate the actual init elements (the previous loop checked
                # the loop index instead, so the assertion could never fail).
                for init_i in init:
                    assert isinstance(init_i, init_types)
            else:
                init = convert([])
            lhs = convert(larr)
            rhs = convert(rarr)
            result = fcombine(lhs, rhs)
            id_elem = fidentity(*dtypes)
        else:
            # Single-value reduction: wrap everything into one-element arrays.
            assert isinstance(expr, tvm.ir.PrimExpr)
            size = 1
            dtype = expr.dtype
            lvar = Var(code.co_varnames[0], dtype)
            rvar = Var(code.co_varnames[1], dtype)
            result = [fcombine(lvar, rvar)]
            id_elem = [fidentity(dtype)]
            lhs = convert([lvar])
            rhs = convert([rvar])
            expr = convert([expr])
            if init is not None:
                assert isinstance(init, init_types)
                init = convert([init])
        result = convert(result)
        id_elem = convert(id_elem)
        combiner = CommReducer(lhs, rhs, result, id_elem)
        axis = convert(axis if isinstance(axis, (list, tuple)) else [axis])
        if where is None:
            where = convert(True)
        if init is None:
            # No initial value is expressed as an empty array.
            init = convert([])
        outputs = tuple(
            tvm.tir.Reduce(combiner, expr, axis, where, i, init) for i in range(size)
        )
        return outputs[0] if size == 1 else outputs

    # pylint: disable=keyword-arg-before-vararg
    def reducer(expr, axis, where=None, init=None, *args):
        # Mode 1: the second argument is an axis (or a list/tuple of them),
        # so build a Reduce expression.
        if isinstance(axis, (tvm.tir.IterVar, list, tuple)):
            assert not args
            return _make_reduce(expr, axis, where, init)
        # Mode 2: every positional argument is an operand to combine directly.
        if where is None:
            assert not args
            return _reduce_directly(expr, axis)
        return _reduce_directly(expr, axis, where, *args)

    doc_str = """Create a {0} expression over axis.

              Parameters
              ----------
              expr : PrimExpr
                  The source expression.
              axis : IterVar
                  The reduction IterVar axis
              where : optional, Expr
                  Filtering predicate of the reduction.
              Returns
              -------
              value : PrimExpr
                  The result value.

              Example
              -------
              .. code-block:: python

                m = te.var("m")
                n = te.var("n")
                A = te.placeholder((m, n), name="A")
                k = te.reduce_axis((0, n), name="k")

                # there are two way to use this {0} reducer:
                # mode 1, accept (expr, axis, where) to produce an Reduce Expr
                # tvm.{0} represents tvm.te.{0} or tvm.tir.{0}.
                B = te.compute((m,), lambda i: tvm.{0}(A[i, k], axis=k), name="B")

                # mode 2, simply use it with multiple Exprs:
                {0}_res = tvm.{0}(m, n)
              """
    reducer.__doc__ = doc_str.format(name)
    return reducer
| |
| |
def TVMBackendAllocWorkspace(device_type, device_id, nbytes, dtype_code_hint, dtype_bits_hint):
    """Build a call to the backend function that allocates temporary workspace.

    Parameters
    ----------
    device_type : int
        The device type on which the space will be allocated.

    device_id : int
        The device id on which the space will be allocated.

    nbytes : int
        The size of the space requested.

    dtype_code_hint : int
        The type code of the array elements. Only used in certain backends such as OpenGL.

    dtype_bits_hint : int
        The type bits of the array elements. Only used in certain backends such as OpenGL.

    Returns
    -------
    call : PrimExpr
        The call expression, created with dtype "handle".
    """
    # Arguments are forwarded to the intrinsic in declaration order.
    workspace_args = (device_type, device_id, nbytes, dtype_code_hint, dtype_bits_hint)
    return call_intrin("handle", "tir.TVMBackendAllocWorkspace", *workspace_args)
| |
| |
def TVMBackendFreeWorkspace(device_type, device_id, ptr):
    """Build a call to the backend function that frees temporary workspace.

    Parameters
    ----------
    device_type : int
        The device type of the space being freed.

    device_id : int
        The device id of the space being freed.

    ptr : Var
        The pointer to the previously allocated space.

    Returns
    -------
    call : PrimExpr
        The call expression, created with dtype "int32".
    """
    # Arguments are forwarded to the intrinsic in declaration order.
    release_args = (device_type, device_id, ptr)
    return call_intrin("int32", "tir.TVMBackendFreeWorkspace", *release_args)
| |
| |
# Built-in commutative reducers, created by comm_reducer when this module is loaded.
# NOTE: comm_reducer reads the combiner's parameter names (here ``x`` and ``y``)
# through ``fcombine.__code__.co_varnames`` to name the reduction variables,
# so the combiners are written as explicit two-argument lambdas.
# pylint: disable=unnecessary-lambda
# sum: combines with ``+``; the identity element is a zero constant of the requested dtype.
sum = comm_reducer(lambda x, y: x + y, lambda t: const(0, dtype=t), name="sum")
# min: combines with the _OpMin FFI operator (no span); ``max_value`` supplies the identity.
min = comm_reducer(lambda x, y: _ffi_api._OpMin(x, y, None), max_value, name="min")  # type: ignore
# max: combines with the _OpMax FFI operator (no span); ``min_value`` supplies the identity.
max = comm_reducer(lambda x, y: _ffi_api._OpMax(x, y, None), min_value, name="max")  # type: ignore