# blob: f5fb42925494b80a6c21f2cb1c479ff1c397c5a1 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
require "net/http"
module Avro::IPC
# Error carrying a remote error payload: built by Requestor#read_error on the
# client, and used by Responder#respond to wrap failures raised by #call.
class AvroRemoteError < Avro::AvroError; end
# Schema of the handshake record that Requestor writes and Responder reads.
HANDSHAKE_REQUEST_SCHEMA = Avro::Schema.parse <<-JSON
{
"type": "record",
"name": "HandshakeRequest", "namespace":"org.apache.avro.ipc",
"fields": [
{"name": "clientHash",
"type": {"type": "fixed", "name": "MD5", "size": 16}},
{"name": "clientProtocol", "type": ["null", "string"]},
{"name": "serverHash", "type": "MD5"},
{"name": "meta", "type": ["null", {"type": "map", "values": "bytes"}]}
]
}
JSON
# Schema of the handshake record that Responder writes and Requestor reads.
HANDSHAKE_RESPONSE_SCHEMA = Avro::Schema.parse <<-JSON
{
"type": "record",
"name": "HandshakeResponse", "namespace": "org.apache.avro.ipc",
"fields": [
{"name": "match",
"type": {"type": "enum", "name": "HandshakeMatch",
"symbols": ["BOTH", "CLIENT", "NONE"]}},
{"name": "serverProtocol", "type": ["null", "string"]},
{"name": "serverHash",
"type": ["null", {"type": "fixed", "name": "MD5", "size": 16}]},
{"name": "meta",
"type": ["null", {"type": "map", "values": "bytes"}]}
]
}
JSON
# Shared handshake readers/writers: the requestor writes requests and reads
# responses; the responder does the opposite.
HANDSHAKE_REQUESTOR_WRITER = Avro::IO::DatumWriter.new(HANDSHAKE_REQUEST_SCHEMA)
HANDSHAKE_REQUESTOR_READER = Avro::IO::DatumReader.new(HANDSHAKE_RESPONSE_SCHEMA)
HANDSHAKE_RESPONDER_WRITER = Avro::IO::DatumWriter.new(HANDSHAKE_RESPONSE_SCHEMA)
HANDSHAKE_RESPONDER_READER = Avro::IO::DatumReader.new(HANDSHAKE_REQUEST_SCHEMA)
# Call metadata (a map with bytes values) precedes every call request and
# call response.
META_SCHEMA = Avro::Schema.parse('{"type": "map", "values": "bytes"}')
META_WRITER = Avro::IO::DatumWriter.new(META_SCHEMA)
META_READER = Avro::IO::DatumReader.new(META_SCHEMA)
# Error schema used when a message declares no errors of its own, and for
# errors the responder reports itself.
SYSTEM_ERROR_SCHEMA = Avro::Schema.parse('["string"]')
# Protocol caches, keyed by transport.remote_name and filled in by the
# Requestor#remote_hash= and Requestor#remote_protocol= writers.
REMOTE_HASHES = {}
REMOTE_PROTOCOLS = {}
# Length of the header that precedes each frame; the transports pack/unpack it
# with the 'N' directive.
BUFFER_HEADER_LENGTH = 4
# Largest chunk a message is split into when it is written as frames.
BUFFER_SIZE = 8192
# Raised when an error message is sent by an Avro requestor or responder.
# Responder#respond uses it to wrap an Avro::AvroError raised while handling
# a call.
class AvroRemoteException < Avro::AvroError; end
# Raised by SocketTransport when a socket read or write transfers no data.
class ConnectionClosedException < Avro::AvroError; end
# Base class for the client side of a protocol interaction.
#
# A Requestor serializes a handshake plus a call request, hands the bytes to
# its transport, and decodes the handshake response and the call response (or
# error) that come back.
class Requestor
  # The writers for remote_protocol and remote_hash are defined explicitly
  # below because they also update the module-level caches.
  attr_reader :local_protocol, :transport, :remote_protocol, :remote_hash
  attr_accessor :send_protocol

  # local_protocol:: the protocol this client speaks (must respond to
  #                  +md5+, +messages+ and +to_s+).
  # transport::      object responding to +transceive+ and +remote_name+.
  def initialize(local_protocol, transport)
    @local_protocol = local_protocol
    @transport = transport
    @remote_protocol = nil
    @remote_hash = nil
    @send_protocol = nil
  end

  # Stores the remote protocol and records it in REMOTE_PROTOCOLS under the
  # transport's remote name.
  def remote_protocol=(new_remote_protocol)
    @remote_protocol = new_remote_protocol
    REMOTE_PROTOCOLS[transport.remote_name] = remote_protocol
  end

  # Stores the remote protocol hash and records it in REMOTE_HASHES under the
  # transport's remote name.
  def remote_hash=(new_remote_hash)
    @remote_hash = new_remote_hash
    REMOTE_HASHES[transport.remote_name] = remote_hash
  end

  # Writes a request message and reads a response or error message.
  #
  # Returns the decoded response datum. Raises AvroRemoteError when the remote
  # side answers with an error, and Avro::AvroError on handshake failure.
  def request(message_name, request_datum)
    # build handshake and call request
    buffer_writer = StringIO.new('', 'w+')
    buffer_encoder = Avro::IO::BinaryEncoder.new(buffer_writer)
    write_handshake_request(buffer_encoder)
    write_call_request(message_name, request_datum, buffer_encoder)

    # send the handshake and call request; block until call response
    call_request = buffer_writer.string
    call_response = transport.transceive(call_request)

    # process the handshake and call response
    buffer_decoder = Avro::IO::BinaryDecoder.new(StringIO.new(call_response))
    if read_handshake_response(buffer_decoder)
      read_call_response(message_name, buffer_decoder)
    else
      # The handshake did not match. read_handshake_response has set
      # send_protocol, so the retry includes the full client protocol; a
      # second mismatch raises there instead of recursing again.
      request(message_name, request_datum)
    end
  end

  # Writes the handshake request record to +encoder+.
  def write_handshake_request(encoder)
    local_hash = local_protocol.md5
    server_hash_guess = REMOTE_HASHES[transport.remote_name]
    unless server_hash_guess
      # Nothing cached for this remote yet: assume it speaks our protocol.
      server_hash_guess = local_hash
      self.remote_protocol = local_protocol
    end

    request_datum = {
      'clientHash' => local_hash,
      'serverHash' => server_hash_guess
    }
    request_datum['clientProtocol'] = local_protocol.to_s if send_protocol
    HANDSHAKE_REQUESTOR_WRITER.write(request_datum, encoder)
  end

  # The format of a call request is:
  #   * request metadata, a map with values of type bytes
  #   * the message name, an Avro string, followed by
  #   * the message parameters. Parameters are serialized according to
  #     the message's request declaration.
  #
  # Raises Avro::AvroError when +message_name+ is not part of the local
  # protocol.
  def write_call_request(message_name, request_datum, encoder)
    # TODO request metadata (not yet implemented)
    request_metadata = {}
    META_WRITER.write(request_metadata, encoder)

    message = local_protocol.messages[message_name]
    raise Avro::AvroError, "Unknown message: #{message_name}" unless message

    encoder.write_string(message.name)
    write_request(message.request, request_datum, encoder)
  end

  # Serializes +request_datum+ according to +request_schema+.
  def write_request(request_schema, request_datum, encoder)
    datum_writer = Avro::IO::DatumWriter.new(request_schema)
    datum_writer.write(request_datum, encoder)
  end

  # Reads the handshake response from +decoder+ and updates the cached remote
  # protocol/hash and the send_protocol flag accordingly.
  #
  # Returns true when the call response that follows can be decoded
  # ('BOTH' or 'CLIENT'), false when the request has to be sent again with
  # the client protocol attached ('NONE'). Raises Avro::AvroError when the
  # handshake still fails after the protocol was sent, or when the match
  # value is not one of the three known symbols.
  def read_handshake_response(decoder)
    handshake_response = HANDSHAKE_REQUESTOR_READER.read(decoder)
    match = handshake_response['match']

    case match
    when 'BOTH'
      self.send_protocol = false
      true
    when 'CLIENT'
      raise Avro::AvroError, 'Handshake failure. match == CLIENT' if send_protocol
      adopt_server_protocol(handshake_response)
      self.send_protocol = false
      true
    when 'NONE'
      raise Avro::AvroError, 'Handshake failure. match == NONE' if send_protocol
      adopt_server_protocol(handshake_response)
      self.send_protocol = true
      false
    else
      raise Avro::AvroError, "Unexpected match: #{match}"
    end
  end

  # The format of a call response is:
  #   * response metadata, a map with values of type bytes
  #   * a one-byte error flag boolean, followed by either:
  #     * if the error flag is false,
  #       the message response, serialized per the message's response schema.
  #     * if the error flag is true,
  #       the error, serialized per the message's error union schema.
  #
  # Returns the decoded response; raises AvroRemoteError when the error flag
  # is set, and Avro::AvroError when either side does not know the message.
  def read_call_response(message_name, decoder)
    # Response metadata is not used, but it must be consumed from the stream.
    META_READER.read(decoder)

    # remote response schema
    remote_message_schema = remote_protocol.messages[message_name]
    raise Avro::AvroError, "Unknown remote message: #{message_name}" unless remote_message_schema

    # local response schema
    local_message_schema = local_protocol.messages[message_name]
    raise Avro::AvroError, "Unknown local message: #{message_name}" unless local_message_schema

    # error flag
    if decoder.read_boolean
      writers_schema = remote_message_schema.errors || SYSTEM_ERROR_SCHEMA
      readers_schema = local_message_schema.errors || SYSTEM_ERROR_SCHEMA
      raise read_error(writers_schema, readers_schema, decoder)
    else
      read_response(remote_message_schema.response, local_message_schema.response, decoder)
    end
  end

  # Decodes a successful response, resolving the remote (writer's) schema
  # against the local (reader's) schema.
  def read_response(writers_schema, readers_schema, decoder)
    datum_reader = Avro::IO::DatumReader.new(writers_schema, readers_schema)
    datum_reader.read(decoder)
  end

  # Decodes an error payload and wraps it in an AvroRemoteError (returned,
  # not raised).
  def read_error(writers_schema, readers_schema, decoder)
    datum_reader = Avro::IO::DatumReader.new(writers_schema, readers_schema)
    AvroRemoteError.new(datum_reader.read(decoder))
  end

  private

  # Caches the protocol and hash the server reported in its handshake.
  def adopt_server_protocol(handshake_response)
    self.remote_protocol = Avro::Protocol.parse(handshake_response['serverProtocol'])
    self.remote_hash = handshake_response['serverHash']
  end
