| |
| .. DO NOT EDIT. THIS FILE WAS AUTOMATICALLY GENERATED BY |
| .. TVM'S MONKEY-PATCHED VERSION OF SPHINX-GALLERY. TO MAKE |
| .. CHANGES, EDIT THE SOURCE PYTHON FILE: |
| .. "how_to/work_with_schedules/intrin_math.py" |
| |
| .. only:: html |
| |
| .. note:: |
| :class: sphx-glr-download-link-note |
| |
| This tutorial can be used interactively with Google Colab! You can also click |
| :ref:`here <sphx_glr_download_how_to_work_with_schedules_intrin_math.py>` to run the Jupyter notebook locally. |
| |
| .. image:: https://raw.githubusercontent.com/tlc-pack/web-data/main/images/utilities/colab_button.svg |
| :align: center |
| :target: https://colab.research.google.com/github/apache/tvm-site/blob/asf-site/docs/_downloads/1e482ba1190961191e3a0bdbd0585faa/intrin_math.ipynb |
| :width: 300px |
| |
| .. rst-class:: sphx-glr-example-title |
| |
| .. _sphx_glr_how_to_work_with_schedules_intrin_math.py: |
| |
| |
| Intrinsics and Math Functions |
| ============================= |
| **Author**: `Tianqi Chen <https://tqchen.github.io>`_ |
| |
While TVM supports basic arithmetic operations, in many cases we
need more complicated builtin functions.
For example, :code:`exp` takes the exponential of its input.

These functions are target system dependent and may have different
names on different target platforms. In this tutorial, we will learn
how to invoke these target-specific functions, and how to unify
the interface via TVM's intrinsic API.
| |
| .. GENERATED FROM PYTHON SOURCE LINES 31-39 |
| |
| .. code-block:: default |
| |
| from __future__ import absolute_import, print_function |
| |
| import numpy as np |
| |
| import tvm |
| from tvm import te |
| from tvm.ir import register_op_attr, register_intrin_lowering |
| |
| |
| |
| |
| |
| |
| |
| |
| .. GENERATED FROM PYTHON SOURCE LINES 40-47 |
| |
| Direct Declare Extern Math Call |
| ------------------------------- |
The most straightforward way to call a target-specific function is via
the extern function call construct in TVM.
In the following example, we use :any:`tvm.tir.call_pure_extern` to call
the :code:`__expf` function, which is only available under CUDA.
| |
| |
| .. GENERATED FROM PYTHON SOURCE LINES 47-58 |
| |
| .. code-block:: default |
| |
| n = te.var("n") |
| A = te.placeholder((n,), name="A") |
| B = te.compute(A.shape, lambda i: tvm.tir.call_pure_extern("float32", "__expf", A[i]), name="B") |
| s = te.create_schedule(B.op) |
| num_thread = 64 |
| bx, tx = s[B].split(B.op.axis[0], factor=num_thread) |
| s[B].bind(bx, te.thread_axis("blockIdx.x")) |
| s[B].bind(tx, te.thread_axis("threadIdx.x")) |
| f = tvm.build(s, [A, B], "cuda", name="myexp") |
| print(f.imported_modules[0].get_source()) |
| |
| |
| |
| |
| |
| .. rst-class:: sphx-glr-script-out |
| |
| .. code-block:: none |
| |
| |
| #if (((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 4)) || \ |
| (__CUDACC_VER_MAJOR__ > 11)) |
| #define TVM_ENABLE_L2_PREFETCH 1 |
| #else |
| #define TVM_ENABLE_L2_PREFETCH 0 |
| #endif |
| |
| #ifdef _WIN32 |
| using uint = unsigned int; |
| using uchar = unsigned char; |
| using ushort = unsigned short; |
| using int64_t = long long; |
| using uint64_t = unsigned long long; |
| #else |
| #define uint unsigned int |
| #define uchar unsigned char |
| #define ushort unsigned short |
| #define int64_t long long |
| #define uint64_t unsigned long long |
| #endif |
| extern "C" __global__ void __launch_bounds__(64) myexp_kernel(float* __restrict__ A, float* __restrict__ B, int n, int stride, int stride_1); |
| extern "C" __global__ void __launch_bounds__(64) myexp_kernel(float* __restrict__ A, float* __restrict__ B, int n, int stride, int stride_1) { |
| if (((int)blockIdx.x) < (n >> 6)) { |
| B[(((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) * stride)] = __expf(A[(((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) * stride_1)]); |
| } else { |
| if (((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) < n) { |
| B[(((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) * stride)] = __expf(A[(((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) * stride_1)]); |
| } |
| } |
| } |
| |
| |
| |
| |
| |
| |
| .. GENERATED FROM PYTHON SOURCE LINES 59-71 |
| |
| Unified Intrinsic Call |
| ---------------------- |
The above code verifies that a direct external call can be used to
call into device-specific functions.
However, that approach only works for the CUDA target with the float type.
Ideally, we want to write the same code for any device and any data type.

TVM intrinsics provide the user a mechanism to achieve this, and this
is the recommended way to solve the problem.
The following code uses :py:func:`tvm.te.exp` instead, which creates an
intrinsic call to take the exponential.
| |
| |
| .. GENERATED FROM PYTHON SOURCE LINES 71-81 |
| |
| .. code-block:: default |
| |
| n = te.var("n") |
| A = te.placeholder((n,), name="A") |
| B = te.compute(A.shape, lambda i: te.exp(A[i]), name="B") |
| s = te.create_schedule(B.op) |
| num_thread = 64 |
| bx, tx = s[B].split(B.op.axis[0], factor=num_thread) |
| s[B].bind(bx, te.thread_axis("blockIdx.x")) |
| s[B].bind(tx, te.thread_axis("threadIdx.x")) |
| fcuda = tvm.build(s, [A, B], "cuda", name="myexp") |
| print(fcuda.imported_modules[0].get_source()) |
| |
| |
| |
| |
| .. rst-class:: sphx-glr-script-out |
| |
| .. code-block:: none |
| |
| |
| #if (((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 4)) || \ |
| (__CUDACC_VER_MAJOR__ > 11)) |
| #define TVM_ENABLE_L2_PREFETCH 1 |
| #else |
| #define TVM_ENABLE_L2_PREFETCH 0 |
| #endif |
| |
| #ifdef _WIN32 |
| using uint = unsigned int; |
| using uchar = unsigned char; |
| using ushort = unsigned short; |
| using int64_t = long long; |
| using uint64_t = unsigned long long; |
| #else |
| #define uint unsigned int |
| #define uchar unsigned char |
| #define ushort unsigned short |
| #define int64_t long long |
| #define uint64_t unsigned long long |
| #endif |
| extern "C" __global__ void __launch_bounds__(64) myexp_kernel(float* __restrict__ A, float* __restrict__ B, int n, int stride, int stride_1); |
| extern "C" __global__ void __launch_bounds__(64) myexp_kernel(float* __restrict__ A, float* __restrict__ B, int n, int stride, int stride_1) { |
| if (((int)blockIdx.x) < (n >> 6)) { |
| B[(((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) * stride)] = __expf(A[(((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) * stride_1)]); |
| } else { |
| if (((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) < n) { |
| B[(((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) * stride)] = __expf(A[(((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) * stride_1)]); |
| } |
| } |
| } |
| |
| |
| |
| |
| |
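Besides :code:`te.exp`, TVM provides other unified math intrinsics such as
:code:`te.log`, :code:`te.sqrt` and :code:`te.tanh`, which follow the same
pattern. A minimal sketch (the tensor names :code:`C` and :code:`D` below are
illustrative and not part of the original tutorial):

.. code-block:: default

    # Each intrinsic lowers to the appropriate target-specific function,
    # e.g. logf/__logf on CUDA, in the same way as te.exp above.
    C = te.compute(A.shape, lambda i: te.log(A[i]), name="C")
    D = te.compute(A.shape, lambda i: te.tanh(A[i]) * te.sqrt(A[i]), name="D")
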
| |
| .. GENERATED FROM PYTHON SOURCE LINES 82-85 |
| |
We can see that the code works for both CUDA and OpenCL.
The same :code:`te.exp` can also be used for float64 data types,
as shown in the sketch after the OpenCL output below.
| |
| |
| .. GENERATED FROM PYTHON SOURCE LINES 85-88 |
| |
| .. code-block:: default |
| |
| fopencl = tvm.build(s, [A, B], "opencl", name="myexp") |
| print(fopencl.imported_modules[0].get_source()) |
| |
| |
| |
| |
| |
| .. rst-class:: sphx-glr-script-out |
| |
| .. code-block:: none |
| |
| // Function: myexp_kernel |
| __kernel void myexp_kernel(__global float* restrict A, __global float* restrict B, int n, int stride, int stride_1); |
| __kernel void myexp_kernel(__global float* restrict A, __global float* restrict B, int n, int stride, int stride_1) { |
| if ((convert_int(get_group_id(0))) < (n >> 6)) { |
| B[((((convert_int(get_group_id(0))) * 64) + (convert_int(get_local_id(0)))) * stride)] = exp(A[((((convert_int(get_group_id(0))) * 64) + (convert_int(get_local_id(0)))) * stride_1)]); |
| } else { |
| if ((((convert_int(get_group_id(0))) * 64) + (convert_int(get_local_id(0)))) < n) { |
| B[((((convert_int(get_group_id(0))) * 64) + (convert_int(get_local_id(0)))) * stride)] = exp(A[((((convert_int(get_group_id(0))) * 64) + (convert_int(get_local_id(0)))) * stride_1)]); |
| } |
| } |
| } |
| |
| |
| |
| |
| |
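As mentioned above, the same :code:`te.exp` also works for float64 data
types. A minimal sketch, assuming the same schedule pattern as before
(the ``*_f64`` names are illustrative and not part of the original tutorial):

.. code-block:: default

    # Build a float64 variant; te.exp lowers to the double-precision exp.
    n_f64 = te.var("n")
    A_f64 = te.placeholder((n_f64,), name="A", dtype="float64")
    B_f64 = te.compute(A_f64.shape, lambda i: te.exp(A_f64[i]), name="B")
    s_f64 = te.create_schedule(B_f64.op)
    bx_f64, tx_f64 = s_f64[B_f64].split(B_f64.op.axis[0], factor=num_thread)
    s_f64[B_f64].bind(bx_f64, te.thread_axis("blockIdx.x"))
    s_f64[B_f64].bind(tx_f64, te.thread_axis("threadIdx.x"))
    fcuda_f64 = tvm.build(s_f64, [A_f64, B_f64], "cuda", name="myexp_f64")
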
| |
| .. GENERATED FROM PYTHON SOURCE LINES 89-98 |
| |
| Intrinsic Lowering Rule |
| ----------------------- |
When :py:func:`tvm.te.exp` is called, TVM creates an intrinsic Call Expr.
TVM uses transformation rules to lower the intrinsic
call to device-specific extern calls.

TVM also allows the user to customize the rules at runtime.
The following example customizes the CUDA lowering rule for :code:`exp`.
| |
| |
| .. GENERATED FROM PYTHON SOURCE LINES 98-118 |
| |
| .. code-block:: default |
| |
| |
| |
| def my_cuda_math_rule(op): |
| """Customized CUDA intrinsic lowering rule""" |
| assert isinstance(op, tvm.tir.Call) |
| name = op.op.name |
| assert name.startswith("tir.") |
| dispatch_name = name[4:] |
| if op.dtype == "float32": |
| # call float function |
| return tvm.tir.call_pure_extern("float32", "%sf" % dispatch_name, op.args[0]) |
| elif op.dtype == "float64": |
| # call double function |
return tvm.tir.call_pure_extern("float64", dispatch_name, op.args[0])
| else: |
| # cannot do translation, return self. |
| return op |
| |
| |
| register_intrin_lowering("tir.exp", target="cuda", f=my_cuda_math_rule, level=99) |
| |
| |
| |
| |
| .. rst-class:: sphx-glr-script-out |
| |
| .. code-block:: none |
| |
| |
| <function my_cuda_math_rule at 0x7fd4f3b94af0> |
| |
| |
| |
| .. GENERATED FROM PYTHON SOURCE LINES 119-124 |
| |
Register the rule to TVM with the override option to override the existing rule.
Notice the difference between the printed code and the previous one:
our new rule uses the math function :code:`expf` instead of the
fast-math version :code:`__expf`.
| |
| |
| .. GENERATED FROM PYTHON SOURCE LINES 124-127 |
| |
| .. code-block:: default |
| |
| fcuda = tvm.build(s, [A, B], "cuda", name="myexp") |
| print(fcuda.imported_modules[0].get_source()) |
| |
| |
| |
| |
| |
| .. rst-class:: sphx-glr-script-out |
| |
| .. code-block:: none |
| |
| |
| #if (((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 4)) || \ |
| (__CUDACC_VER_MAJOR__ > 11)) |
| #define TVM_ENABLE_L2_PREFETCH 1 |
| #else |
| #define TVM_ENABLE_L2_PREFETCH 0 |
| #endif |
| |
| #ifdef _WIN32 |
| using uint = unsigned int; |
| using uchar = unsigned char; |
| using ushort = unsigned short; |
| using int64_t = long long; |
| using uint64_t = unsigned long long; |
| #else |
| #define uint unsigned int |
| #define uchar unsigned char |
| #define ushort unsigned short |
| #define int64_t long long |
| #define uint64_t unsigned long long |
| #endif |
| extern "C" __global__ void __launch_bounds__(64) myexp_kernel(float* __restrict__ A, float* __restrict__ B, int n, int stride, int stride_1); |
| extern "C" __global__ void __launch_bounds__(64) myexp_kernel(float* __restrict__ A, float* __restrict__ B, int n, int stride, int stride_1) { |
| if (((int)blockIdx.x) < (n >> 6)) { |
| B[(((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) * stride)] = expf(A[(((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) * stride_1)]); |
| } else { |
| if (((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) < n) { |
| B[(((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) * stride)] = expf(A[(((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) * stride_1)]); |
| } |
| } |
| } |
| |
| |
| |
| |
| |
| |
| .. GENERATED FROM PYTHON SOURCE LINES 128-134 |
| |
| Add Your Own Intrinsic |
| ---------------------- |
If there is an intrinsic that is not provided by TVM,
the user can easily add a new intrinsic by using the intrinsic rule system.
The following example adds an intrinsic :code:`mylog` to the system.
| |
| |
| .. GENERATED FROM PYTHON SOURCE LINES 134-166 |
| |
| .. code-block:: default |
| |
| |
| |
| def mylog(x): |
| """customized log intrinsic function""" |
| return tvm.tir.call_intrin(x.dtype, "tir.mylog", x) |
| |
| |
| def my_cuda_mylog_rule(op): |
| """CUDA lowering rule for log""" |
| if op.dtype == "float32": |
| return tvm.tir.call_pure_extern("float32", "logf", op.args[0]) |
| elif op.dtype == "float64": |
| return tvm.tir.call_pure_extern("float64", "log", op.args[0]) |
| else: |
| return op |
| |
| |
| # new op registration is triggered by registering an attribute of the op |
| register_op_attr("tir.mylog", "TCallEffectKind", tvm.tir.CallEffectKind.Pure) |
| register_intrin_lowering("tir.mylog", target="cuda", f=my_cuda_mylog_rule, level=99) |
| |
| n = te.var("n") |
| A = te.placeholder((n,), name="A") |
| B = te.compute(A.shape, lambda i: mylog(A[i]), name="B") |
| s = te.create_schedule(B.op) |
| num_thread = 64 |
| bx, tx = s[B].split(B.op.axis[0], factor=num_thread) |
| s[B].bind(bx, te.thread_axis("blockIdx.x")) |
| s[B].bind(tx, te.thread_axis("threadIdx.x")) |
| fcuda = tvm.build(s, [A, B], "cuda", name="mylog") |
| print(fcuda.imported_modules[0].get_source()) |
| |
| |
| |
| |
| |
| .. rst-class:: sphx-glr-script-out |
| |
| .. code-block:: none |
| |
| |
| #if (((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 4)) || \ |
| (__CUDACC_VER_MAJOR__ > 11)) |
| #define TVM_ENABLE_L2_PREFETCH 1 |
| #else |
| #define TVM_ENABLE_L2_PREFETCH 0 |
| #endif |
| |
| #ifdef _WIN32 |
| using uint = unsigned int; |
| using uchar = unsigned char; |
| using ushort = unsigned short; |
| using int64_t = long long; |
| using uint64_t = unsigned long long; |
| #else |
| #define uint unsigned int |
| #define uchar unsigned char |
| #define ushort unsigned short |
| #define int64_t long long |
| #define uint64_t unsigned long long |
| #endif |
| extern "C" __global__ void __launch_bounds__(64) mylog_kernel(float* __restrict__ A, float* __restrict__ B, int n, int stride, int stride_1); |
| extern "C" __global__ void __launch_bounds__(64) mylog_kernel(float* __restrict__ A, float* __restrict__ B, int n, int stride, int stride_1) { |
| if (((int)blockIdx.x) < (n >> 6)) { |
| B[(((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) * stride)] = logf(A[(((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) * stride_1)]); |
| } else { |
| if (((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) < n) { |
| B[(((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) * stride)] = logf(A[(((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) * stride_1)]); |
| } |
| } |
| } |
| |
| |
| |
| |
| |
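As a quick sanity check, the generated :code:`mylog` kernel can be compared
against NumPy. This is a minimal sketch that is not part of the original
tutorial; it assumes a CUDA-capable device is available at runtime:

.. code-block:: default

    # Run the kernel on the GPU and compare against np.log (requires a CUDA device).
    dev = tvm.cuda(0)
    num = 1024
    a = tvm.nd.array(np.random.uniform(1.0, 10.0, size=num).astype("float32"), dev)
    b = tvm.nd.array(np.zeros(num, dtype="float32"), dev)
    fcuda(a, b)
    np.testing.assert_allclose(b.numpy(), np.log(a.numpy()), rtol=1e-5)
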
| |
| .. GENERATED FROM PYTHON SOURCE LINES 167-174 |
| |
| Summary |
| ------- |
- TVM can call extern target-dependent math functions.
- Use intrinsics to define a unified interface for these functions.
- For more intrinsics available in TVM, take a look at :any:`tvm.tir`.
- You can customize the intrinsic behavior by defining your own rules.
| |
| |
| |
| .. _sphx_glr_download_how_to_work_with_schedules_intrin_math.py: |
| |
| .. only:: html |
| |
| .. container:: sphx-glr-footer sphx-glr-footer-example |
| |
| |
| .. container:: sphx-glr-download sphx-glr-download-python |
| |
| :download:`Download Python source code: intrin_math.py <intrin_math.py>` |
| |
| .. container:: sphx-glr-download sphx-glr-download-jupyter |
| |
| :download:`Download Jupyter notebook: intrin_math.ipynb <intrin_math.ipynb>` |
| |
| |
| .. only:: html |
| |
| .. rst-class:: sphx-glr-signature |
| |
| `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_ |