blob: 2a1331f9f94bac92b46fb8fe86df21c50c9680d3 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Tuning a single conv2d operator"""
from collections import namedtuple
import logging
import os
import tvm
from tvm import te
from tvm import autotvm
from tvm import topi
import vta
import vta.testing
env = vta.get_env()
Workload = namedtuple(
"Conv2DWorkload",
[
"batch",
"height",
"width",
"in_filter",
"out_filter",
"hkernel",
"wkernel",
"hpad",
"wpad",
"hstride",
"wstride",
],
)
resnet_wkls = [
# Workloads of resnet18 on imagenet
# ('resnet-18.C1', Workload(env.BATCH, 224, 224, 3, 64, 7, 7, 3, 3, 2, 2)),
("resnet-18.C2", Workload(env.BATCH, 56, 56, 64, 64, 3, 3, 1, 1, 1, 1)),
("resnet-18.C3", Workload(env.BATCH, 56, 56, 64, 128, 3, 3, 1, 1, 2, 2)),
("resnet-18.C4", Workload(env.BATCH, 56, 56, 64, 128, 1, 1, 0, 0, 2, 2)),
("resnet-18.C5", Workload(env.BATCH, 28, 28, 128, 128, 3, 3, 1, 1, 1, 1)),
("resnet-18.C6", Workload(env.BATCH, 28, 28, 128, 256, 3, 3, 1, 1, 2, 2)),
("resnet-18.C7", Workload(env.BATCH, 28, 28, 128, 256, 1, 1, 0, 0, 2, 2)),
("resnet-18.C8", Workload(env.BATCH, 14, 14, 256, 256, 3, 3, 1, 1, 1, 1)),
("resnet-18.C9", Workload(env.BATCH, 14, 14, 256, 512, 3, 3, 1, 1, 2, 2)),
("resnet-18.C10", Workload(env.BATCH, 14, 14, 256, 512, 1, 1, 0, 0, 2, 2)),
("resnet-18.C11", Workload(env.BATCH, 7, 7, 512, 512, 3, 3, 1, 1, 1, 1)),
]
@tvm.te.tag_scope(tag=topi.tag.ELEMWISE)
def my_clip(x, a_min, a_max):
"""Unlike topi's current clip, put min and max into two stages."""
const_min = tvm.tir.const(a_min, x.dtype)
const_max = tvm.tir.const(a_max, x.dtype)
x = te.compute(x.shape, lambda *i: tvm.te.min(x(*i), const_max), name="clipA")
x = te.compute(x.shape, lambda *i: tvm.te.max(x(*i), const_min), name="clipB")
return x
def conv2d(N, CI, H, W, CO, KH, KW, strides, padding, dilation):
data_shape = (N // env.BATCH, CI // env.BLOCK_IN, H, W, env.BATCH, env.BLOCK_IN)
kernel_shape = (CO // env.BLOCK_OUT, CI // env.BLOCK_IN, KH, KW, env.BLOCK_OUT, env.BLOCK_IN)
bias_shape = (N // env.BATCH, CO // env.BLOCK_OUT, 1, 1, env.BATCH, env.BLOCK_OUT)
data = te.placeholder(data_shape, name="data", dtype=env.inp_dtype)
kernel = te.placeholder(kernel_shape, name="kernel", dtype=env.wgt_dtype)
bias = te.placeholder(bias_shape, name="bias", dtype=env.acc_dtype)
with tvm.target.vta():
res = topi.nn.conv2d(
input=data,
filter=kernel,
padding=padding,
strides=strides,
dilation=dilation,
layout="NCHW%dn%dc" % (env.BATCH, env.BLOCK_IN),
out_dtype=env.acc_dtype,
)
res = topi.right_shift(res, env.WGT_WIDTH)
res = topi.add(res, bias)
res = my_clip(res, 0, (1 << env.OUT_WIDTH - 1) - 1)
res = topi.cast(res, env.out_dtype)
if tvm.target.Target.current().device_name == "vta":
s = topi.generic.schedule_conv2d_nchw([res])
else:
s = te.create_schedule([res.op])
return s, [data, kernel, bias, res]
if __name__ == "__main__":
# Logging config (for printing tuning log to the screen)
logging.basicConfig()
# logging.getLogger('autotvm').setLevel(logging.DEBUG)
# Tuning log files
log_file = "%s.conv2d.log" % (env.TARGET)
# create tmp log file
tmp_log_file = log_file + ".tmp"
if os.path.exists(log_file):
os.remove(log_file)
# Get tracker info from env
tracker_host = os.environ.get("TVM_TRACKER_HOST", None)
tracker_port = os.environ.get("TVM_TRACKER_PORT", None)
if not tracker_host or not tracker_port:
print("Set your AutoTVM tracker node host and port variables to run the autotuner")
exit()
for idx, (wl_name, wl) in enumerate(resnet_wkls):
prefix = "[Task %2d/%2d] " % (idx, len(resnet_wkls))
# Read in workload parameters
N = wl.batch
CI = wl.in_filter
H = wl.height
W = wl.width
CO = wl.out_filter
KH = wl.hkernel
KW = wl.wkernel
strides = (wl.hstride, wl.wstride)
padding = (wl.hpad, wl.wpad)
dilation = (1, 1)
# Create task
task = autotvm.task.create(
conv2d,
args=(N, CI, H, W, CO, KH, KW, strides, padding, dilation),
target=tvm.target.vta(),
target_host=env.target_host,
template_key="direct",
)
print(task.config_space)
# Tune
measure_option = autotvm.measure_option(
builder=autotvm.LocalBuilder(),
runner=autotvm.RPCRunner(
env.TARGET,
host=tracker_host,
port=int(tracker_port),
number=5,
timeout=60,
check_correctness=True,
),
)
# Run Tuner
tuner = autotvm.tuner.RandomTuner(task)
tuner.tune(
n_trial=len(task.config_space),
early_stopping=None,
measure_option=measure_option,
callbacks=[
autotvm.callback.progress_bar(len(task.config_space), prefix=prefix),
autotvm.callback.log_to_file(tmp_log_file),
],
)
# Pick best records to a cache file
autotvm.record.pick_best(tmp_log_file, log_file)
os.remove(tmp_log_file)