| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
| import tvm |
| from tvm import rpc, relay |
| from tvm.contrib.download import download_testdata |
| from tvm.relay.expr_functor import ExprMutator |
| from tvm.relay import transform |
| from tvm.relay.op.annotation import compiler_begin, compiler_end |
| from tvm.relay.quantize.quantize import prerequisite_optimize |
| from tvm.contrib import utils, xcode, graph_executor, coreml_runtime |
| from tvm.contrib.target import coreml as _coreml |
| |
| import os |
| import re |
| import sys |
| import numpy as np |
| from mxnet import gluon |
| from PIL import Image |
| import coremltools |
| import argparse |
| |
# Target configuration. The active values are the settings for an iPhone 6s;
# the alternative pair is arch = "x86_64" with sdk = "iphonesimulator".
sdk = "iphoneos"
arch = "arm64"
# LLVM target string used as the host target when building.
target_host = "llvm -mtriple=%s-apple-darwin" % (arch,)
| |
# Maps each supported connection mode name to the rpc function that opens it.
# "proxy" and "standalone" both use rpc.connect; "tracker" uses rpc.connect_tracker.
MODES = dict(
    proxy=rpc.connect,
    tracker=rpc.connect_tracker,
    standalone=rpc.connect,
)
| |
# Override the Metal compile callback so that kernels are compiled for the
# iPhone SDK selected by the module-level ``sdk`` setting.
@tvm.register_func("tvm_callback_metal_compile")
def compile_metal(src):
    """Compile Metal source ``src`` through xcode for the configured ``sdk``.

    Registered with TVM under the name "tvm_callback_metal_compile".
    """
    compiled = xcode.compile_metal(src, sdk=sdk)
    return compiled
| |
| |
def prepare_input():
    """Download the sample image and class labels and build the network input.

    The image is resized to 224x224, has a per-channel offset subtracted and
    is divided by a per-channel scale, is transposed from (H, W, C) to
    (C, H, W) and gets a leading batch axis.

    Returns
    -------
    image : numpy.ndarray
        The preprocessed image as float32; shape (1, 3, 224, 224) for a
        3-channel input image.
    synset : object
        The Python literal parsed from the downloaded
        "imagenet1000_clsid_to_human.txt" file; callers index it by class id.
    """
    # Local import keeps this block self-contained.
    import ast

    img_url = "https://github.com/dmlc/mxnet.js/blob/main/data/cat.png?raw=true"
    img_name = "cat.png"
    synset_url = "".join(
        [
            "https://gist.githubusercontent.com/zhreshold/",
            "4d0b62f3d01426887599d4f7ede23ee5/raw/",
            "596b27d23537e5a1b5751d2b0481ef172f58b539/",
            "imagenet1000_clsid_to_human.txt",
        ]
    )
    synset_name = "imagenet1000_clsid_to_human.txt"
    img_path = download_testdata(img_url, img_name, module="data")
    synset_path = download_testdata(synset_url, synset_name, module="data")
    with open(synset_path) as f:
        # The file content comes from the network, so parse it strictly as a
        # Python literal instead of executing it with eval().
        synset = ast.literal_eval(f.read())
    image = Image.open(img_path).resize((224, 224))

    # Subtracting a float array promotes the pixel data to float, which is
    # what allows the in-place division on the next line.
    image = np.array(image) - np.array([123.0, 117.0, 104.0])
    image /= np.array([58.395, 57.12, 57.375])
    image = image.transpose((2, 0, 1))
    image = image[np.newaxis, :]
    return image.astype("float32"), synset
| |
| |
def get_model(model_name, data_shape):
    """Fetch a pretrained Gluon vision model and convert it to Relay.

    Parameters
    ----------
    model_name : str
        Name passed to ``gluon.model_zoo.vision.get_model``.
    data_shape : tuple
        Shape registered for the "data" input when importing from MXNet.

    Returns
    -------
    func : relay.Function
        The imported "main" function with a softmax applied to its body.
    params : dict
        Parameters returned by ``relay.frontend.from_mxnet``.
    """
    net = gluon.model_zoo.vision.get_model(model_name, pretrained=True)
    shape_dict = {"data": data_shape}
    mod, params = relay.frontend.from_mxnet(net, shape_dict)

    main_fn = mod["main"]
    # The network should output probabilities, so wrap the body in softmax.
    prob_body = relay.nn.softmax(main_fn.body)
    softmax_fn = relay.Function(
        main_fn.params, prob_body, None, main_fn.type_params, main_fn.attrs
    )

    return softmax_fn, params
