// blob: b8454070ddf9583583a3941ca1c14062a36ff424 [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#include "kudu/rpc/serialization.h"
#include <limits>
#include <ostream>
#include <gflags/gflags_declare.h>
#include <glog/logging.h>
#include <google/protobuf/io/coded_stream.h>
#include <google/protobuf/message_lite.h>
#include "kudu/gutil/endian.h"
#include "kudu/gutil/port.h"
#include "kudu/gutil/stringprintf.h"
#include "kudu/gutil/strings/substitute.h"
#include "kudu/rpc/constants.h"
#include "kudu/util/faststring.h"
#include "kudu/util/logging.h"
#include "kudu/util/slice.h"
#include "kudu/util/status.h"
DECLARE_int64(rpc_max_message_size);
using google::protobuf::MessageLite;
using google::protobuf::io::CodedInputStream;
using google::protobuf::io::CodedOutputStream;
using strings::Substitute;
namespace kudu {
namespace rpc {
namespace serialization {
// Indexes of the individual flag bytes in the connection header. They are
// applied to a pointer that has already been advanced past the magic number
// (see SerializeConnHeader() and ValidateConnHeader()).
enum {
  kHeaderPosVersion = 0,       // Byte holding the RPC version.
  kHeaderPosServiceClass = 1,  // Service-class byte; currently written as 0.
  kHeaderPosAuthProto = 2,     // Auth-protocol byte; currently written as 0.
};
// Serializes 'message' into 'param_buf' as
//   [varint32 recorded size][protobuf bytes]
// where the recorded size is the protobuf size plus 'additional_size'.
// Only the varint and the protobuf are written into 'param_buf'; the
// 'additional_size' bytes are accounted for in the prefix but are not
// written by this function.
//
// 'additional_size' must be >= 0 (checked in debug builds). When
// 'use_cached_size' is true, the size is taken from message.GetCachedSize()
// rather than recomputed; debug builds verify it matches ByteSizeLong().
//
// CHECK-fails if the overall size does not fit in an unsigned 32-bit integer.
// Exceeding --rpc_max_message_size only produces a warning.
void SerializeMessage(const MessageLite& message, faststring* param_buf,
                      int additional_size, bool use_cached_size) {
  DCHECK_GE(additional_size, 0);

  size_t payload_len;
  if (use_cached_size) {
    payload_len = message.GetCachedSize();
  } else {
    payload_len = message.ByteSizeLong();
  }
  DCHECK_EQ(message.ByteSizeLong(), payload_len);

  // All size arithmetic is done with 8-byte integers so that an
  // 'additional_size' approaching INT_MAX cannot overflow.
  const int64_t payload_len64 = static_cast<int64_t>(payload_len);
  const int64_t extra_len64 = static_cast<int64_t>(additional_size);
  const int64_t recorded_size = payload_len64 + extra_len64;
  const int64_t prefix_len =
      static_cast<int64_t>(CodedOutputStream::VarintSize32(recorded_size));
  const int64_t buffer_len = payload_len64 + prefix_len;
  const int64_t total_size = buffer_len + extra_len64;

  // The wire format expresses the size as an unsigned 32-bit integer, so the
  // message must fit in one. 'additional_size' is itself bounded by INT_MAX,
  // which makes this a safe limitation.
  CHECK_LE(total_size, std::numeric_limits<uint32_t>::max());

  if (PREDICT_FALSE(total_size > FLAGS_rpc_max_message_size)) {
    LOG(WARNING) << Substitute("Serialized $0 ($1 bytes) is larger than the maximum configured "
                               "RPC message size ($2 bytes). "
                               "Sending anyway, but peer may reject the data.",
                               message.GetTypeName(), total_size, FLAGS_rpc_max_message_size);
  }

  param_buf->resize(buffer_len);
  uint8_t* const begin = param_buf->data();
  uint8_t* cursor = CodedOutputStream::WriteVarint32ToArray(recorded_size, begin);
  cursor = message.SerializeWithCachedSizesToArray(cursor);

  // Exactly the resized buffer must have been filled.
  CHECK_EQ(cursor, begin + buffer_len);
}
// Writes the packet framing and the header protobuf into 'header_buf':
//   [length prefix, stored with NetworkByteOrder::Store32]
//   [varint32 header PB length][header PB bytes]
//
// The length prefix covers everything that follows it: the delimited header
// written here plus 'param_len' bytes that the caller accounts for separately.
//
// CHECK-fails if 'header' is missing required fields.
// NOTE(review): the length is narrowed to 32 bits when stored, with no range
// check in this function - presumably callers bound 'param_len'; confirm.
void SerializeHeader(const MessageLite& header,
                     size_t param_len,
                     faststring* header_buf) {
  CHECK(header.IsInitialized())
      << "RPC header missing fields: " << header.InitializationErrorString();

  // Sizes of the pieces that end up in 'header_buf'.
  const size_t pb_len = header.ByteSizeLong();
  const size_t delim_len = CodedOutputStream::VarintSize32(pb_len);
  const size_t framed_header_len = kMsgLengthPrefixLength + delim_len + pb_len;

  header_buf->resize(framed_header_len);
  uint8_t* const start = header_buf->data();
  uint8_t* cursor = start;

  // 1. Length of the whole request, excluding the length prefix itself.
  const size_t packet_len = framed_header_len + param_len;
  NetworkByteOrder::Store32(cursor, packet_len - kMsgLengthPrefixLength);
  cursor += sizeof(uint32_t);

  // 2. The varint-delimited header protobuf.
  cursor = CodedOutputStream::WriteVarint32ToArray(pb_len, cursor);
  cursor = header.SerializeWithCachedSizesToArray(cursor);

  // The buffer was sized exactly, so the cursor must land on its end.
  CHECK_EQ(cursor, start + framed_header_len);
}
// Parses a complete, length-prefixed RPC packet held in 'buf'.
//
// Layout consumed, in order:
//   [length prefix of kMsgLengthPrefixLength bytes]
//   [varint32 header length][header protobuf]
//   [varint32 main message length][main message bytes]
//
// On success 'parsed_header' holds the decoded header and
// 'parsed_main_message' is set to a Slice pointing at the main message bytes
// inside 'buf'. Nothing is copied, and the main message is skipped over
// rather than parsed.
//
// Returns Status::Corruption when the buffer is shorter than the length
// prefix, the declared length exceeds INT32_MAX, either varint delimiter is
// missing, the header fails to parse, the main message is truncated, or
// bytes remain after the main message.
Status ParseMessage(const Slice& buf,
                    MessageLite* parsed_header,
                    Slice* parsed_main_message) {
  // First grab the total length.
  if (PREDICT_FALSE(buf.size() < kMsgLengthPrefixLength)) {
    return Status::Corruption("Invalid packet: not enough bytes for length header",
                              KUDU_REDACT(buf.ToDebugString()));
  }

  uint32_t total_len = NetworkByteOrder::Load32(buf.data());
  DCHECK_EQ(total_len, buf.size() - kMsgLengthPrefixLength)
      << "Got mis-sized buffer: " << KUDU_REDACT(buf.ToDebugString());

  if (total_len > std::numeric_limits<int32_t>::max()) {
    return Status::Corruption(Substitute("Invalid packet: message had a length of $0, "
                                         "but we only support messages up to $1 bytes\n",
                                         total_len, std::numeric_limits<int32_t>::max()));
  }

  CodedInputStream in(buf.data(), buf.size());
  // Protobuf enforces a 64MB total bytes limit on CodedInputStream by default.
  // Override this default with the actual size of the buffer to allow messages
  // larger than 64MB.
  in.SetTotalBytesLimit(buf.size());
  in.Skip(kMsgLengthPrefixLength);

  uint32_t header_len;
  if (PREDICT_FALSE(!in.ReadVarint32(&header_len))) {
    return Status::Corruption("Invalid packet: missing header delimiter",
                              KUDU_REDACT(buf.ToDebugString()));
  }

  // Restrict the stream to the header bytes while the header is parsed.
  const CodedInputStream::Limit header_limit = in.PushLimit(header_len);
  if (PREDICT_FALSE(!parsed_header->ParseFromCodedStream(&in))) {
    return Status::Corruption("Invalid packet: header too short",
                              KUDU_REDACT(buf.ToDebugString()));
  }
  in.PopLimit(header_limit);

  uint32_t main_msg_len;
  if (PREDICT_FALSE(!in.ReadVarint32(&main_msg_len))) {
    return Status::Corruption("Invalid packet: missing main msg length",
                              KUDU_REDACT(buf.ToDebugString()));
  }

  if (PREDICT_FALSE(!in.Skip(main_msg_len))) {
    // 'main_msg_len' is unsigned, so it is formatted with %u: a peer-supplied
    // length above INT_MAX must not be reported as a negative number.
    return Status::Corruption(
        StringPrintf("Invalid packet: data too short, expected %u byte main_msg", main_msg_len),
        KUDU_REDACT(buf.ToDebugString()));
  }

  if (PREDICT_FALSE(in.BytesUntilLimit() > 0)) {
    return Status::Corruption(
        StringPrintf("Invalid packet: %d extra bytes at end of packet", in.BytesUntilLimit()),
        KUDU_REDACT(buf.ToDebugString()));
  }

  // With nothing left over, the main message is exactly the last
  // 'main_msg_len' bytes of the buffer.
  *parsed_main_message = Slice(buf.data() + buf.size() - main_msg_len,
                               main_msg_len);
  return Status::OK();
}
// Writes the connection header into 'buf': the magic number followed by the
// version, service-class and auth-protocol flag bytes. 'buf' must have room
// for kMagicNumberLength bytes plus the three flag bytes written below.
void SerializeConnHeader(uint8_t* buf) {
  memcpy(buf, kMagicNumber, kMagicNumberLength);

  // The flag bytes start immediately after the magic number.
  uint8_t* const flags = buf + kMagicNumberLength;
  flags[kHeaderPosVersion] = kCurrentRpcVersion;
  flags[kHeaderPosServiceClass] = 0;  // TODO: implement
  flags[kHeaderPosAuthProto] = 0;     // TODO: implement
}
// Validates the entire RPC connection header (magic number + flags) in 'slice'.
//
// Returns OK when the slice starts with the magic number and the version byte
// equals kCurrentRpcVersion. Otherwise returns InvalidArgument, with a
// dedicated message when the bytes look like the start of an HTTP request.
// The slice length is only verified in debug builds.
Status ValidateConnHeader(const Slice& slice) {
  DCHECK_EQ(kMagicNumberLength + kHeaderFlagsLength, slice.size())
      << "Invalid RPC header length";

  if (slice.starts_with(kMagicNumber)) {
    // Magic matched: the flag bytes follow directly after it.
    const uint8_t* const flags = slice.data() + kMagicNumberLength;
    const uint8_t peer_version = flags[kHeaderPosVersion];
    if (peer_version == kCurrentRpcVersion) {
      // TODO: validate additional header flags:
      // RPC_SERVICE_CLASS
      // RPC_AUTH_PROTOCOL
      return Status::OK();
    }
    return Status::InvalidArgument("Unsupported RPC version",
                                   StringPrintf("Received: %d, Supported: %d",
                                                peer_version, kCurrentRpcVersion));
  }

  // Wrong magic. Give a more helpful error if this looks like an HTTP client.
  static const char* const kHttpVerbs[] = {"GET ", "POST", "HEAD"};
  for (const char* verb : kHttpVerbs) {
    if (slice.starts_with(verb)) {
      return Status::InvalidArgument("invalid negotation, appears to be an HTTP client on "
                                     "the RPC port");
    }
  }
  return Status::InvalidArgument("connection must begin with magic number", kMagicNumber);
}
} // namespace serialization
} // namespace rpc
} // namespace kudu