| # |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| # |
| |
| import copy |
| from abc import abstractmethod |
| from collections import OrderedDict |
| from enum import Enum, EnumMeta |
| |
| |
| def _string_representation(x): |
| if hasattr(x, "__name__"): |
| return x.__name__ |
| else: |
| return str(x) |
| |
| |
def _check_record_or_field(x):
    """Ensure ``x`` is usable as the element type of an ``Array`` or ``Map``.

    Accepted values are ``Field`` instances, ``Record`` instances and
    ``Record`` subclasses. Anything else (plain values, unrelated classes,
    Enum classes, ...) is rejected.

    Raises:
        Exception: if ``x`` is neither a Record nor a Field.
    """
    if isinstance(x, (Record, Field)):
        return
    # Record classes are instances of RecordMeta rather than of ``type``
    # itself, so test "is a class" with isinstance instead of ``type(x) is type``.
    if isinstance(x, type) and issubclass(x, Record):
        return
    raise Exception('Argument ' + _string_representation(x) + ' is not a Record or a Field')
| |
| |
class RecordMeta(type):
    """Metaclass that gathers the declared fields of each record class."""

    def __new__(metacls, name, parents, dct):
        # The base class named 'Record' is left untouched; every other class
        # gets its field table and a default ``_required`` flag.
        is_base_class = (name == 'Record')
        if not is_base_class:
            dct['_fields'] = RecordMeta._get_fields(dct)
            dct['_required'] = False
        return type.__new__(metacls, name, parents, dct)

    @classmethod
    def _get_fields(cls, dct):
        """Return an ``OrderedDict`` of the Record/Field attributes in ``dct``.

        Enum classes are wrapped in a ``CustomEnum`` field and record classes
        are replaced by an instance of themselves before being collected.
        """
        collected = OrderedDict()
        for attr_name, attr_value in dct.items():
            if issubclass(type(attr_value), EnumMeta):
                attr_value = CustomEnum(attr_value)
            elif type(attr_value) is RecordMeta:
                # An instance of the record is expected, not the class itself
                attr_value = attr_value()

            if isinstance(attr_value, (Record, Field)):
                collected[attr_name] = attr_value
        return collected
