| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
| """Defines a top-level glue class that operates the Transport and Flasher classes.""" |
| |
| import logging |
| import time |
| |
| from .._ffi import get_global_func |
| from ..contrib import graph_runtime |
| from ..rpc import RPCSession |
| from .transport import TransportLogger |
| |
| try: |
| from .base import _rpc_connect |
| except ImportError: |
| raise ImportError("micro tvm is not enabled. Set USE_MICRO to ON in config.cmake") |
| |
| |
| class Session: |
| """MicroTVM Device Session |
| |
| Parameters |
| ---------- |
| config : dict |
| configuration for this session (as generated by |
| `tvm.micro.device.host.default_config()`, for example) |
| |
| Example |
| -------- |
| .. code-block:: python |
| |
| c_mod = ... # some module generated with "c" as the target |
| dev_config = micro.device.arm.stm32f746xx.default_config('127.0.0.1', 6666) |
| with tvm.micro.Session(dev_config) as sess: |
| micro_mod = sess.create_micro_mod(c_mod) |
| """ |
| |
| def __init__( |
| self, binary=None, flasher=None, transport_context_manager=None, session_name="micro-rpc" |
| ): |
| """Configure a new session. |
| |
| Parameters |
| ---------- |
| binary : MicroBinary |
| If given, `flasher` must also be given. During session initialization, this binary will |
| be flashed to the device before the transport is created. |
| flasher : Flasher |
| If given, `binary` must also be given. Used to flash `binary` during session |
| initialization. |
| transport_context_manager : ContextManager[transport.Transport] |
| If given, `flasher` and `binary` should not be given. On entry, this context manager |
| should establish a tarnsport between this TVM instance and the device. |
| session_name : str |
| Name of the session, used for debugging. |
| """ |
| self.binary = binary |
| self.flasher = flasher |
| self.transport_context_manager = transport_context_manager |
| self.session_name = session_name |
| |
| self._rpc = None |
| self._graph_runtime = None |
| |
| def get_system_lib(self): |
| return self._rpc.get_function("runtime.SystemLib")() |
| |
| def __enter__(self): |
| """Initialize this session and establish an RPC session with the on-device RPC server. |
| |
| Returns |
| ------- |
| Session : |
| Returns self. |
| """ |
| if self.flasher is not None: |
| self.transport_context_manager = self.flasher.flash(self.binary) |
| time.sleep(3.0) |
| |
| self.transport = TransportLogger( |
| self.session_name, self.transport_context_manager, level=logging.INFO |
| ).__enter__() |
| self._rpc = RPCSession( |
| _rpc_connect(self.session_name, self.transport.write, self.transport.read) |
| ) |
| self.context = self._rpc.cpu(0) |
| return self |
| |
| def __exit__(self, exc_type, exc_value, exc_traceback): |
| """Tear down this session and associated RPC session resources.""" |
| self.transport.__exit__(exc_type, exc_value, exc_traceback) |
| |
| |
| def create_local_graph_runtime(graph_json_str, mod, ctx): |
| """Create a local graph runtime driving execution on the remote CPU context given. |
| |
| Parameters |
| ---------- |
| graph_json_str : str |
| A string containing the graph representation. |
| |
| mod : tvm.runtime.Module |
| The remote module containing functions in graph_json_str. |
| |
| ctx : tvm.Context |
| The remote CPU execution context. |
| |
| Returns |
| ------- |
| tvm.contrib.GraphRuntime : |
| A local graph runtime instance that executes on the remote device. |
| """ |
| device_type_id = [ctx.device_type, ctx.device_id] |
| fcreate = get_global_func("tvm.graph_runtime.create") |
| return graph_runtime.GraphModule(fcreate(graph_json_str, mod, *device_type_id)) |