blob: 550d69664f361d4b45170d1a47475eef6d8e3e10 [file] [log] [blame]
#!/usr/bin/env python
##
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Support for inter-process calls."""
from __future__ import absolute_import, division, print_function
import io
import os
from struct import Struct
import avro.io
from avro import protocol, schema
try:
import httplib # type: ignore
except ImportError:
import http.client as httplib # type: ignore
try:
unicode
except NameError:
unicode = str
def _load(name):
dir_path = os.path.dirname(__file__)
rsrc_path = os.path.join(dir_path, name)
with open(rsrc_path, 'r') as f:
return f.read()
# Handshake schemas: the JSON text is read from the .avsc files that sit
# beside this module (see _load) and then parsed into schema objects.
HANDSHAKE_REQUEST_SCHEMA_JSON = _load('HandshakeRequest.avsc')
HANDSHAKE_RESPONSE_SCHEMA_JSON = _load('HandshakeResponse.avsc')
HANDSHAKE_REQUEST_SCHEMA = schema.parse(HANDSHAKE_REQUEST_SCHEMA_JSON)
HANDSHAKE_RESPONSE_SCHEMA = schema.parse(HANDSHAKE_RESPONSE_SCHEMA_JSON)
# The requestor writes handshake requests and reads handshake responses;
# the responder does the reverse.
HANDSHAKE_REQUESTOR_WRITER = avro.io.DatumWriter(HANDSHAKE_REQUEST_SCHEMA)
HANDSHAKE_REQUESTOR_READER = avro.io.DatumReader(HANDSHAKE_RESPONSE_SCHEMA)
HANDSHAKE_RESPONDER_WRITER = avro.io.DatumWriter(HANDSHAKE_RESPONSE_SCHEMA)
HANDSHAKE_RESPONDER_READER = avro.io.DatumReader(HANDSHAKE_REQUEST_SCHEMA)
# Call metadata is a map whose values are bytes; it precedes both the call
# request and the call response.
META_SCHEMA = schema.parse('{"type": "map", "values": "bytes"}')
META_WRITER = avro.io.DatumWriter(META_SCHEMA)
META_READER = avro.io.DatumReader(META_SCHEMA)
# Union with a single "string" branch, used by Responder.respond() to
# serialize errors raised outside the message's own error declaration.
SYSTEM_ERROR_SCHEMA = schema.parse('["string"]')
# Module-wide protocol caches, keyed by the transceiver's remote_name and
# filled in by the BaseRequestor.remote_hash / remote_protocol setters.
REMOTE_HASHES = {}
REMOTE_PROTOCOLS = {}
# Framing: each buffer is prefixed with its length packed as '!I'
# (network byte order, unsigned 32-bit), i.e. BUFFER_HEADER_LENGTH bytes.
BIG_ENDIAN_INT_STRUCT = Struct('!I')
BUFFER_HEADER_LENGTH = 4
# Largest payload FramedWriter places in a single buffer.
BUFFER_SIZE = 8192
#
# Exceptions
#
class AvroRemoteException(schema.AvroException):
    """Error carried in a message sent by an Avro requestor or responder."""

    def __init__(self, fail_msg=None):
        # Delegate straight to the Avro base exception; no extra state is kept.
        super(AvroRemoteException, self).__init__(fail_msg)
class ConnectionClosedException(schema.AvroException):
    """Raised by FramedReader when the underlying reader returns no data."""
#
# Base IPC Classes (Requestor/Responder)
#
class BaseRequestor(object):
    """Base class for the client side of a protocol interaction.

    A requestor serializes a handshake plus a call request, hands the bytes
    to ``issue_request`` and decodes the handshake and call response.

    Subclasses must provide ``issue_request(call_request, message_name,
    request_datum)``; ``request()`` calls it to perform the round trip.
    """

    def __init__(self, local_protocol, transceiver):
        """Store the local protocol and the transceiver used to reach the peer.

        ``transceiver`` must expose a ``remote_name`` attribute; it is the key
        used in the module-level REMOTE_HASHES / REMOTE_PROTOCOLS caches.
        """
        self._local_protocol = local_protocol
        self._transceiver = transceiver
        self._remote_protocol = None
        self._remote_hash = None
        self._send_protocol = None

    # read-only properties
    local_protocol = property(lambda self: self._local_protocol)
    transceiver = property(lambda self: self._transceiver)

    # read/write properties
    def set_remote_protocol(self, new_remote_protocol):
        """Remember the peer's protocol locally and in REMOTE_PROTOCOLS."""
        self._remote_protocol = new_remote_protocol
        REMOTE_PROTOCOLS[self.transceiver.remote_name] = self.remote_protocol
    remote_protocol = property(lambda self: self._remote_protocol,
                               set_remote_protocol)

    def set_remote_hash(self, new_remote_hash):
        """Remember the peer's protocol hash locally and in REMOTE_HASHES."""
        self._remote_hash = new_remote_hash
        REMOTE_HASHES[self.transceiver.remote_name] = self.remote_hash
    remote_hash = property(lambda self: self._remote_hash, set_remote_hash)

    def set_send_protocol(self, new_send_protocol):
        """Set whether the next handshake must carry the full local protocol."""
        self._send_protocol = new_send_protocol
    send_protocol = property(lambda self: self._send_protocol,
                             set_send_protocol)

    def request(self, message_name, request_datum):
        """
        Writes a request message and reads a response or error message.

        Returns whatever ``issue_request`` returns for the serialized call.
        Raises schema.AvroException if ``message_name`` is not declared in
        the local protocol.
        """
        # build handshake and call request
        buffer_writer = io.BytesIO()
        buffer_encoder = avro.io.BinaryEncoder(buffer_writer)
        self.write_handshake_request(buffer_encoder)
        self.write_call_request(message_name, request_datum, buffer_encoder)

        # send the handshake and call request; block until call response
        call_request = buffer_writer.getvalue()
        return self.issue_request(call_request, message_name, request_datum)

    def write_handshake_request(self, encoder):
        """Serialize the handshake request for the current peer to ``encoder``.

        When no server hash is cached for the peer, the local hash is sent as
        the guess and the local protocol is assumed to be the remote one.
        """
        local_hash = self.local_protocol.md5
        remote_name = self.transceiver.remote_name
        remote_hash = REMOTE_HASHES.get(remote_name)
        if remote_hash is None:
            remote_hash = local_hash
            self.remote_protocol = self.local_protocol
        elif self._remote_protocol is None:
            # The hash was cached by another requestor talking to the same
            # peer.  Pick up the matching cached protocol as well; otherwise
            # a 'BOTH' handshake would leave this instance without a remote
            # protocol and read_call_response() would fail on None.
            self._remote_protocol = REMOTE_PROTOCOLS.get(remote_name)
        request_datum = {}
        request_datum['clientHash'] = local_hash
        request_datum['serverHash'] = remote_hash
        if self.send_protocol:
            request_datum['clientProtocol'] = unicode(self.local_protocol)
        HANDSHAKE_REQUESTOR_WRITER.write(request_datum, encoder)

    def write_call_request(self, message_name, request_datum, encoder):
        """
        The format of a call request is:
          * request metadata, a map with values of type bytes
          * the message name, an Avro string, followed by
          * the message parameters. Parameters are serialized according to
            the message's request declaration.

        Raises schema.AvroException if the local protocol has no message
        called ``message_name``.
        """
        # request metadata (not yet implemented)
        request_metadata = {}
        META_WRITER.write(request_metadata, encoder)

        # message name
        message = self.local_protocol.messages.get(message_name)
        if message is None:
            raise schema.AvroException('Unknown message: %s' % message_name)
        encoder.write_utf8(message.name)

        # message parameters
        self.write_request(message.request, request_datum, encoder)

    def write_request(self, request_schema, request_datum, encoder):
        """Serialize ``request_datum`` according to ``request_schema``."""
        datum_writer = avro.io.DatumWriter(request_schema)
        datum_writer.write(request_datum, encoder)

    def read_handshake_response(self, decoder):
        """Decode the handshake response and update the handshake state.

        Returns True when a call response follows in ``decoder`` ('BOTH' or
        'CLIENT' match) and False when the request has to be re-sent with the
        full client protocol ('NONE' match).

        Raises schema.AvroException on an unexpected match value, or when a
        'CLIENT'/'NONE' match arrives although the protocol was already sent.
        """
        handshake_response = HANDSHAKE_REQUESTOR_READER.read(decoder)
        match = handshake_response.get('match')
        if match == 'BOTH':
            self.send_protocol = False
            return True
        if match not in ('CLIENT', 'NONE'):
            raise schema.AvroException('Unexpected match: %s' % match)

        # 'CLIENT' and 'NONE' both carry the server's protocol and hash.
        if self.send_protocol:
            raise schema.AvroException('Handshake failure.')
        self.remote_protocol = protocol.parse(
            handshake_response.get('serverProtocol'))
        self.remote_hash = handshake_response.get('serverHash')
        # Only 'CLIENT' means the server accepted the call as sent; after
        # 'NONE' the next attempt must include the client protocol.
        call_response_follows = (match == 'CLIENT')
        self.send_protocol = not call_response_follows
        return call_response_follows

    def read_call_response(self, message_name, decoder):
        """
        The format of a call response is:
          * response metadata, a map with values of type bytes
          * a one-byte error flag boolean, followed by either:
            o if the error flag is false,
              the message response, serialized per the message's response schema.
            o if the error flag is true,
              the error, serialized per the message's error union schema.

        Returns the decoded response datum, or raises the
        AvroRemoteException built by ``read_error`` when the flag is true.
        Raises schema.AvroException if either protocol lacks the message.
        """
        # response metadata is not used, but must be consumed from the stream
        META_READER.read(decoder)

        # remote response schema
        remote_message_schema = self.remote_protocol.messages.get(message_name)
        if remote_message_schema is None:
            raise schema.AvroException(
                'Unknown remote message: %s' % message_name)

        # local response schema
        local_message_schema = self.local_protocol.messages.get(message_name)
        if local_message_schema is None:
            raise schema.AvroException(
                'Unknown local message: %s' % message_name)

        # error flag
        if not decoder.read_boolean():
            writers_schema = remote_message_schema.response
            readers_schema = local_message_schema.response
            return self.read_response(writers_schema, readers_schema, decoder)
        else:
            writers_schema = remote_message_schema.errors
            readers_schema = local_message_schema.errors
            raise self.read_error(writers_schema, readers_schema, decoder)

    def read_response(self, writers_schema, readers_schema, decoder):
        """Decode and return a response datum using the two given schemas."""
        datum_reader = avro.io.DatumReader(writers_schema, readers_schema)
        result = datum_reader.read(decoder)
        return result

    def read_error(self, writers_schema, readers_schema, decoder):
        """Decode an error datum and return it wrapped (not raised)."""
        datum_reader = avro.io.DatumReader(writers_schema, readers_schema)
        return AvroRemoteException(datum_reader.read(decoder))