| |
| |
class Record(metaclass=RecordMeta):
    """Base class for record schemas.

    Subclasses declare their fields as class attributes; ``RecordMeta``
    collects them into the ``_fields`` mapping. A record instance can also be
    used as a field of another record, which is why it carries the same
    ``default`` / ``required`` / ``required_default`` settings as a field.
    """

    # This field is used to set namespace for Avro Record schema.
    _avro_namespace = None

    # Generate a schema where fields are sorted alphabetically
    _sorted_fields = False

    def __init__(self, default=None, required_default=False, required=False, *args, **kwargs):
        """Create a record, optionally overriding field values by keyword.

        Args:
            default: value returned by ``default()`` for this record.
            required_default: value returned by ``required_default()``.
            required: when true, ``validate_type`` does not map ``None`` to
                the default.
            *args: accepted but not used.
            **kwargs: values for declared fields. Keys that do not match a
                declared field are not used. Dicts given for sub-record
                fields (directly, or inside a list/dict for array/map fields
                of records) are converted into record objects.
        """
        self._required_default = required_default
        self._default = default
        self._required = required

        for k, value in self._fields.items():
            if k in kwargs:
                given = kwargs[k]
                if isinstance(value, Record) and isinstance(given, dict):
                    # Use dict init Record object
                    copied = copy.copy(value)
                    copied.__init__(**given)
                    self.__setattr__(k, copied)
                elif isinstance(value, Array) and isinstance(given, list) and len(given) > 0 \
                        and isinstance(value.array_type, Record) and isinstance(given[0], dict):
                    # List of dicts for an array of records: build one record per item
                    arr = []
                    for item in given:
                        copied = copy.copy(value.array_type)
                        copied.__init__(**item)
                        arr.append(copied)
                    self.__setattr__(k, arr)
                elif isinstance(value, Map) and isinstance(given, dict) and len(given) > 0 \
                        and isinstance(value.value_type, Record) \
                        and isinstance(list(given.values())[0], dict):
                    # Dict of dicts for a map of records: build one record per entry
                    dic = {}
                    for map_key, map_value in given.items():
                        copied = copy.copy(value.value_type)
                        copied.__init__(**map_value)
                        dic[map_key] = copied
                    self.__setattr__(k, dic)
                else:
                    # Value was overridden at constructor
                    self.__setattr__(k, given)
            elif isinstance(value, Record):
                # Value is a subrecord
                self.__setattr__(k, value)
            else:
                # Set field to default value, without revalidating the default value type
                super(Record, self).__setattr__(k, value.default())

    @classmethod
    def schema(cls):
        """Return the schema of this record as a dict."""
        return cls.schema_info(set())

    @classmethod
    def schema_info(cls, defined_names):
        """Return the schema of this record, tracking already defined names.

        Args:
            defined_names: set of names already emitted. It is updated in
                place. If this record's (namespace-qualified) name is already
                in it, only that name is returned instead of the full schema.

        Returns:
            The qualified record name (str) or the record schema (dict).
        """
        namespace_prefix = ''
        if cls._avro_namespace is not None:
            namespace_prefix = cls._avro_namespace + '.'
        namespace_name = namespace_prefix + cls.__name__

        if namespace_name in defined_names:
            return namespace_name

        defined_names.add(namespace_name)

        schema = {
            'type': 'record',
            'name': str(cls.__name__)
        }
        if cls._avro_namespace is not None:
            schema['namespace'] = cls._avro_namespace
        schema['fields'] = []

        def default_for_schema(value):
            # Enum members are written to the schema by name
            if isinstance(value, Enum):
                return value.name
            return value

        if cls._sorted_fields:
            field_names = sorted(cls._fields.keys())
        else:
            field_names = cls._fields.keys()

        for name in field_names:
            field = cls._fields[name]
            # Non-required fields are emitted as a union with 'null'
            if field._required:
                field_type = field.schema_info(defined_names)
            else:
                field_type = ['null', field.schema_info(defined_names)]

            if field.required_default():
                entry = {
                    'name': name,
                    'default': default_for_schema(field.default()),
                    'type': field_type
                }
            else:
                entry = {
                    'name': name,
                    'type': field_type,
                }
            schema['fields'].append(entry)

        return schema

    def __setattr__(self, key, value):
        """Set an attribute, validating values assigned to declared fields.

        Raises:
            AttributeError: if ``key`` is neither an internal setting nor a
                declared field.
        """
        if key in ('_default', '_required_default', '_required'):
            # Internal settings are stored without validation
            super(Record, self).__setattr__(key, value)
            return

        if key not in self._fields:
            raise AttributeError('Cannot set undeclared field ' + key + ' on record')

        # Check that type of value matches the field type
        field = self._fields[key]
        value = field.validate_type(key, value)
        super(Record, self).__setattr__(key, value)

    def __eq__(self, other):
        """Return True when ``other`` has an equal value for every declared field.

        An object that lacks one of the fields compares unequal instead of
        raising ``AttributeError``.
        """
        missing = object()
        for field in self._fields:
            theirs = getattr(other, field, missing)
            if theirs is missing or self.__getattribute__(field) != theirs:
                return False
        return True

    def __ne__(self, other):
        return not self.__eq__(other)

    def __str__(self):
        return str(self.__dict__)

    def type(self):
        """Return the class name of this record as a string."""
        return str(self.__class__.__name__)

    def python_type(self):
        """Return the class of this record."""
        return self.__class__

    def validate_type(self, name, val):
        """Validate ``val`` as a value for a sub-record field called ``name``.

        Returns:
            ``val`` itself, or ``self.default()`` when ``val`` is ``None`` and
            the record is not required.

        Raises:
            TypeError: if ``val`` is not an instance of this record's class.
        """
        if val is None and not self._required:
            return self.default()

        if not isinstance(val, self.__class__):
            raise TypeError("Invalid type '%s' for sub-record field '%s'. Expected: %s" % (
                type(val), name, _string_representation(self.__class__)))
        return val

    def default(self):
        """Return the configured default (``None`` when unset)."""
        return self._default

    def required_default(self):
        """Return the ``required_default`` flag given at construction."""
        return self._required_default
| |
| |
class Field(object):
    """Common base of all field types.

    Holds the default value together with the ``required`` and
    ``required_default`` flags, and provides the generic type validation
    based on ``python_type()``.
    """

    def __init__(self, default=None, required=False, required_default=False):
        # Only a non-None default goes through type validation
        self._default = default if default is None else self.validate_type('default', default)
        self._required_default = required_default
        self._required = required

    @abstractmethod
    def type(self):
        """Return the schema type name of the field."""

    @abstractmethod
    def python_type(self):
        """Return the Python type (or tuple of types) accepted by the field."""

    def validate_type(self, name, val):
        """Return the value to store for ``val``, raising TypeError on a type mismatch."""
        if val is None:
            if not self._required:
                # A missing value on an optional field falls back to the default
                return self.default()

        if isinstance(val, self.python_type()):
            return val
        raise TypeError("Invalid type '%s' for field '%s'. Expected: %s" % (type(val), name, _string_representation(self.python_type())))

    def schema(self):
        # For primitive types, the schema would just be the type itself
        return self.type()

    def schema_info(self, defined_names):
        """Return the schema of the field; ``defined_names`` is not used here."""
        return self.type()

    def default(self):
        """Return the stored default value."""
        return self._default

    def required_default(self):
        """Return the ``required_default`` flag."""
        return self._required_default
