| /* |
| * |
| * Copyright 2018 gRPC authors. |
| * |
| * Licensed under the Apache License, Version 2.0 (the "License"); |
| * you may not use this file except in compliance with the License. |
| * You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| * |
| */ |
| |
| // Package conn contains an implementation of a secure channel created by gRPC |
| // handshakers. |
| package conn |
| |
| import ( |
| "encoding/binary" |
| "fmt" |
| "math" |
| "net" |
| |
| core "google.golang.org/grpc/credentials/alts/internal" |
| ) |
| |
| // ALTSRecordCrypto is the interface for gRPC ALTS record protocol. |
| type ALTSRecordCrypto interface { |
| // Encrypt encrypts the plaintext and computes the tag (if any) of dst |
| // and plaintext, dst and plaintext do not overlap. |
| Encrypt(dst, plaintext []byte) ([]byte, error) |
| // EncryptionOverhead returns the tag size (if any) in bytes. |
| EncryptionOverhead() int |
| // Decrypt decrypts ciphertext and verify the tag (if any). dst and |
| // ciphertext may alias exactly or not at all. To reuse ciphertext's |
| // storage for the decrypted output, use ciphertext[:0] as dst. |
| Decrypt(dst, ciphertext []byte) ([]byte, error) |
| } |
| |
| // ALTSRecordFunc is a function type for factory functions that create |
| // ALTSRecordCrypto instances. |
| type ALTSRecordFunc func(s core.Side, keyData []byte) (ALTSRecordCrypto, error) |
| |
| const ( |
| // MsgLenFieldSize is the byte size of the frame length field of a |
| // framed message. |
| MsgLenFieldSize = 4 |
| // The byte size of the message type field of a framed message. |
| msgTypeFieldSize = 4 |
| // The bytes size limit for a ALTS record message. |
| altsRecordLengthLimit = 1024 * 1024 // 1 MiB |
| // The default bytes size of a ALTS record message. |
| altsRecordDefaultLength = 4 * 1024 // 4KiB |
| // Message type value included in ALTS record framing. |
| altsRecordMsgType = uint32(0x06) |
| // The initial write buffer size. |
| altsWriteBufferInitialSize = 32 * 1024 // 32KiB |
| // The maximum write buffer size. This *must* be multiple of |
| // altsRecordDefaultLength. |
| altsWriteBufferMaxSize = 512 * 1024 // 512KiB |
| ) |
| |
| var ( |
| protocols = make(map[string]ALTSRecordFunc) |
| ) |
| |
| // RegisterProtocol register a ALTS record encryption protocol. |
| func RegisterProtocol(protocol string, f ALTSRecordFunc) error { |
| if _, ok := protocols[protocol]; ok { |
| return fmt.Errorf("protocol %v is already registered", protocol) |
| } |
| protocols[protocol] = f |
| return nil |
| } |
| |
| // conn represents a secured connection. It implements the net.Conn interface. |
| type conn struct { |
| net.Conn |
| crypto ALTSRecordCrypto |
| // buf holds data that has been read from the connection and decrypted, |
| // but has not yet been returned by Read. |
| buf []byte |
| payloadLengthLimit int |
| // protected holds data read from the network but have not yet been |
| // decrypted. This data might not compose a complete frame. |
| protected []byte |
| // writeBuf is a buffer used to contain encrypted frames before being |
| // written to the network. |
| writeBuf []byte |
| // nextFrame stores the next frame (in protected buffer) info. |
| nextFrame []byte |
| // overhead is the calculated overhead of each frame. |
| overhead int |
| } |
| |
| // NewConn creates a new secure channel instance given the other party role and |
| // handshaking result. |
| func NewConn(c net.Conn, side core.Side, recordProtocol string, key []byte, protected []byte) (net.Conn, error) { |
| newCrypto := protocols[recordProtocol] |
| if newCrypto == nil { |
| return nil, fmt.Errorf("negotiated unknown next_protocol %q", recordProtocol) |
| } |
| crypto, err := newCrypto(side, key) |
| if err != nil { |
| return nil, fmt.Errorf("protocol %q: %v", recordProtocol, err) |
| } |
| overhead := MsgLenFieldSize + msgTypeFieldSize + crypto.EncryptionOverhead() |
| payloadLengthLimit := altsRecordDefaultLength - overhead |
| if protected == nil { |
| // We pre-allocate protected to be of size |
| // 2*altsRecordDefaultLength-1 during initialization. We only |
| // read from the network into protected when protected does not |
| // contain a complete frame, which is at most |
| // altsRecordDefaultLength-1 (bytes). And we read at most |
| // altsRecordDefaultLength (bytes) data into protected at one |
| // time. Therefore, 2*altsRecordDefaultLength-1 is large enough |
| // to buffer data read from the network. |
| protected = make([]byte, 0, 2*altsRecordDefaultLength-1) |
| } |
| |
| altsConn := &conn{ |
| Conn: c, |
| crypto: crypto, |
| payloadLengthLimit: payloadLengthLimit, |
| protected: protected, |
| writeBuf: make([]byte, altsWriteBufferInitialSize), |
| nextFrame: protected, |
| overhead: overhead, |
| } |
| return altsConn, nil |
| } |
| |
| // Read reads and decrypts a frame from the underlying connection, and copies the |
| // decrypted payload into b. If the size of the payload is greater than len(b), |
| // Read retains the remaining bytes in an internal buffer, and subsequent calls |
| // to Read will read from this buffer until it is exhausted. |
| func (p *conn) Read(b []byte) (n int, err error) { |
| if len(p.buf) == 0 { |
| var framedMsg []byte |
| framedMsg, p.nextFrame, err = ParseFramedMsg(p.nextFrame, altsRecordLengthLimit) |
| if err != nil { |
| return n, err |
| } |
| // Check whether the next frame to be decrypted has been |
| // completely received yet. |
| if len(framedMsg) == 0 { |
| copy(p.protected, p.nextFrame) |
| p.protected = p.protected[:len(p.nextFrame)] |
| // Always copy next incomplete frame to the beginning of |
| // the protected buffer and reset nextFrame to it. |
| p.nextFrame = p.protected |
| } |
| // Check whether a complete frame has been received yet. |
| for len(framedMsg) == 0 { |
| if len(p.protected) == cap(p.protected) { |
| tmp := make([]byte, len(p.protected), cap(p.protected)+altsRecordDefaultLength) |
| copy(tmp, p.protected) |
| p.protected = tmp |
| } |
| n, err = p.Conn.Read(p.protected[len(p.protected):min(cap(p.protected), len(p.protected)+altsRecordDefaultLength)]) |
| if err != nil { |
| return 0, err |
| } |
| p.protected = p.protected[:len(p.protected)+n] |
| framedMsg, p.nextFrame, err = ParseFramedMsg(p.protected, altsRecordLengthLimit) |
| if err != nil { |
| return 0, err |
| } |
| } |
| // Now we have a complete frame, decrypted it. |
| msg := framedMsg[MsgLenFieldSize:] |
| msgType := binary.LittleEndian.Uint32(msg[:msgTypeFieldSize]) |
| if msgType&0xff != altsRecordMsgType { |
| return 0, fmt.Errorf("received frame with incorrect message type %v, expected lower byte %v", |
| msgType, altsRecordMsgType) |
| } |
| ciphertext := msg[msgTypeFieldSize:] |
| |
| // Decrypt requires that if the dst and ciphertext alias, they |
| // must alias exactly. Code here used to use msg[:0], but msg |
| // starts MsgLenFieldSize+msgTypeFieldSize bytes earlier than |
| // ciphertext, so they alias inexactly. Using ciphertext[:0] |
| // arranges the appropriate aliasing without needing to copy |
| // ciphertext or use a separate destination buffer. For more info |
| // check: https://golang.org/pkg/crypto/cipher/#AEAD. |
| p.buf, err = p.crypto.Decrypt(ciphertext[:0], ciphertext) |
| if err != nil { |
| return 0, err |
| } |
| } |
| |
| n = copy(b, p.buf) |
| p.buf = p.buf[n:] |
| return n, nil |
| } |
| |
| // Write encrypts, frames, and writes bytes from b to the underlying connection. |
| func (p *conn) Write(b []byte) (n int, err error) { |
| n = len(b) |
| // Calculate the output buffer size with framing and encryption overhead. |
| numOfFrames := int(math.Ceil(float64(len(b)) / float64(p.payloadLengthLimit))) |
| size := len(b) + numOfFrames*p.overhead |
| // If writeBuf is too small, increase its size up to the maximum size. |
| partialBSize := len(b) |
| if size > altsWriteBufferMaxSize { |
| size = altsWriteBufferMaxSize |
| const numOfFramesInMaxWriteBuf = altsWriteBufferMaxSize / altsRecordDefaultLength |
| partialBSize = numOfFramesInMaxWriteBuf * p.payloadLengthLimit |
| } |
| if len(p.writeBuf) < size { |
| p.writeBuf = make([]byte, size) |
| } |
| |
| for partialBStart := 0; partialBStart < len(b); partialBStart += partialBSize { |
| partialBEnd := partialBStart + partialBSize |
| if partialBEnd > len(b) { |
| partialBEnd = len(b) |
| } |
| partialB := b[partialBStart:partialBEnd] |
| writeBufIndex := 0 |
| for len(partialB) > 0 { |
| payloadLen := len(partialB) |
| if payloadLen > p.payloadLengthLimit { |
| payloadLen = p.payloadLengthLimit |
| } |
| buf := partialB[:payloadLen] |
| partialB = partialB[payloadLen:] |
| |
| // Write buffer contains: length, type, payload, and tag |
| // if any. |
| |
| // 1. Fill in type field. |
| msg := p.writeBuf[writeBufIndex+MsgLenFieldSize:] |
| binary.LittleEndian.PutUint32(msg, altsRecordMsgType) |
| |
| // 2. Encrypt the payload and create a tag if any. |
| msg, err = p.crypto.Encrypt(msg[:msgTypeFieldSize], buf) |
| if err != nil { |
| return n, err |
| } |
| |
| // 3. Fill in the size field. |
| binary.LittleEndian.PutUint32(p.writeBuf[writeBufIndex:], uint32(len(msg))) |
| |
| // 4. Increase writeBufIndex. |
| writeBufIndex += len(buf) + p.overhead |
| } |
| nn, err := p.Conn.Write(p.writeBuf[:writeBufIndex]) |
| if err != nil { |
| // We need to calculate the actual data size that was |
| // written. This means we need to remove header, |
| // encryption overheads, and any partially-written |
| // frame data. |
| numOfWrittenFrames := int(math.Floor(float64(nn) / float64(altsRecordDefaultLength))) |
| return partialBStart + numOfWrittenFrames*p.payloadLengthLimit, err |
| } |
| } |
| return n, nil |
| } |
| |
| func min(a, b int) int { |
| if a < b { |
| return a |
| } |
| return b |
| } |