| # |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| # |
| |
| from __future__ import absolute_import |
| import datetime |
| from .packer import Packer |
| from .datatypes import serial, timestamp, RangedSet, Struct, UUID |
| from .ops import Compound, PRIMITIVE, COMPOUND |
| |
# Python 2/3 compatibility aliases.  Where the interpreter already has the
# name as a builtin (Python 2) it is left untouched; otherwise it is bound
# to the Python 3 counterpart.
try:
    unicode
except NameError:
    unicode = str  # Python 3: the text type is str

try:
    long
except NameError:
    long = int  # Python 3: int is already arbitrary precision

try:
    buffer
except NameError:
    buffer = memoryview  # Python 3: buffer() was replaced by memoryview
| |
class CodecException(Exception):
    """Raised when a value cannot be encoded by the codec."""
| |
def direct(t):
    """Build an encoding selector that ignores its argument and yields *t*."""
    def select(x):
        return t
    return select
| |
def map_str(s):
    """Choose the wire type name for a ``str`` value.

    Text (Python 3 ``str``) is always sent as "str16"; write_str16 encodes
    it as UTF-8, which every text value supports.  Byte strings (Python 2
    ``str``) are sent as "str16" only when they are pure 7-bit ASCII.  Any
    byte >= 0x80 makes them go out as opaque "vbin16" instead, because
    ``.encode("utf8")`` on such a Python 2 byte string would fail.
    """
    if not isinstance(s, bytes):
        # Text value: previously non-ASCII text was routed to vbin16, which
        # wrote raw text (not bytes) on Python 3.
        return "str16"
    # bytearray iterates as ints on both Python 2 and Python 3.
    if any(octet >= 0x80 for octet in bytearray(s)):
        return "vbin16"
    return "str16"
