blob: b910ba56bceca414348bb85cd0da4fcf9fc938db [file] [log] [blame]
#!/usr/bin/env python
##
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Input/Output utilities, including:
* i/o-specific constants
* i/o-specific exceptions
* schema validation
* leaf value encoding and decoding
* datum readers and writers (DatumReader / DatumWriter)
Also includes a generic representation for data, which
uses the following mapping:
* Schema records are implemented as dict.
* Schema arrays are implemented as list.
* Schema maps are implemented as dict.
* Schema strings are implemented as unicode (str on Python 3).
* Schema bytes are implemented as str (bytes on Python 3).
* Schema ints are implemented as int.
* Schema longs are implemented as long (int on Python 3).
* Schema floats are implemented as float.
* Schema doubles are implemented as float.
* Schema booleans are implemented as bool.
"""
from __future__ import absolute_import, division, print_function

import datetime
import json
import struct
import sys
from decimal import Decimal, getcontext, localcontext
from struct import Struct

from avro import constants, schema, timezones
# Python 2/3 compatibility aliases. Each probe merely evaluates the Python 2
# builtin name; on Python 3 that raises NameError, and the closest Python 3
# equivalent is bound at module level instead.
try:
    unicode  # noqa: F821 -- Python 2 text type
except NameError:
    unicode = str

try:
    basestring  # type: ignore  # noqa: F821 -- Python 2 abstract string type
except NameError:
    basestring = (bytes, unicode)

try:
    long  # noqa: F821 -- Python 2 arbitrary-precision integer
except NameError:
    long = int
#
# Constants
#
_DEBUG_VALIDATE_INDENT = 0
_DEBUG_VALIDATE = False
INT_MIN_VALUE = -(1 << 31)
INT_MAX_VALUE = (1 << 31) - 1
LONG_MIN_VALUE = -(1 << 63)
LONG_MAX_VALUE = (1 << 63) - 1
# TODO(hammer): shouldn't ! be < for little-endian (according to spec?)
STRUCT_FLOAT = Struct('<f') # big-endian float
STRUCT_DOUBLE = Struct('<d') # big-endian double
STRUCT_SIGNED_SHORT = Struct('>h') # big-endian signed short
STRUCT_SIGNED_INT = Struct('>i') # big-endian signed int
STRUCT_SIGNED_LONG = Struct('>q') # big-endian signed long
#
# Exceptions
#
class AvroTypeException(schema.AvroException):
    """Raised when datum is not an example of schema."""

    def __init__(self, expected_schema, datum):
        # Re-serialize the schema's JSON text with indentation so the
        # message shows it in readable form.
        schema_as_json = json.loads(str(expected_schema))
        pretty_expected = json.dumps(schema_as_json, indent=2)
        fail_msg = "The datum %s is not an example of the schema %s" % (datum, pretty_expected)
        schema.AvroException.__init__(self, fail_msg)
class SchemaResolutionException(schema.AvroException):
    """Raised when data written with one schema cannot be read with another.

    Each schema that is supplied is appended to the message as indented JSON.
    """

    def __init__(self, fail_msg, writers_schema=None, readers_schema=None):
        # Serialize only the schemas that were actually supplied: str(None) is
        # 'None', which is not valid JSON, so json.loads() would raise a
        # ValueError and mask the real failure (e.g. when a caller passes only
        # the writer's schema).
        if writers_schema:
            pretty_writers = json.dumps(json.loads(str(writers_schema)), indent=2)
            fail_msg += "\nWriter's Schema: %s" % pretty_writers
        if readers_schema:
            pretty_readers = json.dumps(json.loads(str(readers_schema)), indent=2)
            fail_msg += "\nReader's Schema: %s" % pretty_readers
        schema.AvroException.__init__(self, fail_msg)
#
# Validate
#
def _is_timezone_aware_datetime(dt):
return dt.tzinfo is not None and dt.tzinfo.utcoffset(dt) is not None
# Per-type validators used by validate(). Each takes (schema, datum) and
# returns True when datum can be written with that schema.
_valid = {
    'null': lambda s, d: d is None,
    'boolean': lambda s, d: isinstance(d, bool),
    'string': lambda s, d: isinstance(d, unicode),
    'bytes': lambda s, d: (
        isinstance(d, bytes)
        or (isinstance(d, Decimal)
            and getattr(s, 'logical_type', None) == constants.DECIMAL)),
    'int': lambda s, d: (
        (isinstance(d, (int, long)) and INT_MIN_VALUE <= d <= INT_MAX_VALUE)
        # datetime.datetime subclasses datetime.date, but a datetime cannot be
        # encoded as a date (datetime - date raises TypeError), so only plain
        # dates are accepted for the date logical type.
        or (isinstance(d, datetime.date)
            and not isinstance(d, datetime.datetime)
            and getattr(s, 'logical_type', None) == constants.DATE)
        or (isinstance(d, datetime.time)
            and getattr(s, 'logical_type', None) == constants.TIME_MILLIS)),
    'long': lambda s, d: (
        (isinstance(d, (int, long)) and LONG_MIN_VALUE <= d <= LONG_MAX_VALUE)
        or (isinstance(d, datetime.time)
            and getattr(s, 'logical_type', None) == constants.TIME_MICROS)
        # Require datetime.datetime, not datetime.date: a plain date has no
        # tzinfo attribute, so _is_timezone_aware_datetime() would raise
        # AttributeError instead of validation simply failing.
        or (isinstance(d, datetime.datetime)
            and _is_timezone_aware_datetime(d)
            and getattr(s, 'logical_type', None) in (constants.TIMESTAMP_MILLIS,
                                                     constants.TIMESTAMP_MICROS))),
    'float': lambda s, d: isinstance(d, (int, long, float)),
    'fixed': lambda s, d: (
        (isinstance(d, bytes) and len(d) == s.size)
        or (isinstance(d, Decimal)
            and getattr(s, 'logical_type', None) == constants.DECIMAL)),
    'enum': lambda s, d: d in s.symbols,
    'array': lambda s, d: isinstance(d, list) and all(validate(s.items, item) for item in d),
    'map': lambda s, d: (
        isinstance(d, dict)
        and all(isinstance(key, unicode) for key in d)
        and all(validate(s.values, value) for value in d.values())),
    'union': lambda s, d: any(validate(branch, d) for branch in s.schemas),
    'record': lambda s, d: (
        isinstance(d, dict)
        and all(validate(f.type, d.get(f.name)) for f in s.fields)
        # Reject datums carrying keys that are not fields of the record.
        and {f.name for f in s.fields}.issuperset(d.keys())),
}
# Schema types that share a validator with another type.
_valid['double'] = _valid['float']
_valid['error_union'] = _valid['union']
_valid['error'] = _valid['request'] = _valid['record']
def validate(expected_schema, datum):
    """Check whether a Python datum conforms to an Avro schema.

    Args:
      expected_schema: Schema to validate against.
      datum: Datum to validate.
    Returns:
      True if the datum is an instance of the schema.

    When the module flag _DEBUG_VALIDATE is set, each check is traced to
    stderr; array, map, union and record checks open an indented sub-trace.
    """
    global _DEBUG_VALIDATE_INDENT
    schema_type = expected_schema.type
    label = getattr(expected_schema, 'name', '')
    if label:
        label = ' ' + label

    if schema_type not in ('array', 'map', 'union', 'record'):
        # Leaf types: validate, then emit a single trace line.
        outcome = _valid[schema_type](expected_schema, datum)
        if _DEBUG_VALIDATE:
            print('{!s}{!s}{!s}: {!s} -> {!s}'.format(
                ' ' * _DEBUG_VALIDATE_INDENT, expected_schema.type, label,
                type(datum).__name__, outcome), file=sys.stderr)
        return outcome

    # Compound types: bracket the nested validations with an opening and a
    # closing trace line, indenting everything in between by two spaces.
    if _DEBUG_VALIDATE:
        print('{!s}{!s}{!s}: {!s} {{'.format(
            ' ' * _DEBUG_VALIDATE_INDENT, expected_schema.type, label,
            type(datum).__name__), file=sys.stderr)
        _DEBUG_VALIDATE_INDENT += 2
        if datum is not None and not datum:
            print('{!s}<Empty>'.format(' ' * _DEBUG_VALIDATE_INDENT), file=sys.stderr)
    outcome = _valid[schema_type](expected_schema, datum)
    if _DEBUG_VALIDATE:
        _DEBUG_VALIDATE_INDENT -= 2
        print('{!s}}} -> {!s}'.format(' ' * _DEBUG_VALIDATE_INDENT, outcome), file=sys.stderr)
    return outcome
#
# Decoder/Encoder
#
class BinaryDecoder(object):
    """Read leaf values."""

    def __init__(self, reader):
        """
        reader is a Python object on which we can call read, seek, and tell.
        """
        self._reader = reader

    # read-only properties
    reader = property(lambda self: self._reader)

    def read(self, n):
        """
        Read n bytes.
        """
        return self.reader.read(n)

    def read_null(self):
        """
        null is written as zero bytes
        """
        return None

    def read_boolean(self):
        """
        a boolean is written as a single byte
        whose value is either 0 (false) or 1 (true).
        """
        return ord(self.read(1)) == 1

    def read_int(self):
        """
        int and long values are written using variable-length, zig-zag coding.
        """
        return self.read_long()

    def read_long(self):
        """
        int and long values are written using variable-length, zig-zag coding.

        Each byte carries 7 payload bits, least-significant group first; a set
        high bit means another byte follows.
        """
        b = ord(self.read(1))
        n = b & 0x7F
        shift = 7
        while (b & 0x80) != 0:
            b = ord(self.read(1))
            n |= (b & 0x7F) << shift
            shift += 7
        # Undo the zig-zag mapping: encoded 0, 1, 2, 3, ... -> 0, -1, 1, -2, ...
        return (n >> 1) ^ -(n & 1)

    def read_float(self):
        """
        A float is written as 4 bytes.
        The float is converted into a 32-bit integer using a method equivalent to
        Java's floatToIntBits and then encoded in little-endian format.
        """
        return STRUCT_FLOAT.unpack(self.read(4))[0]

    def read_double(self):
        """
        A double is written as 8 bytes.
        The double is converted into a 64-bit integer using a method equivalent to
        Java's doubleToLongBits and then encoded in little-endian format.
        """
        return STRUCT_DOUBLE.unpack(self.read(8))[0]

    def read_decimal_from_bytes(self, precision, scale):
        """
        Decimal bytes are encoded as a long byte count followed by that many
        bytes holding the unscaled value as a big-endian two's-complement
        integer.
        """
        size = self.read_long()
        return self.read_decimal_from_fixed(precision, scale, size)

    def read_decimal_from_fixed(self, precision, scale, size):
        """
        Decimal is encoded as fixed. Fixed instances are encoded using the
        number of bytes declared in the schema.

        The size bytes hold the unscaled value as a big-endian two's-complement
        integer; the result is that value times 10**-scale, rounded to
        precision significant digits.
        """
        raw = bytearray(self.read(size))
        unscaled_datum = 0
        # Indexing (rather than iterating) raises IndexError on a short read
        # instead of silently decoding fewer bytes.
        for offset in range(size):
            unscaled_datum = (unscaled_datum << 8) | raw[offset]
        # A set high bit in the first byte marks a negative value.
        if size and raw[0] & 0x80:
            unscaled_datum -= 1 << (8 * size)
        # A scoped context leaves the caller's precision untouched, even if
        # scaleb() raises.
        with localcontext() as ctx:
            ctx.prec = precision
            return Decimal(unscaled_datum).scaleb(-scale)

    def read_bytes(self):
        """
        Bytes are encoded as a long followed by that many bytes of data.
        """
        return self.read(self.read_long())

    def read_utf8(self):
        """
        A string is encoded as a long followed by
        that many bytes of UTF-8 encoded character data.
        """
        return unicode(self.read_bytes(), "utf-8")

    def read_date_from_int(self):
        """
        int is decoded as python date object.
        int stores the number of days from
        the unix epoch, 1 January 1970 (ISO calendar).
        """
        days_since_epoch = self.read_int()
        return datetime.date(1970, 1, 1) + datetime.timedelta(days_since_epoch)

    def _build_time_object(self, value, scale_to_micro):
        """Build a datetime.time from value * scale_to_micro microseconds after midnight."""
        value = value * scale_to_micro
        value, microseconds = value // 1000000, value % 1000000
        value, seconds = value // 60, value % 60
        value, minutes = value // 60, value % 60
        hours = value
        return datetime.time(
            hour=hours,
            minute=minutes,
            second=seconds,
            microsecond=microseconds
        )

    def read_time_millis_from_int(self):
        """
        int is decoded as python time object which represents
        the number of milliseconds after midnight, 00:00:00.000.
        """
        milliseconds = self.read_int()
        return self._build_time_object(milliseconds, 1000)

    def read_time_micros_from_long(self):
        """
        long is decoded as python time object which represents
        the number of microseconds after midnight, 00:00:00.000000.
        """
        microseconds = self.read_long()
        return self._build_time_object(microseconds, 1)

    def read_timestamp_millis_from_long(self):
        """
        long is decoded as python datetime object which represents
        the number of milliseconds from the unix epoch, 1 January 1970.
        """
        timestamp_millis = self.read_long()
        timedelta = datetime.timedelta(microseconds=timestamp_millis * 1000)
        unix_epoch_datetime = datetime.datetime(1970, 1, 1, 0, 0, 0, 0, tzinfo=timezones.utc)
        return unix_epoch_datetime + timedelta

    def read_timestamp_micros_from_long(self):
        """
        long is decoded as python datetime object which represents
        the number of microseconds from the unix epoch, 1 January 1970.
        """
        timestamp_micros = self.read_long()
        timedelta = datetime.timedelta(microseconds=timestamp_micros)
        unix_epoch_datetime = datetime.datetime(1970, 1, 1, 0, 0, 0, 0, tzinfo=timezones.utc)
        return unix_epoch_datetime + timedelta

    def skip_null(self):
        pass

    def skip_boolean(self):
        self.skip(1)

    def skip_int(self):
        self.skip_long()

    def skip_long(self):
        # Consume varint bytes until one without the continuation bit.
        b = ord(self.read(1))
        while (b & 0x80) != 0:
            b = ord(self.read(1))

    def skip_float(self):
        self.skip(4)

    def skip_double(self):
        self.skip(8)

    def skip_bytes(self):
        self.skip(self.read_long())

    def skip_utf8(self):
        self.skip_bytes()

    def skip(self, n):
        """Advance the reader by n bytes using seek/tell."""
        self.reader.seek(self.reader.tell() + n)
class BinaryEncoder(object):
    """Write leaf values."""

    def __init__(self, writer):
        """
        writer is a Python object on which we can call write.
        """
        self._writer = writer

    # read-only properties
    writer = property(lambda self: self._writer)

    def write(self, datum):
        """Write an arbitrary datum."""
        self.writer.write(datum)

    def write_null(self, datum):
        """
        null is written as zero bytes
        """
        pass

    def write_boolean(self, datum):
        """
        a boolean is written as a single byte
        whose value is either 0 (false) or 1 (true).
        """
        self.write(bytearray([bool(datum)]))

    def write_int(self, datum):
        """
        int and long values are written using variable-length, zig-zag coding.
        """
        self.write_long(datum)

    def write_long(self, datum):
        """
        int and long values are written using variable-length, zig-zag coding.

        The varint bytes are collected first and passed to the writer in a
        single write() call.
        """
        # Zig-zag map signed to unsigned: 0, -1, 1, -2, ... -> 0, 1, 2, 3, ...
        datum = (datum << 1) ^ (datum >> 63)
        encoded = bytearray()
        while (datum & ~0x7F) != 0:
            # Low 7 bits with the continuation bit set.
            encoded.append((datum & 0x7F) | 0x80)
            datum >>= 7
        encoded.append(datum)
        self.write(encoded)

    def write_float(self, datum):
        """
        A float is written as 4 bytes.
        The float is converted into a 32-bit integer using a method equivalent to
        Java's floatToIntBits and then encoded in little-endian format.
        """
        self.write(STRUCT_FLOAT.pack(datum))

    def write_double(self, datum):
        """
        A double is written as 8 bytes.
        The double is converted into a 64-bit integer using a method equivalent to
        Java's doubleToLongBits and then encoded in little-endian format.
        """
        self.write(STRUCT_DOUBLE.pack(datum))

    def _decimal_to_unscaled(self, datum, scale):
        """Return the signed integer datum * 10**scale.

        Raises schema.AvroException when datum has more fractional digits than
        scale allows, since encoding it would silently drop digits.
        (AvroTypeException is not used here because its constructor requires a
        schema and a datum, not a message.)
        """
        sign, digits, exp = datum.as_tuple()
        if -exp > scale:
            raise schema.AvroException('Scale provided in schema does not match the decimal')
        unscaled_datum = 0
        for digit in digits:
            unscaled_datum = unscaled_datum * 10 + digit
        # Pad with zeros when datum has fewer fractional digits than scale,
        # e.g. Decimal('1.2') at scale 2 must become 120, not 12.
        unscaled_datum *= 10 ** (exp + scale)
        return -unscaled_datum if sign else unscaled_datum

    def _write_twos_complement(self, value, size):
        """Write value as a big-endian two's-complement integer of size bytes."""
        value &= (1 << (8 * size)) - 1
        self.write(bytearray((value >> (8 * index)) & 0xFF
                             for index in range(size - 1, -1, -1)))

    def write_decimal_bytes(self, datum, scale):
        """
        Decimal in bytes is encoded as a long byte count followed by the
        unscaled value as a big-endian two's-complement integer, using the
        fewest bytes that hold the magnitude plus a sign bit.
        """
        unscaled_datum = self._decimal_to_unscaled(datum, scale)
        size = (abs(unscaled_datum).bit_length() + 8) // 8
        self.write_long(size)
        self._write_twos_complement(unscaled_datum, size)

    def write_decimal_fixed(self, datum, scale, size):
        """
        Decimal in fixed is encoded as exactly size bytes holding the unscaled
        value as a big-endian two's-complement integer.

        Raises schema.AvroException when the value does not fit in size bytes.
        Otherwise too many bytes (or a truncated value) would be written.
        """
        unscaled_datum = self._decimal_to_unscaled(datum, scale)
        fits = size > 0 and (-(1 << (8 * size - 1)) <= unscaled_datum < (1 << (8 * size - 1)))
        if not fits:
            raise schema.AvroException(
                'Decimal %s does not fit in a fixed of size %s' % (datum, size))
        self._write_twos_complement(unscaled_datum, size)

    def write_bytes(self, datum):
        """
        Bytes are encoded as a long followed by that many bytes of data.
        """
        self.write_long(len(datum))
        self.write(struct.pack('%ds' % len(datum), datum))

    def write_utf8(self, datum):
        """
        A string is encoded as a long followed by
        that many bytes of UTF-8 encoded character data.
        """
        datum = datum.encode("utf-8")
        self.write_bytes(datum)

    def write_date_int(self, datum):
        """
        Encode python date object as int.
        It stores the number of days from
        the unix epoch, 1 January 1970 (ISO calendar).
        """
        delta_date = datum - datetime.date(1970, 1, 1)
        self.write_int(delta_date.days)

    def write_time_millis_int(self, datum):
        """
        Encode python time object as int.
        It stores the number of milliseconds from midnight, 00:00:00.000
        """
        milliseconds = datum.hour * 3600000 + datum.minute * 60000 + datum.second * 1000 + datum.microsecond // 1000
        self.write_int(milliseconds)

    def write_time_micros_long(self, datum):
        """
        Encode python time object as long.
        It stores the number of microseconds from midnight, 00:00:00.000000
        """
        microseconds = datum.hour * 3600000000 + datum.minute * 60000000 + datum.second * 1000000 + datum.microsecond
        self.write_long(microseconds)

    def _timedelta_total_microseconds(self, timedelta):
        """Exact integer microseconds in timedelta (no float round-trip)."""
        return (
            timedelta.microseconds + (timedelta.seconds + timedelta.days * 24 * 3600) * 10 ** 6)

    def write_timestamp_millis_long(self, datum):
        """
        Encode python datetime object as long.
        It stores the number of milliseconds from midnight of unix epoch, 1 January 1970.
        """
        datum = datum.astimezone(tz=timezones.utc)
        timedelta = datum - datetime.datetime(1970, 1, 1, 0, 0, 0, 0, tzinfo=timezones.utc)
        # Integer floor division stays exact for every datetime. True division
        # goes through a float and can be off by one far from the epoch.
        # Sub-millisecond digits are dropped the same way write_time_millis_int
        # drops them.
        milliseconds = self._timedelta_total_microseconds(timedelta) // 1000
        self.write_long(milliseconds)

    def write_timestamp_micros_long(self, datum):
        """
        Encode python datetime object as long.
        It stores the number of microseconds from midnight of unix epoch, 1 January 1970.
        """
        datum = datum.astimezone(tz=timezones.utc)
        timedelta = datum - datetime.datetime(1970, 1, 1, 0, 0, 0, 0, tzinfo=timezones.utc)
        microseconds = self._timedelta_total_microseconds(timedelta)
        self.write_long(microseconds)