| |
| |
| # All types |
| |
| |
class Null(Field):
    """Field whose only accepted value is ``None``."""

    def type(self):
        return 'null'

    def python_type(self):
        return type(None)

    def validate_type(self, name, val):
        """Return ``val`` when it is ``None``; raise TypeError otherwise."""
        if val is None:
            return val
        raise TypeError('Field ' + name + ' is set to be None')
| |
| |
class Boolean(Field):
    """Boolean field; its default is ``False`` when none was configured."""

    def type(self):
        return 'boolean'

    def python_type(self):
        return bool

    def default(self):
        """Return the configured default, or ``False`` when it is unset."""
        return False if self._default is None else self._default
| |
| |
class Integer(Field):
    """Field of schema type ``int`` holding Python ``int`` values."""

    def type(self):
        return 'int'

    def python_type(self):
        return int

    def default(self):
        """Return the configured default, or ``None`` when it is unset."""
        return self._default
| |
| |
class Long(Field):
    """Field of schema type ``long`` holding Python ``int`` values."""

    def type(self):
        return 'long'

    def python_type(self):
        return int

    def default(self):
        """Return the configured default, or ``None`` when it is unset."""
        return self._default
| |
| |
class Float(Field):
    """Field of schema type ``float``; accepts ``float`` and ``int`` values."""

    def type(self):
        return 'float'

    def python_type(self):
        return (float, int)

    def default(self):
        """Return the configured default, or ``None`` when it is unset."""
        return self._default
| |
| |
class Double(Field):
    """Field of schema type ``double``; accepts ``float`` and ``int`` values."""

    def type(self):
        return 'double'

    def python_type(self):
        return (float, int)

    def default(self):
        """Return the configured default, or ``None`` when it is unset."""
        return self._default
| |
| |
class Bytes(Field):
    """Field of schema type ``bytes``.

    ``str`` values are accepted and stored encoded (``str.encode()``).
    """

    def type(self):
        return 'bytes'

    def python_type(self):
        return bytes, str

    def default(self):
        """Return the configured default, or ``None`` when it is unset."""
        return self._default

    def validate_type(self, name, val):
        """Return ``val`` as bytes, validating its type.

        ``str`` input is encoded. Every other value goes through the base
        class validation, so ``None`` on an optional field yields the default
        and unsupported types are rejected.

        Raises:
            TypeError: if ``val`` is neither ``bytes`` nor ``str`` (or is
                ``None`` on a required field).
        """
        if isinstance(val, str):
            return val.encode()
        # Delegate instead of returning val unchecked, so wrong types
        # (ints, lists, ...) are no longer accepted silently.
        return super(Bytes, self).validate_type(name, val)
| |
| |
class String(Field):
    """Field of schema type ``string``.

    ``bytes`` values are accepted and stored decoded (``bytes.decode()``).
    """

    def type(self):
        return 'string'

    def python_type(self):
        return str, bytes

    def validate_type(self, name, val):
        """Return ``val`` as text, raising TypeError for non-string input."""
        if val is None and not self._required:
            return self.default()

        val_type = type(val)
        if isinstance(val, bytes):
            return val.decode()
        # A type literally named 'unicode' is also accepted as a string type
        if isinstance(val, str) or val_type.__name__ == 'unicode':
            return val
        raise TypeError("Invalid type '%s' for field '%s'. Expected a string" % (val_type, name))

    def default(self):
        """Return the configured default, or ``None`` when it is unset."""
        return self._default
| |
| # Complex types |
| |
| |
class CustomEnum(Field):
    """Field whose values are members of a given ``Enum`` class.

    Values may be given as enum members, as member names (``str``) or as
    member values (``int``); the latter two are converted to the member.
    """

    def __init__(self, enum_type, default=None, required=False, required_default=False):
        if not issubclass(enum_type, Enum):
            raise Exception(_string_representation(enum_type) + " is not a valid Enum type")
        self.enum_type = enum_type
        # Lookup table from member value to member
        self.values = {member.value: member for member in enum_type.__members__.values()}
        super(CustomEnum, self).__init__(default, required, required_default)

    def type(self):
        return 'enum'

    def python_type(self):
        return self.enum_type

    def validate_type(self, name, val):
        """Return the enum member matching ``val`` (``None`` stays ``None``).

        Raises:
            TypeError: if ``val`` is an unknown name or value, or is not of a
                supported type.
        """
        if val is None:
            return None

        val_type = type(val)
        if val_type is str:
            # The enum was passed as a string, we need to check it against the possible values
            members = self.enum_type.__members__
            if val in members:
                return members[val]
            raise TypeError(
                "Invalid enum value '%s' for field '%s'. Expected: %s" % (val, name, members.keys()))

        if val_type is int:
            # The enum was passed as an int, we need to check it against the possible values
            if val in self.values:
                return self.values[val]
            raise TypeError(
                "Invalid enum value '%s' for field '%s'. Expected: %s" % (val, name, self.values.keys()))

        if isinstance(val, self.python_type()):
            return val
        raise TypeError("Invalid type '%s' for field '%s'. Expected: %s" % (type(val), name, _string_representation(self.python_type())))

    def schema(self):
        return self.schema_info(set())

    def schema_info(self, defined_names):
        """Return the enum schema, or only its name if it is already in ``defined_names``."""
        enum_name = self.enum_type.__name__
        if enum_name in defined_names:
            return enum_name
        defined_names.add(enum_name)
        return {
            'type': self.type(),
            'name': enum_name,
            'symbols': [member.name for member in self.enum_type]
        }

    def default(self):
        """Return the configured default, or ``None`` when it is unset."""
        return self._default
