blob: 79b2db5d94a84d9cdb82e628bcef1440744ec57e [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Tensor intrinsics"""
import tvm._ffi
import tvm.tir
from tvm.runtime import Object, convert
from tvm.ir import Range
from .tensor import PlaceholderOp
from . import tensor as _tensor
from . import _ffi_api
def _get_region(tslice):
region = []
for idx in tslice.indices:
if isinstance(idx, slice):
assert idx.step is None
region.append(Range(idx.start, idx.stop))
else:
if isinstance(idx, tvm.tir.IterVar):
begin = idx.var
else:
begin = idx
region.append(Range.from_min_extent(begin, 1))
return region
@tvm._ffi.register_object
class TensorIntrin(Object):
"""Tensor intrinsic functions for certain computation.
See Also
--------
decl_tensor_intrin: Construct a TensorIntrin
"""
def __call__(self, *args, **kwargs):
tensors = [x.tensor for x in args if isinstance(x, _tensor.TensorSlice)]
scalar_inputs = [x for x in args if not isinstance(x, _tensor.TensorSlice)]
regions = [_get_region(x) for x in args if isinstance(x, _tensor.TensorSlice)]
reduce_axis = []
if "reduce_axis" in kwargs:
reduce_axis = kwargs["reduce_axis"]
if not isinstance(reduce_axis, (list, tuple)):
reduce_axis = [reduce_axis]
reduce_axis = convert(reduce_axis)
if scalar_inputs:
scalar_inputs = convert(scalar_inputs)
return _ffi_api.TensorIntrinCall(self, tensors, regions, reduce_axis, scalar_inputs)
def decl_tensor_intrin(
op, fcompute, name="tensor_intrin", binds=None, scalar_params=None, default_buffer_params=None
):
"""Declare a tensor intrinsic function.
Parameters
----------
op: Operation
The symbolic description of the intrinsic operation
fcompute: lambda function of inputs, outputs-> stmt
Specifies the IR statement to do the computation.
See the following note for function signature of fcompute
.. note::
**Parameters**
- **ins** (list of :any:`Buffer`) - Placeholder for each inputs
- **outs** (list of :any:`Buffer`) - Placeholder for each outputs
**Returns**
- **stmt** (:any:`Stmt`, or tuple of three stmts)
- If a single stmt is returned, it represents the body
- If tuple of three stmts are returned they corresponds to body,
reduce_init, reduce_update
name: str, optional
The name of the intrinsic.
binds: dict of :any:`Tensor` to :any:`Buffer`, optional
Dictionary that maps the Tensor to Buffer which specified the data layout
requirement of the function. By default, a new compact buffer is created
for each tensor in the argument.
scalar_params: a list of variables used by op, whose values will be passed
as scalar_inputs when the tensor intrinsic is called.
default_buffer_params: Optional[dict]
Dictionary of buffer arguments to be passed when constructing a buffer.
Returns
-------
intrin: TensorIntrin
A TensorIntrin that can be used in tensorize schedule.
"""
if not isinstance(op, _tensor.Operation):
raise TypeError("expect Operation")
inputs = op.input_tensors
binds = binds if binds else {}
tensors = list(inputs)
for i in range(op.num_outputs):
tensors.append(op.output(i))
binds_list = []
for t in inputs:
if not isinstance(t.op, PlaceholderOp):
raise ValueError("Do not yet support composition op")
default_buffer_params = {} if default_buffer_params is None else default_buffer_params
for t in tensors:
buf = (
binds[t]
if t in binds
else tvm.tir.decl_buffer(t.shape, t.dtype, t.op.name, **default_buffer_params)
)
binds_list.append(buf)
if scalar_params:
body = fcompute(binds_list[: len(inputs)], binds_list[len(inputs) :], scalar_params)
else:
body = fcompute(binds_list[: len(inputs)], binds_list[len(inputs) :])
scalar_params = []
if isinstance(body, (tvm.tir.PrimExpr, tvm.tir.Stmt)):
body = [body]
body = [tvm.tir.Evaluate(x) if isinstance(x, tvm.tir.PrimExpr) else x for x in body]
if len(body) < 3:
body += [None] * (3 - len(body))
return _ffi_api.TensorIntrin(name, op, inputs, binds_list, scalar_params, *body)