| // Copyright 2021-2023 Buf Technologies, Inc. |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| package triple_protocol |
| |
| import ( |
| "bytes" |
| "errors" |
| "io" |
| "math" |
| "strings" |
| "sync" |
| ) |
| |
// Well-known compression names.
const (
	// compressionIdentity names the "no compression" encoding;
	// namedCompressionPools.Get never returns a pool for it.
	compressionIdentity = "identity"
	// compressionGzip names the gzip encoding.
	compressionGzip = "gzip"
)
| |
// A Decompressor is a reusable wrapper that decompresses an underlying data
// source. The standard library's [*gzip.Reader] implements Decompressor.
//
// Values are recycled through compressionPool, which calls Reset each time a
// Decompressor is taken from the pool and again when it is returned.
type Decompressor interface {
	// Reset discards the Decompressor's internal state, if any, and prepares
	// it to read from a new source of compressed data.
	Reset(io.Reader) error

	// Close closes the Decompressor, but not the underlying data source. It
	// may return an error if the Decompressor wasn't read to EOF.
	Close() error

	// Read yields decompressed bytes.
	io.Reader
}
| |
// A Compressor is a reusable wrapper that compresses data written to an
// underlying sink. The standard library's [*gzip.Writer] implements Compressor.
//
// compressionPool recycles Compressors. It calls Reset when one is taken from
// the pool, and Close followed by Reset(io.Discard) when one is returned.
type Compressor interface {
	// Reset discards the Compressor's internal state, if any, and prepares it
	// to write compressed data to a new sink.
	Reset(io.Writer)

	// Close flushes any buffered data to the underlying sink, then closes the
	// Compressor. It must not close the underlying sink.
	Close() error

	// Write accepts uncompressed bytes.
	io.Writer
}
| |
// compressionPool keeps reusable Decompressor and Compressor instances for a
// single compression algorithm, so they can be shared across messages.
type compressionPool struct {
	decompressors sync.Pool // values are Decompressor
	compressors   sync.Pool // values are Compressor
}
| |
| func newCompressionPool( |
| newDecompressor func() Decompressor, |
| newCompressor func() Compressor, |
| ) *compressionPool { |
| if newDecompressor == nil && newCompressor == nil { |
| return nil |
| } |
| return &compressionPool{ |
| decompressors: sync.Pool{ |
| New: func() any { return newDecompressor() }, |
| }, |
| compressors: sync.Pool{ |
| New: func() any { return newCompressor() }, |
| }, |
| } |
| } |
| |
| func (c *compressionPool) Decompress(dst *bytes.Buffer, src *bytes.Buffer, readMaxBytes int64) *Error { |
| decompressor, err := c.getDecompressor(src) |
| if err != nil { |
| return errorf(CodeInvalidArgument, "get decompressor: %w", err) |
| } |
| reader := io.Reader(decompressor) |
| if readMaxBytes > 0 && readMaxBytes < math.MaxInt64 { |
| reader = io.LimitReader(decompressor, readMaxBytes+1) |
| } |
| bytesRead, err := dst.ReadFrom(reader) |
| if err != nil { |
| _ = c.putDecompressor(decompressor) |
| return errorf(CodeInvalidArgument, "decompress: %w", err) |
| } |
| if readMaxBytes > 0 && bytesRead > readMaxBytes { |
| discardedBytes, err := io.Copy(io.Discard, decompressor) |
| _ = c.putDecompressor(decompressor) |
| if err != nil { |
| return errorf(CodeResourceExhausted, "message is larger than configured max %d - unable to determine message size: %w", readMaxBytes, err) |
| } |
| return errorf(CodeResourceExhausted, "message size %d is larger than configured max %d", bytesRead+discardedBytes, readMaxBytes) |
| } |
| if err := c.putDecompressor(decompressor); err != nil { |
| return errorf(CodeUnknown, "recycle decompressor: %w", err) |
| } |
| return nil |
| } |
| |
| func (c *compressionPool) Compress(dst *bytes.Buffer, src *bytes.Buffer) *Error { |
| compressor, err := c.getCompressor(dst) |
| if err != nil { |
| return errorf(CodeUnknown, "get compressor: %w", err) |
| } |
| if _, err := io.Copy(compressor, src); err != nil { |
| _ = c.putCompressor(compressor) |
| return errorf(CodeInternal, "compress: %w", err) |
| } |
| if err := c.putCompressor(compressor); err != nil { |
| return errorf(CodeInternal, "recycle compressor: %w", err) |
| } |
| return nil |
| } |
| |
| func (c *compressionPool) getDecompressor(reader io.Reader) (Decompressor, error) { |
| decompressor, ok := c.decompressors.Get().(Decompressor) |
| if !ok { |
| return nil, errors.New("expected Decompressor, got incorrect type from pool") |
| } |
| return decompressor, decompressor.Reset(reader) |
| } |
| |
| func (c *compressionPool) putDecompressor(decompressor Decompressor) error { |
| if err := decompressor.Close(); err != nil { |
| return err |
| } |
| // While it's in the pool, we don't want the decompressor to retain a |
| // reference to the underlying reader. However, most decompressors attempt to |
| // read some header data from the new data source when Reset; since we don't |
| // know the compression format, we can't provide a valid header. Since we |
| // also reset the decompressor when it's pulled out of the pool, we can |
| // ignore errors here. |
| _ = decompressor.Reset(strings.NewReader("")) |
| c.decompressors.Put(decompressor) |
| return nil |
| } |
| |
| func (c *compressionPool) getCompressor(writer io.Writer) (Compressor, error) { |
| compressor, ok := c.compressors.Get().(Compressor) |
| if !ok { |
| return nil, errors.New("expected Compressor, got incorrect type from pool") |
| } |
| compressor.Reset(writer) |
| return compressor, nil |
| } |
| |
| func (c *compressionPool) putCompressor(compressor Compressor) error { |
| if err := compressor.Close(); err != nil { |
| return err |
| } |
| compressor.Reset(io.Discard) // don't keep references |
| c.compressors.Put(compressor) |
| return nil |
| } |
| |
| // readOnlyCompressionPools is a read-only interface to a map of named |
| // compressionPools. |
| type readOnlyCompressionPools interface { |
| Get(string) *compressionPool |
| Contains(string) bool |
| // Wordy, but clarifies how this is different from readOnlyCodecs.Names(). |
| CommaSeparatedNames() string |
| } |
| |
| func newReadOnlyCompressionPools( |
| nameToPool map[string]*compressionPool, |
| reversedNames []string, |
| ) readOnlyCompressionPools { |
| // Client and handler configs keep compression names in registration order, |
| // but we want the last registered to be the most preferred. |
| names := make([]string, 0, len(reversedNames)) |
| seen := make(map[string]struct{}, len(reversedNames)) |
| for i := len(reversedNames) - 1; i >= 0; i-- { |
| name := reversedNames[i] |
| if _, ok := seen[name]; ok { |
| continue |
| } |
| seen[name] = struct{}{} |
| names = append(names, name) |
| } |
| return &namedCompressionPools{ |
| nameToPool: nameToPool, |
| commaSeparatedNames: strings.Join(names, ","), |
| } |
| } |
| |
| type namedCompressionPools struct { |
| nameToPool map[string]*compressionPool |
| commaSeparatedNames string |
| } |
| |
| func (m *namedCompressionPools) Get(name string) *compressionPool { |
| if name == "" || name == compressionIdentity { |
| return nil |
| } |
| return m.nameToPool[name] |
| } |
| |
| func (m *namedCompressionPools) Contains(name string) bool { |
| _, ok := m.nameToPool[name] |
| return ok |
| } |
| |
| func (m *namedCompressionPools) CommaSeparatedNames() string { |
| return m.commaSeparatedNames |
| } |