.. DO NOT EDIT. THIS FILE WAS AUTOMATICALLY GENERATED BY
.. TVM'S MONKEY-PATCHED VERSION OF SPHINX-GALLERY. TO MAKE
.. CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "tutorial/auto_scheduler_matmul_x86.py"
.. only:: html
.. note::
:class: sphx-glr-download-link-note
This tutorial can be used interactively with Google Colab! You can also click
:ref:`here <sphx_glr_download_tutorial_auto_scheduler_matmul_x86.py>` to run the Jupyter notebook locally.
.. image:: https://raw.githubusercontent.com/tlc-pack/web-data/main/images/utilities/colab_button.svg
:align: center
:target: https://colab.research.google.com/github/apache/tvm-site/blob/asf-site/docs/_downloads/246d4b8509474fd9046e69f6cc9b7f87/auto_scheduler_matmul_x86.ipynb
:width: 300px
.. rst-class:: sphx-glr-example-title
.. _sphx_glr_tutorial_auto_scheduler_matmul_x86.py:
Optimizing Operators with Auto-scheduling
=========================================
**Author**: `Lianmin Zheng <https://github.com/merrymercy>`_, `Chengfan Jia <https://github.com/jcf94/>`_
In this tutorial, we will show how TVM's Auto Scheduling feature can find
optimal schedules without the need for writing a custom template.
Unlike the template-based :doc:`AutoTVM <autotvm_matmul_x86>`, which relies on
manual templates to define the search space, the auto-scheduler does not
require any templates. Users only need to write the computation declaration
without any schedule commands or templates. The auto-scheduler can
automatically generate a large search space and find a good schedule in the
space.
We use matrix multiplication as an example in this tutorial.
.. note::
Note that this tutorial will not run on Windows or recent versions of macOS. To
get it to run, you will need to wrap the body of this tutorial in an :code:`if
__name__ == "__main__":` block.
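As a rough sketch of that wrapping (``run_tutorial`` is a hypothetical helper
standing in for the tutorial body):

.. code-block:: python

   def run_tutorial():
       # Place the tutorial body here (imports, tuning, evaluation).
       pass


   if __name__ == "__main__":
       run_tutorial()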
.. GENERATED FROM PYTHON SOURCE LINES 40-46
.. code-block:: default
import numpy as np
import tvm
from tvm import te, auto_scheduler
.. GENERATED FROM PYTHON SOURCE LINES 47-55
Defining the Matrix Multiplication
----------------------------------
To start, we define a matrix multiplication with a bias addition. Note that
this uses standard operations available in TVM's Tensor Expression language.
The major difference is the use of the :any:`register_workload` decorator at the top
of the function definition. The function should return a list of
input/output tensors. From these tensors, the auto-scheduler can get the
whole computational graph.
.. GENERATED FROM PYTHON SOURCE LINES 55-75
.. code-block:: default
@auto_scheduler.register_workload # Note the auto_scheduler decorator
def matmul_add(N, L, M, dtype):
A = te.placeholder((N, L), name="A", dtype=dtype)
B = te.placeholder((L, M), name="B", dtype=dtype)
C = te.placeholder((N, M), name="C", dtype=dtype)
k = te.reduce_axis((0, L), name="k")
matmul = te.compute(
(N, M),
lambda i, j: te.sum(A[i, k] * B[k, j], axis=k),
name="matmul",
attrs={"layout_free_placeholders": [B]}, # enable automatic layout transform for tensor B
)
out = te.compute((N, M), lambda i, j: matmul[i, j] + C[i, j], name="out")
return [A, B, C, out]
.. GENERATED FROM PYTHON SOURCE LINES 76-90
Create the search task
----------------------
With the function defined, we can now create the task for the auto_scheduler
to search against. We specify the particular parameters for this matrix
multiplication, in this case a multiplication of two square matrices of size
1024x1024. We then create a search task with N=L=M=1024 and dtype="float32".
.. admonition:: Improve performance with custom targets
In order for TVM to take full advantage of specific hardware platforms,
you will want to manually specify your CPU capabilities. For example:
- replace ``llvm`` below with ``llvm -mcpu=core-avx2`` to enable AVX2
- replace ``llvm`` below with ``llvm -mcpu=skylake-avx512`` to enable AVX-512
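As a quick illustration, a target built with one of those strings (assuming an
AVX-512 capable CPU) would look like:

.. code-block:: python

   # Hypothetical target for a Skylake-class CPU with AVX-512 support.
   target = tvm.target.Target("llvm -mcpu=skylake-avx512")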
.. GENERATED FROM PYTHON SOURCE LINES 90-99
.. code-block:: default
target = tvm.target.Target("llvm")
N = L = M = 1024
task = tvm.auto_scheduler.SearchTask(func=matmul_add, args=(N, L, M, "float32"), target=target)
# Inspect the computational graph
print("Computational DAG:")
print(task.compute_dag)
.. rst-class:: sphx-glr-script-out
.. code-block:: none
Computational DAG:
A = PLACEHOLDER [1024, 1024]
B = PLACEHOLDER [1024, 1024]
matmul(i, j) += (A[i, k]*B[k, j])
C = PLACEHOLDER [1024, 1024]
out(i, j) = (matmul[i, j] + C[i, j])
.. GENERATED FROM PYTHON SOURCE LINES 100-112
Set Parameters for Auto-Scheduler
---------------------------------
Next, we set parameters for the auto-scheduler.
* :code:`num_measure_trials` is the number of measurement trials we can use
during the search. We run only 10 trials in this tutorial for a fast
demonstration. In practice, 1000 is a good value for the search to converge.
You can do more trials according to your time budget.
* In addition, we use :any:`RecordToFile <auto_scheduler.RecordToFile>` to log measurement records into a
file ``matmul.json``. The measurement records can be used to query the history
best, resume the search, and do more analyses later.
* See :any:`TuningOptions <auto_scheduler.TuningOptions>` for more parameters.
.. GENERATED FROM PYTHON SOURCE LINES 112-120
.. code-block:: default
log_file = "matmul.json"
tune_option = auto_scheduler.TuningOptions(
num_measure_trials=10,
measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
verbose=2,
)
.. GENERATED FROM PYTHON SOURCE LINES 121-126
Run the search
--------------
Now all inputs are ready. Pretty simple, isn't it? We can kick off the
search and let the auto-scheduler do its magic. After some measurement
trials, we can load the best schedule from the log file and apply it.
.. GENERATED FROM PYTHON SOURCE LINES 126-132
.. code-block:: default
# Run auto-tuning (search)
task.tune(tune_option)
# Apply the best schedule
sch, args = task.apply_best(log_file)
.. GENERATED FROM PYTHON SOURCE LINES 133-139
Inspecting the Optimized Schedule
---------------------------------
We can lower the schedule to see the IR after auto-scheduling. The
auto-scheduler correctly performs optimizations including multi-level tiling,
layout transformation, parallelization, vectorization, unrolling, and
operator fusion.
.. GENERATED FROM PYTHON SOURCE LINES 139-143
.. code-block:: default
print("Lowered TIR:")
print(tvm.lower(sch, args, simple_mode=True))
.. rst-class:: sphx-glr-script-out
.. code-block:: none
Lowered TIR:
# from tvm.script import ir as I
# from tvm.script import tir as T
@I.ir_module
class Module:
@T.prim_func
def main(A: T.Buffer((1024, 1024), "float32"), B: T.Buffer((1024, 1024), "float32"), C: T.Buffer((1024, 1024), "float32"), out: T.Buffer((1024, 1024), "float32")):
T.func_attr({"from_legacy_te_schedule": T.bool(True), "tir.noalias": T.bool(True)})
auto_scheduler_layout_transform = T.allocate([1048576], "float32", "global")
auto_scheduler_layout_transform_1 = T.Buffer((1048576,), data=auto_scheduler_layout_transform)
for ax0_ax1_fused_ax2_fused in T.parallel(128):
for ax4, ax6, ax7 in T.grid(256, 4, 8):
B_1 = T.Buffer((1048576,), data=B.data)
auto_scheduler_layout_transform_1[ax0_ax1_fused_ax2_fused * 8192 + ax4 * 32 + ax6 * 8 + ax7] = B_1[ax4 * 4096 + ax6 * 1024 + ax0_ax1_fused_ax2_fused * 8 + ax7]
for i_outer_outer_j_outer_outer_fused in T.parallel(16384):
matmul = T.allocate([4], "float32x8", "global")
for i_outer_inner in range(2):
matmul_1 = T.Buffer((4,), "float32x8", data=matmul)
matmul_1[0] = T.Broadcast(T.float32(0), 8)
matmul_1[1] = T.Broadcast(T.float32(0), 8)
matmul_1[2] = T.Broadcast(T.float32(0), 8)
matmul_1[3] = T.Broadcast(T.float32(0), 8)
for k_outer, k_inner in T.grid(256, 4):
cse_var_2: T.int32 = i_outer_outer_j_outer_outer_fused % 128 * 8192 + k_outer * 32 + k_inner * 8
cse_var_1: T.int32 = i_outer_outer_j_outer_outer_fused // 128 * 8192 + i_outer_inner * 4096 + k_outer * 4 + k_inner
A_1 = T.Buffer((1048576,), data=A.data)
matmul_1[0] = matmul_1[0] + T.Broadcast(A_1[cse_var_1], 8) * auto_scheduler_layout_transform_1[cse_var_2:cse_var_2 + 8]
matmul_1[1] = matmul_1[1] + T.Broadcast(A_1[cse_var_1 + 1024], 8) * auto_scheduler_layout_transform_1[cse_var_2:cse_var_2 + 8]
matmul_1[2] = matmul_1[2] + T.Broadcast(A_1[cse_var_1 + 2048], 8) * auto_scheduler_layout_transform_1[cse_var_2:cse_var_2 + 8]
matmul_1[3] = matmul_1[3] + T.Broadcast(A_1[cse_var_1 + 3072], 8) * auto_scheduler_layout_transform_1[cse_var_2:cse_var_2 + 8]
for i_inner in range(4):
cse_var_3: T.int32 = i_outer_outer_j_outer_outer_fused // 128 * 8192 + i_outer_inner * 4096 + i_inner * 1024 + i_outer_outer_j_outer_outer_fused % 128 * 8
out_1 = T.Buffer((1048576,), data=out.data)
C_1 = T.Buffer((1048576,), data=C.data)
out_1[cse_var_3:cse_var_3 + 8] = matmul_1[i_inner] + C_1[cse_var_3:cse_var_3 + 8]
.. GENERATED FROM PYTHON SOURCE LINES 144-147
Check correctness and evaluate performance
------------------------------------------
We build the binary and check its correctness and performance.
.. GENERATED FROM PYTHON SOURCE LINES 147-172
.. code-block:: default
func = tvm.build(sch, args, target)
a_np = np.random.uniform(size=(N, L)).astype(np.float32)
b_np = np.random.uniform(size=(L, M)).astype(np.float32)
c_np = np.random.uniform(size=(N, M)).astype(np.float32)
out_np = a_np.dot(b_np) + c_np
dev = tvm.cpu()
a_tvm = tvm.nd.array(a_np, device=dev)
b_tvm = tvm.nd.array(b_np, device=dev)
c_tvm = tvm.nd.array(c_np, device=dev)
out_tvm = tvm.nd.empty(out_np.shape, device=dev)
func(a_tvm, b_tvm, c_tvm, out_tvm)
# Check results
np.testing.assert_allclose(out_np, out_tvm.numpy(), rtol=1e-3)
# Evaluate execution time.
evaluator = func.time_evaluator(func.entry_name, dev, min_repeat_ms=500)
print(
"Execution time of this operator: %.3f ms"
% (np.median(evaluator(a_tvm, b_tvm, c_tvm, out_tvm).results) * 1000)
)
.. rst-class:: sphx-glr-script-out
.. code-block:: none
Execution time of this operator: 98.221 ms
.. GENERATED FROM PYTHON SOURCE LINES 173-182
Using the record file
---------------------
During the search, all measurement records are logged into the record file
``matmul.json``. The measurement records can be used to re-apply search
results, resume the search, and perform other analyses.
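As a side note, the raw records can also be loaded programmatically for such
analyses. A minimal sketch, assuming the ``matmul.json`` file produced above:

.. code-block:: python

   # Iterate over the measurement records in the log file. Each record is a
   # (MeasureInput, MeasureResult) pair; res.costs holds the measured run
   # times in seconds.
   for inp, res in auto_scheduler.load_records(log_file):
       print(res.costs)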
Here is an example where we load the best schedule from a file, and print the
equivalent python schedule API. This can be used for debugging and learning
the behavior of the auto-scheduler.
.. GENERATED FROM PYTHON SOURCE LINES 182-186
.. code-block:: default
print("Equivalent python schedule:")
print(task.print_best(log_file))
.. rst-class:: sphx-glr-script-out
.. code-block:: none
Equivalent python schedule:
matmul_i, matmul_j, matmul_k = tuple(matmul.op.axis) + tuple(matmul.op.reduce_axis)
out_i, out_j = tuple(out.op.axis) + tuple(out.op.reduce_axis)
matmul_i_o_i, matmul_i_i = s[matmul].split(matmul_i, factor=4)
matmul_i_o_o_i, matmul_i_o_i = s[matmul].split(matmul_i_o_i, factor=1)
matmul_i_o_o_o, matmul_i_o_o_i = s[matmul].split(matmul_i_o_o_i, factor=2)
matmul_j_o_i, matmul_j_i = s[matmul].split(matmul_j, factor=8)
matmul_j_o_o_i, matmul_j_o_i = s[matmul].split(matmul_j_o_i, factor=1)
matmul_j_o_o_o, matmul_j_o_o_i = s[matmul].split(matmul_j_o_o_i, factor=1)
matmul_k_o, matmul_k_i = s[matmul].split(matmul_k, factor=4)
s[matmul].reorder(matmul_i_o_o_o, matmul_j_o_o_o, matmul_i_o_o_i, matmul_j_o_o_i, matmul_k_o, matmul_i_o_i, matmul_j_o_i, matmul_k_i, matmul_i_i, matmul_j_i)
out_i_o_i, out_i_i = s[out].split(out_i, factor=4)
out_i_o_o, out_i_o_i = s[out].split(out_i_o_i, factor=2)
out_j_o_i, out_j_i = s[out].split(out_j, factor=8)
out_j_o_o, out_j_o_i = s[out].split(out_j_o_i, factor=1)
s[out].reorder(out_i_o_o, out_j_o_o, out_i_o_i, out_j_o_i, out_i_i, out_j_i)
s[matmul].compute_at(s[out], out_j_o_i)
out_i_o_o_j_o_o_fused = s[out].fuse(out_i_o_o, out_j_o_o)
s[out].parallel(out_i_o_o_j_o_o_fused)
s[matmul].pragma(matmul_i_o_o_o, "auto_unroll_max_step", 8)
s[matmul].pragma(matmul_i_o_o_o, "unroll_explicit", True)
s[matmul].vectorize(matmul_j_i)
s[out].vectorize(out_j_i)
.. GENERATED FROM PYTHON SOURCE LINES 187-191
A more complicated example is resuming the search. In this case, we need to
create the search policy and cost model ourselves and restore their status
from the log file. In the example below, we resume the search and run 5 more
trials.
.. GENERATED FROM PYTHON SOURCE LINES 191-208
.. code-block:: default
def resume_search(task, log_file):
print("Resume search:")
cost_model = auto_scheduler.XGBModel()
cost_model.update_from_file(log_file)
search_policy = auto_scheduler.SketchPolicy(
task, cost_model, init_search_callbacks=[auto_scheduler.PreloadMeasuredStates(log_file)]
)
tune_option = auto_scheduler.TuningOptions(
num_measure_trials=5, measure_callbacks=[auto_scheduler.RecordToFile(log_file)]
)
task.tune(tune_option, search_policy=search_policy)
resume_search(task, log_file)
.. rst-class:: sphx-glr-script-out
.. code-block:: none
Resume search:
*E
.. GENERATED FROM PYTHON SOURCE LINES 209-216
Final Notes and Summary
-----------------------
In this tutorial, we have shown how to use the TVM Auto-Scheduler to
automatically optimize a matrix multiplication, without the need to specify a
search template. This concludes a series of examples, starting with the Tensor
Expression (TE) language, that demonstrates how TVM can optimize computational
operations.
.. rst-class:: sphx-glr-timing
**Total running time of the script:** ( 1 minutes 45.100 seconds)
.. _sphx_glr_download_tutorial_auto_scheduler_matmul_x86.py:
.. only:: html
.. container:: sphx-glr-footer sphx-glr-footer-example
.. container:: sphx-glr-download sphx-glr-download-python
:download:`Download Python source code: auto_scheduler_matmul_x86.py <auto_scheduler_matmul_x86.py>`
.. container:: sphx-glr-download sphx-glr-download-jupyter
:download:`Download Jupyter notebook: auto_scheduler_matmul_x86.ipynb <auto_scheduler_matmul_x86.ipynb>`
.. only:: html
.. rst-class:: sphx-glr-signature
`Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_