// blob: 13f532ecff25ca6e83572591d812d5373749dc72 [file] [log] [blame]
// Copyright 2021-2023 Buf Technologies, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package triple_protocol
import (
"bytes"
"encoding/binary"
"errors"
"io"
)
// flagEnvelopeCompressed is the bit of an envelope's flag byte that marks the
// payload as compressed. It has the same meaning in the gRPC-Web, gRPC-HTTP2,
// and Connect protocols.
const flagEnvelopeCompressed = 0b00000001
// errSpecialEnvelope is the sentinel that envelopeReader.Unmarshal returns
// after reading an envelope whose flags are neither zero nor exactly
// flagEnvelopeCompressed. io.EOF is supplied for the %w verb so that callers
// checking for end of stream with errors.Is(err, io.EOF) match this error.
var errSpecialEnvelope = errorf(
CodeUnknown,
"final message has protocol-specific flags: %w",
// User code checks for end of stream with errors.Is(err, io.EOF).
io.EOF,
)
// envelope is one frame of the length-prefixed wire format shared by gRPC and
// Connect: an opaque payload together with the flag byte that accompanies it.
//
// On the wire, every payload is preceded by a 5-byte prefix. Byte 0 is a uint8
// holding bitwise flags; bytes 1-4 are a big-endian uint32 giving the payload
// length. gRPC and Connect assign different meanings to the flag bits, so
// envelope does not interpret them and leaves that to the caller.
type envelope struct {
	Data  *bytes.Buffer
	Flags uint8
}