end
# Base class for the server side of a protocol interaction.
#
# Subclasses override #call to implement the server logic; #respond takes
# care of the handshake and of (de)serializing requests, responses and errors.
class Responder
  attr_reader :local_protocol, :local_hash, :protocol_cache

  # local_protocol:: the protocol this server implements (must respond to
  #                  +md5+, +messages+ and +to_s+).
  def initialize(local_protocol)
    @local_protocol = local_protocol
    @local_hash = local_protocol.md5
    # Protocols seen so far, keyed by their hash; seeded with our own.
    @protocol_cache = { local_hash => local_protocol }
  end

  # Called by a server to deserialize a request, compute and serialize
  # a response or error. Compare to 'handle()' in Thrift.
  #
  # Returns the serialized bytes: the handshake response, followed (when the
  # handshake identified the client protocol) by the call response or error.
  def respond(call_request)
    buffer_decoder = Avro::IO::BinaryDecoder.new(StringIO.new(call_request))
    buffer_writer = StringIO.new('', 'w+')
    buffer_encoder = Avro::IO::BinaryEncoder.new(buffer_writer)
    error = nil
    response_metadata = {}
    # Offset in buffer_writer just past the handshake response; nil until the
    # handshake has been written completely.
    handshake_end = nil

    begin
      remote_protocol = process_handshake(buffer_decoder, buffer_encoder)
      handshake_end = buffer_writer.pos
      # handshake failure
      return buffer_writer.string unless remote_protocol

      # read request using remote protocol; the request metadata is not used
      # but has to be consumed from the stream
      META_READER.read(buffer_decoder)
      remote_message_name = buffer_decoder.read_string

      # get remote and local request schemas so we can do
      # schema resolution (one fine day)
      remote_message = remote_protocol.messages[remote_message_name]
      unless remote_message
        raise Avro::AvroError, "Unknown remote message: #{remote_message_name}"
      end
      local_message = local_protocol.messages[remote_message_name]
      unless local_message
        raise Avro::AvroError, "Unknown local message: #{remote_message_name}"
      end
      request = read_request(remote_message.request, local_message.request, buffer_decoder)

      # perform server logic
      begin
        response = call(local_message, request)
      rescue AvroRemoteError => e
        error = e
      rescue Exception => e
        # Deliberately broad: the default #call raises NotImplementedError,
        # which is not a StandardError, and it is reported to the client too.
        error = AvroRemoteError.new(e.to_s)
      end

      # write response using local protocol
      META_WRITER.write(response_metadata, buffer_encoder)
      buffer_encoder.write_boolean(!!error)
      if error.nil?
        write_response(local_message.response, response, buffer_encoder)
      else
        write_error(local_message.errors || SYSTEM_ERROR_SCHEMA, error, buffer_encoder)
      end
    rescue Avro::AvroError => e
      # Without a completed handshake there is no valid prefix to attach an
      # error to, so the buffer is returned as it stands.
      if handshake_end
        error = AvroRemoteException.new(e.to_s)
        # Drop whatever part of the call response was already written and
        # put a system error right after the handshake response.
        # Assumes the encoder writes straight through to buffer_writer.
        buffer_writer.truncate(handshake_end)
        buffer_writer.seek(handshake_end)
        META_WRITER.write(response_metadata, buffer_encoder)
        buffer_encoder.write_boolean(true)
        write_error(SYSTEM_ERROR_SCHEMA, error, buffer_encoder)
      end
    end
    buffer_writer.string
  end

  # Reads the handshake request from +decoder+, writes the handshake response
  # to +encoder+, and returns the client's protocol (nil when it is neither
  # cached nor supplied in the request).
  def process_handshake(decoder, encoder)
    handshake_request = HANDSHAKE_RESPONDER_READER.read(decoder)
    handshake_response = {}

    # determine the remote protocol
    client_hash = handshake_request['clientHash']
    client_protocol = handshake_request['clientProtocol']
    remote_protocol = protocol_cache[client_hash]
    if !remote_protocol && client_protocol
      remote_protocol = Avro::Protocol.parse(client_protocol)
      protocol_cache[client_hash] = remote_protocol
    end

    # evaluate remote's guess of the local protocol
    server_hash = handshake_request['serverHash']
    handshake_response['match'] =
      if !remote_protocol
        'NONE'
      elsif local_hash == server_hash
        'BOTH'
      else
        'CLIENT'
      end

    # Unless both sides already agree, tell the client what we speak.
    if handshake_response['match'] != 'BOTH'
      handshake_response['serverProtocol'] = local_protocol.to_s
      handshake_response['serverHash'] = local_hash
    end

    HANDSHAKE_RESPONDER_WRITER.write(handshake_response, encoder)
    remote_protocol
  end

  # Actual work done by server: cf. handler in thrift.
  # Subclasses must override this; the base implementation always raises.
  def call(local_message, request)
    raise NotImplementedError
  end

  # Decodes the request parameters, resolving the client's (writer's) schema
  # against the local (reader's) schema.
  def read_request(writers_schema, readers_schema, decoder)
    datum_reader = Avro::IO::DatumReader.new(writers_schema, readers_schema)
    datum_reader.read(decoder)
  end

  # Serializes a successful response according to +writers_schema+.
  def write_response(writers_schema, response_datum, encoder)
    datum_writer = Avro::IO::DatumWriter.new(writers_schema)
    datum_writer.write(response_datum, encoder)
  end

  # Serializes an error as its string form according to +writers_schema+.
  def write_error(writers_schema, error_exception, encoder)
    datum_writer = Avro::IO::DatumWriter.new(writers_schema)
    datum_writer.write(error_exception.to_s, encoder)
  end
