| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| # pylint: disable=invalid-name |
| """Common topi utilities""" |
| from __future__ import absolute_import as _abs |
| from numbers import Integral |
| |
| import tvm |
| from tvm import te |
| from tvm.tir import layout, bijective_layout |
| from . import tag, cpp |
| |
| |
class InvalidShapeError(ValueError):
    """Invalid shape for a topi function (e.g. calling a winograd template for a non-3x3 kernel)."""
| |
| |
def nchw_pack_layout(layout_info):
    """Return True when the layout string is of the packed NCHWinic form.

    The layout must start with "NCHW" and contain both a lowercase "n"
    and a lowercase "c" (the packed batch/channel sub-dimensions).
    """
    if layout_info[:4] != "NCHW":
        return False
    return "c" in layout_info and "n" in layout_info
| |
| |
def nchw_xc_layout(layout_info):
    """Return True when the layout string is of the NCHWxc form.

    The layout must start with "NCHW", contain a lowercase "c", and have
    a numeric sub-dimension size between "NCHW" and the trailing char.
    """
    return (
        layout_info.startswith("NCHW")
        and "c" in layout_info
        and layout_info[4:-1].isnumeric()
    )
| |
| |
def traverse_inline(s, final_op, callback):
    """Traverse computation graph and do auto inline

    Walks the producer graph rooted at ``final_op``, inlining injective
    intermediate ops into their consumers, and invokes ``callback`` on
    every visited op.

    Parameters
    ----------
    s: schedule
        The schedule
    final_op: Operation
        The final output operator.
    callback: callable
        The callback function on each op
    """
    # Ops already handled; the graph may be a DAG, so each op is visited once.
    visited = set()

    def _traverse(op):
        if op in visited:
            return
        visited.add(op)
        if tag.is_injective(op.tag):
            # Injective intermediates are inlined; scheduled outputs must stay.
            if op not in s.outputs:
                s[op].compute_inline()
            # Only recurse through ComputeOp producers of injective ops;
            # traversal stops at non-injective ops and placeholders.
            # NOTE(review): recursion depth is bounded by graph depth — a very
            # deep graph could hit Python's recursion limit; confirm acceptable.
            for tensor in op.input_tensors:
                if isinstance(tensor.op, tvm.te.ComputeOp):
                    _traverse(tensor.op)
        # The callback fires for every visited op, injective or not.
        callback(op)

    _traverse(final_op)
| |
| |
def prod(x):
    """Multiply together every item of the tuple.

    Parameters
    ----------
    x: tuple
        Input tuple

    Returns
    -------
    value : Expr
        The result value; an int32 constant 1 when the tuple is empty.
    """
    if not x:
        return tvm.tir.const(1, "int32")
    result = x[0]
    for factor in x[1:]:
        result = result * factor
    return result
| |
| |
def get_const_int(expr):
    """Verify that *expr* is a constant integer and return its value.

    Parameters
    ----------
    expr : tvm.Expr or int
        The input expression.

    Returns
    -------
    out_value : int
        The constant integer value.

    Raises
    ------
    ValueError
        If the expression does not simplify to a constant integer.
    """
    # Plain Python ints (and other Integral types) pass through unchanged.
    if isinstance(expr, Integral):
        return expr
    if not isinstance(expr, tvm.tir.IntImm):
        # Try simplification before giving up — the expr may fold to a constant.
        expr = tvm.arith.Analyzer().simplify(expr)
        if not isinstance(expr, tvm.tir.IntImm):
            raise ValueError("Expect value to be constant int")
    return int(expr.value)
| |
| |
def get_const_float(expr):
    """Verify that *expr* is a constant float and return its value.

    Parameters
    ----------
    expr : tvm.Expr or float
        The input expression.

    Returns
    -------
    out_value : float
        The constant floating-point value.

    Raises
    ------
    ValueError
        If the expression does not simplify to a constant float.
    """
    # Plain Python floats pass through unchanged.
    if isinstance(expr, float):
        return float(expr)
    if not isinstance(expr, tvm.tir.FloatImm):
        # Try simplification before giving up — the expr may fold to a constant.
        expr = tvm.arith.Analyzer().simplify(expr)
        if not isinstance(expr, tvm.tir.FloatImm):
            raise ValueError("Expect value to be constant float")
    return float(expr.value)
