.. DO NOT EDIT. THIS FILE WAS AUTOMATICALLY GENERATED BY
.. TVM'S MONKEY-PATCHED VERSION OF SPHINX-GALLERY. TO MAKE
.. CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "how_to/tune_with_autotvm/tune_conv2d_cuda.py"
.. only:: html
.. note::
:class: sphx-glr-download-link-note
This tutorial can be used interactively with Google Colab! You can also click
:ref:`here <sphx_glr_download_how_to_tune_with_autotvm_tune_conv2d_cuda.py>` to run the Jupyter notebook locally.
.. image:: https://raw.githubusercontent.com/tlc-pack/web-data/main/images/utilities/colab_button.svg
:align: center
:target: https://colab.research.google.com/github/apache/tvm-site/blob/asf-site/docs/_downloads/732ed130cbc15432e737da8cc47e1734/tune_conv2d_cuda.ipynb
:width: 300px
.. rst-class:: sphx-glr-example-title
.. _sphx_glr_how_to_tune_with_autotvm_tune_conv2d_cuda.py:
Tuning High Performance Convolution on NVIDIA GPUs
=========================================================================
**Author**: `Lianmin Zheng <https://github.com/merrymercy>`_
This is an advanced tutorial on writing a high-performance tunable template for
NVIDIA GPUs. By running the auto-tuner on this template, we can outperform the
vendor-provided library cuDNN in many cases.
Note that this tutorial will not run on Windows or recent versions of macOS. To
get it to run, you will need to wrap the body of this tutorial in a :code:`if
__name__ == "__main__":` block.
.. GENERATED FROM PYTHON SOURCE LINES 32-50
Install dependencies
--------------------
To use the autotvm package in TVM, we need to install some extra dependencies:
.. code-block:: bash
pip3 install --user psutil xgboost tornado cloudpickle
To make TVM run faster during tuning, it is recommended to use Cython
as the FFI of TVM. In the root directory of TVM, execute:
.. code-block:: bash
pip3 install --user cython
sudo make cython3
Now return to the Python code and import the packages.
.. GENERATED FROM PYTHON SOURCE LINES 50-62
.. code-block:: default
import logging
import sys
import numpy as np
import tvm
from tvm import te, topi, testing
from tvm.topi.testing import conv2d_nchw_python
import tvm.testing
from tvm import autotvm
.. GENERATED FROM PYTHON SOURCE LINES 66-88
Step 1: Define the search space
--------------------------------
There are plenty of useful schedule primitives in TVM. You can also find
some tutorials that describe them in more detail, such as
(1). :ref:`opt-conv-gpu`
(2). `Optimizing DepthwiseConv on NVIDIA GPU <https://tvm.apache.org/2017/08/22/Optimize-Deep-Learning-GPU-Operators-with-TVM-A-Depthwise-Convolution-Example>`_
However, their implementations are manually tuned for some special input
shapes. In this section, we build a large enough space to cover
the techniques used in these tutorials. Then we rely on the efficient auto-tuner
to search through this space and pick some good configurations.
If you are familiar with writing CUDA schedules, you will find that the following
template is very general. In fact, this template can be easily modified
to tune other operators such as depthwise convolution and GEMM.
In order to fully understand this template, you should be familiar with
the schedule primitives and the auto-tuning API. You can refer to the above
tutorials and the :ref:`autotvm tutorial <tutorial-autotvm-matmul-x86>`.
It is worth noting that the search space for a conv2d operator
can be very large (on the order of 10^9 for some input shapes).
.. GENERATED FROM PYTHON SOURCE LINES 88-178
.. code-block:: default
@autotvm.template("tutorial/conv2d_no_batching")
def conv2d_no_batching(N, H, W, CO, CI, KH, KW, stride, padding):
    assert N == 1, "Only consider batch_size = 1 in this template"
    data = te.placeholder((N, CI, H, W), name="data")
    kernel = te.placeholder((CO, CI, KH, KW), name="kernel")
    conv = topi.nn.conv2d_nchw(data, kernel, stride, padding, dilation=1, out_dtype="float32")
    s = te.create_schedule([conv.op])
    ##### space definition begin #####
    n, f, y, x = s[conv].op.axis
    rc, ry, rx = s[conv].op.reduce_axis
    cfg = autotvm.get_config()
    cfg.define_split("tile_f", f, num_outputs=4)
    cfg.define_split("tile_y", y, num_outputs=4)
    cfg.define_split("tile_x", x, num_outputs=4)
    cfg.define_split("tile_rc", rc, num_outputs=3)
    cfg.define_split("tile_ry", ry, num_outputs=3)
    cfg.define_split("tile_rx", rx, num_outputs=3)
    cfg.define_knob("auto_unroll_max_step", [0, 512, 1500])
    cfg.define_knob("unroll_explicit", [0, 1])
    ##### space definition end #####
    # inline padding
    pad_data = s[conv].op.input_tensors[0]
    s[pad_data].compute_inline()
    data, raw_data = pad_data, data
    output = conv
    OL = s.cache_write(conv, "local")
    # create cache stage
    AA = s.cache_read(data, "shared", [OL])
    WW = s.cache_read(kernel, "shared", [OL])
    AL = s.cache_read(AA, "local", [OL])
    WL = s.cache_read(WW, "local", [OL])
    # tile and bind spatial axes
    n, f, y, x = s[output].op.axis
    bf, vf, tf, fi = cfg["tile_f"].apply(s, output, f)
    by, vy, ty, yi = cfg["tile_y"].apply(s, output, y)
    bx, vx, tx, xi = cfg["tile_x"].apply(s, output, x)
    kernel_scope = n  # this is the scope to attach global config inside this kernel
    s[output].bind(bf, te.thread_axis("blockIdx.z"))
    s[output].bind(by, te.thread_axis("blockIdx.y"))
    s[output].bind(bx, te.thread_axis("blockIdx.x"))
    s[output].bind(vf, te.thread_axis("vthread"))
    s[output].bind(vy, te.thread_axis("vthread"))
    s[output].bind(vx, te.thread_axis("vthread"))
    s[output].bind(tf, te.thread_axis("threadIdx.z"))
    s[output].bind(ty, te.thread_axis("threadIdx.y"))
    s[output].bind(tx, te.thread_axis("threadIdx.x"))
    s[output].reorder(n, bf, by, bx, vf, vy, vx, tf, ty, tx, fi, yi, xi)
    s[OL].compute_at(s[output], tx)
    # tile reduction axes
    n, f, y, x = s[OL].op.axis
    rc, ry, rx = s[OL].op.reduce_axis
    rco, rcm, rci = cfg["tile_rc"].apply(s, OL, rc)
    ryo, rym, ryi = cfg["tile_ry"].apply(s, OL, ry)
    rxo, rxm, rxi = cfg["tile_rx"].apply(s, OL, rx)
    s[OL].reorder(rco, ryo, rxo, rcm, rym, rxm, rci, ryi, rxi, n, f, y, x)
    s[AA].compute_at(s[OL], rxo)
    s[WW].compute_at(s[OL], rxo)
    s[AL].compute_at(s[OL], rxm)
    s[WL].compute_at(s[OL], rxm)
    # cooperative fetching
    for load in [AA, WW]:
        n, f, y, x = s[load].op.axis
        fused = s[load].fuse(n, f, y, x)
        tz, fused = s[load].split(fused, nparts=cfg["tile_f"].size[2])
        ty, fused = s[load].split(fused, nparts=cfg["tile_y"].size[2])
        tx, fused = s[load].split(fused, nparts=cfg["tile_x"].size[2])
        s[load].bind(tz, te.thread_axis("threadIdx.z"))
        s[load].bind(ty, te.thread_axis("threadIdx.y"))
        s[load].bind(tx, te.thread_axis("threadIdx.x"))
    # tune unroll
    s[output].pragma(kernel_scope, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val)
    s[output].pragma(kernel_scope, "unroll_explicit", cfg["unroll_explicit"].val)
    return s, [raw_data, kernel, conv]
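To build intuition for what a four-way split such as ``tile_f`` does, the following plain-Python sketch (no TVM required) flattens the four nested loop indices back into the original axis index. The helper ``split_index`` and the example factors ``[8, 2, 16, 2]`` are illustrative assumptions, not values chosen by the tuner; the point is only that any factorization whose product equals the axis extent visits every index exactly once.

```python
from itertools import product


def split_index(sizes, idx):
    """Flatten nested loop indices (outermost first) into the original axis index."""
    flat = 0
    for size, i in zip(sizes, idx):
        flat = flat * size + i
    return flat


# Hypothetical tile_f candidate for an axis of extent 512:
# [block, vthread, thread, inner], mirroring (bf, vf, tf, fi) in the template.
tile_f = [8, 2, 16, 2]

# Enumerate every combination of the four loop indices and flatten each one.
covered = sorted(
    split_index(tile_f, idx) for idx in product(*(range(s) for s in tile_f))
)

# Every index of the original axis is produced exactly once.
assert covered == list(range(512))
```

In the template, the outermost index is bound to ``blockIdx``, the second to a virtual thread, the third to ``threadIdx``, and the innermost stays a serial loop inside each thread.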
.. GENERATED FROM PYTHON SOURCE LINES 179-186
Step 2: Search through the space
---------------------------------
We pick the last layer of ResNet as the test case.
Since our space is very large, :code:`XGBTuner` is the most suitable
for our case. The commented-out call below runs only 5 trials for demonstration.
In practice, about 1000 trials can usually find some good kernels
for this template.
.. GENERATED FROM PYTHON SOURCE LINES 186-221
.. code-block:: default
# logging config (for printing tuning log to screen)
logging.getLogger("autotvm").setLevel(logging.DEBUG)
logging.getLogger("autotvm").addHandler(logging.StreamHandler(sys.stdout))
# the last layer in resnet
N, H, W, CO, CI, KH, KW, strides, padding = 1, 7, 7, 512, 512, 3, 3, (1, 1), (1, 1)
task = autotvm.task.create(
"tutorial/conv2d_no_batching", args=(N, H, W, CO, CI, KH, KW, strides, padding), target="cuda"
)
print(task.config_space)
# Use the local GPU and repeat each measurement 3 times to reduce variance.
# Compilation uses the LocalBuilder default timeout; each run times out after 4 seconds.
measure_option = autotvm.measure_option(
builder=autotvm.LocalBuilder(),
runner=autotvm.LocalRunner(repeat=3, min_repeat_ms=100, timeout=4),
)
record_file = None
# Begin tuning, log records to file `conv2d.log`
# During tuning we will also try many invalid configs, so you are expected to
# see many error reports. As long as you can see non-zero GFLOPS, it is okay.
# We do not run the tuning in our webpage server since it takes too long.
# Uncomment the following lines to run it by yourself.
# tuner = autotvm.tuner.XGBTuner(task)
# record_file = "conv2d.log"
# tuner.tune(
# n_trial=5,
# measure_option=measure_option,
# callbacks=[autotvm.callback.log_to_file(record_file)],
# )
.. rst-class:: sphx-glr-script-out
.. code-block:: none
ConfigSpace (len=10454400, range_length=10454400, space_map=
0 tile_f: Split(policy=factors, product=512, num_outputs=4) len=220
1 tile_y: Split(policy=factors, product=7, num_outputs=4) len=4
2 tile_x: Split(policy=factors, product=7, num_outputs=4) len=4
3 tile_rc: Split(policy=factors, product=512, num_outputs=3) len=55
4 tile_ry: Split(policy=factors, product=3, num_outputs=3) len=3
5 tile_rx: Split(policy=factors, product=3, num_outputs=3) len=3
6 auto_unroll_max_step: OtherOption([0, 512, 1500]) len=3
7 unroll_explicit: OtherOption([0, 1]) len=2
)
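The ``len`` values printed above can be reproduced with a short counting sketch. It assumes that the ``factors`` policy enumerates every ordered tuple of positive integers whose product equals the axis extent, which is consistent with the printed numbers but is not taken from the AutoTVM source. The helper names ``prime_exponents`` and ``count_splits`` are hypothetical.

```python
from math import comb, prod


def prime_exponents(n):
    """Return the exponents in the prime factorization of n."""
    exps = []
    p = 2
    while p * p <= n:
        e = 0
        while n % p == 0:
            n //= p
            e += 1
        if e:
            exps.append(e)
        p += 1
    if n > 1:
        exps.append(1)  # one remaining prime factor
    return exps


def count_splits(extent, num_outputs):
    """Count ordered factorizations of extent into num_outputs factors.

    For each prime with exponent e, distributing its e copies over the
    num_outputs slots is a stars-and-bars problem.
    """
    return prod(
        comb(e + num_outputs - 1, num_outputs - 1) for e in prime_exponents(extent)
    )


# Split knobs from the template: (extent, num_outputs)
splits = [(512, 4), (7, 4), (7, 4), (512, 3), (3, 3), (3, 3)]
# The two non-split knobs have 3 and 2 options respectively.
total = prod(count_splits(e, k) for e, k in splits) * 3 * 2
print(total)
```

Under this assumption the per-knob counts are 220, 4, 4, 55, 3, and 3, and the total matches the reported space size of 10454400.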
.. GENERATED FROM PYTHON SOURCE LINES 222-224
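Finally, we can inspect the best config from the log file, check correctness,
and measure running time. Because the tuning call above is commented out,
``record_file`` is ``None`` here, so the output below comes from the fallback configuration.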
.. GENERATED FROM PYTHON SOURCE LINES 224-254
.. code-block:: default
# inspect the best config
dispatch_context = autotvm.apply_history_best(record_file)
best_config = dispatch_context.query(task.target, task.workload)
print("\nBest config:")
print(best_config)
# apply history best from log file
with autotvm.apply_history_best(record_file):
    with tvm.target.Target("cuda"):
        s, arg_bufs = conv2d_no_batching(N, H, W, CO, CI, KH, KW, strides, padding)
        func = tvm.build(s, arg_bufs)
# check correctness
a_np = np.random.uniform(size=(N, CI, H, W)).astype(np.float32)
w_np = np.random.uniform(size=(CO, CI, KH, KW)).astype(np.float32)
c_np = conv2d_nchw_python(a_np, w_np, strides, padding)
dev = tvm.cuda()
a_tvm = tvm.nd.array(a_np, device=dev)
w_tvm = tvm.nd.array(w_np, device=dev)
c_tvm = tvm.nd.empty(c_np.shape, device=dev)
func(a_tvm, w_tvm, c_tvm)
tvm.testing.assert_allclose(c_np, c_tvm.numpy(), rtol=1e-2)
# Evaluate running time. Here we choose a large number of runs (number=400) to reduce the noise
# and the overhead of kernel launch. You can also use nvprof to validate the result.
evaluator = func.time_evaluator(func.entry_name, dev, number=400)
print("Time cost of this operator: %f" % evaluator(a_tvm, w_tvm, c_tvm).mean)
.. rst-class:: sphx-glr-script-out
.. code-block:: none
Cannot find config for target=cuda -keys=cuda,gpu -arch=sm_75 -max_num_threads=1024 -thread_warp_size=32, workload=('tutorial/conv2d_no_batching', 1, 7, 7, 512, 512, 3, 3, (1, 1), (1, 1)). A fallback configuration is used, which may bring great performance regression.
Best config:
,None
Time cost of this operator: 0.037352
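To relate the reported time to the GFLOPS figures that appear in the tuning log, the sketch below counts the floating-point operations of this layer and divides by the measured time. It assumes the usual convention of two operations (one multiply, one add) per multiply-accumulate; the helper ``conv2d_flops`` is hypothetical and not part of TVM.

```python
def conv2d_flops(n, h, w, co, ci, kh, kw, stride, padding):
    """Count FLOPs of an NCHW conv2d, assuming 2 FLOPs per multiply-accumulate."""
    out_h = (h + 2 * padding[0] - kh) // stride[0] + 1
    out_w = (w + 2 * padding[1] - kw) // stride[1] + 1
    return 2 * n * co * out_h * out_w * ci * kh * kw


# Workload from this tutorial: the last layer of ResNet.
flops = conv2d_flops(1, 7, 7, 512, 512, 3, 3, (1, 1), (1, 1))

# Time (in seconds) printed above for the fallback configuration.
time_cost = 0.037352
gflops = flops / time_cost / 1e9
print(f"{flops} FLOPs, {gflops:.2f} GFLOPS")
```

For the fallback run above this works out to roughly 6.2 GFLOPS, which gives a baseline to compare against the numbers a real tuning run prints.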
.. _sphx_glr_download_how_to_tune_with_autotvm_tune_conv2d_cuda.py:
.. only:: html
.. container:: sphx-glr-footer sphx-glr-footer-example
.. container:: sphx-glr-download sphx-glr-download-python
:download:`Download Python source code: tune_conv2d_cuda.py <tune_conv2d_cuda.py>`
.. container:: sphx-glr-download sphx-glr-download-jupyter
:download:`Download Jupyter notebook: tune_conv2d_cuda.ipynb <tune_conv2d_cuda.ipynb>`
.. only:: html
.. rst-class:: sphx-glr-signature
`Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_