end
# A simple socket-based Transport implementation.
#
# Messages are exchanged as a sequence of frames, each a 4-byte big-endian
# length followed by that many bytes, terminated by a zero-length frame.
class SocketTransport
  # NOTE(review): remote_name is never assigned in this class, so it reads as
  # nil; callers that key caches on it share one entry. Confirm this is
  # intended.
  attr_reader :sock, :remote_name

  # sock:: an IO-like object responding to +read+, +write+ and +close+.
  def initialize(sock)
    @sock = sock
  end

  # Sends +request+ as a framed message and returns the framed reply.
  def transceive(request)
    write_framed_message(request)
    read_framed_message
  end

  # Reads frames until the zero-length terminator and returns their
  # concatenated payload. Raises ConnectionClosedException when the socket
  # runs dry in the middle of a message.
  def read_framed_message
    message = []
    loop do
      buffer_length = read_buffer_length
      return message.join if buffer_length == 0

      buffer = StringIO.new
      while buffer.tell < buffer_length
        chunk = sock.read(buffer_length - buffer.tell)
        # read returns nil (not '') at end of file when a length is given;
        # treat both as a closed connection rather than spinning forever.
        if chunk.nil? || chunk.empty?
          raise ConnectionClosedException.new("Socket read 0 bytes.")
        end
        buffer.write(chunk)
      end
      message << buffer.string
    end
  end

  # Splits +message+ into frames of at most BUFFER_SIZE bytes and writes
  # them, followed by the terminating zero-length frame.
  def write_framed_message(message)
    # Frame lengths are byte counts, so measure and slice in bytes.
    message_length = message.bytesize
    total_bytes_sent = 0
    while total_bytes_sent < message_length
      buffer_length = [message_length - total_bytes_sent, BUFFER_SIZE].min
      write_buffer(message.byteslice(total_bytes_sent, buffer_length))
      total_bytes_sent += buffer_length
    end
    # A message is always terminated by a zero-length buffer.
    write_buffer_length(0)
  end

  # Writes one frame: the length header followed by +chunk+.
  def write_buffer(chunk)
    buffer_length = chunk.bytesize
    write_buffer_length(buffer_length)
    total_bytes_sent = 0
    while total_bytes_sent < buffer_length
      remaining = chunk.byteslice(total_bytes_sent, buffer_length - total_bytes_sent)
      bytes_sent = sock.write(remaining)
      if bytes_sent == 0
        raise ConnectionClosedException.new("Socket sent 0 bytes.")
      end
      total_bytes_sent += bytes_sent
    end
  end

  # Writes +n+ as a 4-byte big-endian frame length.
  def write_buffer_length(n)
    bytes_sent = sock.write([n].pack('N'))
    if bytes_sent == 0
      raise ConnectionClosedException.new("socket sent 0 bytes")
    end
  end

  # Reads and decodes a frame length header. Raises
  # ConnectionClosedException when the header is missing or cut short.
  def read_buffer_length
    header = sock.read(BUFFER_HEADER_LENGTH)
    if header.nil? || header.bytesize < BUFFER_HEADER_LENGTH
      raise ConnectionClosedException.new("Socket read 0 bytes.")
    end
    header.unpack('N')[0]
  end

  # Closes the underlying socket.
  def close
    sock.close
  end