| |
class Codec(Packer):
    """Read and write typed values (str8, vbin16, map, list, struct32, ...).

    The byte transport is supplied by the subclass through ``read(n)`` and
    ``write(data)`` (StringCodec implements both over an in-memory string).
    ``pack``/``unpack`` are presumably provided by ``Packer`` on top of
    those -- not visible here.
    """

    # Python type -> callable(value) returning a PRIMITIVE type name.
    # On Python 3 ``unicode`` is ``str``, so two entries share one key and
    # the later one wins: ``unicode`` is listed after ``str`` so that text
    # is always encoded as UTF-8 "str16".
    ENCODINGS = {
        bool: direct("boolean"),
        str: map_str,
        unicode: direct("str16"),
        buffer: direct("vbin32"),
        int: direct("int64"),
        long: direct("int64"),
        float: direct("double"),
        None.__class__: direct("void"),
        list: direct("list"),
        tuple: direct("list"),
        dict: direct("map"),
        timestamp: direct("datetime"),
        datetime.datetime: direct("datetime"),
        UUID: direct("uuid"),
        Compound: direct("struct32"),
    }

    def encoding(self, obj):
        """Return the PRIMITIVE type used to encode *obj*.

        Raises CodecException when neither obj's class nor any of its base
        classes has an ENCODINGS entry.
        """
        enc = self._encoding(obj.__class__, obj)
        if enc is None:
            # (obj,) so that a tuple-like obj is formatted, not unpacked.
            raise CodecException("no encoding for %r" % (obj,))
        return PRIMITIVE[enc]

    def _encoding(self, klass, obj):
        # Depth-first search of klass and its bases for an ENCODINGS entry;
        # returns None when nothing matches.
        if klass in self.ENCODINGS:
            return self.ENCODINGS[klass](obj)
        for base in klass.__bases__:
            result = self._encoding(base, obj)
            if result is not None:
                return result
        return None

    def read_primitive(self, type):
        """Read one value of PRIMITIVE *type* by dispatching on its NAME."""
        return getattr(self, "read_%s" % type.NAME)()

    def write_primitive(self, type, v):
        """Write *v* as PRIMITIVE *type* by dispatching on its NAME."""
        getattr(self, "write_%s" % type.NAME)(v)

    @staticmethod
    def _unbuffer(b):
        # bytes(), not str(): on Python 3 str(memoryview) is the object's
        # repr, not its contents.  On Python 2 bytes is str, so unchanged.
        if isinstance(b, buffer):
            return bytes(b)
        return b

    def _check_range(self, n, lo, hi, name):
        # Shared bounds check for the fixed-width integer writers.
        if n < lo or n > hi:
            raise CodecException("Cannot encode %d as %s" % (n, name))

    def read_void(self):
        return None

    def write_void(self, v):
        assert v is None

    def read_bit(self):
        # Bits live in the field-presence flags (see read_fields); this is
        # only called for a set flag, so the value is True.
        return True

    def write_bit(self, b):
        # Nothing to write: the flag already carries the bit.
        if not b:
            raise ValueError(b)

    def read_uint8(self):
        return self.unpack("!B")

    def write_uint8(self, n):
        self._check_range(n, 0, 255, "uint8")
        self.pack("!B", n)

    def read_int8(self):
        return self.unpack("!b")

    def write_int8(self, n):
        self._check_range(n, -128, 127, "int8")
        self.pack("!b", n)

    def read_char(self):
        return self.unpack("!c")

    def write_char(self, c):
        self.pack("!c", c)

    def read_boolean(self):
        return self.read_uint8() != 0

    def write_boolean(self, b):
        self.write_uint8(1 if b else 0)

    def read_uint16(self):
        return self.unpack("!H")

    def write_uint16(self, n):
        self._check_range(n, 0, 65535, "uint16")
        self.pack("!H", n)

    def read_int16(self):
        return self.unpack("!h")

    def write_int16(self, n):
        self._check_range(n, -32768, 32767, "int16")
        self.pack("!h", n)

    def read_uint32(self):
        return self.unpack("!L")

    def write_uint32(self, n):
        self._check_range(n, 0, 4294967295, "uint32")
        self.pack("!L", n)

    def read_int32(self):
        return self.unpack("!l")

    def write_int32(self, n):
        self._check_range(n, -2147483648, 2147483647, "int32")
        self.pack("!l", n)

    def read_float(self):
        return self.unpack("!f")

    def write_float(self, f):
        self.pack("!f", f)

    def read_sequence_no(self):
        return serial(self.read_uint32())

    def write_sequence_no(self, n):
        self.write_uint32(n.value)

    def read_uint64(self):
        return self.unpack("!Q")

    def write_uint64(self, n):
        self.pack("!Q", n)

    def read_int64(self):
        return self.unpack("!q")

    def write_int64(self, n):
        self.pack("!q", n)

    def read_datetime(self):
        return timestamp(self.read_uint64())

    def write_datetime(self, t):
        # Accept plain datetime objects by converting them to timestamp.
        if isinstance(t, datetime.datetime):
            t = timestamp(t)
        self.write_uint64(t)

    def read_double(self):
        return self.unpack("!d")

    def write_double(self, d):
        self.pack("!d", d)

    def read_vbin8(self):
        return self.read(self.read_uint8())

    def write_vbin8(self, b):
        b = self._unbuffer(b)
        self.write_uint8(len(b))
        self.write(b)

    def read_str8(self):
        return self.read_vbin8().decode("utf8")

    def write_str8(self, s):
        self.write_vbin8(s.encode("utf8"))

    def read_str16(self):
        return self.read_vbin16().decode("utf8")

    def write_str16(self, s):
        self.write_vbin16(s.encode("utf8"))

    def read_str16_latin(self):
        return self.read_vbin16().decode("iso-8859-15")

    def write_str16_latin(self, s):
        self.write_vbin16(s.encode("iso-8859-15"))

    def read_vbin16(self):
        return self.read(self.read_uint16())

    def write_vbin16(self, b):
        b = self._unbuffer(b)
        self.write_uint16(len(b))
        self.write(b)

    def read_sequence_set(self):
        """Read a uint16 byte size followed by (lower, upper) pairs."""
        result = RangedSet()
        size = self.read_uint16()
        # Each range is two uint32 sequence numbers (8 bytes).  Floor
        # division keeps the count integral on Python 3.
        for _ in range(size // 8):
            lower = self.read_sequence_no()
            upper = self.read_sequence_no()
            result.add(lower, upper)
        return result

    def write_sequence_set(self, ss):
        self.write_uint16(8 * len(ss.ranges))
        for rng in ss.ranges:
            self.write_sequence_no(rng.lower)
            self.write_sequence_no(rng.upper)

    def read_vbin32(self):
        return self.read(self.read_uint32())

    def write_vbin32(self, b):
        b = self._unbuffer(b)
        # Allow unicode values in connection 'response' field
        if isinstance(b, unicode):
            b = b.encode('utf8')
        self.write_uint32(len(b))
        self.write(b)

    def read_map(self):
        """Read a vbin32-framed map; an empty frame decodes as None."""
        sc = StringCodec(self.read_vbin32())
        if not sc.encoded:
            return None
        # Entry count; entries are read until the payload is exhausted.
        sc.read_uint32()
        result = {}
        while sc.encoded:
            key = sc.read_str8()
            prim = PRIMITIVE[sc.read_uint8()]
            result[key] = sc.read_primitive(prim)
        return result

    def _write_map_elem(self, k, v):
        # Encode one entry as: str8 key, uint8 type code, value.
        prim = self.encoding(v)
        sc = StringCodec(b"")
        sc.write_str8(k)
        sc.write_uint8(prim.CODE)
        sc.write_primitive(prim, v)
        return sc.encoded

    def write_map(self, m):
        """Write *m* as a vbin32-framed map; None becomes an empty frame."""
        sc = StringCodec(b"")
        if m is not None:
            sc.write_uint32(len(m))
            sc.write(b"".join(self._write_map_elem(k, v) for k, v in m.items()))
        self.write_vbin32(sc.encoded)

    def read_array(self):
        """Read a homogeneous array; an empty frame decodes as None."""
        sc = StringCodec(self.read_vbin32())
        if not sc.encoded:
            return None
        prim = PRIMITIVE[sc.read_uint8()]
        count = sc.read_uint32()
        return [sc.read_primitive(prim) for _ in range(count)]

    def write_array(self, a):
        """Write *a* using the element type of a[0] (void when empty)."""
        sc = StringCodec(b"")
        if a is not None:
            prim = self.encoding(a[0] if len(a) > 0 else None)
            sc.write_uint8(prim.CODE)
            sc.write_uint32(len(a))
            for item in a:
                sc.write_primitive(prim, item)
        self.write_vbin32(sc.encoded)

    def read_list(self):
        """Read a list whose items each carry their own type code."""
        sc = StringCodec(self.read_vbin32())
        if not sc.encoded:
            return None
        count = sc.read_uint32()
        result = []
        for _ in range(count):
            prim = PRIMITIVE[sc.read_uint8()]
            result.append(sc.read_primitive(prim))
        return result

    def write_list(self, l):
        """Write *l* with a per-item type code; None becomes an empty frame."""
        sc = StringCodec(b"")
        if l is not None:
            sc.write_uint32(len(l))
            for item in l:
                prim = self.encoding(item)
                sc.write_uint8(prim.CODE)
                sc.write_primitive(prim, item)
        self.write_vbin32(sc.encoded)

    def read_struct32(self):
        """Read a uint32 size, uint16 class code, then that class's fields."""
        self.read_uint32()  # byte size; not needed to decode the fields
        code = self.read_uint16()
        op = COMPOUND[code]()
        self.read_fields(op)
        return op

    def write_struct32(self, value):
        self.write_compound(value)

    def read_compound(self, cls):
        """Read an instance of compound class *cls* (size, code, fields)."""
        self.read_size(cls.SIZE)  # size prefix, if any; not needed
        if cls.CODE is not None:
            code = self.read_uint16()
            assert code == cls.CODE
        op = cls()
        self.read_fields(op)
        return op

    def write_compound(self, op):
        """Write *op* prefixed with its encoded size (width op.SIZE)."""
        sc = StringCodec(b"")
        if op.CODE is not None:
            sc.write_uint16(op.CODE)
        sc.write_fields(op)
        self.write_size(op.SIZE, len(sc.encoded))
        self.write(sc.encoded)

    def read_fields(self, op):
        """Read op.PACK flag bytes, then every field whose flag bit is set."""
        flags = 0
        for i in range(op.PACK):
            # Flag bytes are little-endian: byte i holds bits 8*i..8*i+7.
            flags |= self.read_uint8() << (8 * i)
        for i, f in enumerate(op.FIELDS):
            if flags & (1 << i):
                if f.type in COMPOUND:
                    value = self.read_compound(COMPOUND[f.type])
                else:
                    value = getattr(self, "read_%s" % f.type)()
                setattr(op, f.name, value)

    def write_fields(self, op):
        """Write presence flags for op.FIELDS, then each present field."""
        flags = 0
        for i, f in enumerate(op.FIELDS):
            value = getattr(op, f.name)
            # A "bit" field is present when truthy (the flag is its value);
            # any other field is present whenever it is not None.
            present = value if f.type == "bit" else value is not None
            if present:
                flags |= 1 << i
        for i in range(op.PACK):
            self.write_uint8((flags >> (8 * i)) & 0xFF)
        for i, f in enumerate(op.FIELDS):
            if flags & (1 << i):
                if f.type in COMPOUND:
                    enc = self.write_compound
                else:
                    enc = getattr(self, "write_%s" % f.type)
                enc(getattr(op, f.name))

    def read_size(self, width):
        """Read an unsigned size of *width* bytes; None when width is 0."""
        if width > 0:
            return getattr(self, "read_uint%d" % (width * 8))()
        return None

    def write_size(self, width, n):
        """Write *n* as an unsigned int of *width* bytes (nothing if 0)."""
        if width > 0:
            getattr(self, "write_uint%d" % (width * 8))(n)

    def read_uuid(self):
        return UUID(bytes=self.unpack("16s"))

    def write_uuid(self, s):
        if isinstance(s, UUID):
            s = s.bytes
        self.pack("16s", s)

    def read_bin128(self):
        return self.unpack("16s")

    def write_bin128(self, b):
        self.pack("16s", b)
| |
| |
class StringCodec(Codec):
    """Codec whose transport is an in-memory byte string.

    Writes append to ``encoded``; reads consume bytes from its front.
    """

    def __init__(self, encoded=b""):
        # Default is b"" (identical to "" on Python 2): on Python 3 a text
        # "" cannot be concatenated with the packed bytes that write() gets.
        self.encoded = encoded

    def read(self, n):
        """Remove and return up to the first *n* bytes of the buffer."""
        result = self.encoded[:n]
        self.encoded = self.encoded[n:]
        return result

    def write(self, s):
        """Append *s* to the buffer."""
        self.encoded += s