| |
| |
class Array(Field):
    """Field holding a ``list`` whose items all match ``array_type``."""

    def __init__(self, array_type, default=None, required=False, required_default=False):
        """Create an array field.

        Args:
            array_type: Record or Field describing the items.
            default, required, required_default: see ``Field``.

        Raises:
            Exception: if ``array_type`` is rejected by
                ``_check_record_or_field``.
        """
        _check_record_or_field(array_type)
        self.array_type = array_type
        super(Array, self).__init__(default=default, required=required, required_default=required_default)

    def type(self):
        return 'array'

    def python_type(self):
        return list

    def validate_type(self, name, val):
        """Return ``val`` after checking it is a list of valid items.

        ``None`` is returned unchanged.

        Raises:
            TypeError: if ``val`` is not a list or an item has the wrong type.
        """
        if val is None:
            return None

        super(Array, self).validate_type(name, val)

        for x in val:
            if not isinstance(x, self.array_type.python_type()):
                raise TypeError('Array field ' + name + ' items should all be of type ' +
                                _string_representation(self.array_type.type()))
        return val

    def schema(self):
        return self.schema_info(set())

    def schema_info(self, defined_names):
        """Return the array schema as a dict.

        The item schema always comes from the item type's own
        ``schema_info``. For primitive fields that is just the type name; for
        enum items it is the full enum definition (or its name once defined)
        rather than the bare string ``'enum'``.
        """
        return {
            'type': self.type(),
            'items': self.array_type.schema_info(defined_names)
        }

    def default(self):
        """Return the configured default, or ``None`` when it is unset."""
        return self._default
| |
| |
class Map(Field):
    """Field holding a ``dict`` with string keys and ``value_type`` values."""

    def __init__(self, value_type, default=None, required=False, required_default=False):
        """Create a map field.

        Args:
            value_type: Record or Field describing the values.
            default, required, required_default: see ``Field``.

        Raises:
            Exception: if ``value_type`` is rejected by
                ``_check_record_or_field``.
        """
        _check_record_or_field(value_type)
        self.value_type = value_type
        super(Map, self).__init__(default=default, required=required, required_default=required_default)

    def type(self):
        return 'map'

    def python_type(self):
        return dict

    def validate_type(self, name, val):
        """Return ``val`` after checking it is a dict with valid keys and values.

        ``None`` is returned unchanged.

        Raises:
            TypeError: if ``val`` is not a dict, a key is not a string, or a
                value has the wrong type.
        """
        if val is None:
            return None

        super(Map, self).validate_type(name, val)

        for k, v in val.items():
            # isinstance (rather than an exact type comparison) also accepts str subclasses
            if not isinstance(k, str) and not is_unicode(k):
                raise TypeError('Map keys for field ' + name + ' should all be strings')
            if not isinstance(v, self.value_type.python_type()):
                raise TypeError('Map values for field ' + name + ' should all be of type '
                                + _string_representation(self.value_type.python_type()))

        return val

    def schema(self):
        return self.schema_info(set())

    def schema_info(self, defined_names):
        """Return the map schema as a dict.

        The value schema always comes from the value type's own
        ``schema_info``. For primitive fields that is just the type name; for
        enum values it is the full enum definition (or its name once defined)
        rather than the bare string ``'enum'``.
        """
        return {
            'type': self.type(),
            'values': self.value_type.schema_info(defined_names)
        }

    def default(self):
        """Return the configured default, or ``None`` when it is unset."""
        return self._default
| |
| |
# Python3 has no `unicode` type, so a Python2 `unicode` value is detected indirectly:
# its `encode()` returns a `str`. The same check also runs unchanged on Python3.
def is_unicode(x):
    """Return True when ``x`` has an ``encode`` attribute whose result is exactly a ``str``."""
    if 'encode' not in dir(x):
        return False
    encoded = x.encode()
    return type(encoded) == str