class Requestor(BaseRequestor):
    """Requestor that performs one blocking transceive per attempt."""

    def issue_request(self, call_request, message_name, request_datum):
        """Send ``call_request`` and decode the handshake and call response.

        When the handshake reports that no call response follows, the whole
        request is rebuilt and sent again through ``request()``.
        """
        raw_response = self.transceiver.transceive(call_request)

        # process the handshake and call response
        decoder = avro.io.BinaryDecoder(io.BytesIO(raw_response))
        if not self.read_handshake_response(decoder):
            return self.request(message_name, request_datum)
        return self.read_call_response(message_name, decoder)
class Responder(object):
    """Base class for the server side of a protocol interaction.

    Subclasses override ``invoke`` to implement the actual message handling.
    """

    def __init__(self, local_protocol):
        """Store the local protocol and seed the protocol cache with it."""
        self._local_protocol = local_protocol
        self._local_hash = self.local_protocol.md5
        self._protocol_cache = {}
        self.set_protocol_cache(self.local_hash, self.local_protocol)

    # read-only properties
    local_protocol = property(lambda self: self._local_protocol)
    local_hash = property(lambda self: self._local_hash)
    protocol_cache = property(lambda self: self._protocol_cache)

    # utility functions to manipulate protocol cache
    def get_protocol_cache(self, hash):
        """Return the protocol cached under ``hash``, or None."""
        return self.protocol_cache.get(hash)

    def set_protocol_cache(self, hash, protocol):
        """Cache ``protocol`` under ``hash``."""
        self.protocol_cache[hash] = protocol

    def respond(self, call_request):
        """
        Called by a server to deserialize a request, compute and serialize
        a response or error. Compare to 'handle()' in Thrift.

        Returns the serialized handshake response followed, when the
        handshake succeeded, by the call response.  An AvroException raised
        while processing is reported to the caller as a system error.
        """
        buffer_reader = io.BytesIO(call_request)
        buffer_decoder = avro.io.BinaryDecoder(buffer_reader)
        buffer_writer = io.BytesIO()
        buffer_encoder = avro.io.BinaryEncoder(buffer_writer)
        error = None
        response = None
        response_metadata = {}
        # Offset in buffer_writer where the handshake response ends; anything
        # after it is call-response payload that may need to be discarded.
        handshake_end = 0

        try:
            remote_protocol = self.process_handshake(buffer_decoder,
                                                     buffer_encoder)
            handshake_end = buffer_writer.tell()
            # handshake failure
            if remote_protocol is None:
                return buffer_writer.getvalue()

            # read request using remote protocol; the request metadata is
            # not used but must be consumed from the stream
            META_READER.read(buffer_decoder)
            remote_message_name = buffer_decoder.read_utf8()

            # get remote and local request schemas so we can do
            # schema resolution (one fine day)
            remote_message = remote_protocol.messages.get(remote_message_name)
            if remote_message is None:
                fail_msg = 'Unknown remote message: %s' % remote_message_name
                raise schema.AvroException(fail_msg)
            local_message = self.local_protocol.messages.get(
                remote_message_name)
            if local_message is None:
                fail_msg = 'Unknown local message: %s' % remote_message_name
                raise schema.AvroException(fail_msg)
            writers_schema = remote_message.request
            readers_schema = local_message.request
            request = self.read_request(writers_schema, readers_schema,
                                        buffer_decoder)

            # perform server logic; any failure becomes a remote error
            try:
                response = self.invoke(local_message, request)
            except AvroRemoteException as e:
                error = e
            except Exception as e:
                error = AvroRemoteException(unicode(e))

            # write response using local protocol
            META_WRITER.write(response_metadata, buffer_encoder)
            buffer_encoder.write_boolean(error is not None)
            if error is None:
                writers_schema = local_message.response
                self.write_response(writers_schema, response, buffer_encoder)
            else:
                writers_schema = local_message.errors
                self.write_error(writers_schema, error, buffer_encoder)
        except schema.AvroException as e:
            error = AvroRemoteException(unicode(e))
            # Drop any partially written call response but keep the
            # handshake, then write the system error into the buffer that is
            # actually returned.  (Writing it to a fresh, throwaway buffer
            # would silently lose the error.)
            buffer_writer.seek(handshake_end)
            buffer_writer.truncate()
            META_WRITER.write(response_metadata, buffer_encoder)
            buffer_encoder.write_boolean(True)
            self.write_error(SYSTEM_ERROR_SCHEMA, error, buffer_encoder)
        return buffer_writer.getvalue()

    def process_handshake(self, decoder, encoder):
        """Read the handshake request, write the handshake response.

        Returns the client's protocol, or None when it is neither cached nor
        supplied in the request (match 'NONE').
        """
        handshake_request = HANDSHAKE_RESPONDER_READER.read(decoder)
        handshake_response = {}

        # determine the remote protocol
        client_hash = handshake_request.get('clientHash')
        client_protocol = handshake_request.get('clientProtocol')
        remote_protocol = self.get_protocol_cache(client_hash)
        if remote_protocol is None and client_protocol is not None:
            remote_protocol = protocol.parse(client_protocol)
            self.set_protocol_cache(client_hash, remote_protocol)

        # evaluate remote's guess of the local protocol
        server_hash = handshake_request.get('serverHash')
        if remote_protocol is None:
            handshake_response['match'] = 'NONE'
        elif self.local_hash == server_hash:
            handshake_response['match'] = 'BOTH'
        else:
            handshake_response['match'] = 'CLIENT'

        # unless both sides already agree, tell the client who we are
        if handshake_response['match'] != 'BOTH':
            handshake_response['serverProtocol'] = unicode(self.local_protocol)
            handshake_response['serverHash'] = self.local_hash

        HANDSHAKE_RESPONDER_WRITER.write(handshake_response, encoder)
        return remote_protocol

    def invoke(self, local_message, request):
        """
        Actual work done by server: cf. handler in thrift.

        The base implementation does nothing and returns None.
        """
        pass

    def read_request(self, writers_schema, readers_schema, decoder):
        """Decode and return the request datum using the two given schemas."""
        datum_reader = avro.io.DatumReader(writers_schema, readers_schema)
        return datum_reader.read(decoder)

    def write_response(self, writers_schema, response_datum, encoder):
        """Serialize ``response_datum`` according to ``writers_schema``."""
        datum_writer = avro.io.DatumWriter(writers_schema)
        datum_writer.write(response_datum, encoder)

    def write_error(self, writers_schema, error_exception, encoder):
        """Serialize the text of ``error_exception`` per ``writers_schema``."""
        datum_writer = avro.io.DatumWriter(writers_schema)
        datum_writer.write(unicode(error_exception), encoder)
