blob: 792f76432ccf50c852a5e5713df1b04133a11d40 [file] [log] [blame]
#!/usr/bin/env python3
# -*- mode: python -*-
# -*- coding: utf-8 -*-
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""RPC/IPC support."""
import abc
import http.client
import http.server
import io
import logging
import os
import socketserver
from avro import io as avro_io
from avro import protocol
from avro import schema
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# ------------------------------------------------------------------------------
# Constants
def LoadResource(name):
    """Loads a text resource stored in the same directory as this module.

    Args:
      name: File name of the resource, resolved against this module's directory.
    Returns:
      The content of the resource, as a string.
    """
    dir_path = os.path.dirname(__file__)
    rsrc_path = os.path.join(dir_path, name)
    # The resources loaded through this helper are JSON schemas: decode them
    # explicitly as UTF-8 instead of relying on the platform default encoding.
    with open(rsrc_path, 'r', encoding='utf-8') as f:
        return f.read()
# Handshake schema is pulled in during build.
# The JSON text of both schemas is loaded from the .avsc resources through
# LoadResource() when this module is imported.
HANDSHAKE_REQUEST_SCHEMA_JSON = LoadResource('HandshakeRequest.avsc')
HANDSHAKE_RESPONSE_SCHEMA_JSON = LoadResource('HandshakeResponse.avsc')
# Parsed handshake schemas:
HANDSHAKE_REQUEST_SCHEMA = schema.Parse(HANDSHAKE_REQUEST_SCHEMA_JSON)
HANDSHAKE_RESPONSE_SCHEMA = schema.Parse(HANDSHAKE_RESPONSE_SCHEMA_JSON)
# The requestor writes handshake requests and reads handshake responses:
HANDSHAKE_REQUESTOR_WRITER = avro_io.DatumWriter(HANDSHAKE_REQUEST_SCHEMA)
HANDSHAKE_REQUESTOR_READER = avro_io.DatumReader(HANDSHAKE_RESPONSE_SCHEMA)
# The responder writes handshake responses and reads handshake requests:
HANDSHAKE_RESPONDER_WRITER = avro_io.DatumWriter(HANDSHAKE_RESPONSE_SCHEMA)
HANDSHAKE_RESPONDER_READER = avro_io.DatumReader(HANDSHAKE_REQUEST_SCHEMA)
# Call metadata is a map with values of type bytes:
META_SCHEMA = schema.Parse('{"type": "map", "values": "bytes"}')
META_WRITER = avro_io.DatumWriter(META_SCHEMA)
META_READER = avro_io.DatumReader(META_SCHEMA)
# Schema used to serialize system errors: a union with a single string branch.
SYSTEM_ERROR_SCHEMA = schema.Parse('["string"]')
# Content type set on the HTTP requests and responses carrying Avro messages.
AVRO_RPC_MIME = 'avro/binary'
# protocol cache
# Map: remote name -> remote MD5 hash
# NOTE: within this file, this map is only referenced from commented-out code.
_REMOTE_HASHES = {}
# Decoder/encoder for a 32 bits big-endian integer.
UINT32_BE = avro_io.STRUCT_INT
# Default size of the buffers used to frame messages:
BUFFER_SIZE = 8192
# ------------------------------------------------------------------------------
# Exceptions
class AvroRemoteException(schema.AvroException):
    """Error reported through an Avro RPC interaction.

    Raised when an error message is sent by an Avro requestor or responder.
    """

    def __init__(self, fail_msg=None):
        super().__init__(fail_msg)
class ConnectionClosedException(schema.AvroException):
    """Raised when the input ends before a complete frame could be read."""
# ------------------------------------------------------------------------------
# Base IPC Classes (Requestor/Responder)
class BaseRequestor(object, metaclass=abc.ABCMeta):
    """Base class for the client side of a protocol interaction.

    Subclasses implement _IssueRequest() to exchange the serialized request
    with the remote end and decode its reply.
    """

    def __init__(self, local_protocol, transceiver):
        """Initializes a new requestor object.

        Args:
          local_protocol: Avro Protocol describing the messages sent and received.
          transceiver: Transceiver instance to channel messages through.
        """
        self._local_protocol = local_protocol
        self._transceiver = transceiver
        # Remote protocol and hash are unknown until a handshake took place:
        self._remote_protocol = None
        self._remote_hash = None
        # Whether the handshake request must embed the full local protocol:
        self._send_protocol = None

    @property
    def local_protocol(self):
        """Returns: the Avro Protocol describing the messages sent and received."""
        return self._local_protocol

    @property
    def transceiver(self):
        """Returns: the underlying channel used by this requestor."""
        return self._transceiver

    @abc.abstractmethod
    def _IssueRequest(self, call_request, message_name, request_datum):
        """Sends a serialized call request and processes the reply.

        Args:
          call_request: Serialized handshake request and call request, as bytes.
          message_name: Name of the message.
          request_datum: Parameters of the call, i.e. the datum that was
              serialized into call_request.
        Returns:
          The IPC response.
        """
        raise NotImplementedError('Abstract method')

    def Request(self, message_name, request_datum):
        """Writes a request message and reads a response or error message.

        Args:
          message_name: Name of the IPC method.
          request_datum: IPC request.
        Returns:
          The IPC response.
        """
        # build handshake and call request
        buffer_writer = io.BytesIO()
        buffer_encoder = avro_io.BinaryEncoder(buffer_writer)
        self._WriteHandshakeRequest(buffer_encoder)
        self._WriteCallRequest(message_name, request_datum, buffer_encoder)

        # send the handshake and call request; block until call response
        call_request = buffer_writer.getvalue()
        return self._IssueRequest(call_request, message_name, request_datum)

    def _WriteHandshakeRequest(self, encoder):
        """Emits the handshake request.

        Args:
          encoder: Encoder to write the handshake request into.
        """
        local_hash = self._local_protocol.md5

        # if self._remote_hash is None:
        #   remote_name = self.transceiver.remote_name
        #   self._remote_hash = _REMOTE_HASHES.get(remote_name)

        if self._remote_hash is None:
            # Nothing is known about the remote end yet: start by assuming it
            # uses the same protocol as the local one.
            self._remote_hash = local_hash
            self._remote_protocol = self._local_protocol

        request_datum = {
            'clientHash': local_hash,
            'serverHash': self._remote_hash,
        }
        if self._send_protocol:
            request_datum['clientProtocol'] = str(self._local_protocol)

        logger.info('Sending handshake request: %s', request_datum)
        HANDSHAKE_REQUESTOR_WRITER.write(request_datum, encoder)

    def _WriteCallRequest(self, message_name, request_datum, encoder):
        """Serializes a call request.

        The format of a call request is:
         * request metadata, a map with values of type bytes
         * the message name, an Avro string, followed by
         * the message parameters. Parameters are serialized according to
           the message's request declaration.

        Args:
          message_name: Name of the message to send.
          request_datum: Parameters of the call.
          encoder: Encoder to write the call request into.
        Raises:
          schema.AvroException: if the local protocol has no such message.
        """
        # request metadata (not yet implemented)
        request_metadata = {}
        META_WRITER.write(request_metadata, encoder)

        # Identify message to send:
        message = self.local_protocol.message_map.get(message_name)
        if message is None:
            raise schema.AvroException('Unknown message: %s' % message_name)
        encoder.write_utf8(message.name)

        # message parameters
        self._WriteRequest(message.request, request_datum, encoder)

    def _WriteRequest(self, request_schema, request_datum, encoder):
        """Serializes the call parameters according to the request schema."""
        logger.info('writing request: %s', request_datum)
        datum_writer = avro_io.DatumWriter(request_schema)
        datum_writer.write(request_datum, encoder)

    def _ReadHandshakeResponse(self, decoder):
        """Reads and processes the handshake response message.

        Args:
          decoder: Decoder to read messages from.
        Returns:
          True if a call response follows the handshake response;
          False if the call must be issued again (the responder does not
          know the client protocol).
        Raises:
          schema.AvroException: if the handshake match value is not one of
              'BOTH', 'CLIENT' or 'NONE'.
        """
        handshake_response = HANDSHAKE_REQUESTOR_READER.read(decoder)
        logger.info('Processing handshake response: %s', handshake_response)
        match = handshake_response['match']

        if match == 'BOTH':
            # Both client and server protocol hashes match:
            self._send_protocol = False
            return True

        if match in ('CLIENT', 'NONE'):
            # 'CLIENT': the guess about the server protocol was wrong.
            # 'NONE': neither client nor server match.
            # In both cases, the response carries the actual server protocol.
            self._remote_protocol = (
                protocol.Parse(handshake_response['serverProtocol']))
            self._remote_hash = handshake_response['serverHash']
            # On 'NONE', the local protocol must be sent in the next handshake
            # and no call response follows.
            resend = (match == 'NONE')
            self._send_protocol = resend
            return not resend

        raise schema.AvroException('handshake_response.match=%r' % match)

    def _ReadCallResponse(self, message_name, decoder):
        """Reads and processes a method call response.

        The format of a call response is:
         - response metadata, a map with values of type bytes
         - a one-byte error flag boolean, followed by either:
           - if the error flag is false,
             the message response, serialized per the message's response schema.
           - if the error flag is true,
             the error, serialized per the message's error union schema.

        Args:
          message_name: Name of the message that was sent.
          decoder: Decoder to read the call response from.
        Returns:
          The decoded message response, when the error flag is false.
        Raises:
          schema.AvroException: if the message is unknown to the remote or
              to the local protocol.
          AvroRemoteException: wrapping the decoded error, when the error
              flag is true.
        """
        # response metadata (read to advance the decoder, otherwise ignored)
        META_READER.read(decoder)

        # remote response schema
        remote_message_schema = self._remote_protocol.message_map.get(message_name)
        if remote_message_schema is None:
            raise schema.AvroException('Unknown remote message: %s' % message_name)

        # local response schema
        local_message_schema = self._local_protocol.message_map.get(message_name)
        if local_message_schema is None:
            raise schema.AvroException('Unknown local message: %s' % message_name)

        # error flag
        if not decoder.read_boolean():
            writer_schema = remote_message_schema.response
            reader_schema = local_message_schema.response
            return self._ReadResponse(writer_schema, reader_schema, decoder)

        writer_schema = remote_message_schema.errors
        reader_schema = local_message_schema.errors
        raise self._ReadError(writer_schema, reader_schema, decoder)

    def _ReadResponse(self, writer_schema, reader_schema, decoder):
        """Decodes a message response written with writer_schema."""
        datum_reader = avro_io.DatumReader(writer_schema, reader_schema)
        return datum_reader.read(decoder)

    def _ReadError(self, writer_schema, reader_schema, decoder):
        """Decodes an error and returns it wrapped in an AvroRemoteException."""
        datum_reader = avro_io.DatumReader(writer_schema, reader_schema)
        return AvroRemoteException(datum_reader.read(decoder))
class Requestor(BaseRequestor):
    """Concrete requestor implementation."""

    def _IssueRequest(self, call_request, message_name, request_datum):
        """Sends the serialized call through the transceiver and decodes the reply.

        When the handshake response announces that no call response follows,
        the whole request is issued again through Request().
        """
        raw_reply = self.transceiver.Transceive(call_request)

        # process the handshake and call response
        reply_decoder = avro_io.BinaryDecoder(io.BytesIO(raw_reply))
        if not self._ReadHandshakeResponse(reply_decoder):
            return self.Request(message_name, request_datum)
        return self._ReadCallResponse(message_name, reply_decoder)
# ------------------------------------------------------------------------------
class Responder(object, metaclass=abc.ABCMeta):
    """Base class for the server side of a protocol interaction.

    Subclasses implement Invoke() to provide the server logic.
    """

    def __init__(self, local_protocol):
        """Initializes a new responder.

        Args:
          local_protocol: Avro Protocol implemented by this responder.
        """
        self._local_protocol = local_protocol
        self._local_hash = self._local_protocol.md5
        # Map: protocol MD5 hash -> Protocol
        self._protocol_cache = {}
        self.set_protocol_cache(self._local_hash, self._local_protocol)

    @property
    def local_protocol(self):
        """Returns: the Avro Protocol implemented by this responder."""
        return self._local_protocol

    # utility functions to manipulate protocol cache
    def get_protocol_cache(self, hash):
        """Returns: the cached Protocol for the given hash, or None."""
        return self._protocol_cache.get(hash)

    def set_protocol_cache(self, hash, protocol):
        """Registers a Protocol in the cache, under the given hash."""
        self._protocol_cache[hash] = protocol

    def Respond(self, call_request):
        """Entry point to process one procedure call.

        Args:
          call_request: Serialized procedure call request.
        Returns:
          Serialized procedure call response.
          If the handshake does not identify the client protocol, this is the
          handshake response only. If an AvroException occurs, the call
          response is a system error (string) with the error flag set.
        """
        buffer_reader = io.BytesIO(call_request)
        buffer_decoder = avro_io.BinaryDecoder(buffer_reader)
        buffer_writer = io.BytesIO()
        buffer_encoder = avro_io.BinaryEncoder(buffer_writer)
        error = None
        response = None
        response_metadata = {}
        # Offset of the call response in buffer_writer (after the handshake):
        payload_start = 0

        try:
            remote_protocol = self._ProcessHandshake(buffer_decoder, buffer_encoder)
            payload_start = buffer_writer.tell()
            # handshake failure
            if remote_protocol is None:
                return buffer_writer.getvalue()

            # read request using remote protocol
            # (request metadata is read to advance the decoder, then ignored)
            META_READER.read(buffer_decoder)
            remote_message_name = buffer_decoder.read_utf8()

            # get remote and local request schemas so we can do
            # schema resolution (one fine day)
            remote_message = remote_protocol.message_map.get(remote_message_name)
            if remote_message is None:
                fail_msg = 'Unknown remote message: %s' % remote_message_name
                raise schema.AvroException(fail_msg)
            local_message = self.local_protocol.message_map.get(remote_message_name)
            if local_message is None:
                fail_msg = 'Unknown local message: %s' % remote_message_name
                raise schema.AvroException(fail_msg)
            writer_schema = remote_message.request
            reader_schema = local_message.request
            request = self._ReadRequest(writer_schema, reader_schema, buffer_decoder)
            logger.info('Processing request: %r', request)

            # perform server logic
            try:
                response = self.Invoke(local_message, request)
            except AvroRemoteException as exn:
                error = exn
            except Exception as exn:
                error = AvroRemoteException(str(exn))

            # write response using local protocol
            META_WRITER.write(response_metadata, buffer_encoder)
            buffer_encoder.write_boolean(error is not None)
            if error is None:
                writer_schema = local_message.response
                self._WriteResponse(writer_schema, response, buffer_encoder)
            else:
                writer_schema = local_message.errors
                self._WriteError(writer_schema, error, buffer_encoder)
        except schema.AvroException as exn:
            error = AvroRemoteException(str(exn))
            # Discard whatever part of the call response was already written,
            # keep the handshake response, and report a system error instead.
            # The error must go into buffer_writer, since that is the buffer
            # returned to the caller.
            buffer_writer.seek(payload_start)
            buffer_writer.truncate()
            buffer_encoder = avro_io.BinaryEncoder(buffer_writer)
            META_WRITER.write(response_metadata, buffer_encoder)
            buffer_encoder.write_boolean(True)
            self._WriteError(SYSTEM_ERROR_SCHEMA, error, buffer_encoder)
        return buffer_writer.getvalue()

    def _ProcessHandshake(self, decoder, encoder):
        """Processes an RPC handshake.

        Args:
          decoder: Where to read from.
          encoder: Where to write to.
        Returns:
          The requested Protocol, or None if the client protocol is neither
          cached nor included in the handshake request.
        """
        handshake_request = HANDSHAKE_RESPONDER_READER.read(decoder)
        logger.info('Processing handshake request: %s', handshake_request)

        # determine the remote protocol
        client_hash = handshake_request.get('clientHash')
        client_protocol = handshake_request.get('clientProtocol')
        remote_protocol = self.get_protocol_cache(client_hash)
        if remote_protocol is None and client_protocol is not None:
            remote_protocol = protocol.Parse(client_protocol)
            self.set_protocol_cache(client_hash, remote_protocol)

        # evaluate remote's guess of the local protocol
        server_hash = handshake_request.get('serverHash')
        if remote_protocol is None:
            match = 'NONE'
        elif self._local_hash == server_hash:
            match = 'BOTH'
        else:
            match = 'CLIENT'

        handshake_response = {'match': match}
        if match != 'BOTH':
            # The client needs the actual server protocol and hash:
            handshake_response['serverProtocol'] = str(self.local_protocol)
            handshake_response['serverHash'] = self._local_hash

        logger.info('Handshake response: %s', handshake_response)
        HANDSHAKE_RESPONDER_WRITER.write(handshake_response, encoder)
        return remote_protocol

    @abc.abstractmethod
    def Invoke(self, local_message, request):
        """Processes one procedure call.

        Args:
          local_message: Avro message specification.
          request: Call request.
        Returns:
          Call response.
        Raises:
          Any exception raised here is reported to the client as an error:
          AvroRemoteException instances are sent as they are, other exceptions
          are converted using their string representation.
        """
        raise NotImplementedError('abstract method')

    def _ReadRequest(self, writer_schema, reader_schema, decoder):
        """Decodes the call parameters written with writer_schema."""
        datum_reader = avro_io.DatumReader(writer_schema, reader_schema)
        return datum_reader.read(decoder)

    def _WriteResponse(self, writer_schema, response_datum, encoder):
        """Serializes the call response according to writer_schema."""
        datum_writer = avro_io.DatumWriter(writer_schema)
        datum_writer.write(response_datum, encoder)

    def _WriteError(self, writer_schema, error_exception, encoder):
        """Serializes the string representation of an error."""
        datum_writer = avro_io.DatumWriter(writer_schema)
        datum_writer.write(str(error_exception), encoder)
# ------------------------------------------------------------------------------
# Framed message
class FramedReader(object):
    """Reads framed messages out of a file-like object."""

    def __init__(self, reader):
        self._reader = reader

    def Read(self):
        """Reads one message from the configured reader.

        Returns:
          The message, as bytes.
        """
        assembled = io.BytesIO()
        # Keep appending frames; the empty frame (size 0) ends the message.
        while True:
            if not self._ReadFrame(assembled) > 0:
                break
        return assembled.getvalue()

    def _ReadFrame(self, message):
        """Reads and appends one frame into the given message bytes.

        Args:
          message: Message to append the frame to.
        Returns:
          Size of the frame that was read.
          The empty frame (size 0) indicates the end of a message.
        """
        frame_size = self._ReadInt32()
        received = 0
        # The reader may return fewer bytes than requested: loop until the
        # whole frame has been collected.
        while received < frame_size:
            missing = frame_size - received
            piece = self._reader.read(missing)
            if not len(piece):
                raise ConnectionClosedException(
                    'FramedReader: expecting %d more bytes in frame of size %d, got 0.'
                    % (missing, frame_size))
            message.write(piece)
            received += len(piece)
        return frame_size

    def _ReadInt32(self):
        """Reads the 32 bits big-endian integer heading a frame."""
        width = UINT32_BE.size
        header = self._reader.read(width)
        if len(header) != width:
            raise ConnectionClosedException('Invalid header: %r' % header)
        (value,) = UINT32_BE.unpack(header)
        return value
class FramedWriter(object):
    """Wrapper around a file-like object to write framed data."""

    def __init__(self, writer):
        self._writer = writer

    def Write(self, message):
        """Writes a message.

        Message is chunked into sequences of frames terminated by an empty frame.
        Each frame holds at most BUFFER_SIZE bytes.

        Args:
          message: Message to write, as bytes.
        """
        # Walk the message by offset: every frame is capped at BUFFER_SIZE
        # bytes, and the message is not re-sliced after each frame.
        for offset in range(0, len(message), BUFFER_SIZE):
            self._WriteBuffer(message[offset:offset + BUFFER_SIZE])

        # A message is always terminated by a zero-length buffer.
        self._WriteUnsignedInt32(0)

    def _WriteBuffer(self, chunk):
        """Writes one frame: the chunk size, followed by the chunk itself."""
        self._WriteUnsignedInt32(len(chunk))
        self._writer.write(chunk)

    def _WriteUnsignedInt32(self, uint32):
        """Writes an integer as 32 bits big-endian."""
        self._writer.write(UINT32_BE.pack(uint32))
# ------------------------------------------------------------------------------
# Transceiver (send/receive channel)
class Transceiver(object, metaclass=abc.ABCMeta):
    """Abstract channel used to send requests and receive replies."""

    # abc.abstractproperty is deprecated since Python 3.3:
    # property + abstractmethod is the supported equivalent.
    @property
    @abc.abstractmethod
    def remote_name(self):
        """Returns: the remote name, as defined by the concrete transceiver."""
        pass

    @abc.abstractmethod
    def ReadMessage(self):
        """Reads a single message from the channel.

        Blocks until a message can be read.

        Returns:
          The message read from the channel.
        """
        pass

    @abc.abstractmethod
    def WriteMessage(self, message):
        """Writes a message into the channel.

        Blocks until the message has been written.

        Args:
          message: Message to write.
        """
        pass

    def Transceive(self, request):
        """Processes a single request-reply interaction.

        Synchronous request-reply interaction: the request is written first,
        then one message is read back.

        Args:
          request: Request message.
        Returns:
          The reply message.
        """
        self.WriteMessage(request)
        return self.ReadMessage()

    def Close(self):
        """Closes this transceiver. Does nothing by default."""
        pass
class HTTPTransceiver(Transceiver):
    """HTTP-based transceiver implementation."""

    def __init__(self, host, port, req_resource='/'):
        """Initializes a new HTTP transceiver.

        The HTTP connection is opened immediately.

        Args:
          host: Name or IP address of the remote host to interact with.
          port: Port the remote server is listening on.
          req_resource: Optional HTTP resource path to use, '/' by default.
        """
        self._req_resource = req_resource
        self._conn = connection = http.client.HTTPConnection(host, port)
        connection.connect()
        # NOTE(review): getsockname() reports the local address of the socket;
        # getpeername() would report the remote one - confirm which is intended.
        self._remote_name = connection.sock.getsockname()

    @property
    def remote_name(self):
        return self._remote_name

    def ReadMessage(self):
        """Reads the framed message carried by the pending HTTP response."""
        http_response = self._conn.getresponse()
        message = FramedReader(http_response).Read()
        http_response.read()  # ensure we're ready for subsequent requests
        return message

    def WriteMessage(self, message):
        """Frames the message and sends it as the body of an HTTP POST."""
        framed = io.BytesIO()
        FramedWriter(framed).Write(message)
        self._conn.request(
            'POST',
            self._req_resource,
            framed.getvalue(),
            {'Content-Type': AVRO_RPC_MIME},
        )

    def Close(self):
        """Closes the HTTP connection and drops the reference to it."""
        self._conn.close()
        self._conn = None
# ------------------------------------------------------------------------------
# Server Implementations
def _MakeHandlerClass(responder):
    """Builds an HTTP request handler class bound to the given responder.

    Args:
      responder: Object whose Respond() method processes the serialized calls.
    Returns:
      A BaseHTTPRequestHandler subclass answering POST requests.
    """

    class AvroHTTPRequestHandler(http.server.BaseHTTPRequestHandler):
        def do_POST(self):
            # Decode the framed request and hand it over to the responder:
            call_request = FramedReader(self.rfile).Read()
            logger.info('Serialized request: %r', call_request)
            call_response = responder.Respond(call_request)
            logger.info('Serialized response: %r', call_response)

            # Send the framed response back:
            self.send_response(200)
            self.send_header('Content-type', AVRO_RPC_MIME)
            self.end_headers()
            FramedWriter(self.wfile).Write(call_response)
            self.wfile.flush()
            logger.info('Response sent')

    return AvroHTTPRequestHandler
class MultiThreadedHTTPServer(socketserver.ThreadingMixIn, http.server.HTTPServer):
    """Multi-threaded HTTP server."""
class AvroIpcHttpServer(MultiThreadedHTTPServer):
    """Avro IPC server implemented on top of an HTTP server."""

    def __init__(self, interface, port, responder):
        """Initializes a new Avro IPC server.

        Args:
          interface: Interface the server listens on, eg. 'localhost' or '0.0.0.0'.
          port: TCP port the server listens on, eg. 8000.
          responder: Responder implementation to handle RPCs.
        """
        handler_class = _MakeHandlerClass(responder)
        address = (interface, port)
        super().__init__(server_address=address, RequestHandlerClass=handler_class)
if __name__ == '__main__':
    # This module only provides library code: running it as a script is an error.
    raise Exception('Not a standalone module')