end
# Raised by FramedReader when a read yields an empty string.
class ConnectionClosedError < StandardError
end
# Writes messages to +writer+ as a sequence of frames: each frame is a 4-byte
# big-endian length followed by that many bytes, and a zero-length frame
# terminates the message.
class FramedWriter
  attr_reader :writer

  # writer:: an object responding to +write+, +<<+ and +string+
  #          (e.g. a StringIO).
  def initialize(writer)
    @writer = writer
  end

  # Splits +message+ into frames of at most BUFFER_SIZE bytes and writes
  # them, followed by the terminating zero-length frame.
  def write_framed_message(message)
    # Frame lengths are byte counts, so measure and slice in bytes; counting
    # characters would produce a wrong header for multibyte content.
    message_size = message.bytesize
    total_bytes_sent = 0
    while total_bytes_sent < message_size
      buffer_size = [message_size - total_bytes_sent, BUFFER_SIZE].min
      write_buffer(message.byteslice(total_bytes_sent, buffer_size))
      total_bytes_sent += buffer_size
    end
    write_buffer_size(0)
  end

  # Everything written so far, as a string.
  def to_s; writer.string; end

  private

  # Writes one frame: the length header followed by +chunk+.
  def write_buffer(chunk)
    write_buffer_size(chunk.bytesize)
    writer << chunk
  end

  # Writes +n+ as a 4-byte big-endian frame length.
  def write_buffer_size(n)
    writer.write([n].pack('N'))
  end