| |
| |
def equal_const_int(expr, value):
    """Return whether *expr* is a constant integer equal to *value*.

    Parameters
    ----------
    expr : tvm.Expr
        The input expression.
    value : int
        The value to compare against.

    Returns
    -------
    equal : bool
        True when expr is (or simplifies to) a constant int equal to value;
        False otherwise, including when expr is not constant at all.
    """
    # Plain Python ints compare directly.
    if isinstance(expr, Integral):
        return expr == value
    if not isinstance(expr, tvm.tir.IntImm):
        # Non-constant: attempt to fold it; bail out with False if it stays symbolic.
        expr = tvm.arith.Analyzer().simplify(expr)
        if not isinstance(expr, tvm.tir.IntImm):
            return False
    return expr.value == value
| |
| |
def get_const_tuple(in_tuple):
    """Verify the input tuple holds IntImm or Var; return a tuple of int or Var.

    Parameters
    ----------
    in_tuple : tuple of Expr
        The input.

    Returns
    -------
    out_tuple : tuple of int
        The output; symbolic elements (Var/Any, or exprs that do not fold
        to a constant) are kept as-is.
    """
    out = []
    analyzer = None
    for item in in_tuple:
        if isinstance(item, (tvm.tir.Var, tvm.tir.expr.Any)):
            # Symbolic dimensions pass through untouched.
            out.append(item)
        elif isinstance(item, (tvm.tir.IntImm, int)):
            out.append(get_const_int(item))
        else:
            # Lazily create a single Analyzer, shared across elements.
            if analyzer is None:
                analyzer = tvm.arith.Analyzer()
            folded = analyzer.simplify(item)
            if isinstance(folded, tvm.tir.IntImm):
                out.append(get_const_int(folded))
            else:
                # Could not fold to a constant: keep the simplified expr.
                out.append(folded)
    return tuple(out)
| |
| |
def get_float_tuple(in_tuple):
    """Verify the input tuple holds FloatImm; return a tuple of float.

    Parameters
    ----------
    in_tuple : tuple of Expr
        The input.

    Returns
    -------
    out_tuple : tuple of float
        The output.
    """
    return tuple(map(get_const_float, in_tuple))
| |
| |
def simplify(expr):
    """Simplify the expression when it is an Expr; return ints unchanged.

    Parameters
    ----------
    expr : Expr or int
        The input.

    Returns
    -------
    out : Expr or int
        The simplified output
    """
    if not isinstance(expr, tvm.tir.PrimExpr):
        return expr
    return tvm.arith.Analyzer().simplify(expr)
| |
| |
def ravel_index(indices, shape):
    """Flatten a coordinate tuple into a single 1D index (row-major).

    Parameters
    ----------
    indices : tuple of int or tvm.tir.IntImm
        The input coordinates

    shape : tuple of int
        Shape of the tensor.

    Returns
    -------
    idx : int or Expr
        The flattened index; None when indices/shape are empty.
    """
    idx = None
    for dim, coord in zip(shape, indices):
        # Horner-style accumulation; the outermost dimension's size is unused.
        idx = coord if idx is None else idx * dim + coord
    return idx
| |
| |
def unravel_index(idx, shape):
    """Convert a flattened 1D index back into per-dimension coordinates.

    Parameters
    ----------
    idx : int or tvm.tir.IntImm
        The 1D index

    shape : tuple of int
        Shape of the tensor

    Returns
    -------
    indices : list of int or tvm.tir.IntImm
        Corresponding coordinates of the 1D index (innermost dim computed
        first, then the list is reversed into natural order).
    """
    idxd = tvm.tir.indexdiv
    idxm = tvm.tir.indexmod
    coords = []
    # Peel off dimensions from innermost to outermost.
    for dim in reversed(shape):
        coords.append(idxm(idx, dim))
        idx = idxd(idx, dim)
    coords.reverse()
    return coords
| |
| |
def const_matrix(matrix, name="const_matrix"):
    """Convert a constant numpy 2-D matrix into a TVM tensor.

    Parameters
    ----------
    matrix: numpy.ndarray
        Const input array
    name: str, optional
        The name of output op

    Returns
    -------
    tensor: Tensor
        The created tensor
    """
    nrow, ncol = matrix.shape
    dtype = str(matrix.dtype)
    idxm = tvm.tir.indexmod

    def _entry(i, j):
        # Build a nested Select chain that picks matrix[r][c] for the
        # (i mod nrow, j mod ncol) position; conditions are mutually exclusive.
        value = tvm.tir.const(0.0, dtype)
        for r in range(nrow):
            for c in range(ncol):
                hit = tvm.tir.all(idxm(i, nrow) == r, idxm(j, ncol) == c)
                value = tvm.tir.Select(hit, tvm.tir.const(matrix[r][c], dtype), value)
        return value

    return te.compute(matrix.shape, _entry, name=name)
