| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
| """Perform ResNet autoTVM tuning on VTA using Relay.""" |
| |
| import argparse, os, time |
| from mxnet.gluon.model_zoo import vision |
| import numpy as np |
| from PIL import Image |
| |
| import topi |
| import tvm |
| from tvm import rpc, autotvm, relay |
| from tvm.autotvm.measure.measure_methods import request_remote |
| from tvm.autotvm.tuner import XGBTuner, GATuner, RandomTuner, GridSearchTuner |
| from tvm.contrib import graph_runtime, util, download |
| from tvm.contrib.debugger import debug_runtime |
| import vta |
| from vta.testing import simulator |
| from vta.top import graph_pack |
| from tvm.autotvm.task import extract_from_program |
| |
| def parse_arguments(): |
| |
| parser = argparse.ArgumentParser(description='Train a model for image classification.') |
| parser.add_argument('--model', type=str, default='resnet18_v1', choices=['resnet18_v1'], |
| help='Input model name.') |
| parser.add_argument('--start-name', type=str, default='nn.max_pool2d', |
| help='The name of the node where packing starts') |
| parser.add_argument('--stop-name', type=str, default='nn.global_avg_pool2d', |
| help='The name of the node where packing stops') |
| parser.add_argument('--debug-profile', action='store_true', |
| help='Show layer-wise time cost profiling results') |
| parser.add_argument('--device', default='vta', choices=['vta', 'arm_cpu'], |
| help='Select device target') |
| parser.add_argument('--measurements', type=int, default=1, |
| help='Number of measurements during AutoTVM search') |
| parser.add_argument('--tuner', type=str, default="random", |
| help='AutoTVM search strategy') |
| parser.add_argument('--log-filename', type=str, default="resnet-18.log", |
| help='AutoTVM log file name') |
| |
| return parser.parse_args() |
| |
| |
| def register_vta_tuning_tasks(): |
| from tvm.autotvm.task.topi_integration import TaskExtractEnv, deserialize_args |
| |
| @tvm.tag_scope(tag=topi.tag.ELEMWISE) |
| def my_clip(x, a_min, a_max): |
| """Unlike topi's current clip, put min and max into two stages.""" |
| const_min = tvm.const(a_min, x.dtype) |
| const_max = tvm.const(a_max, x.dtype) |
| x = tvm.compute(x.shape, lambda *i: tvm.min(x(*i), const_max), name="clipA") |
| x = tvm.compute(x.shape, lambda *i: tvm.max(x(*i), const_min), name="clipB") |
| return x |
| |
| # init autotvm env to register VTA operator |
| TaskExtractEnv() |
| |
| @autotvm.task.register("topi_nn_conv2d", override=True) |
| def _topi_nn_conv2d(*args, **kwargs): |
| assert not kwargs, "Do not support kwargs in template function call" |
| args = deserialize_args(args) |
| A, W = args[:2] |
| |
| with tvm.target.vta(): |
| res = topi.nn.conv2d(*args, **kwargs) |
| res = topi.right_shift(res, 8) |
| res = my_clip(res, 0, 127) |
| res = topi.cast(res, "int8") |
| |
| if tvm.target.current_target().device_name == 'vta': |
| s = topi.generic.schedule_conv2d_nchw([res]) |
| else: |
| s = tvm.create_schedule([res.op]) |
| return s, [A, W, res] |
| |
| @autotvm.task.register("topi_nn_dense", override=True) |
| def _topi_nn_dense(*args, **kwargs): |
| assert not kwargs, "Do not support kwargs in template function call" |
| args = deserialize_args(args) |
| A, W = args[:2] |
| |
| with tvm.target.vta(): |
| res = topi.nn.dense(*args, **kwargs) |
| res = topi.right_shift(res, 8) |
| res = my_clip(res, 0, 127) |
| res = topi.cast(res, "int8") |
| |
| if tvm.target.current_target().device_name == 'vta': |
| s = topi.generic.schedule_dense([res]) |
| else: |
| s = tvm.create_schedule([res.op]) |
| |
| return s, [A, W, res] |
| |
| |
| def compile_network(opt, env, target): |
| |
| # Populate the shape and data type dictionary |
| dtype_dict = {"data": 'float32'} |
| shape_dict = {"data": (env.BATCH, 3, 224, 224)} |
| |
| # Get off the shelf gluon model, and convert to relay |
| gluon_model = vision.get_model(opt.model, pretrained=True) |
| mod, params = relay.frontend.from_mxnet(gluon_model, shape_dict) |
| |
| # Update shape and type dictionary |
| shape_dict.update({k: v.shape for k, v in params.items()}) |
| dtype_dict.update({k: str(v.dtype) for k, v in params.items()}) |
| |
| # Perform quantization in Relay |
| # Note: We set opt_level to 3 in order to fold batch norm |
| with relay.build_config(opt_level=3): |
| with relay.quantize.qconfig(global_scale=8.0, |
| skip_conv_layers=[0]): |
| relay_prog = relay.quantize.quantize(mod["main"], params=params) |
| |
| # Perform graph packing and constant folding for VTA target |
| if target.device_name == "vta": |
| assert env.BLOCK_IN == env.BLOCK_OUT |
| relay_prog = graph_pack( |
| relay_prog, |
| env.BATCH, |
| env.BLOCK_OUT, |
| env.WGT_WIDTH, |
| start_name=opt.start_name, |
| stop_name=opt.stop_name) |
| |
| return relay_prog, params |
| |
| |
| def tune_tasks(tasks, |
| measure_option, |
| tuner='xgb', |
| n_trial=1000, |
| early_stopping=None, |
| log_filename='tuning.log', |
| use_transfer_learning=True, |
| try_winograd=True): |
| |
| # create tmp log file |
| tmp_log_file = log_filename + ".tmp" |
| if os.path.exists(tmp_log_file): |
| os.remove(tmp_log_file) |
| |
| for i, tsk in enumerate(reversed(tasks)): |
| prefix = "[Task %2d/%2d] " % (i+1, len(tasks)) |
| |
| # create tuner |
| if tuner == 'xgb' or tuner == 'xgb-rank': |
| tuner_obj = XGBTuner(tsk, loss_type='rank') |
| elif tuner == 'ga': |
| tuner_obj = GATuner(tsk, pop_size=50) |
| elif tuner == 'random': |
| tuner_obj = RandomTuner(tsk) |
| elif tuner == 'gridsearch': |
| tuner_obj = GridSearchTuner(tsk) |
| else: |
| raise ValueError("Invalid tuner: " + tuner) |
| |
| if use_transfer_learning: |
| if os.path.isfile(tmp_log_file): |
| tuner_obj.load_history(autotvm.record.load_from_file(tmp_log_file)) |
| |
| # do tuning |
| n_trial_ = min(n_trial, len(tsk.config_space)) |
| tuner_obj.tune(n_trial_, |
| early_stopping=early_stopping, |
| measure_option=measure_option, |
| callbacks=[ |
| autotvm.callback.progress_bar(n_trial_, prefix=prefix), |
| autotvm.callback.log_to_file(tmp_log_file)]) |
| |
| # pick best records to a cache file |
| autotvm.record.pick_best(tmp_log_file, log_filename) |
| os.remove(tmp_log_file) |
| |
| if __name__ == '__main__': |
| |
| opt = parse_arguments() |
| |
| # Make sure that TVM was compiled with RPC=1 |
| assert tvm.module.enabled("rpc") |
| |
| # Read in VTA environment |
| env = vta.get_env() |
| |
| # Get remote from fleet node |
| tracker_host = os.environ.get("TVM_TRACKER_HOST", None) |
| tracker_port = os.environ.get("TVM_TRACKER_PORT", None) |
| if not tracker_host or not tracker_port: |
| print("Set your AutoTVM tracker node host and port variables to run the autotuner") |
| exit() |
| |
| # Get remote |
| if env.TARGET != "sim": |
| |
| # Measure build start time |
| reconfig_start = time.time() |
| |
| # Get remote from fleet node |
| remote = autotvm.measure.request_remote(env.TARGET, tracker_host, int(tracker_port), timeout=10000) |
| |
| # Reconfigure the JIT runtime and FPGA. |
| # You can program the FPGA with your own custom bitstream |
| # by passing the path to the bitstream file instead of None. |
| vta.reconfig_runtime(remote) |
| vta.program_fpga(remote, bitstream=None) |
| |
| # Report on reconfiguration time |
| reconfig_time = time.time() - reconfig_start |
| print("Reconfigured FPGA and RPC runtime in {0:.2f}s!".format(reconfig_time)) |
| |
| # In simulation mode, host the RPC server locally. |
| else: |
| remote = rpc.LocalSession() |
| |
| # VTA target and execution context |
| target = env.target if opt.device == "vta" else env.target_vta_cpu |
| ctx = remote.ext_dev(0) if opt.device == "vta" else remote.cpu(0) |
| |
| # Compile Relay program |
| print("Initial compile...") |
| relay_prog, params = compile_network(opt, env, target) |
| |
| # Register VTA tuning tasks |
| register_vta_tuning_tasks() |
| |
| # Perform task extraction on Relay program |
| print("Extracting tasks...") |
| tasks = extract_from_program(func=relay_prog, |
| params=params, |
| ops=(tvm.relay.op.nn.conv2d,), |
| target=target, |
| target_host=env.target_host) |
| |
| # Perform Autotuning |
| print("Tuning...") |
| tuning_opt = { |
| 'log_filename': opt.log_filename, |
| 'tuner': opt.tuner, |
| 'n_trial': 1e9, |
| 'early_stopping': None, |
| 'measure_option': autotvm.measure_option( |
| builder=autotvm.LocalBuilder(build_func=vta.vta_autotvm_build_func), |
| runner=autotvm.RPCRunner(env.TARGET, tracker_host, tracker_port, |
| number=4, min_repeat_ms=150, repeat=opt.measurements, timeout=60, |
| check_correctness=True)) |
| } |
| tune_tasks(tasks, **tuning_opt) |
| |
| # Compile kernels with history best records |
| with autotvm.tophub.context(target, extra_files=[opt.log_filename]): |
| |
| # Compile network |
| print("Compiling network with best tuning parameters...") |
| with relay.build_config(opt_level=3, disabled_pass={"AlterOpLayout"}): |
| if target.device_name != "vta": |
| graph, lib, params = relay.build( |
| relay_prog, target=target, |
| params=params, target_host=env.target_host) |
| else: |
| with vta.build_config(): |
| graph, lib, params = relay.build( |
| relay_prog, target=target, |
| params=params, target_host=env.target_host) |
| |
| # Export library |
| temp = util.tempdir() |
| lib.save(temp.relpath("graphlib.o")) |
| remote.upload(temp.relpath("graphlib.o")) |
| lib = remote.load_module("graphlib.o") |
| |
| # If detailed runtime info is needed build with debug runtime |
| if opt.debug_profile: |
| m = debug_runtime.create(graph, lib, ctx) |
| else: |
| m = graph_runtime.create(graph, lib, ctx) |
| |
| # Set the network parameters and synthetic input |
| image = tvm.nd.array( |
| (np.random.uniform(size=(1, 3, 224, 224))).astype('float32')) |
| m.set_input(**params) |
| m.set_input('data', image) |
| |
| # Perform inference |
| timer = m.module.time_evaluator("run", ctx, number=4, repeat=opt.measurements) |
| tcost = timer() |
| prof_res = np.array(tcost.results) * 1000 # convert to millisecond |
| print("Mean inference time (std dev): %.2f ms (%.2f ms)" % |
| (np.mean(prof_res), np.std(prof_res))) |
| |
| # Display profile information |
| if opt.debug_profile: |
| m.run() |