#
# Utility classes
#
class FramedReader(object):
    """Wrapper around a file-like object to read framed data.

    A framed message is a sequence of buffers, each prefixed with its length
    packed with BIG_ENDIAN_INT_STRUCT, terminated by a zero-length buffer.
    """

    def __init__(self, reader):
        """Wrap ``reader``, an object with a ``read(n)`` method."""
        self._reader = reader

    # read-only properties
    reader = property(lambda self: self._reader)

    def read_framed_message(self):
        """Read buffers up to the zero-length terminator and return them joined.

        Raises ConnectionClosedException if the reader runs dry first.
        """
        frames = []
        while True:
            buffer_length = self._read_buffer_length()
            if buffer_length == 0:
                return b''.join(frames)
            frames.append(self._read_exactly(buffer_length))

    def _read_exactly(self, size):
        """Read exactly ``size`` bytes, looping over short reads.

        Raises ConnectionClosedException when the reader returns no data
        before ``size`` bytes have been collected.
        """
        pieces = []
        remaining = size
        while remaining > 0:
            chunk = self.reader.read(remaining)
            # Truthiness rather than ``== ''``: on Python 3 the reader yields
            # bytes and b'' never equals '', which made end-of-stream go
            # unnoticed.
            if not chunk:
                raise ConnectionClosedException("Reader read 0 bytes.")
            pieces.append(chunk)
            remaining -= len(chunk)
        return b''.join(pieces)

    def _read_buffer_length(self):
        """Read one buffer-length header and return it as an int."""
        header = self._read_exactly(BUFFER_HEADER_LENGTH)
        return BIG_ENDIAN_INT_STRUCT.unpack(header)[0]
