blob: 6343003bd4da8f9b40c23eb5912f93da4a78bd9a [file] [log] [blame]
#!/usr/bin/env python3
##
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from copy import copy
from enum import Enum
from typing import Container, Iterable, List, Optional, Set, cast
from avro.errors import AvroRuntimeException
from avro.schema import (
ArraySchema,
EnumSchema,
Field,
FixedSchema,
MapSchema,
NamedSchema,
RecordSchema,
Schema,
UnionSchema,
)
class SchemaType(str, Enum):
ARRAY = "array"
BOOLEAN = "boolean"
BYTES = "bytes"
DOUBLE = "double"
ENUM = "enum"
FIXED = "fixed"
FLOAT = "float"
INT = "int"
LONG = "long"
MAP = "map"
NULL = "null"
RECORD = "record"
STRING = "string"
UNION = "union"
def __str__(self):
return self.value
class SchemaCompatibilityType(Enum):
compatible = "compatible"
incompatible = "incompatible"
recursion_in_progress = "recursion_in_progress"
class SchemaIncompatibilityType(Enum):
name_mismatch = "name_mismatch"
fixed_size_mismatch = "fixed_size_mismatch"
missing_enum_symbols = "missing_enum_symbols"
reader_field_missing_default_value = "reader_field_missing_default_value"
type_mismatch = "type_mismatch"
missing_union_branch = "missing_union_branch"
PRIMITIVE_TYPES = {
SchemaType.NULL,
SchemaType.BOOLEAN,
SchemaType.INT,
SchemaType.LONG,
SchemaType.FLOAT,
SchemaType.DOUBLE,
SchemaType.BYTES,
SchemaType.STRING,
}
class SchemaCompatibilityResult:
def __init__(
self,
compatibility: SchemaCompatibilityType = SchemaCompatibilityType.recursion_in_progress,
incompatibilities: Optional[List[SchemaIncompatibilityType]] = None,
messages: Optional[Set[str]] = None,
locations: Optional[Set[str]] = None,
):
self.locations = locations or {"/"}
self.messages = messages or set()
self.compatibility = compatibility
self.incompatibilities = incompatibilities or []
def merge(this: SchemaCompatibilityResult, that: SchemaCompatibilityResult) -> SchemaCompatibilityResult:
"""
Merges two {@code SchemaCompatibilityResult} into a new instance, combining the list of Incompatibilities
and regressing to the SchemaCompatibilityType.incompatible state if any incompatibilities are encountered.
:param this: SchemaCompatibilityResult
:param that: SchemaCompatibilityResult
:return: SchemaCompatibilityResult
"""
that = cast(SchemaCompatibilityResult, that)
merged = [*copy(this.incompatibilities), *copy(that.incompatibilities)]
if this.compatibility is SchemaCompatibilityType.compatible:
compat = that.compatibility
messages = that.messages
locations = that.locations
else:
compat = this.compatibility
messages = this.messages.union(that.messages)
locations = this.locations.union(that.locations)
return SchemaCompatibilityResult(
compatibility=compat,
incompatibilities=merged,
messages=messages,
locations=locations,
)
CompatibleResult = SchemaCompatibilityResult(SchemaCompatibilityType.compatible)
class ReaderWriter:
def __init__(self, reader: Schema, writer: Schema) -> None:
self.reader, self.writer = reader, writer
def __hash__(self) -> int:
return id(self.reader) ^ id(self.writer)
def __eq__(self, other) -> bool:
if not isinstance(other, ReaderWriter):
return False
return self.reader is other.reader and self.writer is other.writer
class ReaderWriterCompatibilityChecker:
ROOT_REFERENCE_TOKEN = "/"
def __init__(self):
self.memoize_map = {}
def get_compatibility(
self,
reader: Schema,
writer: Schema,
reference_token: str = ROOT_REFERENCE_TOKEN,
location: Optional[List[str]] = None,
) -> SchemaCompatibilityResult:
if location is None:
location = []
pair = ReaderWriter(reader, writer)
if pair in self.memoize_map:
result = cast(SchemaCompatibilityResult, self.memoize_map[pair])
if result.compatibility is SchemaCompatibilityType.recursion_in_progress:
result = CompatibleResult
else:
self.memoize_map[pair] = SchemaCompatibilityResult()
result = self.calculate_compatibility(reader, writer, location + [reference_token])
self.memoize_map[pair] = result
return result
# pylSchemaType.INT: disable=too-many-return-statements
def calculate_compatibility(
self,
reader: Schema,
writer: Schema,
location: List[str],
) -> SchemaCompatibilityResult:
"""
Calculates the compatibility of a reader/writer schema pair. Will be positive if the reader is capable of reading
whatever the writer may write
:param reader: avro.schema.Schema
:param writer: avro.schema.Schema
:param location: List[str]
:return: SchemaCompatibilityResult
"""
assert reader is not None
assert writer is not None
result = CompatibleResult
if reader.type == writer.type:
if reader.type in PRIMITIVE_TYPES:
return result
if reader.type == SchemaType.ARRAY:
reader, writer = cast(ArraySchema, reader), cast(ArraySchema, writer)
return merge(
result,
self.get_compatibility(reader.items, writer.items, "items", location),
)
if reader.type == SchemaType.MAP:
reader, writer = cast(MapSchema, reader), cast(MapSchema, writer)
return merge(
result,
self.get_compatibility(reader.values, writer.values, "values", location),
)
if reader.type == SchemaType.FIXED:
reader, writer = cast(FixedSchema, reader), cast(FixedSchema, writer)
result = merge(result, check_schema_names(reader, writer, location))
return merge(result, check_fixed_size(reader, writer, location))
if reader.type == SchemaType.ENUM:
reader, writer = cast(EnumSchema, reader), cast(EnumSchema, writer)
result = merge(result, check_schema_names(reader, writer, location))
return merge(
result,
check_reader_enum_contains_writer_enum(reader, writer, location),
)
if reader.type == SchemaType.RECORD:
reader, writer = cast(RecordSchema, reader), cast(RecordSchema, writer)
result = merge(result, check_schema_names(reader, writer, location))
return merge(
result,
self.check_reader_writer_record_fields(reader, writer, location),
)
if reader.type == SchemaType.UNION:
reader, writer = cast(UnionSchema, reader), cast(UnionSchema, writer)
for i, writer_branch in enumerate(writer.schemas):
compat = self.get_compatibility(reader, writer_branch)
if compat.compatibility is SchemaCompatibilityType.incompatible:
result = merge(
result,
incompatible(
SchemaIncompatibilityType.missing_union_branch,
f"reader union lacking writer type: {writer_branch.type.upper()}",
location + [str(i)],
),
)
return result
raise AvroRuntimeException(f"Unknown schema type: {reader.type}")
if writer.type == SchemaType.UNION:
writer = cast(UnionSchema, writer)
for s in writer.schemas:
result = merge(result, self.get_compatibility(reader, s))
return result
if reader.type in {SchemaType.NULL, SchemaType.BOOLEAN, SchemaType.INT}:
return merge(result, type_mismatch(reader, writer, location))
if reader.type == SchemaType.LONG:
if writer.type == SchemaType.INT:
return result
return merge(result, type_mismatch(reader, writer, location))
if reader.type == SchemaType.FLOAT:
if writer.type in {SchemaType.INT, SchemaType.LONG}:
return result
return merge(result, type_mismatch(reader, writer, location))
if reader.type == SchemaType.DOUBLE:
if writer.type in {SchemaType.INT, SchemaType.LONG, SchemaType.FLOAT}:
return result
return merge(result, type_mismatch(reader, writer, location))
if reader.type == SchemaType.BYTES:
if writer.type == SchemaType.STRING:
return result
return merge(result, type_mismatch(reader, writer, location))
if reader.type == SchemaType.STRING:
if writer.type == SchemaType.BYTES:
return result
return merge(result, type_mismatch(reader, writer, location))
if reader.type in {
SchemaType.ARRAY,
SchemaType.MAP,
SchemaType.FIXED,
SchemaType.ENUM,
SchemaType.RECORD,
}:
return merge(result, type_mismatch(reader, writer, location))
if reader.type == SchemaType.UNION:
reader = cast(UnionSchema, reader)
for reader_branch in reader.schemas:
compat = self.get_compatibility(reader_branch, writer)
if compat.compatibility is SchemaCompatibilityType.compatible:
return result
# No branch in reader compatible with writer
message = f"reader union lacking writer type {writer.type}"
return merge(
result,
incompatible(SchemaIncompatibilityType.missing_union_branch, message, location),
)
raise AvroRuntimeException(f"Unknown schema type: {reader.type}")
# pylSchemaType.INT: enable=too-many-return-statements
def check_reader_writer_record_fields(self, reader: RecordSchema, writer: RecordSchema, location: List[str]) -> SchemaCompatibilityResult:
result = CompatibleResult
for i, reader_field in enumerate(reader.fields):
reader_field = cast(Field, reader_field)
writer_field = lookup_writer_field(writer_schema=writer, reader_field=reader_field)
if writer_field is None:
if not reader_field.has_default:
if reader_field.type.type == SchemaType.ENUM and reader_field.type.props.get("default"):
result = merge(
result,
self.get_compatibility(
reader_field.type,
writer,
"type",
location + ["fields", str(i)],
),
)
else:
result = merge(
result,
incompatible(
SchemaIncompatibilityType.reader_field_missing_default_value,
reader_field.name,
location + ["fields", str(i)],
),
)
else:
result = merge(
result,
self.get_compatibility(
reader_field.type,
writer_field.type,
"type",
location + ["fields", str(i)],
),
)
return result
def type_mismatch(reader: Schema, writer: Schema, location: List[str]) -> SchemaCompatibilityResult:
message = f"reader type: {reader.type} not compatible with writer type: {writer.type}"
return incompatible(SchemaIncompatibilityType.type_mismatch, message, location)
def check_schema_names(reader: NamedSchema, writer: NamedSchema, location: List[str]) -> SchemaCompatibilityResult:
result = CompatibleResult
if not schema_name_equals(reader, writer):
message = f"expected: {writer.fullname}"
result = incompatible(SchemaIncompatibilityType.name_mismatch, message, location + ["name"])
return result
def check_fixed_size(reader: FixedSchema, writer: FixedSchema, location: List[str]) -> SchemaCompatibilityResult:
result = CompatibleResult
actual = reader.size
expected = writer.size
if actual != expected:
message = f"expected: {expected}, found: {actual}"
result = incompatible(
SchemaIncompatibilityType.fixed_size_mismatch,
message,
location + ["size"],
)
return result
def check_reader_enum_contains_writer_enum(reader: EnumSchema, writer: EnumSchema, location: List[str]) -> SchemaCompatibilityResult:
result = CompatibleResult
writer_symbols, reader_symbols = set(writer.symbols), set(reader.symbols)
extra_symbols = writer_symbols.difference(reader_symbols)
if extra_symbols:
default = reader.props.get("default")
if default and default in reader_symbols:
result = CompatibleResult
else:
result = incompatible(
SchemaIncompatibilityType.missing_enum_symbols,
str(extra_symbols),
location + ["symbols"],
)
return result
def incompatible(incompat_type: SchemaIncompatibilityType, message: str, location: List[str]) -> SchemaCompatibilityResult:
locations = "/".join(location)
if len(location) > 1:
locations = locations[1:]
ret = SchemaCompatibilityResult(
compatibility=SchemaCompatibilityType.incompatible,
incompatibilities=[incompat_type],
locations={locations},
messages={message},
)
return ret
def schema_name_equals(reader: NamedSchema, writer: NamedSchema) -> bool:
aliases = reader.props.get("aliases")
return (reader.name == writer.name) or (isinstance(aliases, Container) and writer.fullname in aliases)
def lookup_writer_field(writer_schema: RecordSchema, reader_field: Field) -> Optional[Field]:
direct = writer_schema.fields_dict.get(reader_field.name)
if direct:
return cast(Field, direct)
aliases = reader_field.props.get("aliases")
if not isinstance(aliases, Iterable):
return None
for alias in aliases:
writer_field = writer_schema.fields_dict.get(alias)
if writer_field is not None:
return cast(Field, writer_field)
return None