.. DO NOT EDIT. THIS FILE WAS AUTOMATICALLY GENERATED BY
.. TVM'S MONKEY-PATCHED VERSION OF SPHINX-GALLERY. TO MAKE
.. CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "how_to/tune_with_autoscheduler/tune_network_x86.py"
.. only:: html
.. note::
:class: sphx-glr-download-link-note
This tutorial can be used interactively with Google Colab! You can also click
:ref:`here <sphx_glr_download_how_to_tune_with_autoscheduler_tune_network_x86.py>` to run the Jupyter notebook locally.
.. image:: https://raw.githubusercontent.com/tlc-pack/web-data/main/images/utilities/colab_button.svg
:align: center
:target: https://colab.research.google.com/github/apache/tvm-site/blob/asf-site/docs/_downloads/ad2a7f55d615d188ad664d56696815a6/tune_network_x86.ipynb
:width: 300px
.. rst-class:: sphx-glr-example-title
.. _sphx_glr_how_to_tune_with_autoscheduler_tune_network_x86.py:
Auto-scheduling a Neural Network for x86 CPU
============================================
**Author**: `Lianmin Zheng <https://github.com/merrymercy>`_, `Chengfan Jia <https://github.com/jcf94/>`_
Auto-tuning for specific devices and workloads is critical for getting the
best performance. This is a tutorial on how to tune a whole neural
network for x86 CPU with the auto-scheduler.
To auto-tune a neural network, we partition the network into small subgraphs and
tune them independently. Each subgraph is treated as one search task.
A task scheduler slices the tuning time and dynamically allocates it to
these tasks. The task scheduler predicts the impact of each task on the end-to-end
execution time and prioritizes the task that can reduce it the most.
For each subgraph, we use the compute declaration in :code:`tvm/python/topi` to
get the computational DAG in the tensor expression form.
We then use the auto-scheduler to construct a search space of this DAG and search
for good schedules (low-level optimizations).
Different from the template-based :ref:`autotvm <tutorials-autotvm-sec>` which relies on
manual templates to define the search space, the auto-scheduler does not require any
schedule templates. In other words, the auto-scheduler only uses the compute declarations
in :code:`tvm/python/topi` and does not use existing schedule templates.
Note that this tutorial will not run on Windows or recent versions of macOS. To
get it to run, you will need to wrap the body of this tutorial in an
:code:`if __name__ == "__main__":` block, as sketched below.
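A minimal sketch of such a wrapper (the tutorial body itself is elided):

.. code-block:: python

    if __name__ == "__main__":
        # Put the rest of this tutorial's code here. This guard is needed
        # because auto-scheduler measurement uses multiprocessing, and the
        # "spawn" start method on Windows and recent macOS re-imports the
        # script in child processes.
        ...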
.. GENERATED FROM PYTHON SOURCE LINES 47-58
.. code-block:: default
import sys
import numpy as np
import tvm
from tvm import relay, auto_scheduler
from tvm.relay import data_dep_optimization as ddo
import tvm.relay.testing
from tvm.contrib import graph_executor
.. GENERATED FROM PYTHON SOURCE LINES 59-71
Define a Network
----------------
First, we need to define the network with the Relay frontend API.
We can load a pre-defined network from :code:`tvm.relay.testing`.
We can also load models from MXNet, ONNX, PyTorch, and TensorFlow
(see :ref:`front end tutorials<tutorial-frontend>`).
For convolutional neural networks, although the auto-scheduler can work correctly
with any layout, we found that the best performance is typically achieved with the NHWC layout.
We also implemented more optimizations for the NHWC layout in the auto-scheduler,
so it is recommended to convert your models to NHWC layout to use the auto-scheduler.
You can use the :ref:`ConvertLayout <convert-layout-usage>` pass to do the layout conversion in TVM,
as sketched below.
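A minimal sketch of such a conversion, assuming :code:`mod` is a Relay module
imported from a frontend (the operator list here covers only :code:`nn.conv2d`
and may need to be extended for your model):

.. code-block:: python

    import tvm
    from tvm import relay

    # Convert conv2d layers to NHWC; "default" lets the pass pick the
    # matching kernel layout.
    desired_layouts = {"nn.conv2d": ["NHWC", "default"]}
    seq = tvm.transform.Sequential(
        [
            relay.transform.RemoveUnusedFunctions(),
            relay.transform.ConvertLayout(desired_layouts),
        ]
    )
    with tvm.transform.PassContext(opt_level=3):
        mod = seq(mod)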
.. GENERATED FROM PYTHON SOURCE LINES 71-146
.. code-block:: default
def get_network(name, batch_size, layout="NHWC", dtype="float32", use_sparse=False):
"""Get the symbol definition and random weight of a network"""
# auto-scheduler prefers NHWC layout
if layout == "NHWC":
image_shape = (224, 224, 3)
elif layout == "NCHW":
image_shape = (3, 224, 224)
else:
raise ValueError("Invalid layout: " + layout)
input_shape = (batch_size,) + image_shape
output_shape = (batch_size, 1000)
if name.startswith("resnet-"):
n_layer = int(name.split("-")[1])
mod, params = relay.testing.resnet.get_workload(
num_layers=n_layer,
batch_size=batch_size,
layout=layout,
dtype=dtype,
image_shape=image_shape,
)
elif name.startswith("resnet3d-"):
n_layer = int(name.split("-")[1])
mod, params = relay.testing.resnet.get_workload(
num_layers=n_layer,
batch_size=batch_size,
layout=layout,
dtype=dtype,
image_shape=image_shape,
)
elif name == "mobilenet":
mod, params = relay.testing.mobilenet.get_workload(
batch_size=batch_size, layout=layout, dtype=dtype, image_shape=image_shape
)
elif name == "squeezenet_v1.1":
assert layout == "NCHW", "squeezenet_v1.1 only supports NCHW layout"
mod, params = relay.testing.squeezenet.get_workload(
version="1.1",
batch_size=batch_size,
dtype=dtype,
image_shape=image_shape,
)
elif name == "inception_v3":
input_shape = (batch_size, 3, 299, 299) if layout == "NCHW" else (batch_size, 299, 299, 3)
mod, params = relay.testing.inception_v3.get_workload(batch_size=batch_size, dtype=dtype)
elif name == "mlp":
mod, params = relay.testing.mlp.get_workload(
batch_size=batch_size, dtype=dtype, image_shape=image_shape, num_classes=1000
)
else:
raise ValueError("Network not found.")
if use_sparse:
from tvm.topi.sparse.utils import convert_model_dense_to_sparse
mod, params = convert_model_dense_to_sparse(mod, params, bs_r=4, random_params=True)
return mod, params, input_shape, output_shape
# Define the neural network and compilation target.
# If the target machine supports avx512 instructions, replace
# "llvm -mcpu=core-avx2" with "llvm -mcpu=skylake-avx512".
network = "resnet-50"
use_sparse = False
batch_size = 1
layout = "NHWC"
target = tvm.target.Target("llvm -mcpu=core-avx2")
dtype = "float32"
log_file = "%s-%s-B%d-%s.json" % (network, layout, batch_size, target.kind.name)
.. GENERATED FROM PYTHON SOURCE LINES 147-156
Extract Search Tasks
--------------------
Next, we extract the search tasks and their weights from a network.
The weight of a task is the number of appearances of the task's subgraph
in the whole network.
By using the weights, we can approximate the end-to-end latency of the network
as :code:`sum(latency[t] * weight[t])`, where :code:`latency[t]` is the
latency of a task and :code:`weight[t]` is the weight of the task.
The task scheduler optimizes exactly this objective, as illustrated by the sketch below.
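As an illustration only (this helper is not part of the TVM API), the
objective can be written as:

.. code-block:: python

    def estimated_end_to_end_latency(task_latencies, task_weights):
        # Weighted sum that the task scheduler minimizes:
        # sum(latency[t] * weight[t]) over all tasks.
        return sum(l * w for l, w in zip(task_latencies, task_weights))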
.. GENERATED FROM PYTHON SOURCE LINES 156-174
.. code-block:: default
# Extract tasks from the network
print("Get model...")
mod, params, input_shape, output_shape = get_network(
network,
batch_size,
layout,
dtype=dtype,
use_sparse=use_sparse,
)
print("Extract tasks...")
tasks, task_weights = auto_scheduler.extract_tasks(mod["main"], params, target)
for idx, task in enumerate(tasks):
print("========== Task %d (workload key: %s) ==========" % (idx, task.workload_key))
print(task.compute_dag)
.. rst-class:: sphx-glr-script-out
.. code-block:: none
Get model...
Extract tasks...
========== Task 0 (workload key: ["6d628209072e3e3dd8f49359935acea6", [1, 28, 28, 512], [1, 1, 512, 128], [1, 1, 1, 128], [1, 28, 28, 128]]) ==========
p0 = PLACEHOLDER [1, 28, 28, 512]
pad_temp(i0, i1, i2, i3) = p0[i0, i1, i2, i3]
p1 = PLACEHOLDER [1, 1, 512, 128]
conv2d_nhwc(nn, yy, xx, ff) += (pad_temp[nn, (yy + ry), (xx + rx), rc]*p1[ry, rx, rc, ff])
p2 = PLACEHOLDER [1, 1, 1, 128]
T_add(ax0, ax1, ax2, ax3) = (conv2d_nhwc[ax0, ax1, ax2, ax3] + p2[ax0, 0, 0, ax3])
T_relu(ax0, ax1, ax2, ax3) = max(T_add[ax0, ax1, ax2, ax3], 0f)
========== Task 1 (workload key: ["3060808fc5c74e18b1276729071fbae0", [1, 56, 56, 64], [1, 1, 64, 256], [1, 56, 56, 256], [1, 56, 56, 256]]) ==========
p0 = PLACEHOLDER [1, 56, 56, 64]
pad_temp(i0, i1, i2, i3) = p0[i0, i1, i2, i3]
p1 = PLACEHOLDER [1, 1, 64, 256]
conv2d_nhwc(nn, yy, xx, ff) += (pad_temp[nn, (yy + ry), (xx + rx), rc]*p1[ry, rx, rc, ff])
p2 = PLACEHOLDER [1, 56, 56, 256]
T_add(ax0, ax1, ax2, ax3) = (conv2d_nhwc[ax0, ax1, ax2, ax3] + p2[ax0, ax1, ax2, ax3])
========== Task 2 (workload key: ["2d10de6646307f0e3e5cf4b31c20e69b", [1, 56, 56, 64], [1, 1, 64, 256], [1, 56, 56, 256]]) ==========
p0 = PLACEHOLDER [1, 56, 56, 64]
pad_temp(i0, i1, i2, i3) = p0[i0, i1, i2, i3]
p1 = PLACEHOLDER [1, 1, 64, 256]
conv2d_nhwc(nn, yy, xx, ff) += (pad_temp[nn, (yy + ry), (xx + rx), rc]*p1[ry, rx, rc, ff])
========== Task 3 (workload key: ["6d628209072e3e3dd8f49359935acea6", [1, 56, 56, 64], [1, 1, 64, 64], [1, 1, 1, 64], [1, 56, 56, 64]]) ==========
p0 = PLACEHOLDER [1, 56, 56, 64]
pad_temp(i0, i1, i2, i3) = p0[i0, i1, i2, i3]
p1 = PLACEHOLDER [1, 1, 64, 64]
conv2d_nhwc(nn, yy, xx, ff) += (pad_temp[nn, (yy + ry), (xx + rx), rc]*p1[ry, rx, rc, ff])
p2 = PLACEHOLDER [1, 1, 1, 64]
T_add(ax0, ax1, ax2, ax3) = (conv2d_nhwc[ax0, ax1, ax2, ax3] + p2[ax0, 0, 0, ax3])
T_relu(ax0, ax1, ax2, ax3) = max(T_add[ax0, ax1, ax2, ax3], 0f)
========== Task 4 (workload key: ["08f7449d79e570b7274174709e5e5e01", [1, 2048], [1000, 2048], [1, 1000], [1, 1000]]) ==========
p0 = PLACEHOLDER [1, 2048]
p1 = PLACEHOLDER [1000, 2048]
T_matmul_NT(i0, i1) += (p0[i0, k]*p1[i1, k])
p2 = PLACEHOLDER [1, 1000]
T_add(ax0, ax1) = (T_matmul_NT[ax0, ax1] + p2[ax0, ax1])
========== Task 5 (workload key: ["76afb7bf408a1ffa0b8b7bc09d077dc3", [1, 56, 56, 64], [1, 1, 64, 256], [1, 56, 56, 256], [1, 1, 1, 256], [1, 56, 56, 256]]) ==========
p0 = PLACEHOLDER [1, 56, 56, 64]
pad_temp(i0, i1, i2, i3) = p0[i0, i1, i2, i3]
p1 = PLACEHOLDER [1, 1, 64, 256]
conv2d_nhwc(nn, yy, xx, ff) += (pad_temp[nn, (yy + ry), (xx + rx), rc]*p1[ry, rx, rc, ff])
p2 = PLACEHOLDER [1, 56, 56, 256]
T_add(ax0, ax1, ax2, ax3) = (conv2d_nhwc[ax0, ax1, ax2, ax3] + p2[ax0, ax1, ax2, ax3])
p3 = PLACEHOLDER [1, 1, 1, 256]
T_add(ax0, ax1, ax2, ax3) = (T_add[ax0, ax1, ax2, ax3] + p3[ax0, 0, 0, ax3])
T_relu(ax0, ax1, ax2, ax3) = max(T_add[ax0, ax1, ax2, ax3], 0f)
========== Task 6 (workload key: ["8c53ca2904398da2889aa7508082d7bb", [1, 7, 7, 2048], [1, 1, 1, 2048]]) ==========
p0 = PLACEHOLDER [1, 7, 7, 2048]
adaptive_pool_sum(ax0, ax1, ax2, ax3) += p0[ax0, ((ax1*7) + rv0), ((ax2*7) + rv1), ax3]
adaptive_pool_avg(ax0, ax1, ax2, ax3) = (adaptive_pool_sum[ax0, ax1, ax2, ax3]/(float32((select((bool)1, ((ax1 + 1)*7), (((ax1 + 1)*7) + 1)) - (ax1*7)))*float32((select((bool)1, ((ax2 + 1)*7), (((ax2 + 1)*7) + 1)) - (ax2*7)))))
========== Task 7 (workload key: ["2beb39e9afe4c74822fffbcbb8533595", [1, 14, 14, 1024], [1, 1, 1024, 512], [1, 1, 1, 512], [1, 7, 7, 512]]) ==========
p0 = PLACEHOLDER [1, 14, 14, 1024]
pad_temp(i0, i1, i2, i3) = p0[i0, i1, i2, i3]
p1 = PLACEHOLDER [1, 1, 1024, 512]
conv2d_nhwc(nn, yy, xx, ff) += (pad_temp[nn, ((yy*2) + ry), ((xx*2) + rx), rc]*p1[ry, rx, rc, ff])
p2 = PLACEHOLDER [1, 1, 1, 512]
T_add(ax0, ax1, ax2, ax3) = (conv2d_nhwc[ax0, ax1, ax2, ax3] + p2[ax0, 0, 0, ax3])
T_relu(ax0, ax1, ax2, ax3) = max(T_add[ax0, ax1, ax2, ax3], 0f)
========== Task 8 (workload key: ["0fad1b42d0d33418e0a8d15d3bbad3c9", [1, 14, 14, 1024], [1, 1, 1024, 2048], [1, 7, 7, 2048]]) ==========
p0 = PLACEHOLDER [1, 14, 14, 1024]
pad_temp(i0, i1, i2, i3) = p0[i0, i1, i2, i3]
p1 = PLACEHOLDER [1, 1, 1024, 2048]
conv2d_nhwc(nn, yy, xx, ff) += (pad_temp[nn, ((yy*2) + ry), ((xx*2) + rx), rc]*p1[ry, rx, rc, ff])
========== Task 9 (workload key: ["3060808fc5c74e18b1276729071fbae0", [1, 7, 7, 512], [1, 1, 512, 2048], [1, 7, 7, 2048], [1, 7, 7, 2048]]) ==========
p0 = PLACEHOLDER [1, 7, 7, 512]
pad_temp(i0, i1, i2, i3) = p0[i0, i1, i2, i3]
p1 = PLACEHOLDER [1, 1, 512, 2048]
conv2d_nhwc(nn, yy, xx, ff) += (pad_temp[nn, (yy + ry), (xx + rx), rc]*p1[ry, rx, rc, ff])
p2 = PLACEHOLDER [1, 7, 7, 2048]
T_add(ax0, ax1, ax2, ax3) = (conv2d_nhwc[ax0, ax1, ax2, ax3] + p2[ax0, ax1, ax2, ax3])
========== Task 10 (workload key: ["2beb39e9afe4c74822fffbcbb8533595", [1, 56, 56, 256], [1, 1, 256, 128], [1, 1, 1, 128], [1, 28, 28, 128]]) ==========
p0 = PLACEHOLDER [1, 56, 56, 256]
pad_temp(i0, i1, i2, i3) = p0[i0, i1, i2, i3]
p1 = PLACEHOLDER [1, 1, 256, 128]
conv2d_nhwc(nn, yy, xx, ff) += (pad_temp[nn, ((yy*2) + ry), ((xx*2) + rx), rc]*p1[ry, rx, rc, ff])
p2 = PLACEHOLDER [1, 1, 1, 128]
T_add(ax0, ax1, ax2, ax3) = (conv2d_nhwc[ax0, ax1, ax2, ax3] + p2[ax0, 0, 0, ax3])
T_relu(ax0, ax1, ax2, ax3) = max(T_add[ax0, ax1, ax2, ax3], 0f)
========== Task 11 (workload key: ["76afb7bf408a1ffa0b8b7bc09d077dc3", [1, 28, 28, 128], [1, 1, 128, 512], [1, 28, 28, 512], [1, 1, 1, 512], [1, 28, 28, 512]]) ==========
p0 = PLACEHOLDER [1, 28, 28, 128]
pad_temp(i0, i1, i2, i3) = p0[i0, i1, i2, i3]
p1 = PLACEHOLDER [1, 1, 128, 512]
conv2d_nhwc(nn, yy, xx, ff) += (pad_temp[nn, (yy + ry), (xx + rx), rc]*p1[ry, rx, rc, ff])
p2 = PLACEHOLDER [1, 28, 28, 512]
T_add(ax0, ax1, ax2, ax3) = (conv2d_nhwc[ax0, ax1, ax2, ax3] + p2[ax0, ax1, ax2, ax3])
p3 = PLACEHOLDER [1, 1, 1, 512]
T_add(ax0, ax1, ax2, ax3) = (T_add[ax0, ax1, ax2, ax3] + p3[ax0, 0, 0, ax3])
T_relu(ax0, ax1, ax2, ax3) = max(T_add[ax0, ax1, ax2, ax3], 0f)
========== Task 12 (workload key: ["0fad1b42d0d33418e0a8d15d3bbad3c9", [1, 56, 56, 256], [1, 1, 256, 512], [1, 28, 28, 512]]) ==========
p0 = PLACEHOLDER [1, 56, 56, 256]
pad_temp(i0, i1, i2, i3) = p0[i0, i1, i2, i3]
p1 = PLACEHOLDER [1, 1, 256, 512]
conv2d_nhwc(nn, yy, xx, ff) += (pad_temp[nn, ((yy*2) + ry), ((xx*2) + rx), rc]*p1[ry, rx, rc, ff])
========== Task 13 (workload key: ["7d79c516e212fe1d73f5dbb90eaca2cf", [1, 1000], [1, 1000]]) ==========
p0 = PLACEHOLDER [1, 1000]
T_softmax_maxelem(i0) max= p0[i0, k]
T_softmax_exp(i0, i1) = tir.exp((p0[i0, i1] - T_softmax_maxelem[i0]))
T_softmax_expsum(i0) += T_softmax_exp[i0, k]
T_softmax_norm(i0, i1) = (T_softmax_exp[i0, i1]/T_softmax_expsum[i0])
========== Task 14 (workload key: ["3060808fc5c74e18b1276729071fbae0", [1, 14, 14, 256], [1, 1, 256, 1024], [1, 14, 14, 1024], [1, 14, 14, 1024]]) ==========
p0 = PLACEHOLDER [1, 14, 14, 256]
pad_temp(i0, i1, i2, i3) = p0[i0, i1, i2, i3]
p1 = PLACEHOLDER [1, 1, 256, 1024]
conv2d_nhwc(nn, yy, xx, ff) += (pad_temp[nn, (yy + ry), (xx + rx), rc]*p1[ry, rx, rc, ff])
p2 = PLACEHOLDER [1, 14, 14, 1024]
T_add(ax0, ax1, ax2, ax3) = (conv2d_nhwc[ax0, ax1, ax2, ax3] + p2[ax0, ax1, ax2, ax3])
========== Task 15 (workload key: ["07f9fcad27bdd3233f86fe35a5185d33", [1, 224, 224, 3], [7, 7, 3, 64], [1, 1, 1, 64], [1, 112, 112, 64]]) ==========
p0 = PLACEHOLDER [1, 224, 224, 3]
pad_temp(i0, i1, i2, i3) = tir.if_then_else(((((i1 >= 3) && (i1 < 227)) && (i2 >= 3)) && (i2 < 227)), p0[i0, (i1 - 3), (i2 - 3), i3], 0f)
p1 = PLACEHOLDER [7, 7, 3, 64]
conv2d_nhwc(nn, yy, xx, ff) += (pad_temp[nn, ((yy*2) + ry), ((xx*2) + rx), rc]*p1[ry, rx, rc, ff])
p2 = PLACEHOLDER [1, 1, 1, 64]
T_add(ax0, ax1, ax2, ax3) = (conv2d_nhwc[ax0, ax1, ax2, ax3] + p2[ax0, 0, 0, ax3])
T_relu(ax0, ax1, ax2, ax3) = max(T_add[ax0, ax1, ax2, ax3], 0f)
========== Task 16 (workload key: ["6d012ba18a086c11ee2b85c7324e16f2", [1, 112, 112, 64], [1, 1, 1, 64], [1, 56, 56, 64]]) ==========
p0 = PLACEHOLDER [1, 112, 112, 64]
pad_temp(ax0, ax1, ax2, ax3) = tir.if_then_else(((((ax1 >= 1) && (ax1 < 113)) && (ax2 >= 1)) && (ax2 < 113)), p0[ax0, (ax1 - 1), (ax2 - 1), ax3], -3.40282e+38f)
pool_max(ax0, ax1, ax2, ax3) max= pad_temp[ax0, ((ax1*2) + rv0), ((ax2*2) + rv1), ax3]
p1 = PLACEHOLDER [1, 1, 1, 64]
T_add(ax0, ax1, ax2, ax3) = (pool_max[ax0, ax1, ax2, ax3] + p1[ax0, 0, 0, ax3])
T_relu(ax0, ax1, ax2, ax3) = max(T_add[ax0, ax1, ax2, ax3], 0f)
========== Task 17 (workload key: ["6d628209072e3e3dd8f49359935acea6", [1, 7, 7, 2048], [1, 1, 2048, 512], [1, 1, 1, 512], [1, 7, 7, 512]]) ==========
p0 = PLACEHOLDER [1, 7, 7, 2048]
pad_temp(i0, i1, i2, i3) = p0[i0, i1, i2, i3]
p1 = PLACEHOLDER [1, 1, 2048, 512]
conv2d_nhwc(nn, yy, xx, ff) += (pad_temp[nn, (yy + ry), (xx + rx), rc]*p1[ry, rx, rc, ff])
p2 = PLACEHOLDER [1, 1, 1, 512]
T_add(ax0, ax1, ax2, ax3) = (conv2d_nhwc[ax0, ax1, ax2, ax3] + p2[ax0, 0, 0, ax3])
T_relu(ax0, ax1, ax2, ax3) = max(T_add[ax0, ax1, ax2, ax3], 0f)
========== Task 18 (workload key: ["6d628209072e3e3dd8f49359935acea6", [1, 14, 14, 1024], [1, 1, 1024, 256], [1, 1, 1, 256], [1, 14, 14, 256]]) ==========
p0 = PLACEHOLDER [1, 14, 14, 1024]
pad_temp(i0, i1, i2, i3) = p0[i0, i1, i2, i3]
p1 = PLACEHOLDER [1, 1, 1024, 256]
conv2d_nhwc(nn, yy, xx, ff) += (pad_temp[nn, (yy + ry), (xx + rx), rc]*p1[ry, rx, rc, ff])
p2 = PLACEHOLDER [1, 1, 1, 256]
T_add(ax0, ax1, ax2, ax3) = (conv2d_nhwc[ax0, ax1, ax2, ax3] + p2[ax0, 0, 0, ax3])
T_relu(ax0, ax1, ax2, ax3) = max(T_add[ax0, ax1, ax2, ax3], 0f)
========== Task 19 (workload key: ["38552500208b25b4035682b0e93cbce3", [1, 14, 14, 256], [6, 6, 256, 256], [1, 1, 1, 256], [1, 14, 14, 256]]) ==========
p0 = PLACEHOLDER [1, 14, 14, 256]
data_pad(i0, i1, i2, i3) = tir.if_then_else(((((i1 >= 1) && (i1 < 15)) && (i2 >= 1)) && (i2 < 15)), p0[i0, (i1 - 1), (i2 - 1), i3], 0f)
input_tile(eps, nu, p, ci) = data_pad[floordiv(p, 16), ((floormod(floordiv(p, 4), 4)*4) + eps), ((floormod(p, 4)*4) + nu), ci]
B(i, j) = select(((floormod(i, 6) == 5) && (floormod(j, 6) == 5)), 1f, select(((floormod(i, 6) == 5) && (floormod(j, 6) == 4)), ..(OMITTED).. (floormod(j, 6) == 1)), 0f, select(((floormod(i, 6) == 0) && (floormod(j, 6) == 0)), 1f, 0f))))))))))))))))))))))))))))))))))))
data_pack(eps, nu, p, ci) += ((input_tile[r_a, r_b, p, ci]*B[r_a, eps])*B[r_b, nu])
p1 = PLACEHOLDER [6, 6, 256, 256]
bgemm(eps, nu, p, co) += (data_pack[eps, nu, p, ci]*p1[eps, nu, co, ci])
A(i, j) = select(((floormod(i, 6) == 5) && (floormod(j, 4) == 3)), 1f, select(((floormod(i, 6) == 5) && (floormod(j, 4) == 2)), ..(OMITTED).. 6) == 0) && (floormod(j, 4) == 1)), 0f, select(((floormod(i, 6) == 0) && (floormod(j, 4) == 0)), 1f, 0f))))))))))))))))))))))))
inverse(vh, vw, p, co) += ((bgemm[r_a, r_b, p, co]*A[r_a, vh])*A[r_b, vw])
conv2d_winograd(n, h, w, co) = inverse[floormod(h, 4), floormod(w, 4), ((((n*4)*4) + (floordiv(h, 4)*4)) + floordiv(w, 4)), co]
p2 = PLACEHOLDER [1, 1, 1, 256]
T_add(ax0, ax1, ax2, ax3) = (conv2d_winograd[ax0, ax1, ax2, ax3] + p2[ax0, 0, 0, ax3])
T_relu(ax0, ax1, ax2, ax3) = max(T_add[ax0, ax1, ax2, ax3], 0f)
========== Task 20 (workload key: ["6d628209072e3e3dd8f49359935acea6", [1, 56, 56, 256], [1, 1, 256, 64], [1, 1, 1, 64], [1, 56, 56, 64]]) ==========
p0 = PLACEHOLDER [1, 56, 56, 256]
pad_temp(i0, i1, i2, i3) = p0[i0, i1, i2, i3]
p1 = PLACEHOLDER [1, 1, 256, 64]
conv2d_nhwc(nn, yy, xx, ff) += (pad_temp[nn, (yy + ry), (xx + rx), rc]*p1[ry, rx, rc, ff])
p2 = PLACEHOLDER [1, 1, 1, 64]
T_add(ax0, ax1, ax2, ax3) = (conv2d_nhwc[ax0, ax1, ax2, ax3] + p2[ax0, 0, 0, ax3])
T_relu(ax0, ax1, ax2, ax3) = max(T_add[ax0, ax1, ax2, ax3], 0f)
========== Task 21 (workload key: ["f07e228ef5f642b386d23a62df615e7b", [1, 7, 7, 512], [1, 1, 512, 2048], [1, 7, 7, 2048], [1, 1, 1, 2048], [1, 1, 1, 2048], [1, 7, 7, 2048]]) ==========
p0 = PLACEHOLDER [1, 7, 7, 512]
pad_temp(i0, i1, i2, i3) = p0[i0, i1, i2, i3]
p1 = PLACEHOLDER [1, 1, 512, 2048]
conv2d_nhwc(nn, yy, xx, ff) += (pad_temp[nn, (yy + ry), (xx + rx), rc]*p1[ry, rx, rc, ff])
p2 = PLACEHOLDER [1, 7, 7, 2048]
T_add(ax0, ax1, ax2, ax3) = (conv2d_nhwc[ax0, ax1, ax2, ax3] + p2[ax0, ax1, ax2, ax3])
p3 = PLACEHOLDER [1, 1, 1, 2048]
T_multiply(ax0, ax1, ax2, ax3) = (T_add[ax0, ax1, ax2, ax3]*p3[ax0, 0, 0, ax3])
p4 = PLACEHOLDER [1, 1, 1, 2048]
T_add(ax0, ax1, ax2, ax3) = (T_multiply[ax0, ax1, ax2, ax3] + p4[ax0, 0, 0, ax3])
T_relu(ax0, ax1, ax2, ax3) = max(T_add[ax0, ax1, ax2, ax3], 0f)
========== Task 22 (workload key: ["3060808fc5c74e18b1276729071fbae0", [1, 28, 28, 128], [1, 1, 128, 512], [1, 28, 28, 512], [1, 28, 28, 512]]) ==========
p0 = PLACEHOLDER [1, 28, 28, 128]
pad_temp(i0, i1, i2, i3) = p0[i0, i1, i2, i3]
p1 = PLACEHOLDER [1, 1, 128, 512]
conv2d_nhwc(nn, yy, xx, ff) += (pad_temp[nn, (yy + ry), (xx + rx), rc]*p1[ry, rx, rc, ff])
p2 = PLACEHOLDER [1, 28, 28, 512]
T_add(ax0, ax1, ax2, ax3) = (conv2d_nhwc[ax0, ax1, ax2, ax3] + p2[ax0, ax1, ax2, ax3])
========== Task 23 (workload key: ["0fad1b42d0d33418e0a8d15d3bbad3c9", [1, 28, 28, 512], [1, 1, 512, 1024], [1, 14, 14, 1024]]) ==========
p0 = PLACEHOLDER [1, 28, 28, 512]
pad_temp(i0, i1, i2, i3) = p0[i0, i1, i2, i3]
p1 = PLACEHOLDER [1, 1, 512, 1024]
conv2d_nhwc(nn, yy, xx, ff) += (pad_temp[nn, ((yy*2) + ry), ((xx*2) + rx), rc]*p1[ry, rx, rc, ff])
========== Task 24 (workload key: ["2beb39e9afe4c74822fffbcbb8533595", [1, 28, 28, 512], [1, 1, 512, 256], [1, 1, 1, 256], [1, 14, 14, 256]]) ==========
p0 = PLACEHOLDER [1, 28, 28, 512]
pad_temp(i0, i1, i2, i3) = p0[i0, i1, i2, i3]
p1 = PLACEHOLDER [1, 1, 512, 256]
conv2d_nhwc(nn, yy, xx, ff) += (pad_temp[nn, ((yy*2) + ry), ((xx*2) + rx), rc]*p1[ry, rx, rc, ff])
p2 = PLACEHOLDER [1, 1, 1, 256]
T_add(ax0, ax1, ax2, ax3) = (conv2d_nhwc[ax0, ax1, ax2, ax3] + p2[ax0, 0, 0, ax3])
T_relu(ax0, ax1, ax2, ax3) = max(T_add[ax0, ax1, ax2, ax3], 0f)
========== Task 25 (workload key: ["76afb7bf408a1ffa0b8b7bc09d077dc3", [1, 14, 14, 256], [1, 1, 256, 1024], [1, 14, 14, 1024], [1, 1, 1, 1024], [1, 14, 14, 1024]]) ==========
p0 = PLACEHOLDER [1, 14, 14, 256]
pad_temp(i0, i1, i2, i3) = p0[i0, i1, i2, i3]
p1 = PLACEHOLDER [1, 1, 256, 1024]
conv2d_nhwc(nn, yy, xx, ff) += (pad_temp[nn, (yy + ry), (xx + rx), rc]*p1[ry, rx, rc, ff])
p2 = PLACEHOLDER [1, 14, 14, 1024]
T_add(ax0, ax1, ax2, ax3) = (conv2d_nhwc[ax0, ax1, ax2, ax3] + p2[ax0, ax1, ax2, ax3])
p3 = PLACEHOLDER [1, 1, 1, 1024]
T_add(ax0, ax1, ax2, ax3) = (T_add[ax0, ax1, ax2, ax3] + p3[ax0, 0, 0, ax3])
T_relu(ax0, ax1, ax2, ax3) = max(T_add[ax0, ax1, ax2, ax3], 0f)
========== Task 26 (workload key: ["d37380659057397544e056461ea3bad3", [1, 56, 56, 64], [3, 3, 64, 64], [1, 1, 1, 64], [1, 56, 56, 64]]) ==========
p0 = PLACEHOLDER [1, 56, 56, 64]
pad_temp(i0, i1, i2, i3) = tir.if_then_else(((((i1 >= 1) && (i1 < 57)) && (i2 >= 1)) && (i2 < 57)), p0[i0, (i1 - 1), (i2 - 1), i3], 0f)
p1 = PLACEHOLDER [3, 3, 64, 64]
conv2d_nhwc(nn, yy, xx, ff) += (pad_temp[nn, (yy + ry), (xx + rx), rc]*p1[ry, rx, rc, ff])
p2 = PLACEHOLDER [1, 1, 1, 64]
T_add(ax0, ax1, ax2, ax3) = (conv2d_nhwc[ax0, ax1, ax2, ax3] + p2[ax0, 0, 0, ax3])
T_relu(ax0, ax1, ax2, ax3) = max(T_add[ax0, ax1, ax2, ax3], 0f)
========== Task 27 (workload key: ["cfd09cf1ca9e943f0ee12a18813a5c75", [1, 28, 28, 128], [6, 6, 128, 128], [1, 1, 1, 128], [1, 28, 28, 128]]) ==========
p0 = PLACEHOLDER [1, 28, 28, 128]
data_pad(i0, i1, i2, i3) = tir.if_then_else(((((i1 >= 1) && (i1 < 29)) && (i2 >= 1)) && (i2 < 29)), p0[i0, (i1 - 1), (i2 - 1), i3], 0f)
input_tile(eps, nu, p, ci) = data_pad[floordiv(p, 49), ((floormod(floordiv(p, 7), 7)*4) + eps), ((floormod(p, 7)*4) + nu), ci]
B(i, j) = select(((floormod(i, 6) == 5) && (floormod(j, 6) == 5)), 1f, select(((floormod(i, 6) == 5) && (floormod(j, 6) == 4)), ..(OMITTED).. (floormod(j, 6) == 1)), 0f, select(((floormod(i, 6) == 0) && (floormod(j, 6) == 0)), 1f, 0f))))))))))))))))))))))))))))))))))))
data_pack(eps, nu, p, ci) += ((input_tile[r_a, r_b, p, ci]*B[r_a, eps])*B[r_b, nu])
p1 = PLACEHOLDER [6, 6, 128, 128]
bgemm(eps, nu, p, co) += (data_pack[eps, nu, p, ci]*p1[eps, nu, co, ci])
A(i, j) = select(((floormod(i, 6) == 5) && (floormod(j, 4) == 3)), 1f, select(((floormod(i, 6) == 5) && (floormod(j, 4) == 2)), ..(OMITTED).. 6) == 0) && (floormod(j, 4) == 1)), 0f, select(((floormod(i, 6) == 0) && (floormod(j, 4) == 0)), 1f, 0f))))))))))))))))))))))))
inverse(vh, vw, p, co) += ((bgemm[r_a, r_b, p, co]*A[r_a, vh])*A[r_b, vw])
conv2d_winograd(n, h, w, co) = inverse[floormod(h, 4), floormod(w, 4), ((((n*7)*7) + (floordiv(h, 4)*7)) + floordiv(w, 4)), co]
p2 = PLACEHOLDER [1, 1, 1, 128]
T_add(ax0, ax1, ax2, ax3) = (conv2d_winograd[ax0, ax1, ax2, ax3] + p2[ax0, 0, 0, ax3])
T_relu(ax0, ax1, ax2, ax3) = max(T_add[ax0, ax1, ax2, ax3], 0f)
========== Task 28 (workload key: ["d37380659057397544e056461ea3bad3", [1, 7, 7, 512], [3, 3, 512, 512], [1, 1, 1, 512], [1, 7, 7, 512]]) ==========
p0 = PLACEHOLDER [1, 7, 7, 512]
pad_temp(i0, i1, i2, i3) = tir.if_then_else(((((i1 >= 1) && (i1 < 8)) && (i2 >= 1)) && (i2 < 8)), p0[i0, (i1 - 1), (i2 - 1), i3], 0f)
p1 = PLACEHOLDER [3, 3, 512, 512]
conv2d_nhwc(nn, yy, xx, ff) += (pad_temp[nn, (yy + ry), (xx + rx), rc]*p1[ry, rx, rc, ff])
p2 = PLACEHOLDER [1, 1, 1, 512]
T_add(ax0, ax1, ax2, ax3) = (conv2d_nhwc[ax0, ax1, ax2, ax3] + p2[ax0, 0, 0, ax3])
T_relu(ax0, ax1, ax2, ax3) = max(T_add[ax0, ax1, ax2, ax3], 0f)
.. GENERATED FROM PYTHON SOURCE LINES 175-191
Begin Tuning
------------
Now, we set some options for tuning and launch the search tasks:

* :code:`num_measure_trials` is the number of measurement trials we can use during tuning.
  You can set it to a small number (e.g., 200) for a fast demonstrative run.
  In practice, we recommend setting it to around :code:`800 * len(tasks)`,
  which is typically enough for the search to converge.
  For example, there are 29 tasks in resnet-50, so we can set it to 20000.
  You can adjust this parameter according to your time budget.
* In addition, we use :code:`RecordToFile` to dump measurement records into a log file.
  These measurement records can be used to query the history best, resume the search,
  and do more analyses later.
* See :any:`auto_scheduler.TuningOptions` and
  :any:`auto_scheduler.LocalRunner` for more parameters.
.. GENERATED FROM PYTHON SOURCE LINES 191-225
.. code-block:: default
def run_tuning():
print("Begin tuning...")
tuner = auto_scheduler.TaskScheduler(tasks, task_weights)
tune_option = auto_scheduler.TuningOptions(
num_measure_trials=200, # change this to 20000 to achieve the best performance
runner=auto_scheduler.LocalRunner(repeat=10, enable_cpu_cache_flush=True),
measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
)
if use_sparse:
from tvm.topi.sparse.utils import sparse_sketch_rules
search_policy = [
auto_scheduler.SketchPolicy(
task,
program_cost_model=auto_scheduler.XGBModel(),
init_search_callbacks=sparse_sketch_rules(),
)
for task in tasks
]
tuner.tune(tune_option, search_policy=search_policy)
else:
tuner.tune(tune_option)
# We do not run the tuning on our web server since it takes too long.
# Uncomment the following line to run it by yourself.
# run_tuning()
.. GENERATED FROM PYTHON SOURCE LINES 226-284
.. note:: Explain the printed information during tuning
During tuning, a lot of information is printed on the console.
It is used for debugging purposes. The most important information is the output
of the task scheduler. The following table is a sample of its output.
.. code-block:: none
----------------------------------------------------------------------
------------------------------ [ Task Scheduler ]
----------------------------------------------------------------------
| ID | Latency (ms) | Speed (GFLOPS) | Trials |
-------------------------------------------------
| 0 | 0.010 | 0.40 | 64 |
| 1 | 0.087 | 47.19 | 64 |
| 2 | 0.008 | -0.00 | 64 |
| 3 | 0.177 | 582.07 | 64 |
| 4 | 0.268 | 862.37 | 256 |
| 5 | 0.166 | 621.13 | 128 |
| 6 | 0.170 | 605.10 | 128 |
| 7 | 0.128 | 403.20 | 64 |
| 8 | 0.189 | 545.71 | 64 |
| 9 | 0.231 | 1001.01 | 448 |
| 10 | 0.155 | 664.80 | 256 |
| 11 | 0.155 | 662.86 | 256 |
| 12 | 0.119 | 434.08 | 64 |
| 13 | 0.199 | 522.13 | 64 |
| 14 | 0.235 | 986.56 | 320 |
| 15 | 0.149 | 689.13 | 128 |
| 16 | 0.155 | 664.80 | 192 |
| 17 | 0.151 | 340.64 | 64 |
| 18 | 0.176 | 597.55 | 128 |
| 19 | 0.220 | 1054.37 | 192 |
| 20 | 0.150 | 686.01 | 128 |
| 21 | 0.159 | 650.88 | 128 |
| 22 | 0.073 | 358.19 | 64 |
| 23 | 0.031 | 70.63 | 64 |
| 24 | 0.251 | 947.73 | 128 |
| 25 | 0.157 | 652.47 | 128 |
| 26 | 0.215 | 954.84 | 128 |
| 27 | 0.237 | 868.92 | 128 |
| 28 | 0.266 | 774.06 | 128 |
-------------------------------------------------
Estimated total latency: 10.016 ms Trials: 3992 Used time : 1131 s Next ID: 15
This table lists the latency and (estimated) speed of all tasks.
It also lists the allocation of measurement trials for all tasks.
The last line prints the total weighted latency of these tasks,
which can be a rough estimation of the end-to-end execution time
of the network.
The last line also prints the total number of measurement trials,
the total time spent on auto-tuning, and the ID of the next task to tune.
There will also be some "tvm::Error" messages, because the
auto-scheduler tries some invalid schedules.
You can safely ignore them as long as the tuning continues, because these
errors are isolated from the main process.
.. GENERATED FROM PYTHON SOURCE LINES 286-292
.. note:: Terminate the tuning earlier
You can terminate the tuning early by forcibly killing this process.
As long as you get at least one valid schedule for each task in the log file,
you should be able to do the compilation (the section below).
.. GENERATED FROM PYTHON SOURCE LINES 295-300
Compile and Evaluate
--------------------
After auto-tuning, we can compile the network with the best schedules we found.
All measurement records are dumped into the log file during auto-tuning,
so we can read the log file and load the best schedules.
.. GENERATED FROM PYTHON SOURCE LINES 300-318
.. code-block:: default
# Compile with the history best
print("Compile...")
with auto_scheduler.ApplyHistoryBest(log_file):
with tvm.transform.PassContext(opt_level=3, config={"relay.backend.use_auto_scheduler": True}):
lib = relay.build(mod, target=target, params=params)
# Create graph executor
dev = tvm.device(str(target), 0)
module = graph_executor.GraphModule(lib["default"](dev))
data_tvm = tvm.nd.array((np.random.uniform(size=input_shape)).astype(dtype))
module.set_input("data", data_tvm)
# Evaluate
print("Evaluate inference time cost...")
print(module.benchmark(dev, repeat=3, min_repeat_ms=500))
.. rst-class:: sphx-glr-script-out
.. code-block:: none
Compile...
Evaluate inference time cost...
Execution time summary:
mean (ms) median (ms) max (ms) min (ms) std (ms)
710.5261 710.5235 710.8594 710.1955 0.2710
.. GENERATED FROM PYTHON SOURCE LINES 319-335
Other Tips
----------
1. During tuning, the auto-scheduler needs to compile many programs and
   extract features from them. This part is CPU-intensive,
   so a high-performance CPU with many cores is recommended for faster search.
2. You can use :code:`python3 -m tvm.auto_scheduler.measure_record --mode distill -i log.json`
   to distill the large log file and keep only the best records.
3. You can resume a search from a previous log file. You just need to
   add the argument :code:`load_log_file` when creating the task scheduler
   in the function :code:`run_tuning`, e.g.,
   :code:`tuner = auto_scheduler.TaskScheduler(tasks, task_weights, load_log_file=log_file)`
4. If you have multiple target CPUs, you can use all of them for measurements to
parallelize the measurements. Check this :ref:`section <tutorials-autotvm-scale-up-rpc-tracker>`
to learn how to use the RPC Tracker and RPC Server.
   To use the RPC Tracker in the auto-scheduler, replace the runner in :code:`TuningOptions`
   with :any:`auto_scheduler.RPCRunner`, as sketched after this list.
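A sketch of this replacement, where the device key, host, and port are
placeholders for your own RPC tracker setup:

.. code-block:: python

    tune_option = auto_scheduler.TuningOptions(
        num_measure_trials=20000,
        # "my-x86-device", "0.0.0.0", and 9190 are placeholder values;
        # use the device key and address of your own RPC tracker.
        runner=auto_scheduler.RPCRunner(
            "my-x86-device",
            host="0.0.0.0",
            port=9190,
            repeat=10,
            enable_cpu_cache_flush=True,
        ),
        measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
    )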
.. rst-class:: sphx-glr-timing
**Total running time of the script:** ( 1 minutes 38.598 seconds)
.. _sphx_glr_download_how_to_tune_with_autoscheduler_tune_network_x86.py:
.. only:: html
.. container:: sphx-glr-footer sphx-glr-footer-example
.. container:: sphx-glr-download sphx-glr-download-python
:download:`Download Python source code: tune_network_x86.py <tune_network_x86.py>`
.. container:: sphx-glr-download sphx-glr-download-jupyter
:download:`Download Jupyter notebook: tune_network_x86.ipynb <tune_network_x86.ipynb>`
.. only:: html
.. rst-class:: sphx-glr-signature
`Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_