class FramedWriter(object):
    """Wrapper around a file-like object to write framed data."""

    def __init__(self, writer):
        """Wrap ``writer``, an object with a ``write(data)`` method."""
        self._writer = writer

    @property
    def writer(self):
        """The wrapped writer (read-only)."""
        return self._writer

    def write_framed_message(self, message):
        """Write ``message`` as buffers of at most BUFFER_SIZE bytes each."""
        for offset in range(0, len(message), BUFFER_SIZE):
            self.write_buffer(message[offset:offset + BUFFER_SIZE])
        # A message is always terminated by a zero-length buffer.
        self.write_buffer_length(0)

    def write_buffer(self, chunk):
        """Write one buffer: its length header followed by its content."""
        self.write_buffer_length(len(chunk))
        self.writer.write(chunk)

    def write_buffer_length(self, n):
        """Write ``n`` packed with BIG_ENDIAN_INT_STRUCT."""
        self.writer.write(BIG_ENDIAN_INT_STRUCT.pack(n))
#
# Transceiver Implementations
#
class HTTPTransceiver(object):
    """
    A simple HTTP-based transceiver implementation.
    Useful for clients but not for servers
    """

    def __init__(self, host, port, req_resource='/'):
        """Open an HTTP connection to ``host``:``port`` right away."""
        self.req_resource = req_resource
        self.conn = httplib.HTTPConnection(host, port)
        self.conn.connect()
        # NOTE(review): getsockname() is the socket's own address, not the
        # peer's -- confirm this is the intended cache key for remote_name.
        self.remote_name = self.conn.sock.getsockname()

    def transceive(self, request):
        """POST ``request`` as a framed message and return the framed reply."""
        self.write_framed_message(request)
        return self.read_framed_message()

    def read_framed_message(self):
        """Read one framed message from the pending HTTP response."""
        response = self.conn.getresponse()
        framed_message = FramedReader(response).read_framed_message()
        response.read()  # ensure we're ready for subsequent requests
        return framed_message

    def write_framed_message(self, message):
        """Frame ``message`` in memory and send it as the body of a POST."""
        body_sink = io.BytesIO()
        FramedWriter(body_sink).write_framed_message(message)
        headers = {'Content-Type': 'avro/binary'}
        self.conn.request('POST', self.req_resource, body_sink.getvalue(),
                          headers)

    def close(self):
        """Close the underlying HTTP connection."""
        self.conn.close()
#
# Server Implementations (none yet)
#