blob: 1ac81a0245fbdfe832c5ecce8f1025b170e78e6c [file] [log] [blame]
/*
* Copyright (c) 2019-2022 ExpoLab, UC Davis
*
* Permission is hereby granted, free of charge, to any person
* obtaining a copy of this software and associated documentation
* files (the "Software"), to deal in the Software without
* restriction, including without limitation the rights to use,
* copy, modify, merge, publish, distribute, sublicense, and/or
* sell copies of the Software, and to permit persons to whom the
* Software is furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be
* included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
* OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
* HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
* WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
* DEALINGS IN THE SOFTWARE.
*
*/
#include "interface/rdbc/net_channel.h"
#include <gtest/gtest.h>
#include "common/crypto/key_generator.h"
#include "platform/common/network/mock_socket.h"
#include "platform/proto/client_test.pb.h"
namespace resdb {
namespace {
using ::testing::Invoke;
using ::testing::Return;
using ::testing::Test;
class NetChannelTest : public Test {
protected:
ResDBMessage GetSendPackage(const google::protobuf::Message& message,
Request::Type type) {
ResDBMessage resdb_message;
Request request;
request.set_type(type);
EXPECT_TRUE(message.SerializeToString(request.mutable_data()));
EXPECT_TRUE(request.SerializeToString(resdb_message.mutable_data()));
return resdb_message;
}
std::string GetSendPackageString(const google::protobuf::Message& message,
Request::Type type) {
ResDBMessage resdb_message = GetSendPackage(message, type);
std::string data;
EXPECT_TRUE(resdb_message.SerializeToString(&data));
return data;
}
ResDBMessage GetSendPackage(const google::protobuf::Message& message) {
ResDBMessage resdb_message;
EXPECT_TRUE(message.SerializeToString(resdb_message.mutable_data()));
return resdb_message;
}
std::string GetSendPackageString(const google::protobuf::Message& message) {
ResDBMessage resdb_message = GetSendPackage(message);
std::string data;
EXPECT_TRUE(resdb_message.SerializeToString(&data));
return data;
}
};
TEST_F(NetChannelTest, ClientConnectFail) {
std::unique_ptr<MockSocket> socket = std::make_unique<MockSocket>();
EXPECT_CALL(*socket, Connect("127.0.0.1", 1234))
.Times(3)
.WillRepeatedly(Return(-1));
NetChannel client("127.0.0.1", 1234);
client.SetSocket(std::move(socket));
ClientTestRequest client_request;
std::string data;
ASSERT_TRUE(client_request.SerializeToString(&data));
int ret = client.SendRawMessage(client_request);
EXPECT_LE(ret, 0);
}
TEST_F(NetChannelTest, KeepAliveAndAsyncReConnect) {
std::unique_ptr<MockSocket> socket = std::make_unique<MockSocket>();
EXPECT_CALL(*socket, Connect("127.0.0.1", 1234))
.Times(3)
.WillRepeatedly(Return(-1));
EXPECT_CALL(*socket, SetAsync(true)).Times(3).WillRepeatedly(Return(0));
NetChannel client("127.0.0.1", 1234);
client.SetAsyncSend(true);
client.IsLongConnection(true);
client.SetSocket(std::move(socket));
ClientTestRequest client_request;
std::string data;
ASSERT_TRUE(client_request.SerializeToString(&data));
int ret = client.SendRawMessage(client_request);
EXPECT_LE(ret, 0);
}
TEST_F(NetChannelTest, SendRequest) {
std::string data = "test_data";
ClientTestRequest client_request;
client_request.set_value("test_value");
std::string expected_request_str =
GetSendPackageString(client_request, Request::TYPE_CLIENT_REQUEST);
std::unique_ptr<MockSocket> socket = std::make_unique<MockSocket>();
EXPECT_CALL(*socket, Connect("127.0.0.1", 1234)).WillOnce(Return(0));
EXPECT_CALL(*socket, Send(expected_request_str)).WillOnce(Return(0));
NetChannel client("127.0.0.1", 1234);
client.SetSocket(std::move(socket));
client.SetSignatureVerifier(nullptr);
int ret = client.SendRequest(client_request, Request::TYPE_CLIENT_REQUEST);
EXPECT_EQ(ret, 0);
}
TEST_F(NetChannelTest, SendRequestFail) {
std::string data = "test_data";
ClientTestRequest client_request;
client_request.set_value("test_value");
std::string expected_request_str =
GetSendPackageString(client_request, Request::TYPE_CLIENT_REQUEST);
std::unique_ptr<MockSocket> socket = std::make_unique<MockSocket>();
EXPECT_CALL(*socket, ReInit()).Times(3);
EXPECT_CALL(*socket, Connect("127.0.0.1", 1234)).Times(3).WillOnce(Return(0));
EXPECT_CALL(*socket, Send(expected_request_str))
.Times(3)
.WillRepeatedly(Return(-1));
NetChannel client("127.0.0.1", 1234);
client.SetSocket(std::move(socket));
client.SetSignatureVerifier(nullptr);
int ret = client.SendRequest(client_request, Request::TYPE_CLIENT_REQUEST);
EXPECT_NE(ret, 0);
}
TEST_F(NetChannelTest, SendMessage) {
ClientTestRequest client_request;
client_request.set_value("test_value");
std::string data = GetSendPackageString(client_request);
std::unique_ptr<MockSocket> socket = std::make_unique<MockSocket>();
EXPECT_CALL(*socket, Connect("127.0.0.1", 1234)).WillOnce(Return(0));
EXPECT_CALL(*socket, Send(data)).WillOnce(Return(0));
NetChannel client("127.0.0.1", 1234);
client.SetSocket(std::move(socket));
client.SetSignatureVerifier(nullptr);
EXPECT_EQ(client.SendRawMessage(client_request), 0);
}
TEST_F(NetChannelTest, SignMessage) {
SecretKey my_key = KeyGenerator ::GeneratorKeys(SignatureInfo::RSA);
int64_t my_node_id = 1;
KeyInfo private_key;
private_key.set_key(my_key.private_key());
private_key.set_hash_type(my_key.hash_type());
CertificateInfo cert_info;
cert_info.set_node_id(my_node_id);
cert_info.mutable_public_key()
->mutable_public_key_info()
->mutable_key()
->set_key(my_key.public_key());
cert_info.mutable_public_key()
->mutable_public_key_info()
->mutable_key()
->set_hash_type(my_key.hash_type());
cert_info.mutable_public_key()->mutable_public_key_info()->set_node_id(
my_node_id);
SignatureVerifier verifier(private_key, cert_info);
ClientTestRequest client_request;
client_request.set_value("test_value");
std::string data = GetSendPackageString(client_request);
std::unique_ptr<MockSocket> socket = std::make_unique<MockSocket>();
EXPECT_CALL(*socket, Connect("127.0.0.1", 1234)).WillOnce(Return(0));
EXPECT_CALL(*socket, Send).WillOnce(Invoke([&](const std::string& data) {
ResDBMessage resdb_message;
EXPECT_TRUE(resdb_message.ParseFromString(data));
EXPECT_TRUE(verifier.VerifyMessage(resdb_message.data(),
resdb_message.signature()));
return 0;
}));
NetChannel client("127.0.0.1", 1234);
client.SetSocket(std::move(socket));
client.SetSignatureVerifier(&verifier);
EXPECT_EQ(client.SendRawMessage(client_request), 0);
}
} // namespace
} // namespace resdb