blob: 2d7a107566b11e76b6c561970595c77ae1b20211 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import zlib
from cassandra import DriverException
from cassandra.marshal import int32_pack
from cassandra.protocol import write_uint_le, read_uint_le
# Initial register value and generator polynomial used by compute_crc24.
CRC24_INIT = 0x875060
CRC24_POLY = 0x1974F0B

# Sizes in bytes of the CRC24 that follows a segment header and of the
# CRC32 that follows a segment payload.
CRC24_LENGTH, CRC32_LENGTH = 3, 4

# Seed passed to zlib.crc32 for payload checksums: the running CRC32 of these
# four bytes, so every payload CRC is computed as if they preceded the payload.
CRC32_INITIAL = zlib.crc32(bytes((0xFA, 0x2D, 0x55, 0xCA)))
class CrcException(Exception):
    """
    Raised when a segment header or payload fails its CRC check.

    TODO: declared in this module to avoid an import cycle with
    cassandra.connection. In the next major release, exceptions should
    move to a dedicated exceptions.py module.
    """
def compute_crc24(data, length):
    """
    Compute the 24-bit header CRC over the low ``length`` bytes of ``data``.

    Bytes are consumed least-significant first; each one is folded into the
    top byte of the 24-bit register, which is then shifted left bit by bit
    and reduced by CRC24_POLY whenever bit 24 is set.

    :param data: integer whose low ``length`` bytes are checksummed.
    :param length: number of bytes of ``data`` to process.
    :return: the CRC as an integer (fits in 24 bits).
    """
    register = CRC24_INIT
    for shift in range(0, 8 * length, 8):
        register ^= ((data >> shift) & 0xFF) << 16
        for _ in range(8):
            register <<= 1
            if register & 0x1000000:
                register ^= CRC24_POLY
    return register
def compute_crc32(data, value):
    """
    Return the CRC32 of ``data``, continuing from the running checksum ``value``.

    Thin wrapper over :func:`zlib.crc32`; the segment codec passes
    CRC32_INITIAL as ``value`` for every payload checksum.
    """
    return zlib.crc32(data, value)
class SegmentHeader(object):
    """
    Decoded fields of a segment header, as built by SegmentCodec.decode_header.

    ``uncompressed_payload_length`` is negative (``-1``) when the codec does
    not use compression. Otherwise it is the payload length before
    compression, or ``0`` when the sender left this payload uncompressed.
    """

    payload_length = None
    uncompressed_payload_length = None
    is_self_contained = None

    def __init__(self, payload_length, uncompressed_payload_length, is_self_contained):
        self.payload_length = payload_length
        self.uncompressed_payload_length = uncompressed_payload_length
        self.is_self_contained = is_self_contained

    @property
    def segment_length(self):
        """
        Return the total length of the segment in bytes: header, header
        CRC24, payload and payload CRC32.
        """
        # The header size depends on whether the codec uses compression, not
        # on whether this particular payload was compressed. In compressed
        # mode an uncompressed length of 0 means "payload not compressed",
        # but the header is still COMPRESSED_HEADER_LENGTH bytes. Only a
        # negative value (set by decode_header when compression is off)
        # selects the short header format.
        if self.uncompressed_payload_length < 0:
            header_length = SegmentCodec.UNCOMPRESSED_HEADER_LENGTH
        else:
            header_length = SegmentCodec.COMPRESSED_HEADER_LENGTH
        return header_length + CRC24_LENGTH + self.payload_length + CRC32_LENGTH
class Segment(object):
    """
    A segment's payload together with its self-contained flag, as produced
    by SegmentCodec.decode.
    """

    # Largest payload one segment can carry: the header stores the payload
    # length in a 17-bit field.
    MAX_PAYLOAD_LENGTH = (1 << 17) - 1

    payload = None
    is_self_contained = None

    def __init__(self, payload, is_self_contained):
        self.payload, self.is_self_contained = payload, is_self_contained