#
# DatumReader/Writer
#
class DatumReader(object):
    """Deserialize Avro-encoded data into a Python data structure."""

    def __init__(self, writers_schema=None, readers_schema=None):
        """
        As defined in the Avro specification, we call the schema encoded
        in the data the "writer's schema", and the schema expected by the
        reader the "reader's schema".
        """
        self._writers_schema = writers_schema
        self._readers_schema = readers_schema

    # read/write properties
    def set_writers_schema(self, writers_schema):
        self._writers_schema = writers_schema

    writers_schema = property(lambda self: self._writers_schema,
                              set_writers_schema)

    def set_readers_schema(self, readers_schema):
        self._readers_schema = readers_schema

    readers_schema = property(lambda self: self._readers_schema,
                              set_readers_schema)

    def read(self, decoder):
        """Read one datum; the reader's schema defaults to the writer's schema."""
        if self.readers_schema is None:
            self.readers_schema = self.writers_schema
        return self.read_data(self.writers_schema, self.readers_schema, decoder)

    def read_data(self, writers_schema, readers_schema, decoder):
        """Decode one value written with writers_schema, resolved against readers_schema."""
        # schema matching
        if not readers_schema.match(writers_schema):
            fail_msg = 'Schemas do not match.'
            raise SchemaResolutionException(fail_msg, writers_schema, readers_schema)

        logical_type = getattr(writers_schema, 'logical_type', None)

        # function dispatch for reading data based on type of writer's schema
        if writers_schema.type in ['union', 'error_union']:
            return self.read_union(writers_schema, readers_schema, decoder)

        if readers_schema.type in ['union', 'error_union']:
            # schema resolution: reader's schema is a union, writer's schema is not
            for s in readers_schema.schemas:
                if s.match(writers_schema):
                    return self.read_data(writers_schema, s, decoder)

            # This shouldn't happen because of the match check at the start of this method.
            fail_msg = 'Schemas do not match.'
            raise SchemaResolutionException(fail_msg, writers_schema, readers_schema)

        if writers_schema.type == 'null':
            return decoder.read_null()
        elif writers_schema.type == 'boolean':
            return decoder.read_boolean()
        elif writers_schema.type == 'string':
            return decoder.read_utf8()
        elif writers_schema.type == 'int':
            if logical_type == constants.DATE:
                return decoder.read_date_from_int()
            if logical_type == constants.TIME_MILLIS:
                return decoder.read_time_millis_from_int()
            return decoder.read_int()
        elif writers_schema.type == 'long':
            if logical_type == constants.TIME_MICROS:
                return decoder.read_time_micros_from_long()
            elif logical_type == constants.TIMESTAMP_MILLIS:
                return decoder.read_timestamp_millis_from_long()
            elif logical_type == constants.TIMESTAMP_MICROS:
                return decoder.read_timestamp_micros_from_long()
            else:
                return decoder.read_long()
        elif writers_schema.type == 'float':
            return decoder.read_float()
        elif writers_schema.type == 'double':
            return decoder.read_double()
        elif writers_schema.type == 'bytes':
            if logical_type == 'decimal':
                return decoder.read_decimal_from_bytes(
                    writers_schema.get_prop('precision'),
                    writers_schema.get_prop('scale')
                )
            else:
                return decoder.read_bytes()
        elif writers_schema.type == 'fixed':
            if logical_type == 'decimal':
                return decoder.read_decimal_from_fixed(
                    writers_schema.get_prop('precision'),
                    writers_schema.get_prop('scale'),
                    writers_schema.size
                )
            return self.read_fixed(writers_schema, readers_schema, decoder)
        elif writers_schema.type == 'enum':
            return self.read_enum(writers_schema, readers_schema, decoder)
        elif writers_schema.type == 'array':
            return self.read_array(writers_schema, readers_schema, decoder)
        elif writers_schema.type == 'map':
            return self.read_map(writers_schema, readers_schema, decoder)
        elif writers_schema.type in ['record', 'error', 'request']:
            return self.read_record(writers_schema, readers_schema, decoder)
        else:
            fail_msg = "Cannot read unknown schema type: %s" % writers_schema.type
            raise schema.AvroException(fail_msg)

    def skip_data(self, writers_schema, decoder):
        """Consume one value written with writers_schema without decoding it."""
        if writers_schema.type == 'null':
            return decoder.skip_null()
        elif writers_schema.type == 'boolean':
            return decoder.skip_boolean()
        elif writers_schema.type == 'string':
            return decoder.skip_utf8()
        elif writers_schema.type == 'int':
            return decoder.skip_int()
        elif writers_schema.type == 'long':
            return decoder.skip_long()
        elif writers_schema.type == 'float':
            return decoder.skip_float()
        elif writers_schema.type == 'double':
            return decoder.skip_double()
        elif writers_schema.type == 'bytes':
            return decoder.skip_bytes()
        elif writers_schema.type == 'fixed':
            return self.skip_fixed(writers_schema, decoder)
        elif writers_schema.type == 'enum':
            return self.skip_enum(writers_schema, decoder)
        elif writers_schema.type == 'array':
            return self.skip_array(writers_schema, decoder)
        elif writers_schema.type == 'map':
            return self.skip_map(writers_schema, decoder)
        elif writers_schema.type in ['union', 'error_union']:
            return self.skip_union(writers_schema, decoder)
        elif writers_schema.type in ['record', 'error', 'request']:
            return self.skip_record(writers_schema, decoder)
        else:
            fail_msg = "Unknown schema type: %s" % writers_schema.type
            raise schema.AvroException(fail_msg)

    def read_fixed(self, writers_schema, readers_schema, decoder):
        """
        Fixed instances are encoded using the number of bytes declared
        in the schema.
        """
        return decoder.read(writers_schema.size)

    def skip_fixed(self, writers_schema, decoder):
        return decoder.skip(writers_schema.size)

    def read_enum(self, writers_schema, readers_schema, decoder):
        """
        An enum is encoded by a int, representing the zero-based position
        of the symbol in the schema.
        """
        # read data
        index_of_symbol = decoder.read_int()
        # Reject negative indices as well: Python would otherwise silently
        # pick a symbol counted from the end of the list.
        if not 0 <= index_of_symbol < len(writers_schema.symbols):
            fail_msg = "Can't access enum index %d for enum with %d symbols" \
                % (index_of_symbol, len(writers_schema.symbols))
            raise SchemaResolutionException(fail_msg, writers_schema, readers_schema)
        read_symbol = writers_schema.symbols[index_of_symbol]

        # schema resolution
        if read_symbol not in readers_schema.symbols:
            fail_msg = "Symbol %s not present in Reader's Schema" % read_symbol
            raise SchemaResolutionException(fail_msg, writers_schema, readers_schema)

        return read_symbol

    def skip_enum(self, writers_schema, decoder):
        return decoder.skip_int()

    def read_array(self, writers_schema, readers_schema, decoder):
        """
        Arrays are encoded as a series of blocks.

        Each block consists of a long count value,
        followed by that many array items.
        A block with count zero indicates the end of the array.
        Each item is encoded per the array's item schema.

        If a block's count is negative,
        then the count is followed immediately by a long block size,
        indicating the number of bytes in the block.
        The actual count in this case
        is the absolute value of the count written.
        """
        read_items = []
        block_count = decoder.read_long()
        while block_count != 0:
            if block_count < 0:
                block_count = -block_count
                # Block size in bytes; not needed since every item is decoded.
                decoder.read_long()
            for i in range(block_count):
                read_items.append(self.read_data(writers_schema.items,
                                                 readers_schema.items, decoder))
            block_count = decoder.read_long()
        return read_items

    def skip_array(self, writers_schema, decoder):
        block_count = decoder.read_long()
        while block_count != 0:
            if block_count < 0:
                # A negative count carries a byte size, so skip the block wholesale.
                block_size = decoder.read_long()
                decoder.skip(block_size)
            else:
                for i in range(block_count):
                    self.skip_data(writers_schema.items, decoder)
            block_count = decoder.read_long()

    def read_map(self, writers_schema, readers_schema, decoder):
        """
        Maps are encoded as a series of blocks.

        Each block consists of a long count value,
        followed by that many key/value pairs.
        A block with count zero indicates the end of the map.
        Each item is encoded per the map's value schema.

        If a block's count is negative,
        then the count is followed immediately by a long block size,
        indicating the number of bytes in the block.
        The actual count in this case
        is the absolute value of the count written.
        """
        read_items = {}
        block_count = decoder.read_long()
        while block_count != 0:
            if block_count < 0:
                block_count = -block_count
                # Block size in bytes; not needed since every entry is decoded.
                decoder.read_long()
            for i in range(block_count):
                key = decoder.read_utf8()
                read_items[key] = self.read_data(writers_schema.values,
                                                 readers_schema.values, decoder)
            block_count = decoder.read_long()
        return read_items

    def skip_map(self, writers_schema, decoder):
        block_count = decoder.read_long()
        while block_count != 0:
            if block_count < 0:
                # A negative count carries a byte size, so skip the block wholesale.
                block_size = decoder.read_long()
                decoder.skip(block_size)
            else:
                for i in range(block_count):
                    decoder.skip_utf8()
                    self.skip_data(writers_schema.values, decoder)
            block_count = decoder.read_long()

    def read_union(self, writers_schema, readers_schema, decoder):
        """
        A union is encoded by first writing a long value indicating
        the zero-based position within the union of the schema of its value.
        The value is then encoded per the indicated schema within the union.
        """
        # schema resolution
        index_of_schema = int(decoder.read_long())
        # Negative indices would silently select a branch from the end.
        if not 0 <= index_of_schema < len(writers_schema.schemas):
            fail_msg = "Can't access branch index %d for union with %d branches" \
                % (index_of_schema, len(writers_schema.schemas))
            raise SchemaResolutionException(fail_msg, writers_schema, readers_schema)
        selected_writers_schema = writers_schema.schemas[index_of_schema]

        # read data
        return self.read_data(selected_writers_schema, readers_schema, decoder)

    def skip_union(self, writers_schema, decoder):
        index_of_schema = int(decoder.read_long())
        # Negative indices would silently select a branch from the end.
        if not 0 <= index_of_schema < len(writers_schema.schemas):
            fail_msg = "Can't access branch index %d for union with %d branches" \
                % (index_of_schema, len(writers_schema.schemas))
            raise SchemaResolutionException(fail_msg, writers_schema)
        return self.skip_data(writers_schema.schemas[index_of_schema], decoder)

    def read_record(self, writers_schema, readers_schema, decoder):
        """
        A record is encoded by encoding the values of its fields
        in the order that they are declared. In other words, a record
        is encoded as just the concatenation of the encodings of its fields.
        Field values are encoded per their schema.

        Schema Resolution:
         * the ordering of fields may be different: fields are matched by name.
         * schemas for fields with the same name in both records are resolved
           recursively.
         * if the writer's record contains a field with a name not present in the
           reader's record, the writer's value for that field is ignored.
         * if the reader's record schema has a field that contains a default value,
           and writer's schema does not have a field with the same name, then the
           reader should use the default value from its field.
         * if the reader's record schema has a field with no default value, and
           writer's schema does not have a field with the same name, a
           SchemaResolutionException is raised.
        """
        # schema resolution
        readers_fields_dict = readers_schema.fields_dict
        read_record = {}
        for field in writers_schema.fields:
            readers_field = readers_fields_dict.get(field.name)
            if readers_field is not None:
                field_val = self.read_data(field.type, readers_field.type, decoder)
                read_record[field.name] = field_val
            else:
                self.skip_data(field.type, decoder)

        # fill in default values
        if len(readers_fields_dict) > len(read_record):
            writers_fields_dict = writers_schema.fields_dict
            for field_name, field in readers_fields_dict.items():
                if field_name not in writers_fields_dict:
                    if field.has_default:
                        field_val = self._read_default_value(field.type, field.default)
                        read_record[field.name] = field_val
                    else:
                        fail_msg = 'No default value for field %s' % field_name
                        raise SchemaResolutionException(fail_msg, writers_schema,
                                                        readers_schema)
        return read_record

    def skip_record(self, writers_schema, decoder):
        for field in writers_schema.fields:
            self.skip_data(field.type, decoder)

    def _read_default_value(self, field_schema, default_value):
        """
        Convert a JSON default value from a schema into the Python
        representation used for field_schema's type.
        """
        if field_schema.type == 'null':
            return None
        elif field_schema.type == 'boolean':
            return bool(default_value)
        elif field_schema.type == 'int':
            return int(default_value)
        elif field_schema.type == 'long':
            return long(default_value)
        elif field_schema.type in ['float', 'double']:
            return float(default_value)
        elif field_schema.type in ['enum', 'fixed', 'string', 'bytes']:
            return default_value
        elif field_schema.type == 'array':
            read_array = []
            for json_val in default_value:
                item_val = self._read_default_value(field_schema.items, json_val)
                read_array.append(item_val)
            return read_array
        elif field_schema.type == 'map':
            read_map = {}
            for key, json_val in default_value.items():
                map_val = self._read_default_value(field_schema.values, json_val)
                read_map[key] = map_val
            return read_map
        elif field_schema.type in ['union', 'error_union']:
            # A union's default is interpreted against its first branch.
            return self._read_default_value(field_schema.schemas[0], default_value)
        elif field_schema.type == 'record':
            read_record = {}
            for field in field_schema.fields:
                json_val = default_value.get(field.name)
                if json_val is None:
                    json_val = field.default
                field_val = self._read_default_value(field.type, json_val)
                read_record[field.name] = field_val
            return read_record
        else:
            fail_msg = 'Unknown type: %s' % field_schema.type
            raise schema.AvroException(fail_msg)