| |
| |
def test_mobilenet(host, port, key, mode):
    """Build MobileNetV2 and run it on a remote device on CPU, Metal and Core ML.

    Parameters
    ----------
    host : str
        Address handed to the RPC connect function.
    port : int
        Port handed to the RPC connect function.
    key : str
        Device key. In "tracker" mode it is requested from the tracker;
        otherwise it is passed as ``key=`` to the connect function.
    mode : str
        Key into ``MODES`` selecting how the RPC session is opened.
    """
    temp = utils.tempdir()
    image, synset = prepare_input()
    model, params = get_model("mobilenetv2_1.0", image.shape)

    def run(mod, target):
        """Compile ``mod`` for ``target``, run it remotely and print the results.

        Prints the top-1 prediction, then the mean and standard deviation of
        the timing results scaled by 1000 (labelled "ms").
        """
        # tvm.transform.PassContext replaces the deprecated relay.build_config.
        with tvm.transform.PassContext(opt_level=3):
            lib = relay.build(
                mod, target=tvm.target.Target(target, host=target_host), params=params
            )
        path_dso = temp.relpath("deploy.dylib")
        lib.export_library(path_dso, xcode.create_dylib, arch=arch, sdk=sdk)

        # connect to the remote device; the tracker hands out a session by key,
        # the other modes take the key directly
        if mode == "tracker":
            remote = MODES[mode](host, port).request(key)
        else:
            remote = MODES[mode](host, port, key=key)
        remote.upload(path_dso)

        if target == "metal":
            dev = remote.metal(0)
        else:
            dev = remote.cpu(0)
        # From here on ``lib`` is the module loaded on the remote side.
        lib = remote.load_module("deploy.dylib")
        m = graph_executor.GraphModule(lib["default"](dev))

        m.set_input("data", tvm.nd.array(image, dev))
        m.run()
        tvm_output = m.get_output(0)
        top1 = np.argmax(tvm_output.numpy()[0])
        print("TVM prediction top-1:", top1, synset[top1])

        # evaluate
        ftimer = m.module.time_evaluator("run", dev, number=3, repeat=10)
        prof_res = np.array(ftimer().results) * 1000
        print("%-19s (%s)" % ("%.2f ms" % np.mean(prof_res), "%.2f ms" % np.std(prof_res)))

    def annotate(func, compiler):
        """
        An annotator for Core ML.

        Binds the parameters in ``params`` into ``func`` as constants, then
        annotates and partitions the graph for ``compiler`` and returns the
        resulting IRModule.
        """
        # Bind free variables to the constant values.
        bind_dict = {}
        for arg in func.params:
            name = arg.name_hint
            if name in params:
                bind_dict[arg] = relay.const(params[name])

        func = relay.bind(func, bind_dict)

        # Annotate the entire graph for Core ML
        mod = tvm.IRModule()
        mod["main"] = func

        seq = tvm.transform.Sequential(
            [
                transform.SimplifyInference(),
                transform.FoldConstant(),
                transform.FoldScaleAxis(),
                transform.AnnotateTarget(compiler),
                transform.MergeCompilerRegions(),
                transform.PartitionGraph(),
            ]
        )

        # tvm.transform.PassContext replaces the deprecated relay.build_config.
        with tvm.transform.PassContext(opt_level=3):
            mod = seq(mod)

        return mod

    # CPU
    run(model, target_host)
    # Metal
    run(model, "metal")
    # CoreML
    run(annotate(model, "coremlcompiler"), target_host)
| |
| |
if __name__ == "__main__":
    # Command-line entry point: parse the connection settings and run the demo.
    parser = argparse.ArgumentParser(description="Demo app demonstrates how ios_rpc works.")
    parser.add_argument("--host", required=True, type=str, help="Address of rpc server")
    parser.add_argument("--port", type=int, default=9090, help="rpc port (default: 9090)")
    parser.add_argument("--key", type=str, default="iphone", help="device key (default: iphone)")
    parser.add_argument(
        "--mode",
        type=str,
        default="tracker",
        # Let argparse reject unknown modes with a usage error; an assert
        # would be stripped when Python runs with -O.
        choices=list(MODES),
        help="type of RPC connection (default: tracker), possible values: {}".format(
            ", ".join(MODES)
        ),
    )

    args = parser.parse_args()
    test_mobilenet(args.host, args.port, args.key, args.mode)