blob: a1d57fcb5436e15e8ac5919e08833ccb01f8cce1 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include <gtest/gtest.h>
#include <tvm/runtime/crt/memory.h>
#include <tvm/runtime/crt/rpc_common/frame_buffer.h>
#include <tvm/runtime/crt/rpc_common/session.h>
#include <string>
#include <vector>
#include "buffer_write_stream.h"
#include "crt_config.h"
#include "platform.cc"
using ::tvm::runtime::micro_rpc::Framer;
using ::tvm::runtime::micro_rpc::MessageType;
using ::tvm::runtime::micro_rpc::Session;
using ::tvm::runtime::micro_rpc::Unframer;
extern "C" {
void TestSessionMessageReceivedThunk(void* context, MessageType message_type, FrameBuffer* buf);
}
class ReceivedMessage {
public:
ReceivedMessage(MessageType type, std::string message) : type{type}, message{message} {}
bool operator==(const ReceivedMessage& other) const {
return other.type == type && other.message == message;
}
MessageType type;
std::string message;
};
class TestSession {
public:
TestSession(uint8_t initial_nonce)
: framer{&framer_write_stream},
receive_buffer{receive_buffer_array, sizeof(receive_buffer_array)},
sess{initial_nonce, &framer, &receive_buffer, TestSessionMessageReceivedThunk, this},
unframer{sess.Receiver()} {}
void WriteTo(TestSession* other) {
auto framer_buffer = framer_write_stream.BufferContents();
size_t bytes_to_write = framer_buffer.size();
const uint8_t* write_cursor = reinterpret_cast<const uint8_t*>(framer_buffer.data());
while (bytes_to_write > 0) {
size_t bytes_consumed;
auto to_return = other->unframer.Write(write_cursor, bytes_to_write, &bytes_consumed);
EXPECT_EQ(to_return, kTvmErrorNoError);
bytes_to_write -= bytes_consumed;
write_cursor += bytes_consumed;
}
}
void ClearBuffers() {
framer_write_stream.Reset();
messages_received.clear();
sess.ClearReceiveBuffer();
}
std::vector<ReceivedMessage> messages_received;
BufferWriteStream<300> framer_write_stream;
Framer framer;
uint8_t receive_buffer_array[300];
FrameBuffer receive_buffer;
Session sess;
Unframer unframer;
};
#define EXPECT_FRAMED_PACKET(session, expected) \
EXPECT_EQ(std::string(expected, sizeof(expected) - 1), \
(session).framer_write_stream.BufferContents());
extern "C" {
void TestSessionMessageReceivedThunk(void* context, MessageType message_type, FrameBuffer* buf) {
std::string message;
if (message_type != MessageType::kStartSessionReply) {
uint8_t message_buf[300];
EXPECT_LE(buf->ReadAvailable(), sizeof(message_buf));
size_t message_size_bytes = buf->Read(message_buf, sizeof(message_buf));
message = std::string(reinterpret_cast<char*>(message_buf), message_size_bytes);
}
static_cast<TestSession*>(context)->messages_received.emplace_back(
ReceivedMessage(message_type, message));
}
}
void PrintTo(tvm_crt_error_t p, std::ostream* os) {
std::ios_base::fmtflags f(os->flags());
*os << "tvm_crt_error_t(0x" << std::hex << std::setw(8) << std::setfill('0') << p << ")";
os->flags(f);
}
void PrintTo(ReceivedMessage msg, std::ostream* os) {
*os << "ReceivedMessage(" << int(msg.type) << ", \"" << msg.message << "\")";
}
class SessionTest : public ::testing::Test {
public:
static constexpr const uint8_t kAliceNonce = 0x3c;
static constexpr const uint8_t kBobNonce = 0xab;
TestSession alice_{kAliceNonce};
TestSession bob_{kBobNonce};
};
TEST_F(SessionTest, NormalExchange) {
tvm_crt_error_t err;
err = alice_.sess.Initialize();
EXPECT_EQ(kTvmErrorNoError, err);
EXPECT_FRAMED_PACKET(alice_,
"\xfe\xff\xfd\x03\0\0\0\0\0\x02"
"fw");
alice_.WriteTo(&bob_);
err = bob_.sess.Initialize();
EXPECT_EQ(kTvmErrorNoError, err);
EXPECT_FRAMED_PACKET(bob_,
"\xfe\xff\xfd\x03\0\0\0\0\0\x02"
"fw");
alice_.WriteTo(&alice_);
bob_.ClearBuffers();
alice_.ClearBuffers();
err = alice_.sess.StartSession();
EXPECT_EQ(err, kTvmErrorNoError);
EXPECT_FRAMED_PACKET(alice_, "\xff\xfd\x04\0\0\0\x82\0\0\x01{\xE9");
bob_.ClearBuffers();
alice_.WriteTo(&bob_);
EXPECT_FRAMED_PACKET(bob_,
"\xff\xfd\x4\0\0\0\x82"
"f\x01\x01\x81\xf3");
EXPECT_TRUE(bob_.sess.IsEstablished());
bob_.WriteTo(&alice_);
EXPECT_TRUE(alice_.sess.IsEstablished());
ASSERT_EQ(alice_.messages_received.size(), 1);
EXPECT_EQ(alice_.messages_received[0], ReceivedMessage(MessageType::kStartSessionReply, ""));
alice_.ClearBuffers();
alice_.sess.SendMessage(MessageType::kNormal, reinterpret_cast<const uint8_t*>("hello"), 5);
EXPECT_FRAMED_PACKET(alice_,
"\xFF\xFD\b\0\0\0\x82"
"f\x10hello\x90(");
alice_.WriteTo(&bob_);
ASSERT_EQ(bob_.messages_received.size(), 2);
EXPECT_EQ(bob_.messages_received[0], ReceivedMessage(MessageType::kStartSessionReply, ""));
EXPECT_EQ(bob_.messages_received[1], ReceivedMessage(MessageType::kNormal, "hello"));
bob_.ClearBuffers();
bob_.sess.SendMessage(MessageType::kNormal, reinterpret_cast<const uint8_t*>("olleh"), 5);
EXPECT_FRAMED_PACKET(bob_,
"\xff\xfd\b\0\0\0\x82"
"f\x10ollehLv");
bob_.WriteTo(&alice_);
ASSERT_EQ(alice_.messages_received.size(), 1);
EXPECT_EQ(alice_.messages_received[0], ReceivedMessage(MessageType::kNormal, "olleh"));
alice_.ClearBuffers();
bob_.ClearBuffers();
alice_.sess.SendMessage(MessageType::kLog, reinterpret_cast<const uint8_t*>("log1"), 4);
EXPECT_FRAMED_PACKET(alice_, "\xff\xfd\a\0\0\0\0\0\x03log1\xf0\xd4");
alice_.WriteTo(&bob_);
ASSERT_EQ(bob_.messages_received.size(), 1);
EXPECT_EQ(bob_.messages_received[0], ReceivedMessage(MessageType::kLog, "log1"));
bob_.sess.SendMessage(MessageType::kLog, reinterpret_cast<const uint8_t*>("zero"), 4);
EXPECT_FRAMED_PACKET(bob_, "\xff\xfd\a\0\0\0\0\0\x03zero\xb2h");
bob_.WriteTo(&alice_);
ASSERT_EQ(alice_.messages_received.size(), 1);
EXPECT_EQ(alice_.messages_received[0], ReceivedMessage(MessageType::kLog, "zero"));
}
TEST_F(SessionTest, LogBeforeSessionStart) {
alice_.sess.SendMessage(MessageType::kLog, reinterpret_cast<const uint8_t*>("log1"), 4);
EXPECT_FRAMED_PACKET(alice_, "\xfe\xff\xfd\a\0\0\0\0\0\x03log1\xf0\xd4");
alice_.WriteTo(&bob_);
ASSERT_EQ(bob_.messages_received.size(), 1);
EXPECT_EQ(bob_.messages_received[0], ReceivedMessage(MessageType::kLog, "log1"));
bob_.sess.SendMessage(MessageType::kLog, reinterpret_cast<const uint8_t*>("zero"), 4);
EXPECT_FRAMED_PACKET(bob_, "\xfe\xff\xfd\a\0\0\0\0\0\x03zero\xb2h");
bob_.WriteTo(&alice_);
ASSERT_EQ(alice_.messages_received.size(), 1);
EXPECT_EQ(alice_.messages_received[0], ReceivedMessage(MessageType::kLog, "zero"));
}
static constexpr const char kBobStartPacket[] = "\xff\xfd\x04\0\0\0f\0\0\x01`\xa7";
TEST_F(SessionTest, DoubleStart) {
tvm_crt_error_t err;
err = alice_.sess.Initialize();
EXPECT_EQ(kTvmErrorNoError, err);
EXPECT_FRAMED_PACKET(alice_,
"\xfe\xff\xfd\x03\0\0\0\0\0\x02"
"fw");
alice_.WriteTo(&bob_);
err = bob_.sess.Initialize();
EXPECT_EQ(kTvmErrorNoError, err);
EXPECT_FRAMED_PACKET(bob_,
"\xfe\xff\xfd\x03\0\0\0\0\0\x02"
"fw");
alice_.WriteTo(&alice_);
bob_.ClearBuffers();
alice_.ClearBuffers();
EXPECT_EQ(kTvmErrorNoError, alice_.sess.StartSession());
EXPECT_FRAMED_PACKET(alice_, "\xff\xfd\x04\0\0\0\x82\0\0\x01{\xe9");
EXPECT_FALSE(alice_.sess.IsEstablished());
EXPECT_EQ(kTvmErrorNoError, bob_.sess.StartSession());
EXPECT_FRAMED_PACKET(bob_, kBobStartPacket);
EXPECT_FALSE(bob_.sess.IsEstablished());
// Sending Alice -> Bob should have no effect (regenerated Bob nonce > regenerated Alice nonce).
bob_.framer_write_stream.Reset();
alice_.WriteTo(&bob_);
EXPECT_FRAMED_PACKET(bob_, "");
EXPECT_FALSE(bob_.sess.IsEstablished());
// Sending Bob -> Alice should start the session.
alice_.ClearBuffers();
size_t bytes_consumed;
EXPECT_EQ(kTvmErrorNoError,
alice_.unframer.Write(reinterpret_cast<const uint8_t*>(kBobStartPacket),
sizeof(kBobStartPacket), &bytes_consumed));
EXPECT_EQ(bytes_consumed, sizeof(kBobStartPacket));
EXPECT_FRAMED_PACKET(alice_, "\xFF\xFD\x4\0\0\0fE\x01\x01\fb");
EXPECT_TRUE(alice_.sess.IsEstablished());
bob_.ClearBuffers();
alice_.WriteTo(&bob_);
EXPECT_TRUE(bob_.sess.IsEstablished());
}
int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe";
return RUN_ALL_TESTS();
}