| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| """ |
| Compile PyTorch Models |
| ====================== |
| **Author**: `Alex Wong <https://github.com/alexwong/>`_ |
| |
This article is an introductory tutorial on deploying PyTorch models with Relay.
| |
To get started, PyTorch must be installed.
TorchVision is also required, since we will use it as our model zoo.
| |
A quick solution is to install via pip:

.. code-block:: bash

    pip install torch==1.4.0
    pip install torchvision==0.5.0

or refer to the official site:
https://pytorch.org/get-started/locally/
| |
PyTorch versions should be backward compatible, but each one must be paired
with its matching TorchVision version.
| |
| Currently, TVM supports PyTorch 1.4 and 1.3. Other versions may |
| be unstable. |
| """ |
| |
| import tvm |
| from tvm import relay |
| |
| import numpy as np |
| |
| from tvm.contrib.download import download_testdata |
| |
| # PyTorch imports |
| import torch |
| import torchvision |
| |
######################################################################
# Load a pretrained PyTorch model
# -------------------------------
# Build the torchvision model by name with pretrained weights, put it in
# evaluation mode, and trace it with a dummy input to get TorchScript.
model_name = "resnet18"
model_factory = getattr(torchvision.models, model_name)
model = model_factory(pretrained=True).eval()

# Tracing records the operations executed on this dummy input; only its
# shape matters, the random values themselves are irrelevant.
input_shape = [1, 3, 224, 224]
input_data = torch.randn(input_shape)
scripted_model = torch.jit.trace(model, input_data)
scripted_model = scripted_model.eval()
| |
######################################################################
# Load a test image
# -----------------
# Classic cat example!
from PIL import Image

img_url = "https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true"
img_path = download_testdata(img_url, "cat.png", module="data")
# Convert to RGB so the tensor always has the three channels that
# ``Normalize`` below expects (a PNG may carry an alpha channel or a palette).
# No manual resize here: ``Resize(256)`` + ``CenterCrop(224)`` below are the
# standard ImageNet preprocessing. Squashing the image to 224x224 first would
# distort its aspect ratio, and the following resize/crop would then
# interpolate it a second time and cut away its borders.
with Image.open(img_path) as raw_img:
    img = raw_img.convert("RGB")

# Preprocess the image and convert to tensor
from torchvision import transforms

my_preprocess = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)
img = my_preprocess(img)
# Add a leading batch dimension: the result is a NumPy array of shape
# (1, 3, 224, 224), which is fed to both Relay and PyTorch below.
img = np.expand_dims(img, 0)
| |
######################################################################
# Import the graph to Relay
# -------------------------
# Convert the TorchScript graph into a Relay module plus its parameters.
# The frontend is told the name and shape of each graph input; the name is
# arbitrary, and it is reused below when feeding the image to the module.
input_name = "input0"
input_info = (input_name, img.shape)
shape_list = [input_info]
mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)
| |
######################################################################
# Relay Build
# -----------
# Compile the graph for the llvm (CPU) target with the given input
# specification; the same target is used for the host-side code.
ctx = tvm.cpu(0)
target = "llvm"
target_host = target
with tvm.transform.PassContext(opt_level=3):
    # opt_level in the PassContext selects which optimization passes run
    lib = relay.build(mod, target=target, target_host=target_host, params=params)
| |
######################################################################
# Execute the portable graph on TVM
# ---------------------------------
# Now we can try deploying the compiled model on target.
from tvm.contrib import graph_runtime

dtype = "float32"
# Instantiate the "default" module from the build result on ``ctx`` and
# wrap it in the GraphModule convenience API.
graph_factory = lib["default"]
m = graph_runtime.GraphModule(graph_factory(ctx))
# Bind the preprocessed image to the named graph input
tvm_input = tvm.nd.array(img.astype(dtype))
m.set_input(input_name, tvm_input)
# Run inference
m.run()
# Fetch output 0
tvm_output = m.get_output(0)
| |
#####################################################################
# Look up synset name
# -------------------
# Look up prediction top 1 index in 1000 class synset.
synset_url = "".join(
    [
        "https://raw.githubusercontent.com/Cadene/",
        "pretrained-models.pytorch/master/data/",
        "imagenet_synsets.txt",
    ]
)
synset_name = "imagenet_synsets.txt"
synset_path = download_testdata(synset_url, synset_name, module="data")
# Explicit encoding so decoding does not depend on the platform locale.
with open(synset_path, encoding="utf-8") as f:
    synsets = [line.strip() for line in f]

# Each line is "<key> <class name ...>": split on the first space only.
key_to_classname = {
    key: classname for key, _, classname in (line.partition(" ") for line in synsets)
}

class_url = "".join(
    [
        "https://raw.githubusercontent.com/Cadene/",
        "pretrained-models.pytorch/master/data/",
        "imagenet_classes.txt",
    ]
)
class_name = "imagenet_classes.txt"
class_path = download_testdata(class_url, class_name, module="data")
with open(class_path, encoding="utf-8") as f:
    # Line i holds the synset key of class id i.
    class_id_to_key = [line.strip() for line in f]

# Get top-1 result for TVM. The output is presumably (batch=1, 1000 logits);
# index [0] selects the single image in the batch.
top1_tvm = np.argmax(tvm_output.asnumpy()[0])
tvm_class_key = class_id_to_key[top1_tvm]

# Run the original eager PyTorch model on the same preprocessed input for
# comparison; no_grad() skips autograd bookkeeping during inference.
with torch.no_grad():
    torch_img = torch.from_numpy(img)
    output = model(torch_img)

# Get top-1 result for PyTorch, indexing the batch entry exactly as for TVM
# above (an argmax over the flattened array would be wrong for batch > 1).
top1_torch = np.argmax(output.numpy()[0])
torch_class_key = class_id_to_key[top1_torch]

print("Relay top-1 id: {}, class name: {}".format(top1_tvm, key_to_classname[tvm_class_key]))
print("Torch top-1 id: {}, class name: {}".format(top1_torch, key_to_classname[torch_class_key]))