end
# Reads messages from +reader+ that were written as a sequence of frames:
# each frame is a 4-byte big-endian length followed by that many bytes, and a
# zero-length frame terminates the message.
class FramedReader
  attr_reader :reader

  # reader:: an object responding to +read(length)+ (e.g. a StringIO).
  def initialize(reader)
    @reader = reader
  end

  # Reads frames until the zero-length terminator and returns their
  # concatenated payload. Raises ConnectionClosedError when the reader runs
  # out of data before the message is complete.
  def read_framed_message
    message = []
    loop do
      buffer_size = read_buffer_size
      return message.join if buffer_size == 0

      # Frame lengths are byte counts, so collect and measure in bytes.
      buffer = String.new
      while buffer.bytesize < buffer_size
        chunk = reader.read(buffer_size - buffer.bytesize)
        chunk_error?(chunk)
        buffer << chunk.b
      end
      message << buffer
    end
  end

  private

  # Reads and decodes a frame length header.
  def read_buffer_size
    header = reader.read(BUFFER_HEADER_LENGTH)
    chunk_error?(header)
    if header.bytesize < BUFFER_HEADER_LENGTH
      raise ConnectionClosedError.new("Reader read truncated frame header")
    end
    header.unpack('N')[0]
  end

  # Raises ConnectionClosedError when a read produced no data. read(length)
  # signals end of input with nil, so nil is checked as well as ''.
  def chunk_error?(chunk)
    raise ConnectionClosedError.new("Reader read 0 bytes") if chunk.nil? || chunk == ''
  end
end
# Only works for clients. Sigh.
#
# Client transport that POSTs each framed request to '/' on host:port over a
# connection opened in the constructor, and unframes the response body.
class HTTPTransceiver
  attr_reader :remote_name, :host, :port

  # Opens the HTTP connection immediately; remote_name is "host:port".
  def initialize(host, port)
    @host = host
    @port = port
    @remote_name = "#{host}:#{port}"
    @conn = Net::HTTP.start(host, port)
  end

  # Frames +message+, posts it, and returns the unframed response payload.
  def transceive(message)
    framed_request = FramedWriter.new(StringIO.new)
    framed_request.write_framed_message(message)

    headers = {'Content-Type' => 'avro/binary'}
    response = @conn.post('/', framed_request.to_s, headers)

    response_reader = FramedReader.new(StringIO.new(response.body))
    response_reader.read_framed_message
  end
end
end