blob: 629d2fdc4324b83343af0b5a2df823c934b5df9c [file] [log] [blame]
#!/usr/bin/env python3
# -*- mode: python -*-
# -*- coding: utf-8 -*-
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Avro protocol implementation: parses JSON protocol descriptors into Protocol and Message objects.
"""
import hashlib
import json
import logging
from avro import schema
# Local alias for avro.schema's immutable dict type.
ImmutableDict = schema.ImmutableDict
# ------------------------------------------------------------------------------
# Constants
# Schema kinds that a protocol may declare in its top-level "types" list.
VALID_TYPE_SCHEMA_TYPES = frozenset({'record', 'error', 'enum', 'fixed'})
# ------------------------------------------------------------------------------
# Exceptions
class ProtocolParseException(schema.AvroException):
    """Raised when a JSON protocol descriptor cannot be parsed or is invalid."""
# ------------------------------------------------------------------------------
# Base Classes
class Protocol(object):
    """An Avro application protocol: a named collection of types and messages."""

    @staticmethod
    def _ParseTypeDesc(type_desc, names, protocol_name=None):
        """Parses one entry of a protocol's "types" list.

        Args:
          type_desc: JSON descriptor of the type.
          names: Tracker of the named Avro schemas.
          protocol_name: Optional name of the enclosing protocol; only used to
              make the error message more specific.
        Returns:
          The parsed type schema.
        Raises:
          ProtocolParseException: if the parsed schema's type is not one of
              VALID_TYPE_SCHEMA_TYPES.
        """
        type_schema = schema.SchemaFromJSONData(type_desc, names=names)
        if type_schema.type not in VALID_TYPE_SCHEMA_TYPES:
            location = (
                ' in protocol %r' % (protocol_name,)
                if protocol_name is not None else '')
            # sorted(): frozenset iteration order is not stable across runs.
            raise ProtocolParseException(
                'Invalid type %r%s: protocols can only declare types %s.'
                % (type_schema, location,
                   ','.join(sorted(VALID_TYPE_SCHEMA_TYPES))))
        return type_schema

    @staticmethod
    def _ParseMessageDesc(name, message_desc, names):
        """Parses a protocol message descriptor.

        Args:
          name: Name of the message.
          message_desc: Descriptor of the message (a dict with "request",
              "response" and optional "errors" entries).
          names: Tracker of the named Avro schema.
        Returns:
          The parsed protocol message.
        Raises:
          ProtocolParseException: if the descriptor is invalid.
        """
        if not isinstance(message_desc, dict):
            raise ProtocolParseException(
                'Invalid message descriptor for %r: %r.' % (name, message_desc))

        request_desc = message_desc.get('request')
        if request_desc is None:
            raise ProtocolParseException(
                'Invalid message descriptor with no "request": %r.'
                % (message_desc,))
        request_schema = Message._ParseRequestFromJSONDesc(
            request_desc=request_desc,
            names=names,
        )

        response_desc = message_desc.get('response')
        if response_desc is None:
            raise ProtocolParseException(
                'Invalid message descriptor with no "response": %r.'
                % (message_desc,))
        response_schema = Message._ParseResponseFromJSONDesc(
            response_desc=response_desc,
            names=names,
        )

        # Errors are optional:
        errors_desc = message_desc.get('errors', ())
        error_union_schema = Message._ParseErrorsFromJSONDesc(
            errors_desc=errors_desc,
            names=names,
        )

        return Message(
            name=name,
            request=request_schema,
            response=response_schema,
            errors=error_union_schema,
        )

    @staticmethod
    def _ParseMessageDescMap(message_desc_map, names):
        """Lazily parses every (name -> descriptor) entry of a messages map.

        Yields:
          One parsed Message per entry, in the map's iteration order.
        """
        for name, message_desc in message_desc_map.items():
            yield Protocol._ParseMessageDesc(
                name=name,
                message_desc=message_desc,
                names=names,
            )

    def __init__(
        self,
        name,
        namespace=None,
        types=(),
        messages=(),
    ):
        """Initializes a new protocol object.

        Args:
          name: Protocol name (absolute or relative).
          namespace: Optional explicit namespace (if name is relative).
          types: Collection of types in the protocol.
          messages: Collection of messages in the protocol.
        Raises:
          ProtocolParseException: if two types share a full name, or two
              messages share a name.
        """
        self._avro_name = schema.Name(name=name, namespace=namespace)
        self._fullname = self._avro_name.fullname
        self._name = self._avro_name.simple_name
        self._namespace = self._avro_name.namespace

        self._props = {}
        self._props['name'] = self._name
        if self._namespace:
            self._props['namespace'] = self._namespace

        self._names = schema.Names(default_namespace=self._namespace)

        self._types = tuple(types)
        # Map: type full name -> type schema
        self._type_map = ImmutableDict(
            (type_schema.fullname, type_schema) for type_schema in self._types)
        # Parsing through a shared Names tracker already rejects duplicates,
        # but direct construction does not; use an exception, not an assert,
        # so the check survives `python -O`.
        if len(self._types) != len(self._type_map):
            raise ProtocolParseException(
                'Invalid protocol %s with duplicate type name: %r'
                % (self._avro_name, self._types))
        # TODO: set props['types']

        self._messages = tuple(messages)
        # Map: message name -> Message
        # Note that message names are simple names unique within the protocol.
        self._message_map = ImmutableDict(
            items=((message.name, message) for message in self._messages))
        if len(self._messages) != len(self._message_map):
            raise ProtocolParseException(
                'Invalid protocol %s with duplicate message name: %r'
                % (self._avro_name, self._messages))
        # TODO: set props['messages']

        # Digest of the JSON text produced by str(self).
        self._md5 = hashlib.md5(str(self).encode('utf-8')).digest()

    @property
    def name(self):
        """Returns: the simple name of the protocol."""
        return self._name

    @property
    def namespace(self):
        """Returns: the namespace this protocol belongs to."""
        return self._namespace

    @property
    def fullname(self):
        """Returns: the fully qualified name of this protocol."""
        return self._fullname

    @property
    def types(self):
        """Returns: the collection of types declared in this protocol."""
        return self._types

    @property
    def type_map(self):
        """Returns: the map of types in this protocol, indexed by their full name."""
        return self._type_map

    @property
    def messages(self):
        """Returns: the collection of messages declared in this protocol."""
        return self._messages

    @property
    def message_map(self):
        """Returns: the map of messages in this protocol, indexed by their name."""
        return self._message_map

    @property
    def md5(self):
        """Returns: the MD5 digest (bytes) of str(self), computed at construction."""
        return self._md5

    @property
    def props(self):
        """Returns: the properties dict ('name' and, if set, 'namespace')."""
        return self._props

    def to_json(self):
        """Returns: a JSON-serializable dict describing this protocol.

        A single Names tracker is shared by the types and then the messages,
        so the serialization order (types before messages) matters.
        """
        to_dump = {'protocol': self.name}
        names = schema.Names(default_namespace=self.namespace)
        if self.namespace:
            to_dump['namespace'] = self.namespace
        if self.types:
            to_dump['types'] = [t.to_json(names) for t in self.types]
        if self.messages:
            to_dump['messages'] = {
                message_name: message.to_json(names)
                for message_name, message in self.message_map.items()
            }
        return to_dump

    def __str__(self):
        return json.dumps(self.to_json())

    def __eq__(self, that):
        """Protocols are equal when their JSON descriptors are equal."""
        if not isinstance(that, Protocol):
            return NotImplemented
        # Round-trip through json so container types compare uniformly.
        return json.loads(str(self)) == json.loads(str(that))
# ------------------------------------------------------------------------------
class Message(object):
    """A Protocol message: a name plus request, response and error schemas."""

    @staticmethod
    def _ParseRequestFromJSONDesc(request_desc, names):
        """Parses the request descriptor of a protocol message.

        Args:
          request_desc: Descriptor of the message request.
              This is a list of fields that defines an unnamed record.
          names: Tracker for named Avro schemas.
        Returns:
          The parsed request schema, as an unnamed record.
        """
        fields = schema.RecordSchema._MakeFieldList(request_desc, names=names)
        return schema.RecordSchema(
            name=None,
            namespace=None,
            fields=fields,
            names=names,
            record_type=schema.REQUEST,
        )

    @staticmethod
    def _ParseResponseFromJSONDesc(response_desc, names):
        """Parses the response descriptor of a protocol message.

        Args:
          response_desc: Descriptor of the message response.
              This is an arbitrary Avro schema descriptor.
          names: Tracker for named Avro schemas.
        Returns:
          The parsed response schema.
        """
        return schema.SchemaFromJSONData(response_desc, names=names)

    @staticmethod
    def _ParseErrorsFromJSONDesc(errors_desc, names):
        """Parses the errors descriptor of a protocol message.

        Args:
          errors_desc: Descriptor of the errors thrown by the protocol message.
              This is a list of error types understood as an implicit union.
              Each error type is an arbitrary Avro schema.
          names: Tracker for named Avro schemas.
        Returns:
          The parsed ErrorUnionSchema.
        """
        error_union_desc = {
            'type': schema.ERROR_UNION,
            'declared_errors': errors_desc,
        }
        return schema.SchemaFromJSONData(error_union_desc, names=names)

    def __init__(self, name, request, response, errors=None):
        """Initializes a message.

        Args:
          name: Simple name of the message.
          request: Request schema.
          response: Response schema.
          errors: Optional error-union schema.
        """
        self._name = name
        self._props = {}
        # TODO: set properties
        self._request = request
        self._response = response
        self._errors = errors

    @property
    def name(self):
        """Returns: the name of the message."""
        return self._name

    @property
    def request(self):
        """Returns: the request schema."""
        return self._request

    @property
    def response(self):
        """Returns: the response schema."""
        return self._response

    @property
    def errors(self):
        """Returns: the error-union schema, or None."""
        return self._errors

    def props(self):
        """Returns: the properties dict of this message.

        Note: this is a plain method (call it as ``message.props()``), kept that
        way for backward compatibility with existing callers.
        """
        return self._props

    def __str__(self):
        return json.dumps(self.to_json())

    def to_json(self, names=None):
        """Returns: a JSON-serializable dict describing this message.

        Args:
          names: Optional tracker of named schemas, passed through to each
              schema's to_json(); a fresh schema.Names() is used when omitted.
        """
        if names is None:
            names = schema.Names()
        to_dump = {}
        to_dump['request'] = self.request.to_json(names)
        to_dump['response'] = self.response.to_json(names)
        if self.errors:
            to_dump['errors'] = self.errors.to_json(names)
        return to_dump

    def __eq__(self, that):
        """Messages are equal when their names and property dicts are equal."""
        if not isinstance(that, Message):
            return NotImplemented
        # `props` is a method: compare the dicts, not the bound methods
        # (bound-method equality depends on the identity of `self`).
        return self.name == that.name and self._props == that._props
# ------------------------------------------------------------------------------
def ProtocolFromJSONData(json_data):
    """Builds an Avro Protocol from its JSON descriptor.

    Args:
      json_data: JSON data representing the descriptor of the Avro protocol
          (a dict with a "protocol" name and optional "namespace", "types"
          and "messages" entries).
    Returns:
      The Avro Protocol parsed from the JSON descriptor.
    Raises:
      ProtocolParseException: if the descriptor is invalid.
    """
    if not isinstance(json_data, dict):
        # Wrap in a 1-tuple: `'%r' % some_tuple` would itself raise TypeError.
        raise ProtocolParseException(
            'Invalid JSON descriptor for an Avro protocol: %r' % (json_data,))

    name = json_data.get('protocol')
    if name is None:
        raise ProtocolParseException(
            'Invalid protocol descriptor with no "protocol": %r' % (json_data,))

    # Namespace is optional
    namespace = json_data.get('namespace')

    avro_name = schema.Name(name=name, namespace=namespace)
    names = schema.Names(default_namespace=avro_name.namespace)

    # Types must be parsed before messages: messages may refer to them by
    # name through the shared `names` tracker.
    type_desc_list = json_data.get('types', ())
    if not isinstance(type_desc_list, (list, tuple)):
        raise ProtocolParseException(
            'Invalid protocol descriptor: "types" must be a list, got %r'
            % (type_desc_list,))
    types = tuple(
        Protocol._ParseTypeDesc(type_desc, names=names)
        for type_desc in type_desc_list)

    message_desc_map = json_data.get('messages', {})
    if not isinstance(message_desc_map, dict):
        raise ProtocolParseException(
            'Invalid protocol descriptor: "messages" must be a map, got %r'
            % (message_desc_map,))
    messages = tuple(Protocol._ParseMessageDescMap(message_desc_map, names=names))

    return Protocol(
        name=name,
        namespace=namespace,
        types=types,
        messages=messages,
    )
def Parse(json_string):
    """Constructs a Protocol from its JSON descriptor in text form.

    Args:
      json_string: String representation of the JSON descriptor of the protocol.
    Returns:
      The parsed protocol.
    Raises:
      ProtocolParseException: on JSON parsing error,
          or if the JSON descriptor is invalid.
    """
    try:
        json_data = json.loads(json_string)
    except (TypeError, ValueError, RecursionError) as exn:
        # TypeError: non-text input; ValueError (incl. JSONDecodeError and
        # UnicodeDecodeError): malformed text; RecursionError: pathologically
        # deep nesting. Chain the cause so the original traceback is kept.
        raise ProtocolParseException(
            'Error parsing protocol from JSON: %r. '
            'Error message: %r.'
            % (json_string, exn)) from exn
    return ProtocolFromJSONData(json_data)