Add c++ and python local deploy example (#5)

* add deploy example

* fix pylint complain and add more information into readme.

* Test deploy in xilinx FPGA board and udpate ReadMe.md

* Fix python deploy bug.

* add synset support and fix error report

* add bitstream flash logic

* address review comments.

* add bitstream flash file

* change file name

* fix plint complain

* return once no input parameter
diff --git a/apps/deploy/Makefile b/apps/deploy/Makefile
new file mode 100644
index 0000000..777cb89
--- /dev/null
+++ b/apps/deploy/Makefile
@@ -0,0 +1,61 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# Makefile Example to deploy TVM modules.
+TVM_ROOT=${TVM_HOME}
+CUR_DIR=$(shell pwd)
+DMLC_CORE=${TVM_ROOT}/3rdparty/dmlc-core
+
+TARGET := ${shell python3 ../../config/vta_config.py --target}
+
+
+VTA_LIB=vta
+ifeq (${TARGET}, sim)
+	VTA_LIB=vta_fsim
+endif
+
+PKG_CFLAGS = -std=c++11 -O0 -g -fPIC\
+						 -I${TVM_ROOT}/include\
+						 -I${TVM_ROOT}/vta/include\
+						 -I${DMLC_CORE}/include\
+						 -I${TVM_ROOT}/3rdparty/dlpack/include\
+						 -I${TVM_ROOT}/3rdparty/vta-hw/include\
+						 -I${TVM_ROOT}/\
+
+PKG_LDFLAGS = -L${TVM_ROOT}/build  -L${CUR_DIR} -ldl -pthread -l${VTA_LIB} -ltvm_runtime
+
+.PHONY: clean all
+
+all:./build/deploy copylib
+
+./build/deploy: ./build/deploy.o ./build/model/lib.so
+	$(CXX) $(PKG_CFLAGS) -o $@  $^ $(PKG_LDFLAGS)
+
+./build/deploy.o: cpp_deploy.cc
+	@mkdir -p $(@D)
+	$(CXX) -c $(PKG_CFLAGS) -o $@  $^
+
+./build/model/lib.so: ./build/model/lib.o
+	$(CXX) $(PKG_CFLAGS) -o $@  $^ $(PKG_LDFLAGS) -shared
+
+copylib: ${TVM_ROOT}/build/libtvm_runtime.so ${TVM_ROOT}/build/lib${VTA_LIB}.so 
+	@cp ${TVM_ROOT}/build/libtvm_runtime.so ./build
+	@cp ${TVM_ROOT}/build/lib${VTA_LIB}.so ./build
+
+clean:
+	rm -rf  ./build/*.o ./build/deploy
+
diff --git a/apps/deploy/README.md b/apps/deploy/README.md
new file mode 100644
index 0000000..173562f
--- /dev/null
+++ b/apps/deploy/README.md
@@ -0,0 +1,120 @@
+<!--- Licensed to the Apache Software Foundation (ASF) under one -->
+<!--- or more contributor license agreements.  See the NOTICE file -->
+<!--- distributed with this work for additional information -->
+<!--- regarding copyright ownership.  The ASF licenses this file -->
+<!--- to you under the Apache License, Version 2.0 (the -->
+<!--- "License"); you may not use this file except in compliance -->
+<!--- with the License.  You may obtain a copy of the License at -->
+
+<!---   http://www.apache.org/licenses/LICENSE-2.0 -->
+
+<!--- Unless required by applicable law or agreed to in writing, -->
+<!--- software distributed under the License is distributed on an -->
+<!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->
+<!--- KIND, either express or implied.  See the License for the -->
+<!--- specific language governing permissions and limitations -->
+<!--- under the License. -->
+
+
+How to Deploy TVM-VTA Modules
+=============================
+This folder contains an example on how to deploy TVM-VTA modules.
+It also contains an example code to deploy with C++ and Python.
+
+1. In host machine tvm project enable vta fsim or FPGA and compile tvm successfully.
+
+2. In target FPGA machine, flash bitstream into FPGA, following are example on pynq board
+
+   'cd' into vta-hw/app/deploy, run following command,
+    "/home/xilinx/vta.bit" is the bitstream file
+
+	```bash
+        sudo python3 ./bitstream.py /home/xilinx/vta.bit
+	```
+
+3. Compile and Deploy with C++
+
+   3.1 Deploy with FPGA
+
+       3.1.1 in host machine change ./vta-hw/config/vta_config.json TARGET into FPGA type
+             for example "ultra96"
+
+       3.1.2 in host machine run resnet_export.py, this script would compile mxnet resnet18
+             into vta library, and compute graph, parameter and save into ./build/model folder.
+       
+	```bash
+  	python3 ./resnet_export.py
+	```
+
+       3.1.3 from host machine, copy './build/' folder(generate by #2) into target FPGA board folder 
+             "tvm/3rdparty/vta-hw/apps/deploy/"
+
+       3.1.4 in target FPGA board, enable FPGA in config file and run following command
+
+             to build libvta.so and libtvm_runtime.so
+       ```bash
+       make runtime vta
+       ```
+
+       3.1.5 in target FPGA board goto "tvm/3rdparty/vta-hw/apps/deploy/"
+       ```bash
+       cd tvm/3rdparty/vta-hw/apps/deploy/
+       ```
+
+       3.1.6 int FPGA board Run "make" command, the script would build "lib.so" and cop libtvm_runtime.so
+             and "libvta*.so" into "./build" folder and compile execute file "./deploy"
+      ```bash
+      make
+      ```
+  
+      3.1.7. in FPGA board use following command to convert a image into correct image size that match 
+             mxnet resnet18 requirement.
+      ```bash
+      ./img_data_help.py <image path>
+      ```
+      the said command would output a file name 'img_data'
+
+      3.1.8. in FPGA board run following command to get the image type
+      ```bash
+      ./deploy img_data
+      ```
+
+   3.2 Deploy with vta simulator(all steps happen in host machine)
+
+       3.2.1 change ./vta-hw/config/vta_config.json TARGET into "sim"
+
+       3.2.2 run resnet_export.py, this script would compile mxnet resnet18 into vta library, 
+            and compute graph, parameter and save into ./build/model folder.
+
+	```bash
+  	python3 ./resnet_export.py
+	```
+       
+       3.2.3 Run "make" command, the script would build "lib.so" and copy libtvm_runtime.so
+             and libvta*.so into ./build folder and compile execute file ./deploy
+       ```bash
+       make
+       ```
+
+4. Python deploy
+
+      4.1 Deploy with FPGA.
+
+          4.1.1 From host machine Copy "./vta-hw/apps/deploy/build" folder into 
+                target FPGA board "vta-hw/apps/deploy/" folder
+
+          4.1.2 on FPGA board build libtvmruntime.so and libvta.so
+
+          ```bash
+          make runtime vta
+          ```
+
+          4.1.3 in ./vta-hw/apps/deploy run make to compile ./build/model/lib.so
+          ```
+          make
+          ```
+
+          4.1.4 run python_deploy.py by "run_python_deploy.sh"
+          ```bash
+          sudo ./run_python_deploy.sh
+          ```
diff --git a/apps/deploy/bitstream.py b/apps/deploy/bitstream.py
new file mode 100644
index 0000000..6f646ef
--- /dev/null
+++ b/apps/deploy/bitstream.py
@@ -0,0 +1,25 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+""" Download bitstream into FPGA """
+
+import sys
+from pynq import Bitstream
+
+assert len(sys.argv) == 2, "usage: <bitstream path>"
+FILE = sys.argv[1]
+BITSTREAM = Bitstream(FILE)
+BITSTREAM.download()
diff --git a/apps/deploy/cpp_deploy.cc b/apps/deploy/cpp_deploy.cc
new file mode 100644
index 0000000..3f21368
--- /dev/null
+++ b/apps/deploy/cpp_deploy.cc
@@ -0,0 +1,122 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <cstdio>
+#include <dlpack/dlpack.h>
+#include <tvm/runtime/module.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/runtime/packed_func.h>
+#include <fstream>
+#include <iterator>
+#include <algorithm>
+#include <vta/runtime/runtime.h>
+
+
+void graph_test(std::string img,
+                std::string model_path,
+                std::string lib,
+                std::string graph,
+                std::string params) {
+  tvm::runtime::Module mod_dylib =
+        tvm::runtime::Module::LoadFromFile((model_path+lib).c_str()) ;
+  std::ifstream json_in((model_path + graph).c_str());
+  if(json_in.fail())
+  {
+    throw std::runtime_error("could not open json file");
+  }
+
+  std::ifstream params_in((model_path + params).c_str(), std::ios::binary);
+  if(params_in.fail())
+  {
+    throw std::runtime_error("could not open json file");
+  }
+
+  const std::string json_data((std::istreambuf_iterator<char>(json_in)),
+                               std::istreambuf_iterator<char>());
+  json_in.close();
+  const std::string params_data((std::istreambuf_iterator<char>(params_in)),
+                                 std::istreambuf_iterator<char>());
+  params_in.close();
+
+  TVMByteArray params_arr;
+  params_arr.data = params_data.c_str();
+  params_arr.size = params_data.length();
+
+  int dtype_code = kDLFloat;
+  int dtype_bits = 32;
+  int dtype_lanes = 1;
+  int device_type = kDLExtDev;
+  int device_id = 0;
+
+  // get global function module for graph runtime
+  tvm::runtime::Module mod = 
+    (*tvm::runtime::Registry::Get("tvm.graph_runtime.create"))(json_data,
+                                                                mod_dylib,
+                                                                device_type,
+                                                                device_id);
+  DLTensor* x;
+  tvm::runtime::PackedFunc get_input = mod.GetFunction("get_input");
+  x = get_input(0);
+  VTACommandHandle cmd;
+  char * vta_ptr = (char *)VTABufferCPUPtr(cmd, static_cast<void*>(x->data));
+
+  int in_ndim = 4;
+  int64_t in_shape[4] = {1, 3, 224, 224};
+  // load image data saved in binary
+  std::ifstream data_fin(img.c_str(), std::ios::binary);
+  data_fin.read(static_cast<char*>(vta_ptr), 3 * 224 * 224 * 4);
+  // get the function from the module(load parameters)
+  tvm::runtime::PackedFunc load_params = mod.GetFunction("load_params");
+  load_params(params_arr);
+  tvm::runtime::PackedFunc run = mod.GetFunction("run");
+  run();
+
+  DLTensor* y;
+  int out_ndim = 2;
+  int64_t out_shape[2] = {1, 1000};
+  TVMArrayAlloc(out_shape, out_ndim, dtype_code, dtype_bits, dtype_lanes,
+                  kDLCPU, device_id, &y);
+
+  // get the function from the module(get output data)
+  tvm::runtime::PackedFunc get_output = mod.GetFunction("get_output");
+  get_output(0, y);
+
+    // get the maximum position in output vector
+  auto y_iter = static_cast<float*>(y->data);
+  auto max_iter = std::max_element(y_iter, y_iter + 1000);
+  auto max_index = std::distance(y_iter, max_iter);
+  std::cout << "The maximum position in output vector is: " << max_index << std::endl;
+
+  TVMArrayFree(x);
+  TVMArrayFree(y);
+}
+
+int main(int argc, char *argv[]) {
+  if (argc <= 1) {
+  	printf("deploy <file name>\n");
+	return 0;
+  }
+  graph_test(argv[1],
+             "./model/",
+             "lib.so",
+             "graph.json",
+             "params.params");
+  return 0;
+
+}
+
diff --git a/apps/deploy/img_data_help.py b/apps/deploy/img_data_help.py
new file mode 100644
index 0000000..61452d6
--- /dev/null
+++ b/apps/deploy/img_data_help.py
@@ -0,0 +1,28 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Resize Image To Match MXNET Model."""
+import sys
+import os
+from PIL import Image
+import numpy as np
+
+if __name__ == "__main__":
+    assert len(sys.argv) == 2, "usage: <image path>"
+    IMG_PATH = sys.argv[1]
+    assert os.path.isfile(IMG_PATH), "file " + IMG_PATH + "  not exist"
+    IMAGE = Image.open(IMG_PATH).resize((224, 224))
+    np.array(IMAGE).astype('float32').tofile("./img_data")
diff --git a/apps/deploy/python_deploy.py b/apps/deploy/python_deploy.py
new file mode 100644
index 0000000..dc6c37b
--- /dev/null
+++ b/apps/deploy/python_deploy.py
@@ -0,0 +1,80 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Python VTA Deploy."""
+from __future__ import absolute_import, print_function
+
+import os
+from os.path import join
+from io import BytesIO
+from PIL import Image
+
+import requests
+import numpy as np
+
+import tvm
+from tvm.contrib import graph_runtime, download
+
+
+CTX = tvm.ext_dev(0)
+
+def load_vta_library():
+    """load vta lib"""
+    curr_path = os.path.dirname(
+        os.path.abspath(os.path.expanduser(__file__)))
+    proj_root = os.path.abspath(os.path.join(curr_path, "../../../../"))
+    vtadll = os.path.abspath(os.path.join(proj_root, "build/libvta.so"))
+    return tvm.runtime.load_module(vtadll)
+
+
+def load_model():
+    """ Load VTA Model  """
+
+    load_vta_library()
+
+    with open("./build/model/graph.json", "r") as graphfile:
+        graph = graphfile.read()
+
+    lib = tvm.runtime.load_module("./build/model/lib.so")
+
+    model = graph_runtime.create(graph, lib, CTX)
+
+    with open("./build/model/params.params", "rb") as paramfile:
+        param_bytes = paramfile.read()
+
+    categ_url = "https://github.com/uwsaml/web-data/raw/master/vta/models/"
+    categ_fn = "synset.txt"
+    download.download(join(categ_url, categ_fn), categ_fn)
+    synset = eval(open(categ_fn).read())
+
+    return model, param_bytes, synset
+
+if __name__ == "__main__":
+    MOD, PARAMS_BYTES, SYNSET = load_model()
+
+    IMAGE_URL = 'https://homes.cs.washington.edu/~moreau/media/vta/cat.jpg'
+    RESPONSE = requests.get(IMAGE_URL)
+
+    # Prepare test image for inference
+    IMAGE = Image.open(BytesIO(RESPONSE.content)).resize((224, 224))
+
+    MOD.set_input('data', IMAGE)
+    MOD.load_params(PARAMS_BYTES)
+    MOD.run()
+
+    TVM_OUTPUT = MOD.get_output(0, tvm.nd.empty((1, 1000), "float32", CTX))
+    TOP_CATEGORIES = np.argsort(TVM_OUTPUT.asnumpy()[0])
+    print("\t#1:", SYNSET[TOP_CATEGORIES[-1]])
diff --git a/apps/deploy/resnet_export.py b/apps/deploy/resnet_export.py
new file mode 100644
index 0000000..ad248bc
--- /dev/null
+++ b/apps/deploy/resnet_export.py
@@ -0,0 +1,132 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+""" Compile And Export MXNET Resnet18 Model With VTA As Backend """
+from __future__ import absolute_import, print_function
+
+import os
+from os.path import exists
+import numpy as np
+from mxnet.gluon.model_zoo import vision
+
+import tvm
+from tvm import autotvm, relay
+from tvm.relay import op, transform
+
+import vta
+from vta.top import graph_pack
+from vta.top.graphpack import run_opt_pass
+
+# Load VTA parameters from the vta/config/vta_config.json file
+ENV = vta.get_env()
+assert ENV.target.device_name == "vta"
+# Dictionary lookup for when to start/end bit packing
+PACK_DICT = {"resnet18_v1": ["nn.max_pool2d", "nn.global_avg_pool2d", None, None],}
+
+# Name of Gluon model to compile
+MODEL = "resnet18_v1"
+assert MODEL in PACK_DICT
+
+def merge_transform_to_mxnet_model(mod):
+    """ Add Image Transform Logic Into Model """
+    svalue = np.array([123., 117., 104.])
+    sub_data = relay.Constant(tvm.nd.array(svalue)).astype("float32")
+    dvalue = np.array([58.395, 57.12, 57.37])
+    divide_data = relay.Constant(tvm.nd.array(dvalue)).astype("float32")
+
+    data_shape = (224, 224, 3)
+    data = relay.var("data", relay.TensorType(data_shape, "float32"))
+
+    simple_net = relay.expand_dims(data, axis=0, num_newaxis=1)
+    # To do, relay not support dynamic shape now, future need to add resize logic
+    # simple_net = relay.image.resize(simple_net, (224, 224), "NHWC", "bilinear", "align_corners")
+    simple_net = relay.subtract(simple_net, sub_data)
+    simple_net = relay.divide(simple_net, divide_data)
+    simple_net = relay.transpose(simple_net, ((0, 3, 1, 2)))
+
+    #merge tranform into pretrained model network
+    entry = mod["main"]
+    anf = run_opt_pass(entry.body, transform.ToANormalForm())
+    call = anf.value
+    data, weights = call.args
+    first_op = op.nn.conv2d(
+        simple_net,
+        weights,
+        strides=call.attrs.strides,
+        padding=call.attrs.padding,
+        dilation=call.attrs.dilation,
+        groups=call.attrs.groups,
+        channels=call.attrs.channels,
+        kernel_size=call.attrs.kernel_size,
+        out_dtype=call.attrs.out_dtype)
+    net = relay.expr.Let(anf.var, first_op, anf.body)
+    net = run_opt_pass(net, transform.ToGraphNormalForm())
+
+    mod['main'] = net
+    return mod
+
+def compile_mxnet_gulon_resnet(_env, _model):
+    """ Compile Model """
+    # Generate tvm IR from mxnet gluon model
+    # Populate the shape and data type dictionary for ImageNet classifier input
+    dtype_dict = {"data": 'float32'}
+    shape_dict = {"data": (_env.BATCH, 3, 224, 224)}
+    # Get off the shelf gluon model, and convert to relay
+    gluon_model = vision.get_model(_model, pretrained=True)
+    # Start front end compilation
+    mod, params = relay.frontend.from_mxnet(gluon_model, shape_dict)
+    mod = merge_transform_to_mxnet_model(mod)
+    # Update shape and type dictionary
+    shape_dict.update({k: v.shape for k, v in params.items()})
+    dtype_dict.update({k: str(v.dtype) for k, v in params.items()})
+
+    # Load pre-configured AutoTVM schedules
+    with autotvm.tophub.context(_env.target):
+        # Perform quantization in Relay
+        # Note: We set opt_level to 3 in order to fold batch norm
+        with relay.build_config(opt_level=3):
+            with relay.quantize.qconfig(global_scale=8.0, skip_conv_layers=[0]):
+                mod = relay.quantize.quantize(mod, params=params)
+            # Perform graph packing and constant folding for VTA target
+            relay_prog = graph_pack(
+                mod["main"],
+                _env.BATCH,
+                _env.BLOCK_IN,
+                _env.WGT_WIDTH,
+                start_name=PACK_DICT[_model][0],
+                stop_name=PACK_DICT[_model][1])
+
+    # Compile Relay program with AlterOpLayout disabled
+    with relay.build_config(opt_level=3, disabled_pass={"AlterOpLayout"}):
+        with vta.build_config(debug_flag=0):
+            graph, lib, params = relay.build(
+                relay_prog, target=_env.target,
+                params=params, target_host=_env.target_host)
+
+    return graph, lib, params
+
+def export_tvm_compile(graph, lib, params, path):
+    """ Export Model"""
+    if not exists(path):
+        os.makedirs(path)
+    lib.save(path+"/lib.o")
+    with open(path+"/graph.json", "w") as graphfile:
+        graphfile.write(graph)
+    with open(path+"/params.params", "wb") as paramfile:
+        paramfile.write(relay.save_param_dict(params))
+
+GRAPH, LIB, PARAMS = compile_mxnet_gulon_resnet(ENV, MODEL)
+export_tvm_compile(GRAPH, LIB, PARAMS, "./build/model")
diff --git a/apps/deploy/run_python_deploy.sh b/apps/deploy/run_python_deploy.sh
new file mode 100644
index 0000000..4065d53
--- /dev/null
+++ b/apps/deploy/run_python_deploy.sh
@@ -0,0 +1,22 @@
+#!/bin/bash
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+PROJROOT="$( cd "$( dirname "${BASH_SOURCE[0]}" )/../../../../" && pwd )"
+
+export PYTHONPATH=${PYTHONPATH}:${PROJROOT}/python:${PROJROOT}/vta/python:./
+export PYTHONPATH=${PYTHONPATH}:/home/xilinx/pynq
+python3 -m python_deploy