class SegmentCodec(object):
    """
    Encode messages into segments and decode segments back, optionally
    compressing payloads.

    On the wire a segment is: a little-endian header (UNCOMPRESSED_HEADER_LENGTH
    or COMPRESSED_HEADER_LENGTH bytes), a CRC24 of that header, the payload,
    and a CRC32 of the payload.

    Header bit layout, from the least significant bit:
      * bits 0-16: payload length (as sent, possibly compressed);
      * compressed mode only, next 17 bits: uncompressed length
        (0 when the payload was sent uncompressed);
      * next bit: self-contained flag.
    """

    COMPRESSED_HEADER_LENGTH = 5
    UNCOMPRESSED_HEADER_LENGTH = 3
    # Bit width of each length field in the header, i.e. the bit offset of
    # the field that follows the payload length.
    FLAG_OFFSET = 17

    compressor = None
    decompressor = None

    def __init__(self, compressor=None, decompressor=None):
        """
        :param compressor: callable compressing bytes; its output is expected
            to start with a 4-byte uncompressed-length prefix (see compress).
        :param decompressor: callable inverting ``compressor`` (see decompress).
        Compression is enabled only when both are provided.
        """
        self.compressor = compressor
        self.decompressor = decompressor

    @property
    def header_length(self):
        """Length in bytes of the segment header, excluding its CRC24."""
        if self.compression:
            return self.COMPRESSED_HEADER_LENGTH
        return self.UNCOMPRESSED_HEADER_LENGTH

    @property
    def header_length_with_crc(self):
        """Length in bytes of the segment header plus its CRC24."""
        return self.header_length + CRC24_LENGTH

    @property
    def compression(self):
        """True when both a compressor and a decompressor are configured."""
        return bool(self.compressor and self.decompressor)

    def compress(self, data):
        """Compress ``data`` and strip the compressor's 4-byte length prefix."""
        # the uncompressed length is already encoded in the header, so
        # we remove it here
        return self.compressor(data)[4:]

    def decompress(self, encoded_data, uncompressed_length):
        """Re-attach the length prefix removed by compress, then decompress."""
        return self.decompressor(int32_pack(uncompressed_length) + encoded_data)

    def encode_header(self, buffer, payload_length, uncompressed_length, is_self_contained):
        """
        Write a segment header followed by its CRC24 to ``buffer``.

        :param buffer: writable binary stream.
        :param payload_length: length of the payload as sent (possibly compressed).
        :param uncompressed_length: payload length before compression, or 0
            when the payload is sent uncompressed. Ignored without compression.
        :param is_self_contained: value of the header's self-contained flag.
        :raises DriverException: if a length does not fit in its 17-bit field.
        """
        if payload_length > Segment.MAX_PAYLOAD_LENGTH:
            raise DriverException('Payload length exceed Segment.MAX_PAYLOAD_LENGTH')

        header_data = payload_length
        flag_offset = self.FLAG_OFFSET
        if self.compression:
            # An oversized value would silently spill into the self-contained bit.
            if uncompressed_length > Segment.MAX_PAYLOAD_LENGTH:
                raise DriverException(
                    'Uncompressed payload length exceed Segment.MAX_PAYLOAD_LENGTH')
            header_data |= uncompressed_length << flag_offset
            flag_offset += self.FLAG_OFFSET

        if is_self_contained:
            header_data |= 1 << flag_offset

        write_uint_le(buffer, header_data, size=self.header_length)
        header_crc = compute_crc24(header_data, self.header_length)
        write_uint_le(buffer, header_crc, size=CRC24_LENGTH)

    def _encode_segment(self, buffer, payload, is_self_contained):
        """
        Encode a message to a single segment.

        With compression enabled, the payload is sent compressed only when that
        makes it strictly smaller; otherwise it is sent as-is with an
        uncompressed length of 0 in the header.
        """
        uncompressed_payload = payload
        uncompressed_payload_length = len(payload)

        if self.compression:
            compressed_payload = self.compress(uncompressed_payload)
            if len(compressed_payload) >= uncompressed_payload_length:
                encoded_payload = uncompressed_payload
                uncompressed_payload_length = 0
            else:
                encoded_payload = compressed_payload
        else:
            encoded_payload = uncompressed_payload

        payload_length = len(encoded_payload)
        self.encode_header(buffer, payload_length, uncompressed_payload_length, is_self_contained)
        payload_crc = compute_crc32(encoded_payload, CRC32_INITIAL)
        buffer.write(encoded_payload)
        write_uint_le(buffer, payload_crc)

    def encode(self, buffer, msg):
        """
        Encode a message to one or more segments.

        Messages longer than Segment.MAX_PAYLOAD_LENGTH are split into chunks
        of at most that size; such segments are written with the
        self-contained flag cleared.
        """
        max_length = Segment.MAX_PAYLOAD_LENGTH
        if len(msg) > max_length:
            payloads = [msg[i:i + max_length] for i in range(0, len(msg), max_length)]
        else:
            payloads = [msg]

        is_self_contained = len(payloads) == 1
        for payload in payloads:
            self._encode_segment(buffer, payload, is_self_contained)

    def decode_header(self, buffer):
        """
        Read a segment header and its CRC24 from ``buffer`` and validate it.

        :return: a :class:`SegmentHeader`; its ``uncompressed_payload_length``
            is -1 when this codec does not use compression.
        :raises CrcException: if the header CRC does not match.
        """
        header_data = read_uint_le(buffer, self.header_length)

        expected_header_crc = read_uint_le(buffer, CRC24_LENGTH)
        actual_header_crc = compute_crc24(header_data, self.header_length)
        if actual_header_crc != expected_header_crc:
            raise CrcException('CRC mismatch on header {:x}. Received {:x}, computed {:x}.'.format(
                header_data, expected_header_crc, actual_header_crc))

        # Segment.MAX_PAYLOAD_LENGTH is 0x1FFFF, so it doubles as the 17-bit field mask.
        payload_length = header_data & Segment.MAX_PAYLOAD_LENGTH
        header_data >>= self.FLAG_OFFSET

        if self.compression:
            uncompressed_payload_length = header_data & Segment.MAX_PAYLOAD_LENGTH
            header_data >>= self.FLAG_OFFSET
        else:
            uncompressed_payload_length = -1

        is_self_contained = (header_data & 1) == 1

        return SegmentHeader(payload_length, uncompressed_payload_length, is_self_contained)

    def decode(self, buffer, header):
        """
        Read the payload described by ``header`` and its CRC32 from ``buffer``,
        verify it, and decompress it when it was sent compressed.

        :return: a :class:`Segment`.
        :raises CrcException: if the payload CRC does not match.
        """
        encoded_payload = buffer.read(header.payload_length)
        expected_payload_crc = read_uint_le(buffer)

        actual_payload_crc = compute_crc32(encoded_payload, CRC32_INITIAL)
        if actual_payload_crc != expected_payload_crc:
            raise CrcException('CRC mismatch on payload. Received {:x}, computed {:x}.'.format(
                expected_payload_crc, actual_payload_crc))

        payload = encoded_payload
        # An uncompressed length of 0 means the sender kept this payload
        # uncompressed (see _encode_segment).
        if self.compression and header.uncompressed_payload_length > 0:
            payload = self.decompress(encoded_payload, header.uncompressed_payload_length)

        return Segment(payload, header.is_self_contained)