blob: 3565040e1d76f7b868b41a59d35a55bbd6ba60c8 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Defines a top-level glue class that operates the Transport and Flasher classes."""
import logging
import time
from .._ffi import get_global_func
from ..contrib import graph_runtime
from ..rpc import RPCSession
from .transport import TransportLogger
try:
from .base import _rpc_connect
except ImportError:
raise ImportError("micro tvm is not enabled. Set USE_MICRO to ON in config.cmake")
class Session:
    """MicroTVM device session.

    Optionally flashes a binary to an attached device, opens a transport to it, and establishes
    an RPC session with the on-device RPC server. Intended to be used as a context manager: the
    RPC session (and ``get_system_lib``) is only usable between ``__enter__`` and ``__exit__``.

    Example
    -------
    .. code-block:: python

        micro_binary = ...  # a MicroBinary for the target device
        flasher = ...  # a Flasher able to program the attached device
        with tvm.micro.Session(binary=micro_binary, flasher=flasher) as sess:
            system_lib = sess.get_system_lib()
    """

    def __init__(
        self, binary=None, flasher=None, transport_context_manager=None, session_name="micro-rpc"
    ):
        """Configure a new session.

        Nothing is flashed or connected here; all device interaction happens in ``__enter__``.

        Parameters
        ----------
        binary : MicroBinary
            If given, `flasher` must also be given. During session initialization, this binary will
            be flashed to the device before the transport is created.
        flasher : Flasher
            If given, `binary` must also be given. Used to flash `binary` during session
            initialization. The value returned by ``flasher.flash(binary)`` replaces
            `transport_context_manager`.
        transport_context_manager : ContextManager[transport.Transport]
            If given, `flasher` and `binary` should not be given. On entry, this context manager
            should establish a transport between this TVM instance and the device.
        session_name : str
            Name of the session, used for debugging.
        """
        self.binary = binary
        self.flasher = flasher
        self.transport_context_manager = transport_context_manager
        self.session_name = session_name
        self._rpc = None
        self._graph_runtime = None
        # Populated by __enter__; declared here so every attribute exists after construction.
        self.transport = None
        self.context = None

    def get_system_lib(self):
        """Return the result of calling the remote ``runtime.SystemLib`` function.

        Only valid while the session is entered (``self._rpc`` is set in ``__enter__``).
        """
        return self._rpc.get_function("runtime.SystemLib")()

    def __enter__(self):
        """Initialize this session and establish an RPC session with the on-device RPC server.

        If a flasher was given, the binary is flashed first and the flasher's return value is
        used as the transport context manager.

        Returns
        -------
        Session :
            Returns self.
        """
        if self.flasher is not None:
            self.transport_context_manager = self.flasher.flash(self.binary)
            # NOTE(review): fixed delay after flashing, presumably to let the device reset/boot
            # before the transport is opened.
            time.sleep(3.0)

        self.transport = TransportLogger(
            self.session_name, self.transport_context_manager, level=logging.INFO
        ).__enter__()

        try:
            self._rpc = RPCSession(
                _rpc_connect(self.session_name, self.transport.write, self.transport.read)
            )
            self.context = self._rpc.cpu(0)
        except BaseException as exc:
            # A `with` statement does not call __exit__ when __enter__ raises, so release the
            # already-opened transport here instead of leaking it, then propagate the error.
            self.transport.__exit__(type(exc), exc, exc.__traceback__)
            raise

        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        """Tear down this session by exiting the transport opened in ``__enter__``.

        The active exception info is forwarded to the transport. This method returns None, so
        exceptions raised inside the ``with`` block are never suppressed.
        """
        self.transport.__exit__(exc_type, exc_value, exc_traceback)
def create_local_graph_runtime(graph_json_str, mod, ctx):
    """Build a graph runtime that lives locally but executes on the given remote CPU context.

    The runtime is created through the global ``tvm.graph_runtime.create`` function and wrapped
    in ``graph_runtime.GraphModule``.

    Parameters
    ----------
    graph_json_str : str
        A string containing the graph representation.
    mod : tvm.runtime.Module
        The remote module containing functions in graph_json_str.
    ctx : tvm.Context
        The remote CPU execution context; its ``device_type`` and ``device_id`` are forwarded
        to the runtime constructor.

    Returns
    -------
    tvm.contrib.GraphRuntime :
        A local graph runtime instance that executes on the remote device.
    """
    device_type, device_id = ctx.device_type, ctx.device_id
    create_runtime = get_global_func("tvm.graph_runtime.create")
    runtime_module = create_runtime(graph_json_str, mod, device_type, device_id)
    return graph_runtime.GraphModule(runtime_module)