AVRO-2748: Refactor Schema Matching (#861)
Move schema matching into schema objects because each type knows its matching logic best.
diff --git a/lang/py/avro/io.py b/lang/py/avro/io.py
index 52b631a..ffbd0af 100644
--- a/lang/py/avro/io.py
+++ b/lang/py/avro/io.py
@@ -585,61 +585,8 @@
#
# DatumReader/Writer
#
-
class DatumReader(object):
"""Deserialize Avro-encoded data into a Python data structure."""
- @staticmethod
- def check_props(schema_one, schema_two, prop_list):
- for prop in prop_list:
- if getattr(schema_one, prop) != getattr(schema_two, prop):
- return False
- return True
-
- @staticmethod
- def match_schemas(writers_schema, readers_schema):
- w_type = writers_schema.type
- r_type = readers_schema.type
- if 'union' in [w_type, r_type] or 'error_union' in [w_type, r_type]:
- return True
- elif (w_type in schema.PRIMITIVE_TYPES and r_type in schema.PRIMITIVE_TYPES
- and w_type == r_type):
- return True
- elif (w_type == r_type == 'record' and
- DatumReader.check_props(writers_schema, readers_schema,
- ['fullname'])):
- return True
- elif (w_type == r_type == 'error' and
- DatumReader.check_props(writers_schema, readers_schema,
- ['fullname'])):
- return True
- elif (w_type == r_type == 'request'):
- return True
- elif (w_type == r_type == 'fixed' and
- DatumReader.check_props(writers_schema, readers_schema,
- ['fullname', 'size'])):
- return True
- elif (w_type == r_type == 'enum' and
- DatumReader.check_props(writers_schema, readers_schema,
- ['fullname'])):
- return True
- elif (w_type == r_type == 'map' and
- DatumReader.check_props(writers_schema.values,
- readers_schema.values, ['type'])):
- return True
- elif (w_type == r_type == 'array' and
- DatumReader.check_props(writers_schema.items,
- readers_schema.items, ['type'])):
- return True
-
- # Handle schema promotion
- if w_type == 'int' and r_type in ['long', 'float', 'double']:
- return True
- elif w_type == 'long' and r_type in ['float', 'double']:
- return True
- elif w_type == 'float' and r_type == 'double':
- return True
- return False
-
def __init__(self, writers_schema=None, readers_schema=None):
"""
As defined in the Avro specification, we call the schema encoded
@@ -658,7 +605,6 @@
self._readers_schema = readers_schema
readers_schema = property(lambda self: self._readers_schema,
set_readers_schema)
-
def read(self, decoder):
if self.readers_schema is None:
self.readers_schema = self.writers_schema
@@ -666,21 +612,26 @@
def read_data(self, writers_schema, readers_schema, decoder):
# schema matching
- if not DatumReader.match_schemas(writers_schema, readers_schema):
+ if not readers_schema.match(writers_schema):
fail_msg = 'Schemas do not match.'
raise SchemaResolutionException(fail_msg, writers_schema, readers_schema)
logical_type = getattr(writers_schema, 'logical_type', None)
- # schema resolution: reader's schema is a union, writer's schema is not
- if (writers_schema.type not in ['union', 'error_union']
- and readers_schema.type in ['union', 'error_union']):
+
+ # function dispatch for reading data based on type of writer's schema
+ if writers_schema.type in ['union', 'error_union']:
+ return self.read_union(writers_schema, readers_schema, decoder)
+
+ if readers_schema.type in ['union', 'error_union']:
+ # schema resolution: reader's schema is a union, writer's schema is not
for s in readers_schema.schemas:
- if DatumReader.match_schemas(writers_schema, s):
+ if s.match(writers_schema):
return self.read_data(writers_schema, s, decoder)
+
+ # This shouldn't happen because of the match check at the start of this method.
fail_msg = 'Schemas do not match.'
raise SchemaResolutionException(fail_msg, writers_schema, readers_schema)
- # function dispatch for reading data based on type of writer's schema
if writers_schema.type == 'null':
return decoder.read_null()
elif writers_schema.type == 'boolean':
@@ -728,8 +679,6 @@
return self.read_array(writers_schema, readers_schema, decoder)
elif writers_schema.type == 'map':
return self.read_map(writers_schema, readers_schema, decoder)
- elif writers_schema.type in ['union', 'error_union']:
- return self.read_union(writers_schema, readers_schema, decoder)
elif writers_schema.type in ['record', 'error', 'request']:
return self.read_record(writers_schema, readers_schema, decoder)
else:
diff --git a/lang/py/avro/schema.py b/lang/py/avro/schema.py
index f41852d..1248ec5 100644
--- a/lang/py/avro/schema.py
+++ b/lang/py/avro/schema.py
@@ -177,6 +177,23 @@
other_props = property(lambda self: get_other_props(self._props, SCHEMA_RESERVED_PROPS),
doc="dictionary of non-reserved properties")
+ def check_props(self, other, props):
+ """Check that the given props are identical in two schemas.
+
+ @arg other: The other schema to check
+ @arg props: An iterable of properties to check
+ @return bool: True if all the properties match
+ """
+ return all(getattr(self, prop) == getattr(other, prop) for prop in props)
+
+ def match(self, writer):
+ """Return True if the current schema (as reader) matches the writer schema.
+
+ @arg writer: the writer schema to match against.
+ @return bool
+ """
+ raise NotImplemented("Must be implemented by subclasses")
+
# utility functions to manipulate properties dict
def get_prop(self, key):
return self._props.get(key)
@@ -460,6 +477,19 @@
self.fullname = type
+ def match(self, writer):
+ """Return True if the current schema (as reader) matches the writer schema.
+
+ @arg writer: the schema to match against
+ @return bool
+ """
+ return self.type == writer.type or {
+ 'float': self.type == 'double',
+ 'int': self.type in {'double', 'float', 'long'},
+ 'long': self.type in {'double', 'float',},
+ }.get(writer.type, False)
+
+
def to_json(self, names=None):
if len(self.props) == 1:
return self.fullname
@@ -472,7 +502,6 @@
#
# Decimal Bytes Type
#
-
class BytesDecimalSchema(PrimitiveSchema, DecimalLogicalSchema):
def __init__(self, precision, scale=0, other_props=None):
DecimalLogicalSchema.__init__(self, precision, scale, max_precision=((1 << 31) - 1))
@@ -494,7 +523,6 @@
#
# Complex Types (non-recursive)
#
-
class FixedSchema(NamedSchema):
def __init__(self, name, namespace, size, names=None, other_props=None):
# Ensure valid ctor args
@@ -511,6 +539,14 @@
# read-only properties
size = property(lambda self: self.get_prop('size'))
+ def match(self, writer):
+ """Return True if the current schema (as reader) matches the writer schema.
+
+ @arg writer: the schema to match against
+ @return bool
+ """
+ return self.type == writer.type and self.check_props(writer, ['fullname', 'size'])
+
def to_json(self, names=None):
if names is None:
names = Names()
@@ -574,6 +610,14 @@
symbols = property(lambda self: self.get_prop('symbols'))
doc = property(lambda self: self.get_prop('doc'))
+ def match(self, writer):
+ """Return True if the current schema (as reader) matches the writer schema.
+
+ @arg writer: the schema to match against
+ @return bool
+ """
+ return self.type == writer.type and self.check_props(writer, ['fullname'])
+
def to_json(self, names=None):
if names is None:
names = Names()
@@ -610,6 +654,14 @@
# read-only properties
items = property(lambda self: self.get_prop('items'))
+ def match(self, writer):
+ """Return True if the current schema (as reader) matches the writer schema.
+
+ @arg writer: the schema to match against
+ @return bool
+ """
+ return self.type == writer.type and self.items.check_props(writer.items, ['type'])
+
def to_json(self, names=None):
if names is None:
names = Names()
@@ -643,6 +695,14 @@
# read-only properties
values = property(lambda self: self.get_prop('values'))
+ def match(self, writer):
+ """Return True if the current schema (as reader) matches the writer schema.
+
+ @arg writer: the schema to match against
+ @return bool
+ """
+ return writer.type == self.type and self.values.check_props(writer.values, ['type'])
+
def to_json(self, names=None):
if names is None:
names = Names()
@@ -690,6 +750,14 @@
# read-only properties
schemas = property(lambda self: self._schemas)
+ def match(self, writer):
+ """Return True if the current schema (as reader) matches the writer schema.
+
+ @arg writer: the schema to match against
+ @return bool
+ """
+ return writer.type in {'union', 'error_union'} or any(s.match(writer) for s in self.schemas)
+
def to_json(self, names=None):
if names is None:
names = Names()
@@ -750,6 +818,14 @@
field_objects.append(new_field)
return field_objects
+ def match(self, writer):
+ """Return True if the current schema (as reader) matches the writer schema.
+
+ @arg writer: the schema to match against
+ @return bool
+ """
+ return writer.type == self.type and (self.type == 'request' or self.check_props(writer, ['fullname']))
+
def __init__(self, name, namespace, fields, names=None, schema_type='record',
doc=None, other_props=None):
# Ensure valid ctor args