| |
| |
def get_max_power2_factor(n, max_value=None):
    """Get the max power-of-2 factor of n. If max_value is specified, the
    factor will be no more than max_value.

    Parameters
    ----------
    n : int
        The input value

    max_value : int, optional
        The max value for the factor

    Returns
    -------
    factor : int
        The max factor in power of 2.
    """
    x = 1
    while n % 2 == 0:
        if max_value is not None and max_value < x * 2:
            break
        x *= 2
        # Use integer division: the original `n /= 2` converted n to float,
        # which silently loses precision (and gives wrong factors) once n
        # exceeds 2**53.
        n //= 2
    return x
| |
| |
def get_shape(src_shape, src_layout, dst_layout):
    """Given a source shape, a source layout and a destination layout, infer
    the destination shape.

    Parameters
    ----------
    src_shape : tuple of int or IntImm
        Source shape

    src_layout : str or Layout
        Source layout

    dst_layout : str or Layout
        Destination layout

    Returns
    -------
    dst_shape : tuple of int
        Destination shape
    """
    # Identical layouts need no index remapping.
    if src_layout == dst_layout:
        return get_const_tuple(src_shape)

    src = layout(src_layout) if isinstance(src_layout, str) else src_layout
    dst = layout(dst_layout) if isinstance(dst_layout, str) else dst_layout

    assert len(src) == len(dst), "Incompatible layout %s vs %s" % (src, dst)

    # Map each source axis position to its destination position, then
    # gather the source dims in destination order.
    mapping = bijective_layout(src, dst)
    dst_axes = mapping.forward_index(tvm.runtime.convert(list(range(len(src)))))
    return get_const_tuple(tuple(src_shape[axis.value] for axis in dst_axes))
| |
| |
def within_index(b, e, s, i):
    """Build a boolean expression telling whether position *i* is selected
    by the slice-like index (b, e, s).

    Parameters
    ----------
    b : Expr
        beginning of the index

    e : Expr
        end of the index

    s : Expr
        strides of index

    i : Expr
        array position

    Returns
    -------
    selected: Expr
        bool expression that is True if the array position would be selected
        by the index and False otherwise
    """
    # Out-of-range on the "begin" side (direction depends on the stride sign).
    before = tvm.tir.Select(s < 0, i <= e, i < b)
    # Out-of-range on the "end" side.
    after = tvm.tir.Select(s < 0, i > b, i >= e)
    # Offset from the slice start modulo the stride; 0 means i lies on a step.
    phase = te.if_then_else(s < 0, ((i - e) + (e % te.abs(s)) + 1) % te.abs(s), (i - b) % s)
    return tvm.tir.Select(tvm.tir.Or(before, after), tvm.tir.const(False), phase.equal(0))
| |
| |
def make_idx(b, e, s, z, i):
    """Build the expression for the position within the selection that
    corresponds to array position *i* in the full array.

    The returned value is only meaningful if within_index() returns True
    for the same set of parameters.

    Parameters
    ----------
    b : Expr
        beginning of the index

    e : Expr
        end of the index

    s : Expr
        strides of index

    z : Expr
        size of the indexed dimension

    i : Expr
        array position

    Returns
    -------
    postion: Expr
        int expression that corresponds to an array position in the selection.
    """
    # Out-of-range tests mirror within_index(); direction depends on stride sign.
    before = tvm.tir.Select(s < 0, i <= e, i < b)
    after = tvm.tir.Select(s < 0, i > b, i >= e)

    # Clamp the start to the array size (must happen after the range tests above).
    b = tvm.tir.Select(z < b, z - 1, b)

    offset = tvm.tir.if_then_else(s < 0, (b - i) // te.abs(s), (i - b) // s)
    # 88 is an arbitrary sentinel: the value is meaningless when i is outside
    # the selection (callers must gate on within_index()).
    return tvm.tir.if_then_else(tvm.tir.Or(before, after), 88, offset)
| |
| |
def is_empty_shape(shape):
    """Check whether an input shape has a dimension with size 0.

    Delegates to the C++ implementation (topi.cpp.util.is_empty_shape).

    Parameters
    ----------
    shape : list of Expr
        Input shape

    Returns
    -------
    is_empty: bool
        Whether the input shape is empty or has a dimension with size 0.
    """
    return cpp.util.is_empty_shape(shape)