| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
| """Perform ResNet autoTVM tuning on VTA using NNVM.""" |
| |
| import argparse |
| import os |
| import time |
| import numpy as np |
| |
| import tvm |
| from tvm import rpc, autotvm |
| from tvm.autotvm.measure.measure_methods import request_remote |
| from tvm.autotvm.tuner import XGBTuner, GATuner, RandomTuner, GridSearchTuner |
| from tvm.contrib import graph_runtime, util |
| from tvm.contrib.download import download |
| |
| import topi |
| import nnvm.compiler |
| import vta |
| import vta.testing |
| |
# VTA hardware configuration handle; the code below reads env.BATCH,
# env.BLOCK_IN/BLOCK_OUT (packing), env.TARGET and env.target_host.
env = vta.get_env()
| |
def register_vta_tuning_tasks():
    """Register the conv2d tuning template for VTA with AutoTVM.

    Overrides the "topi_nn_conv2d" task so that extracted conv2d workloads
    are lowered with the VTA-specific compute pattern (right-shift,
    two-stage clip, cast to int8) and the VTA conv2d schedule.
    """
    from tvm.autotvm.task.topi_integration import TaskExtractEnv, deserialize_args

    @tvm.tag_scope(tag=topi.tag.ELEMWISE)
    def my_clip(x, a_min, a_max):
        """Unlike topi's current clip, put min and max into two stages."""
        const_min = tvm.const(a_min, x.dtype)
        const_max = tvm.const(a_max, x.dtype)
        # Two separate compute stages (min then max) instead of a fused clip.
        x = tvm.compute(x.shape, lambda *i: tvm.min(x(*i), const_max), name="clipA")
        x = tvm.compute(x.shape, lambda *i: tvm.max(x(*i), const_min), name="clipB")
        return x

    # init autotvm env to register VTA operator
    TaskExtractEnv()

    @autotvm.task.register("topi_nn_conv2d", override=True)
    def _topi_nn_conv2d(*args, **kwargs):
        # Template entry point: args arrive serialized from task extraction.
        assert not kwargs, "Do not support kwargs in template function call"
        args = deserialize_args(args)
        A, W = args[:2]  # input activation and weight placeholders

        with tvm.target.vta():
            res = topi.nn.conv2d(*args, **kwargs)
            # Requantization chain: shift down 8 bits, clip to int8 range, cast.
            res = topi.right_shift(res, 8)
            res = my_clip(res, 0, 127)
            res = topi.cast(res, "int8")

        # Use the VTA schedule when targeting VTA; otherwise a default
        # schedule (e.g. when falling back to the CPU target).
        if tvm.target.current_target().device_name == 'vta':
            s = topi.generic.schedule_conv2d_nchw([res])
        else:
            s = tvm.create_schedule([res.op])
        return s, [A, W, res]
| |
| |
| |
def generate_graph(sym, params, target, target_host):
    """Run the VTA graph passes on *sym* and compile it with NNVM.

    Returns the (graph, lib, params) triple produced by nnvm.compiler.build.
    """
    # Shapes/dtypes for the input placeholder plus every parameter tensor.
    shapes = {"data": (1, 3, 224, 224)}
    dtypes = {"data": 'float32'}
    for name, arr in params.items():
        shapes[name] = arr.shape
        dtypes[name] = str(arr.dtype)

    # NNVM graph passes: strip redundant casts, fuse conv-adjacent ops,
    # then pack tensors into VTA's batch/block layout.
    sym = vta.graph.clean_cast(sym)
    sym = vta.graph.clean_conv_fuse(sym)
    assert env.BLOCK_IN == env.BLOCK_OUT
    sym = vta.graph.pack(sym, shapes, env.BATCH, env.BLOCK_OUT)

    # Compile with VTA's build config so VTA lowering passes are applied.
    with nnvm.compiler.build_config(opt_level=3):
        with vta.build_config():
            graph, lib, params = nnvm.compiler.build(
                sym, target, shapes, dtypes,
                params=params, target_host=target_host)

    return graph, lib, params
| |
| |
def extract_tasks(sym, params, target, target_host):
    """Extract conv2d tuning tasks from the VTA-packed NNVM graph."""
    # Shapes/dtypes for the input placeholder plus every parameter tensor.
    shapes = {"data": (1, 3, 224, 224)}
    dtypes = {"data": 'float32'}
    for name, arr in params.items():
        shapes[name] = arr.shape
        dtypes[name] = str(arr.dtype)

    # Same graph passes used for compilation, so the extracted workloads
    # match what will actually be built.
    sym = vta.graph.clean_cast(sym)
    sym = vta.graph.clean_conv_fuse(sym)
    assert env.BLOCK_IN == env.BLOCK_OUT
    sym = vta.graph.pack(sym, shapes, env.BATCH, env.BLOCK_OUT)

    # Only conv2d symbols are tunable on VTA.
    with vta.build_config():
        tasks = autotvm.task.extract_from_graph(graph=sym, shape=shapes, dtype=dtypes, target=target,
                                                params=params, symbols=(nnvm.sym.conv2d,), target_host=target_host)
    return tasks