class DatumWriter(object):
"""DatumWriter for generic python objects."""
def __init__(self, writers_schema=None):
self._writers_schema = writers_schema
# read/write properties
def set_writers_schema(self, writers_schema):
self._writers_schema = writers_schema
writers_schema = property(lambda self: self._writers_schema,
set_writers_schema)
def write(self, datum, encoder):
if not validate(self.writers_schema, datum):
raise AvroTypeException(self.writers_schema, datum)
self.write_data(self.writers_schema, datum, encoder)
def write_data(self, writers_schema, datum, encoder):
# function dispatch to write datum
logical_type = getattr(writers_schema, 'logical_type', None)
if writers_schema.type == 'null':
encoder.write_null(datum)
elif writers_schema.type == 'boolean':
encoder.write_boolean(datum)
elif writers_schema.type == 'string':
encoder.write_utf8(datum)
elif writers_schema.type == 'int':
if logical_type == constants.DATE:
encoder.write_date_int(datum)
elif logical_type == constants.TIME_MILLIS:
encoder.write_time_millis_int(datum)
else:
encoder.write_int(datum)
elif writers_schema.type == 'long':
if logical_type == constants.TIME_MICROS:
encoder.write_time_micros_long(datum)
elif logical_type == constants.TIMESTAMP_MILLIS:
encoder.write_timestamp_millis_long(datum)
elif logical_type == constants.TIMESTAMP_MICROS:
encoder.write_timestamp_micros_long(datum)
else:
encoder.write_long(datum)
elif writers_schema.type == 'float':
encoder.write_float(datum)
elif writers_schema.type == 'double':
encoder.write_double(datum)
elif writers_schema.type == 'bytes':
if logical_type == 'decimal':
encoder.write_decimal_bytes(datum, writers_schema.get_prop('scale'))
else:
encoder.write_bytes(datum)
elif writers_schema.type == 'fixed':
if logical_type == 'decimal':
encoder.write_decimal_fixed(
datum,
writers_schema.get_prop('scale'),
writers_schema.get_prop('size')
)
else:
self.write_fixed(writers_schema, datum, encoder)
elif writers_schema.type == 'enum':
self.write_enum(writers_schema, datum, encoder)
elif writers_schema.type == 'array':
self.write_array(writers_schema, datum, encoder)
elif writers_schema.type == 'map':
self.write_map(writers_schema, datum, encoder)
elif writers_schema.type in ['union', 'error_union']:
self.write_union(writers_schema, datum, encoder)
elif writers_schema.type in ['record', 'error', 'request']:
self.write_record(writers_schema, datum, encoder)
else:
fail_msg = 'Unknown type: %s' % writers_schema.type
raise schema.AvroException(fail_msg)
def write_fixed(self, writers_schema, datum, encoder):
"""
Fixed instances are encoded using the number of bytes declared
in the schema.
"""
encoder.write(datum)
def write_enum(self, writers_schema, datum, encoder):
"""
An enum is encoded by a int, representing the zero-based position
of the symbol in the schema.
"""
index_of_datum = writers_schema.symbols.index(datum)
encoder.write_int(index_of_datum)
def write_array(self, writers_schema, datum, encoder):
"""
Arrays are encoded as a series of blocks.
Each block consists of a long count value,
followed by that many array items.
A block with count zero indicates the end of the array.
Each item is encoded per the array's item schema.
If a block's count is negative,
then the count is followed immediately by a long block size,
indicating the number of bytes in the block.
The actual count in this case
is the absolute value of the count written.
"""
if len(datum) > 0:
encoder.write_long(len(datum))
for item in datum:
self.write_data(writers_schema.items, item, encoder)
encoder.write_long(0)
def write_map(self, writers_schema, datum, encoder):
"""
Maps are encoded as a series of blocks.
Each block consists of a long count value,
followed by that many key/value pairs.
A block with count zero indicates the end of the map.
Each item is encoded per the map's value schema.
If a block's count is negative,
then the count is followed immediately by a long block size,
indicating the number of bytes in the block.
The actual count in this case
is the absolute value of the count written.
"""
if len(datum) > 0:
encoder.write_long(len(datum))
for key, val in datum.items():
encoder.write_utf8(key)
self.write_data(writers_schema.values, val, encoder)
encoder.write_long(0)
def write_union(self, writers_schema, datum, encoder):
"""
A union is encoded by first writing a long value indicating
the zero-based position within the union of the schema of its value.
The value is then encoded per the indicated schema within the union.
"""
# resolve union
index_of_schema = -1
for i, candidate_schema in enumerate(writers_schema.schemas):
if validate(candidate_schema, datum):
index_of_schema = i
if index_of_schema < 0:
raise AvroTypeException(writers_schema, datum)
# write data
encoder.write_long(index_of_schema)
self.write_data(writers_schema.schemas[index_of_schema], datum, encoder)
def write_record(self, writers_schema, datum, encoder):
"""
A record is encoded by encoding the values of its fields
in the order that they are declared. In other words, a record
is encoded as just the concatenation of the encodings of its fields.
Field values are encoded per their schema.
"""
for field in writers_schema.fields:
self.write_data(field.type, datum.get(field.name), encoder)