// IsSet reports whether every bit in flag is also set in the envelope's Flags.
func (e *envelope) IsSet(flag uint8) bool {
	// Clearing from flag every bit that Flags has leaves zero exactly when no
	// requested bit is missing.
	return flag&^e.Flags == 0
}
// envelopeWriter frames outgoing messages as envelopes and writes them to an
// io.Writer, marshaling and compressing them as configured.
type envelopeWriter struct {
// writer receives the 5-byte prefix and the payload of every envelope.
writer io.Writer
// codec serializes the messages handed to Marshal.
codec Codec
// compressMinBytes is the payload size below which Write skips compression.
compressMinBytes int
// compressionPool compresses payloads; when nil, Write never compresses.
compressionPool *compressionPool
// bufferPool supplies the scratch buffer Write compresses into and takes
// back the buffer Marshal wraps around the codec's output.
bufferPool *bufferPool
// sendMaxBytes, when positive, caps the size of the payload Write will send.
sendMaxBytes int
}
// Marshal serializes message with the writer's codec, wraps the resulting
// bytes in an envelope, and sends it through Write. A nil message produces no
// envelope: only an empty write is issued to the underlying writer.
//
// Codec failures are reported as CodeInternal. A failed empty write is
// returned as-is when asError recognizes it and as CodeUnknown otherwise.
func (w *envelopeWriter) Marshal(message any) *Error {
	if message != nil {
		payload, marshalErr := w.codec.Marshal(message)
		if marshalErr != nil {
			return errorf(CodeInternal, "marshal message: %w", marshalErr)
		}
		// The codec had to allocate this slice anyway, so donate it to the pool
		// for reuse once the envelope has been written.
		buf := bytes.NewBuffer(payload)
		defer w.bufferPool.Put(buf)
		return w.Write(&envelope{Data: buf})
	}

	_, writeErr := w.writer.Write(nil)
	if writeErr == nil {
		return nil
	}
	if connectErr, ok := asError(writeErr); ok {
		return connectErr
	}
	return NewError(CodeUnknown, writeErr)
}
// Write frames env and sends it, compressing the payload first when called
// for. The payload is sent untouched when it is already flagged as
// compressed, when the writer has no compression pool, or when it is shorter
// than compressMinBytes.
//
// When sendMaxBytes is positive, the payload that would go on the wire (after
// any compression) must not exceed it; otherwise a CodeResourceExhausted error
// is returned and nothing is written. Write does not retain env or its
// underlying data, so the caller is free to reuse both afterwards.
func (w *envelopeWriter) Write(env *envelope) *Error {
	shouldCompress := !env.IsSet(flagEnvelopeCompressed) &&
		w.compressionPool != nil &&
		env.Data.Len() >= w.compressMinBytes

	if !shouldCompress {
		if size := env.Data.Len(); w.sendMaxBytes > 0 && size > w.sendMaxBytes {
			return errorf(CodeResourceExhausted, "message size %d exceeds sendMaxBytes %d", size, w.sendMaxBytes)
		}
		// Send the payload exactly as supplied.
		return w.write(env)
	}

	compressed := w.bufferPool.Get()
	defer w.bufferPool.Put(compressed)
	if err := w.compressionPool.Compress(compressed, env.Data); err != nil {
		return err
	}
	if size := compressed.Len(); w.sendMaxBytes > 0 && size > w.sendMaxBytes {
		return errorf(CodeResourceExhausted, "compressed message size %d exceeds sendMaxBytes %d", size, w.sendMaxBytes)
	}
	// Keep the caller's flags and add the compressed bit.
	out := envelope{
		Data:  compressed,
		Flags: env.Flags | flagEnvelopeCompressed,
	}
	return w.write(&out)
}
// write emits env's 5-byte prefix (the flag byte followed by the payload
// length as a big-endian uint32) and then the payload itself. It performs no
// compression and no size checks.
//
// A failed prefix write is returned as-is when asError recognizes it and as
// CodeUnknown otherwise; a failed payload copy is always CodeUnknown.
func (w *envelopeWriter) write(env *envelope) *Error {
	var header [5]byte
	header[0] = env.Flags
	binary.BigEndian.PutUint32(header[1:], uint32(env.Data.Len()))

	_, err := w.writer.Write(header[:])
	if err != nil {
		connectErr, ok := asError(err)
		if !ok {
			return errorf(CodeUnknown, "write envelope: %w", err)
		}
		return connectErr
	}

	if _, err = io.Copy(w.writer, env.Data); err != nil {
		return errorf(CodeUnknown, "write message: %w", err)
	}
	return nil
}
// envelopeReader reads envelopes from an io.Reader and unmarshals their
// payloads, decompressing them when needed.
type envelopeReader struct {
// reader is the stream that envelopes are read from.
reader io.Reader
// codec deserializes payloads in Unmarshal.
codec Codec
// last holds a copy of the most recent envelope whose flags were neither
// zero nor exactly flagEnvelopeCompressed; Unmarshal stores it before
// returning errSpecialEnvelope.
last envelope
// compressionPool decompresses payloads; when nil, Unmarshal rejects
// non-empty envelopes that have the compressed flag set.
compressionPool *compressionPool
// bufferPool supplies the buffers Unmarshal reads and decompresses into.
bufferPool *bufferPool
// readMaxBytes, when positive, is the largest payload size Read accepts. It
// is also passed to Decompress, presumably as the decompressed-size limit.
readMaxBytes int
}
// Unmarshal reads the next envelope from the stream and decodes its payload
// into message with the reader's codec, decompressing it first when the
// compressed flag is set.
//
// A standard envelope (flags zero or only the compressed bit) with an empty
// payload leaves message untouched and returns nil. An envelope carrying any
// other flags is not decoded: a copy of it is stored in r.last and
// errSpecialEnvelope is returned so protocol-specific code can process it.
// Errors from Read, including the one wrapping io.EOF at end of stream, are
// returned unchanged.
func (r *envelopeReader) Unmarshal(message any) *Error {
	scratch := r.bufferPool.Get()
	defer r.bufferPool.Put(scratch)

	env := &envelope{Data: scratch}
	if err := r.Read(env); err != nil {
		// Either the stream has ended (the error wraps io.EOF) or something
		// went wrong; in both cases the caller gets the error as-is.
		return err
	}

	// None of the top 7 bits are set: this is an ordinary message.
	standard := env.Flags == 0 || env.Flags == flagEnvelopeCompressed
	if standard && env.Data.Len() == 0 {
		// No data, so the zero value of the message is already correct.
		return nil
	}

	payload := env.Data
	if env.IsSet(flagEnvelopeCompressed) && payload.Len() > 0 {
		if r.compressionPool == nil {
			return errorf(
				CodeInvalidArgument,
				"gRPC protocol error: sent compressed message without Grpc-Encoding header",
			)
		}
		inflated := r.bufferPool.Get()
		defer r.bufferPool.Put(inflated)
		if err := r.compressionPool.Decompress(inflated, payload, int64(r.readMaxBytes)); err != nil {
			return err
		}
		payload = inflated
	}

	if !standard {
		// A protocol-specific flag is set, which marks the end of the stream.
		// The buffers above go back to the pool when this function returns, so
		// the payload is copied into a fresh buffer that is deliberately never
		// returned to the pool: r.last is referenced elsewhere.
		r.last = envelope{
			Data:  r.bufferPool.Get(),
			Flags: env.Flags,
		}
		if _, err := r.last.Data.ReadFrom(payload); err != nil {
			return errorf(CodeUnknown, "copy final envelope: %w", err)
		}
		return errSpecialEnvelope
	}

	if err := r.codec.Unmarshal(payload.Bytes(), message); err != nil {
		return errorf(CodeInvalidArgument, "unmarshal into %T: %w", message, err)
	}
	return nil
}
// Read reads one complete envelope from the underlying reader into env: first
// the 5-byte prefix, then as many payload bytes as the prefix announces. The
// payload is appended to env.Data and env.Flags is set only on success.
//
// Returned errors:
//   - CodeUnknown wrapping io.EOF when the stream ends before any prefix byte
//     is read (a clean end of stream);
//   - the error found by asError or asMaxBytesError, when either recognizes
//     the underlying read error;
//   - CodeInvalidArgument when the prefix or the payload is cut short;
//   - CodeResourceExhausted when readMaxBytes is positive and the announced
//     size exceeds it (the oversized payload is discarded from the stream);
//   - CodeUnknown for any other failure while reading the payload.
func (r *envelopeReader) Read(env *envelope) *Error {
	// Read prefix firstly, then read the packet with length specified by length field
	prefixes := [5]byte{}
	// A single Read call is allowed to return fewer bytes than requested with
	// a nil error, which would wrongly be treated as a truncated envelope and
	// leave the stream out of sync. io.ReadFull keeps reading until the whole
	// prefix has arrived or the reader fails.
	prefixBytesRead, err := io.ReadFull(r.reader, prefixes[:])
	if err != nil {
		if prefixBytesRead == 0 && errors.Is(err, io.EOF) {
			// The stream ended cleanly. That's expected, but we need to propagate
			// it to the user so that they know that the stream has ended. We
			// shouldn't add any alarming text about protocol errors, though.
			return NewError(CodeUnknown, err)
		}
		// Something else has gone wrong - the stream didn't end cleanly.
		if connectErr, ok := asError(err); ok {
			return connectErr
		}
		if maxBytesErr := asMaxBytesError(err, "read 5 byte message prefix"); maxBytesErr != nil {
			// We're reading from an http.MaxBytesHandler, and we've exceeded the read limit.
			return maxBytesErr
		}
		return errorf(
			CodeInvalidArgument,
			"protocol error: incomplete envelope: %w", err,
		)
	}
	if isSizeZeroPrefix(prefixes) {
		// Successfully read prefix and expect no additional data.
		env.Flags = prefixes[0]
		return nil
	}

	size := int(binary.BigEndian.Uint32(prefixes[1:5]))
	if size < 0 {
		// Only reachable where int is 32 bits wide and the announced length
		// does not fit in it.
		return errorf(CodeInvalidArgument, "message size %d overflowed uint32", size)
	}
	if r.readMaxBytes > 0 && size > r.readMaxBytes {
		// Drain the oversized payload so the next Read starts at the following
		// prefix instead of in the middle of this message.
		_, err := io.CopyN(io.Discard, r.reader, int64(size))
		if err != nil && !errors.Is(err, io.EOF) {
			return errorf(CodeUnknown, "read enveloped message: %w", err)
		}
		return errorf(CodeResourceExhausted, "message size %d is larger than configured max %d", size, r.readMaxBytes)
	}

	env.Data.Grow(size)
	// At layer 7, we don't know exactly what's happening down in L4. Large
	// length-prefixed messages may arrive in chunks, so we may need to read
	// the request body past EOF. We also need to take care that we don't retry
	// forever if the message is malformed.
	remaining := int64(size)
	for remaining > 0 {
		bytesRead, err := io.CopyN(env.Data, r.reader, remaining)
		if err != nil && !errors.Is(err, io.EOF) {
			if maxBytesErr := asMaxBytesError(err, "read %d byte message", size); maxBytesErr != nil {
				// We're reading from an http.MaxBytesHandler, and we've exceeded the read limit.
				return maxBytesErr
			}
			return errorf(CodeUnknown, "read enveloped message: %w", err)
		}
		if errors.Is(err, io.EOF) && bytesRead == 0 {
			// We've gotten zero-length chunk of data. Message is likely malformed,
			// don't wait for additional chunks.
			return errorf(
				CodeInvalidArgument,
				"protocol error: promised %d bytes in enveloped message, got %d bytes",
				size,
				int64(size)-remaining,
			)
		}
		remaining -= bytesRead
	}

	env.Flags = prefixes[0]
	return nil
}
// isSizeZeroPrefix reports whether the four length bytes of an envelope prefix
// (indices 1 through 4) are all zero, meaning the announced payload is empty.
// The flag byte at index 0 is ignored.
func isSizeZeroPrefix(prefix [5]byte) bool {
	// OR-ing the length bytes together yields zero only if each one is zero.
	return prefix[1]|prefix[2]|prefix[3]|prefix[4] == 0
}