| |
| |
def download_model():
    """Download the quantized ResNet-18 artifacts (if needed) and load them.

    Files are cached under the local ``_data`` directory.

    Returns
    -------
    sym : nnvm graph
        The NNVM graph loaded from the JSON file.
    params : dict
        The parameter dict loaded from the params file.
    """
    url = "https://github.com/uwsaml/web-data/raw/master/vta/models/"
    categ_fn = 'synset.txt'
    graph_fn = 'resnet18_qt8.json'
    params_fn = 'resnet18_qt8.params'
    data_dir = '_data'
    if not os.path.exists(data_dir):
        os.makedirs(data_dir)

    for file in [categ_fn, graph_fn, params_fn]:
        # Bug fix: check for the file inside data_dir (the original checked
        # the current working directory, so cached files were re-downloaded
        # on every run).
        if not os.path.isfile(os.path.join(data_dir, file)):
            download(os.path.join(url, file), os.path.join(data_dir, file))

    # Use context managers so the file handles are closed deterministically.
    with open(os.path.join(data_dir, graph_fn)) as graph_file:
        sym = nnvm.graph.load_json(graph_file.read())
    with open(os.path.join(data_dir, params_fn), 'rb') as params_file:
        params = nnvm.compiler.load_param_dict(params_file.read())

    return sym, params
| |
| |
def tune_tasks(tasks,
               measure_option,
               tuner='xgb',
               n_trial=1000,
               early_stopping=None,
               log_filename='tuning.log',
               use_transfer_learning=True,
               try_winograd=True):
    """Tune all tasks and distill the best record per workload into log_filename."""
    # All measurements go to a temporary log; the best entries are picked
    # out at the end.
    tmp_log_file = log_filename + ".tmp"
    if os.path.exists(tmp_log_file):
        os.remove(tmp_log_file)

    task_count = len(tasks)
    for idx, task in enumerate(reversed(tasks)):
        prefix = "[Task %2d/%2d] " % (idx + 1, task_count)

        # Instantiate the requested tuner.
        if tuner in ('xgb', 'xgb-rank'):
            searcher = XGBTuner(task, loss_type='rank')
        elif tuner == 'ga':
            searcher = GATuner(task, pop_size=50)
        elif tuner == 'random':
            searcher = RandomTuner(task)
        elif tuner == 'gridsearch':
            searcher = GridSearchTuner(task)
        else:
            raise ValueError("Invalid tuner: " + tuner)

        # Warm-start from measurements of previously tuned tasks.
        if use_transfer_learning and os.path.isfile(tmp_log_file):
            searcher.load_history(autotvm.record.load_from_file(tmp_log_file))

        # Never request more trials than the config space contains.
        budget = min(n_trial, len(task.config_space))
        searcher.tune(budget,
                      early_stopping=early_stopping,
                      measure_option=measure_option,
                      callbacks=[
                          autotvm.callback.progress_bar(budget, prefix=prefix),
                          autotvm.callback.log_to_file(tmp_log_file)])

    # Keep only the best record per workload and drop the temporary log.
    autotvm.record.pick_best(tmp_log_file, log_filename)
    os.remove(tmp_log_file)
| |
if __name__ == '__main__':

    # Get tracker info from env
    tracker_host = os.environ.get("TVM_TRACKER_HOST", None)
    tracker_port = os.environ.get("TVM_TRACKER_PORT", None)
    if not tracker_host or not tracker_port:
        print("Set your AutoTVM tracker node host and port variables to run the autotuner")
        exit()

    # Download model
    sym, params = download_model()

    # Register VTA tuning tasks
    register_vta_tuning_tasks()

    # Extract tasks
    print("Extracting tasks...")
    target = tvm.target.vta()
    target_host = env.target_host
    tasks = extract_tasks(sym, params, target, target_host)

    # Perform Autotuning
    print("Tuning...")
    tuning_opt = {
        'log_filename': 'resnet-18.log',

        'tuner': 'random',
        # Effectively unbounded; tune_tasks caps trials at the config
        # space size per task.
        'n_trial': 1e9,
        'early_stopping': None,

        'measure_option': autotvm.measure_option(
            builder=autotvm.LocalBuilder(build_func=vta.vta_autotvm_build_func),
            runner=autotvm.RPCRunner(env.TARGET, tracker_host, int(tracker_port),
                                     number=4, repeat=3, timeout=60,
                                     check_correctness=True))
    }
    tune_tasks(tasks, **tuning_opt)

    # Compile kernels with history best records
    with autotvm.tophub.context(target, extra_files=[tuning_opt['log_filename']]):

        # ResNet input parameters
        input_shape = (1, 3, 224, 224)
        dtype = 'float32'

        # Compile network
        print("Compiling network with best tuning parameters...")
        graph, lib, params = generate_graph(sym, params, target, target_host)

        # Export library
        tmp = util.tempdir()
        filename = "net.tar"
        lib.export_library(tmp.relpath(filename))

        # Upload module to device
        print("Upload...")
        # Bug fix: the tracker port read from the environment is a string;
        # convert to int as done for the RPCRunner above.
        remote = autotvm.measure.request_remote(env.TARGET, tracker_host,
                                                int(tracker_port), timeout=10000)
        remote.upload(tmp.relpath(filename))
        rlib = remote.load_module(filename)

        # Upload parameters to device
        ctx = remote.context(str(target), 0)
        rparams = {k: tvm.nd.array(v, ctx) for k, v in params.items()}
        data_tvm = tvm.nd.array((np.random.uniform(size=input_shape)).astype(dtype))
        module = graph_runtime.create(graph, rlib, ctx)
        module.set_input('data', data_tvm)
        module.set_input(**rparams)

        # Evaluate
        print("Evaluate inference time cost...")
        ftimer = module.module.time_evaluator("run", ctx, number=3, repeat=3)
        prof_res = np.array(ftimer().results) * 1000  # convert to millisecond
        print("Mean inference time (std dev): %.2f ms (%.2f ms)" %
              (np.mean(prof_res), np.std(prof_res)))
| |