blob: 03dfb9995e6d054cdc11f489f23413cdb54852c0 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Bring Your Own Datatypes custom datatype framework
TODO(@gussmith23 @hypercubestart) link to BYODT docs when they exist"""
import tvm
from tvm.runtime import convert, DataType
from tvm.tir.expr import (
Call as _Call,
Cast as _Cast,
FloatImm as _FloatImm,
BinaryOpExpr as _BinaryOpExpr,
)
from tvm.tir.op import call_pure_extern
from tvm._ffi import register_func as _register_func
from tvm.tir import call_intrin
def register(type_name, type_code):
"""Register a custom datatype with the given type name and type code
Currently, the type code is manually allocated by the user, and the user
must ensure that no two custom types share the same code. Generally, this
should be straightforward, as the user will be manually registering all of
their custom types.
Example:
.. code-block:: python
# Register a dtype named 'posites2' under type code 130.
tvm.target.datatype.register('posites2', 130)
Parameters
----------
type_name : str
The name of the custom datatype.
type_code : int
The type's code, which should be >= kCustomBegin. See
include/tvm/runtime/data_type.h.
"""
tvm.runtime._ffi_api._datatype_register(type_name, type_code)
def get_type_name(type_code):
"""Get the type name of a custom datatype from the type code.
Note that this only works for custom datatypes registered with
tvm.target.datatype.register(). It does not work for TVM-native types.
Example:
.. code-block:: python
tvm.target.datatype.register('posites2', 130)
assert tvm.target.datatype.get_type_name(130) == 'posites2'
Parameters
----------
type_code : int
The type code of the custom datatype.
Returns
-------
type_name : String
The name of the custom datatype.
"""
return tvm.runtime._ffi_api._datatype_get_type_name(type_code)
def get_type_code(type_name):
"""Get the type code of a custom datatype from its type name
Note that this only works for custom datatypes registered with
tvm.target.datatype.register(). It does not work for TVM-native types.
Example:
.. code-block:: python
tvm.target.datatype.register('posites2', 130)
assert tvm.target.datatype.get_type_code('posites2') == 130
Parameters
----------
type_name : str
The type name
Returns
-------
type_code : int
The type code of the custom datatype.
"""
return tvm.runtime._ffi_api._datatype_get_type_code(type_name)
def get_type_registered(type_code):
"""Returns true if a custom datatype is registered under the given type code
Example:
.. code-block:: python
tvm.target.datatype.register('posites2', 130)
assert tvm.target.datatype.get_type_registered(130)
Parameters
----------
type_code: int
The type code
Returns
-------
type_registered : bool
True if a custom datatype is registered under this type code, and false
otherwise.
"""
return tvm.runtime._ffi_api._datatype_get_type_registered(type_code)
def register_op(
lower_func, op_name, target, src_type_name, dest_type_name=None, intrinsic_name=None
):
"""Register a lowering function for a specific operator of a custom datatype
At build time, Relay must lower operators over custom datatypes into
operators it understands how to compile. For each custom datatype operator
which Relay finds while lowering custom datatypes, Relay expects to find a
user-defined lowering function. Users register their user-defined lowering
functions using this function.
Users should use create_lower_func to create their lowering function. It
should serve most use-cases.
Currently, this will work with Casts, intrinsics (e.g. sqrt, sigmoid), and
binary expressions (e.g. Add, Sub, Mul, Div).
See the LowerCustomDatatypes pass to see how registered functions are used.
Lowering Functions
------------------
TODO(@gussmith23) Get the terminology right here.
Lowering functions take in a Relay node, and should return a semantically
equivalent Relay node which Relay can build. This means that the returned
node should not contain any custom datatypes. Users should likely not need
to define lowering functions by hand -- see the helper function
create_lower_func.
Parameters
----------
lower_func : function
The lowering function to call. See create_lower_func.
op_name : str
The name of the operation which the function computes, given by its
class name (e.g. Add, LE, Cast, Call).
target : str
The name of codegen target.
src_type_name : str
The name of the custom datatype, e.g. posites2 (but not custom[posites2]32).
If op_name is not "Cast", then target type is guaranteed to be the same as src_type_name.
dest_type_name : str
If op_name is "Cast", then this is required and should be set to the dest datatype of
the argument to the Cast. If op_name is not "Cast", this is unused.
intrinsic_name : str
If op_name is "Call" and intrinsic_name is not None, then we assume the
op is a Call to an Intrinsic, and intrinsic_name is the intrinsic's
name.
"""
if op_name == "Cast":
assert dest_type_name is not None
lower_func_name = (
"tvm.datatype.lower."
+ target
+ "."
+ op_name
+ "."
+ dest_type_name
+ "."
+ src_type_name
)
elif op_name == "Call" and intrinsic_name is not None:
lower_func_name = (
"tvm.datatype.lower."
+ target
+ "."
+ op_name
+ ".intrin."
+ intrinsic_name
+ "."
+ src_type_name
)
else:
lower_func_name = "tvm.datatype.lower." + target + "." + op_name + "." + src_type_name
tvm._ffi.register_func(lower_func_name, lower_func)
def register_min_func(func, type_name):
"""Register the function that returns the minimum representable value of type_name.
Operators such as max pooling and argmax require the minimum
finite value representable by the datatype the op operating on.
Users can use this function to register a function that returns a TIR expression node
outputting the minimum representable value of their custom data type.
Users should use create_min_lower_func to create their lowering function. It
should serve most use-cases.
Note: for special cases when it is known that the custom datatype is representable
by a float, the user can create their own lowering func that returns a FloatImm.
The benefits are allowing optimizations such as rewrites to work as expected on custom
datatypes.
Parameters
----------
func : function
Input is an integer num_bits, should return a TIR expression node that
represents a scalar tensor of type custom[type_name]num_bits with the minimum
representable value.
type_name : str
The name of the custom datatype, e.g. posites2 (but not custom[posites2]32).
"""
_register_func("tvm.datatype.min." + type_name, func)
def create_min_lower_func(extern_func_map, type_name):
"""Returns a lowering function for getting the minimum value of a custom datatype.
Parameters
----------
extern_func_map : map
A map from bit lengths to the name of the extern "C" function to lower to.
type_name : string
The name of the custom datatype, e.g. posites2 (but not custom[posites2]32).
"""
def lower(num_bits):
dtype = f"custom[{type_name}]{num_bits}"
if num_bits not in extern_func_map:
raise RuntimeError("missing minimum function for {dtype}")
return call_pure_extern(dtype, extern_func_map[num_bits])
return lower
def create_lower_func(extern_func_map):
"""Returns a function which lowers an operation to a function call.
Parameters
----------
extern_func_map : map
If lowering a Cast, extern_func_map should be a map from tuples of
(src_bit_length, dest_bit_length) to the name of the extern "C" function to lower to.
Otherwise, for unary and binary ops, it should simply be a map
from bit_length to the name of the extern "C" function to lower to.
"""
def lower(op):
"""
Takes an op---either a Cast, Call, or a binary op (e.g. an Add) and returns a
call to the specified external function, passing the op's argument
or arguments. The return type of the call depends
on the type of the op: if it is a custom type, then a uint of the same
width as the custom type is returned. Otherwise, the type is
unchanged."""
dtype = op.dtype
t = DataType(dtype)
if get_type_registered(t.type_code):
dtype = "uint" + str(t.bits)
if t.lanes > 1:
dtype += "x" + str(t.lanes)
key = t.bits
if isinstance(op, _Cast):
src_bits = DataType(op.value.dtype).bits
key = (src_bits, t.bits)
if key not in extern_func_map:
raise RuntimeError(f"missing key {key} in extern_func_map for {op.astext()}")
if isinstance(op, _Cast):
return call_pure_extern(dtype, extern_func_map[key], op.value)
if isinstance(op, _FloatImm):
return call_pure_extern(dtype, extern_func_map[key], op.value)
if isinstance(op, _Call):
return call_pure_extern(dtype, extern_func_map[key], *op.args)
if isinstance(op, _BinaryOpExpr):
return call_pure_extern(dtype, extern_func_map[key], op.a, op.b)
raise RuntimeError(f"lowering unsupported op: {op.astext()}")
return lower
def lower_ite(ite_op):
"""Lowered if then else function that calls intrinsic if_then_else.
Unlike a function lowered by create_lower_func, this function
calls the tvm intrinsic if_then_else.
Parameters
----------
ite_op : Op
Takes an if then else op and returns a
call to tir.if_then_else function, passing the op's
arguments. The return type of the call if a uint of the same
width as the custom type is returned.
"""
dtype = ite_op.dtype
t = tvm.DataType(dtype)
assert get_type_registered(t.type_code)
dtype = "uint" + str(t.bits)
if t.lanes > 1:
dtype += "x" + str(t.lanes)
return call_intrin(
dtype,
"tir.if_then_else",
convert(ite_op.args[0]),
convert(ite_op.args[1]),
convert(ite_op.args[2]),
)
def lower_call_pure_extern(op):
"""Lowered call pure extern function that calls intrinsic call_pure_extern.
Unlike a function lowered by create_lower_func, this function
calls the tvm intrinsic call_pure_extern.
Parameters
----------
ite_op : Op
Takes a call_pure_extern op and returns a
call to tir.call_pure_extern function, passing the op's
arguments. The return type of the call if a uint of the same
width as the custom type is returned.
"""
dtype = op.dtype
t = tvm.DataType(dtype)
assert get_type_registered(t.type_code)
dtype = "uint" + str(t.bits)
if t.lanes > 1:
dtype += "x" + str(t.lanes)
return call_intrin(dtype, "tir.call_pure_extern", *op.args)