blob: 804f83587f009f87c74b89782838d642b4604167 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Tests for common micro transports."""
import logging
import sys
import unittest
import pytest
import tvm.testing
# Implemented as a fixture so that tvm.micro is not imported until fixture
# setup time. This lets pytest's collection phase succeed when
# USE_MICRO=OFF, while the tests are still explicitly reported as
# skipped.
@tvm.testing.fixture
def transport():
    """Return a scriptable mock ``tvm.micro.transport.Transport``.

    Tests script the next ``read``/``write`` call by assigning attributes on
    the returned object:

    * ``exc`` -- if not None, the next call raises it and clears it.
    * ``to_return`` -- otherwise, if not None, the next call returns it and
      clears it.

    Calling ``read``/``write`` with neither attribute set is a bug in the
    test and raises ``AssertionError``.
    """
    # Deferred so that pytest collection works when USE_MICRO=OFF.
    import tvm.micro

    class MockTransport_Impl(tvm.micro.transport.Transport):
        """In-memory Transport whose read/write results are set by the test."""

        def __init__(self):
            # Exception to raise from the next read/write call, or None.
            self.exc = None
            # Value to return from the next read/write call, or None.
            self.to_return = None

        def _raise_or_return(self):
            """Consume and raise ``exc`` if set, else consume and return ``to_return``."""
            if self.exc is not None:
                to_raise, self.exc = self.exc, None
                raise to_raise
            if self.to_return is not None:
                to_return, self.to_return = self.to_return, None
                return to_return
            # Explicit raise instead of ``assert False`` so a mis-scripted
            # test still fails loudly when run under ``python -O``.
            raise AssertionError("should not get here")

        def open(self):
            pass

        def close(self):
            pass

        def timeouts(self):
            raise NotImplementedError()

        def read(self, n, timeout_sec):
            return self._raise_or_return()

        def write(self, data, timeout_sec):
            return self._raise_or_return()

    return MockTransport_Impl()
@tvm.testing.fixture
def transport_logger(transport):
    """Wrap the mock ``transport`` in a ``TransportLogger`` named "foo".

    Records are sent to the "transport_logger_test" logger, which is the
    logger the ``get_latest_log`` fixture captures.
    """
    # Import explicitly (still deferred to fixture setup time) instead of
    # relying on the ``transport`` fixture's import having made ``tvm.micro``
    # reachable as a side effect.
    import tvm.micro

    logger = logging.getLogger("transport_logger_test")
    return tvm.micro.transport.TransportLogger("foo", transport, logger=logger)
@tvm.testing.fixture
def get_latest_log(caplog):
    """Yield a callable that returns the message of the newest captured record.

    While the fixture is active, the capture level for the
    "transport_logger_test" logger is set to INFO. The callable raises
    IndexError when no record has been captured. Tests rely on this to
    check that nothing was logged.
    """

    def latest_message():
        newest_record = caplog.records[-1]
        return newest_record.getMessage()

    with caplog.at_level(logging.INFO, "transport_logger_test"):
        yield latest_message
@tvm.testing.requires_micro
def test_open(transport_logger, get_latest_log):
    """Opening through the logger records an 'opening transport' message."""
    transport_logger.open()
    logged = get_latest_log()
    assert logged == "foo: opening transport"
@tvm.testing.requires_micro
def test_close(transport_logger, get_latest_log):
    """Closing through the logger records a 'closing transport' message."""
    transport_logger.close()
    logged = get_latest_log()
    assert logged == "foo: closing transport"
@tvm.testing.requires_micro
def test_read_normal(transport, transport_logger, get_latest_log):
    """A short successful read logs timeout, request size, and a one-line dump."""
    transport.to_return = b"data"
    transport_logger.read(23, 3.0)
    expected = (
        "foo: read { 3.00s} 23 B -> [ 4 B]: 64 61 74 61"
        " data"
    )
    assert get_latest_log() == expected
@tvm.testing.requires_micro
def test_read_multiline(transport, transport_logger, get_latest_log):
    """A 24-byte read is logged as a two-row hex/ASCII dump with offsets."""
    payload = b"data" * 6
    transport.to_return = payload
    transport_logger.read(23, 3.0)
    expected = (
        "foo: read { 3.00s} 23 B -> [ 24 B]:\n"
        "0000 64 61 74 61 64 61 74 61 64 61 74 61 64 61 74 61 datadatadatadata\n"
        "0010 64 61 74 61 64 61 74 61 datadata"
    )
    assert get_latest_log() == expected
@tvm.testing.requires_micro
def test_read_no_timeout_prints(transport, transport_logger, get_latest_log):
    """A read with ``timeout_sec=None`` logs the timeout field as 'None'."""
    transport.to_return = b"data"
    transport_logger.read(15, None)
    logged = get_latest_log()
    assert logged == (
        "foo: read { None } 15 B -> [ 4 B]: 64 61 74 61"
        " data"
    )
@tvm.testing.requires_micro
def test_read_io_timeout(transport, transport_logger, get_latest_log):
    """A timed-out read re-raises IoTimeoutError; the log includes the timeout value."""
    timeout_error = tvm.micro.transport.IoTimeoutError
    transport.exc = timeout_error()
    with pytest.raises(timeout_error):
        transport_logger.read(23, 0.0)
    expected = "foo: read { 0.00s} 23 B -> [IoTimeoutError 0.00s]"
    assert get_latest_log() == expected
@tvm.testing.requires_micro
def test_read_other_exception(transport, transport_logger, get_latest_log):
    """Any other read failure is re-raised and logged by its exception name."""
    closed_error = tvm.micro.transport.TransportClosedError
    transport.exc = closed_error()
    with pytest.raises(closed_error):
        transport_logger.read(8, 0.0)
    expected = "foo: read { 0.00s} 8 B -> [err: TransportClosedError]"
    assert get_latest_log() == expected
@tvm.testing.requires_micro
def test_read_keyboard_interrupt(transport, transport_logger, get_latest_log):
    """A KeyboardInterrupt during read propagates and leaves no log record."""
    transport.exc = KeyboardInterrupt()
    with pytest.raises(KeyboardInterrupt):
        transport_logger.read(8, 0.0)
    # Nothing was captured, so fetching the newest record raises IndexError.
    with pytest.raises(IndexError):
        get_latest_log()
@tvm.testing.requires_micro
def test_write_normal(transport, transport_logger, get_latest_log):
    """A short write logs timeout and a one-line dump of the 4-byte payload."""
    transport.to_return = 3
    transport_logger.write(b"data", 3.0)
    expected = (
        "foo: write { 3.00s} <- [ 4 B]: 64 61 74 61"
        " data"
    )
    assert get_latest_log() == expected
@tvm.testing.requires_micro
def test_write_multiline(transport, transport_logger, get_latest_log):
    """A 24-byte write is logged as a two-row hex/ASCII dump with offsets."""
    transport.to_return = 20
    payload = b"data" * 6
    transport_logger.write(payload, 3.0)
    expected = (
        "foo: write { 3.00s} <- [ 24 B]:\n"
        "0000 64 61 74 61 64 61 74 61 64 61 74 61 64 61 74 61 datadatadatadata\n"
        "0010 64 61 74 61 64 61 74 61 datadata"
    )
    assert get_latest_log() == expected
@tvm.testing.requires_micro
def test_write_no_timeout_prints(transport, transport_logger, get_latest_log):
    """A write with ``timeout_sec=None`` logs the timeout field as 'None'."""
    transport.to_return = 3
    transport_logger.write(b"data", None)
    logged = get_latest_log()
    assert logged == (
        "foo: write { None } <- [ 4 B]: 64 61 74 61"
        " data"
    )
@tvm.testing.requires_micro
def test_write_io_timeout(transport, transport_logger, get_latest_log):
    """A timed-out write re-raises IoTimeoutError; the log includes the timeout value."""
    timeout_error = tvm.micro.transport.IoTimeoutError
    transport.exc = timeout_error()
    with pytest.raises(timeout_error):
        transport_logger.write(b"data", 0.0)
    expected = "foo: write { 0.00s} <- [ 4 B]: [IoTimeoutError 0.00s]"
    assert get_latest_log() == expected
@tvm.testing.requires_micro
def test_write_other_exception(transport, transport_logger, get_latest_log):
    """Any other write failure is re-raised and logged by its exception name."""
    closed_error = tvm.micro.transport.TransportClosedError
    transport.exc = closed_error()
    with pytest.raises(closed_error):
        transport_logger.write(b"data", 0.0)
    expected = "foo: write { 0.00s} <- [ 4 B]: [err: TransportClosedError]"
    assert get_latest_log() == expected
@tvm.testing.requires_micro
def test_write_keyboard_interrupt(transport, transport_logger, get_latest_log):
    """A KeyboardInterrupt during write propagates and leaves no log record."""
    transport.exc = KeyboardInterrupt()
    with pytest.raises(KeyboardInterrupt):
        transport_logger.write(b"data", 0.0)
    # Nothing was captured, so fetching the newest record raises IndexError.
    with pytest.raises(IndexError):
        get_latest_log()
if __name__ == "__main__":
    # Running this file directly hands control to tvm.testing.main().
    tvm.testing.main()