blob: 096618b8b4923cf0aad208f482475a4289e974b4 [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#pragma once
#include <stdint.h>
#include <algorithm>
#include <chrono>
#include <memory>
#include <type_traits>
#include <utility>
#include <vector>
#include "client/partition_resolver.h"
#include "dsn.layer2_types.h"
#include "rpc_address.h"
#include "runtime/api_layer1.h"
#include "runtime/rpc/rpc_message.h"
#include "runtime/rpc/serialization.h"
#include "runtime/task/async_calls.h"
#include "runtime/task/task.h"
#include "runtime/task/task_code.h"
#include "utils/autoref_ptr.h"
#include "utils/chrono_literals.h"
#include "utils/error_code.h"
#include "utils/fmt_logging.h"
#include "utils/function_traits.h"
#include "utils/ports.h"
namespace dsn {
// Forward declaration only: the code in this header refers to task_tracker solely
// through pointers, so the full definition is not needed here.
class task_tracker;
// Brings the `_ms` duration literal into namespace dsn. It is used below for the
// `timeout = 0_ms` default argument of the rpc_holder constructor.
using literals::chrono_literals::operator"" _ms;
//
// rpc_holder is mainly designed for RAII of message_ex*.
// Since the request message will be automatically released after the rpc ends,
// it will become inaccessible when you use it in an async call (probably via tasking::enqueue).
// So in rpc_holder we hold another reference of the message, preventing it from being deleted.
//
// rpc_holder also provides a simple approach for rpc mocking.
// For example:
//
// typedef rpc_holder<write_request, write_response> write_rpc;
//
// void write() {
// ....
// auto request = std::make_unique<write_request>();
// request->data = "abc";
// request->timestamp = 12;
// write_rpc rpc(std::move(request), RPC_WRITE);
// rpc.call(rpc_address("10.57.223.31", 12321), nullptr, on_write_rpc_reply);
// ...
// }
//
// RPC_MOCKING(write_rpc) {
// write();
// ASSERT_EQ(1, write_rpc::mail_box().size());
// }
//
// Here in the instance of `RPC_MOCKING`, we call a function `write`, in which a write_rpc
// was sent to "10.57.223.31:12321". However, since we are in the mock mode, every `write_request`
// message will be dropped into `write_rpc::mail_box` without going through network.
//
template <typename TRequest, typename TResponse>
class rpc_holder
{
public:
using request_type = TRequest;
using response_type = TResponse;
public:
explicit rpc_holder(message_ex *req = nullptr)
{
if (req != nullptr) {
_i = std::make_shared<internal>(req);
}
}
rpc_holder(std::unique_ptr<TRequest> req,
task_code code,
std::chrono::milliseconds timeout = 0_ms,
uint64_t partition_hash = 0,
int thread_hash = 0)
: _i(new internal(req, code, timeout, partition_hash, thread_hash))
{
}
// copyable and movable
// Copying an rpc_holder doesn't produce a deep copy, the new instance will
// reference the same rpc internal data. So, just feel free to copy :)
rpc_holder(const rpc_holder &) = default;
rpc_holder(rpc_holder &&) noexcept = default;
rpc_holder &operator=(const rpc_holder &) = default;
rpc_holder &operator=(rpc_holder &&) noexcept = default;
bool is_initialized() const { return bool(_i); }
const TRequest &request() const
{
CHECK(_i, "rpc_holder is uninitialized");
return *(_i->thrift_request);
}
TRequest *mutable_request() const
{
CHECK(_i, "rpc_holder is uninitialized");
return _i->thrift_request.get();
}
TResponse &response() const
{
CHECK(_i, "rpc_holder is uninitialized");
return _i->thrift_response;
}
dsn::error_code &error() const
{
CHECK(_i, "rpc_holder is uninitialized");
return _i->rpc_error;
}
message_ex *dsn_request() const
{
CHECK(_i, "rpc_holder is uninitialized");
return _i->dsn_request;
}
// the remote address where reveice request from and send response to.
rpc_address remote_address() const { return dsn_request()->header->from_address; }
// TCallback = void(error_code)
// NOTE that the `error_code` is not the error carried by response. Users should
// check the responded error themselves.
template <typename TCallback>
task_ptr call(const rpc_address &server,
task_tracker *tracker,
TCallback &&callback,
int reply_thread_hash = 0)
{
// ensures that TCallback receives exactly one argument, which must be a error_code.
static_assert(function_traits<TCallback>::arity == 1,
"TCallback must receive exactly one argument");
static_assert(
std::is_same<typename function_traits<TCallback>::template arg_t<0>, error_code>::value,
"the first argument of TCallback must be error_code");
if (dsn_unlikely(_mail_box != nullptr)) {
_mail_box->emplace_back(*this);
return nullptr;
}
rpc_response_task_ptr t = rpc::create_rpc_response_task(
dsn_request(),
tracker,
[ cb_fwd = std::forward<TCallback>(callback),
rpc = *this ](error_code err, message_ex * req, message_ex * resp) mutable {
if (err == ERR_OK) {
unmarshall(resp, rpc.response());
}
cb_fwd(err);
},
reply_thread_hash);
dsn_rpc_call(server, t);
return t;
}
template <typename TCallback>
task_ptr call(replication::partition_resolver_ptr &resolver,
task_tracker *tracker,
TCallback &&callback,
int reply_thread_hash = 0)
{
static_assert(function_traits<TCallback>::arity == 1,
"TCallback must receive exactly one argument");
static_assert(
std::is_same<typename function_traits<TCallback>::template arg_t<0>, error_code>::value,
"the first argument of TCallback must be error_code");
if (dsn_unlikely(_mail_box != nullptr)) {
_mail_box->emplace_back(*this);
return nullptr;
}
rpc_response_task_ptr t = rpc::create_rpc_response_task(
dsn_request(),
tracker,
[ cb_fwd = std::forward<TCallback>(callback),
rpc = *this ](error_code err, message_ex * req, message_ex * resp) mutable {
if (err == ERR_OK) {
unmarshall(resp, rpc.response());
}
cb_fwd(err);
},
reply_thread_hash);
resolver->call_task(t);
return t;
}
void forward(const rpc_address &addr)
{
_i->auto_reply = false;
if (dsn_unlikely(_forward_mail_box != nullptr)) {
dsn_request()->header->from_address = addr;
_forward_mail_box->emplace_back(*this);
return;
}
dsn_rpc_forward(dsn_request(), addr);
}
// Returns an rpc_holder that will reply the request after its lifetime ends.
// By default rpc_holder never replies.
// SEE: serverlet<T>::register_rpc_handler_with_rpc_holder
static inline rpc_holder auto_reply(message_ex *req)
{
rpc_holder rpc(req);
rpc._i->auto_reply = true;
return rpc;
}
// Only use this function when testing.
// In mock mode, all messages will be dropped into mail_box without going through network,
// and response callbacks will never be called.
// This function is not thread-safe.
using mail_box_t = std::vector<rpc_holder<TRequest, TResponse>>;
using mail_box_u_ptr = std::unique_ptr<mail_box_t>;
static void enable_mocking()
{
CHECK(_mail_box == nullptr && _forward_mail_box == nullptr,
"remember to call clear_mocking_env after testing");
_mail_box = std::make_unique<mail_box_t>();
_forward_mail_box = std::make_unique<mail_box_t>();
}
// Only use this function when testing.
// Remember to call it after test finishes, or it may effect the results of other tests.
// This function is not thread-safe.
static void clear_mocking_env()
{
_mail_box.reset(nullptr);
_forward_mail_box.reset(nullptr);
}
static mail_box_t &mail_box()
{
CHECK(_mail_box, "call this function only when you are in mock mode");
return *_mail_box.get();
}
static mail_box_t &forward_mail_box()
{
CHECK(_forward_mail_box, "call this function only when you are in mock mode");
return *_forward_mail_box.get();
}
friend bool operator<(const rpc_holder &lhs, const rpc_holder &rhs) { return lhs._i < rhs._i; }
private:
friend class rpc_holder_test;
struct internal
{
explicit internal(message_ex *req)
: dsn_request(req), thrift_request(std::make_unique<TRequest>()), auto_reply(false)
{
// we must hold one reference for the request, or rdsn will delete it after
// the rpc call ends.
dsn_request->add_ref();
unmarshall(req, *thrift_request);
}
internal(std::unique_ptr<TRequest> &req,
task_code code,
std::chrono::milliseconds timeout,
uint64_t partition_hash,
int thread_hash)
: thrift_request(std::move(req)), auto_reply(false)
{
CHECK(thrift_request, "req should not be null");
dsn_request = message_ex::create_request(
code, static_cast<int>(timeout.count()), thread_hash, partition_hash);
dsn_request->add_ref();
marshall(dsn_request, *thrift_request);
}
void reply()
{
if (dsn_unlikely(_mail_box != nullptr)) {
rpc_holder<TRequest, TResponse> rpc(std::move(thrift_request),
dsn_request->rpc_code());
rpc.response() = std::move(thrift_response);
_mail_box->emplace_back(std::move(rpc));
return;
}
message_ex *dsn_response = dsn_request->create_response();
marshall(dsn_response, thrift_response);
dsn_rpc_reply(dsn_response, rpc_error);
}
~internal()
{
if (auto_reply) {
reply();
}
dsn_request->release_ref();
}
message_ex *dsn_request;
std::unique_ptr<TRequest> thrift_request;
TResponse thrift_response;
dsn::error_code rpc_error = dsn::ERR_OK;
bool auto_reply;
};
std::shared_ptr<internal> _i;
static mail_box_u_ptr _mail_box;
static mail_box_u_ptr _forward_mail_box;
};
// ======== type traits ========
// Compile-time predicate telling whether a type is an instantiation of rpc_holder.
// is_rpc_holder<T>::value is true only when T is rpc_holder<Req, Resp> for some
// Req/Resp, and false for every other type.
template <typename TAny>
struct is_rpc_holder : std::false_type
{
};

// Matches exactly the rpc_holder<...> instantiations.
template <typename TReq, typename TResp>
struct is_rpc_holder<rpc_holder<TReq, TResp>> : std::true_type
{
};
// ======== utilities ========
namespace rpc {
// Free-function counterpart of rpc_holder::call(): sends `rpc` to `server` and returns
// whatever the holder's call() returns.
// TCallback must be callable as void(error_code).
template <typename TCallback, typename TRpcHolder>
task_ptr call(rpc_address server,
              TRpcHolder rpc,
              task_tracker *tracker,
              TCallback &&callback,
              int reply_thread_hash = 0)
{
    // Reject anything that is not an rpc_holder instantiation at compile time.
    constexpr bool holder_is_valid = is_rpc_holder<TRpcHolder>::value;
    static_assert(holder_is_valid, "TRpcHolder must be an rpc_holder");

    return rpc.call(server, tracker, std::forward<TCallback>(callback), reply_thread_hash);
}
} // namespace rpc
// ======== rpc mock ========
// Out-of-class definitions of rpc_holder's static mock mail boxes (one pair per
// rpc_holder instantiation). Both start out null; they are populated by
// enable_mocking() and reset by clear_mocking_env().
template <typename TReq, typename TResp>
typename rpc_holder<TReq, TResp>::mail_box_u_ptr rpc_holder<TReq, TResp>::_mail_box;

template <typename TReq, typename TResp>
typename rpc_holder<TReq, TResp>::mail_box_u_ptr rpc_holder<TReq, TResp>::_forward_mail_box;
// Scope guard for mock mode: calls TRpcHolder::enable_mocking() on construction and
// TRpcHolder::clear_mocking_env() on destruction.
// `counter` is a helper for the RPC_MOCKING macro, which uses it as the loop variable
// that makes its `for` body execute exactly once.
template <typename TRpcHolder>
struct rpc_mock_wrapper
{
    rpc_mock_wrapper() { TRpcHolder::enable_mocking(); }
    ~rpc_mock_wrapper() { TRpcHolder::clear_mocking_env(); }

    // Not copyable: a copy would call clear_mocking_env() a second time when destroyed,
    // possibly while the original guard is still supposed to keep mocking enabled.
    rpc_mock_wrapper(const rpc_mock_wrapper &) = delete;
    rpc_mock_wrapper &operator=(const rpc_mock_wrapper &) = delete;

    int counter = 0;
};
// Runs the statement or block that follows with mock mode enabled for `__rpc_type__`.
//
// The macro expands to a `for` loop whose init-statement declares an
// rpc_mock_wrapper<__rpc_type__> named `<__rpc_type__>_mocking__`. The wrapper's
// constructor enables mocking; its `counter` starts at 0 and is incremented after the
// body, so the body executes once and the loop then ends, destroying the wrapper
// (which clears the mocking environment).
//
// Because the variable name is produced by token pasting (`##`), `__rpc_type__` must be
// a single identifier such as a typedef/alias name, not a qualified name or template-id.
#define RPC_MOCKING(__rpc_type__) \
for (::dsn::rpc_mock_wrapper<__rpc_type__> __rpc_type__##_mocking__; \
__rpc_type__##_mocking__.counter != 1; \
__rpc_type__##_mocking__.counter++)
} // namespace dsn