| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| |
| require "net/http" |
| |
| module Avro::IPC |
| |
# Error carrying a failure reported by the remote side of a call.
# Requestor#read_error wraps the decoded error datum in this class, and
# Responder#respond uses it for failures raised by the server-side #call.
class AvroRemoteError < Avro::AvroError
end
| |
# Schema of the handshake record a client writes ahead of each call request.
# It backs HANDSHAKE_REQUESTOR_WRITER (client side) and
# HANDSHAKE_RESPONDER_READER (server side).
HANDSHAKE_REQUEST_SCHEMA = Avro::Schema.parse(<<-JSON)
  {
    "type": "record",
    "name": "HandshakeRequest", "namespace":"org.apache.avro.ipc",
    "fields": [
      {"name": "clientHash",
       "type": {"type": "fixed", "name": "MD5", "size": 16}},
      {"name": "clientProtocol", "type": ["null", "string"]},
      {"name": "serverHash", "type": "MD5"},
      {"name": "meta", "type": ["null", {"type": "map", "values": "bytes"}]}
    ]
  }
JSON
| |
# Schema of the handshake record a server writes ahead of each call response.
# The "match" enum tells the client whether both hashes matched (BOTH), only
# the client's hash was known to the server (CLIENT), or neither (NONE).
# It backs HANDSHAKE_RESPONDER_WRITER and HANDSHAKE_REQUESTOR_READER.
HANDSHAKE_RESPONSE_SCHEMA = Avro::Schema.parse(<<-JSON)
  {
    "type": "record",
    "name": "HandshakeResponse", "namespace": "org.apache.avro.ipc",
    "fields": [
      {"name": "match",
       "type": {"type": "enum", "name": "HandshakeMatch",
                "symbols": ["BOTH", "CLIENT", "NONE"]}},
      {"name": "serverProtocol", "type": ["null", "string"]},
      {"name": "serverHash",
       "type": ["null", {"type": "fixed", "name": "MD5", "size": 16}]},
      {"name": "meta",
       "type": ["null", {"type": "map", "values": "bytes"}]}
    ]
  }
JSON
| |
# Call metadata: a map with values of type bytes, written before every call
# request and every call response.
META_SCHEMA = Avro::Schema.parse('{"type": "map", "values": "bytes"}')
META_READER = Avro::IO::DatumReader.new(META_SCHEMA)
META_WRITER = Avro::IO::DatumWriter.new(META_SCHEMA)

# Client side of the handshake: writes requests, reads responses.
HANDSHAKE_REQUESTOR_WRITER = Avro::IO::DatumWriter.new(HANDSHAKE_REQUEST_SCHEMA)
HANDSHAKE_REQUESTOR_READER = Avro::IO::DatumReader.new(HANDSHAKE_RESPONSE_SCHEMA)

# Server side of the handshake: reads requests, writes responses.
HANDSHAKE_RESPONDER_READER = Avro::IO::DatumReader.new(HANDSHAKE_REQUEST_SCHEMA)
HANDSHAKE_RESPONDER_WRITER = Avro::IO::DatumWriter.new(HANDSHAKE_RESPONSE_SCHEMA)
| |
# Error union used when a message declares no errors of its own, and for
# failures raised by the IPC machinery itself.
SYSTEM_ERROR_SCHEMA = Avro::Schema.parse('["string"]')

# Protocol cache, keyed by the transport's remote name. Requestor fills both
# hashes through its remote_hash= and remote_protocol= writers.
REMOTE_HASHES = {}
REMOTE_PROTOCOLS = {}

# Framing: each buffer is preceded by a 4-byte length header (packed with
# 'N'), and writers cut messages into buffers of at most BUFFER_SIZE.
BUFFER_HEADER_LENGTH = 4
BUFFER_SIZE = 8 * 1024
| |
# Raised when an error message is sent by an Avro requestor or responder.
class AvroRemoteException < Avro::AvroError
end

# Raised by SocketTransport when a socket read or write transfers no bytes.
class ConnectionClosedException < Avro::AvroError
end
| |
# Base class for the client side of a protocol interaction.
#
# A Requestor encodes a handshake followed by a call request, sends the bytes
# through its transport, and decodes the handshake and call response that the
# transport returns.
class Requestor
  attr_reader :local_protocol, :transport, :remote_protocol, :remote_hash
  attr_accessor :send_protocol

  # local_protocol:: protocol object; #md5, #messages and #to_s are called on it.
  # transport::      object responding to #transceive(bytes) and #remote_name.
  def initialize(local_protocol, transport)
    @local_protocol = local_protocol
    @transport = transport
    @remote_protocol = nil
    @remote_hash = nil
    @send_protocol = nil
  end

  # Stores the remote protocol and caches it under the transport's remote name.
  def remote_protocol=(new_remote_protocol)
    @remote_protocol = new_remote_protocol
    REMOTE_PROTOCOLS[transport.remote_name] = remote_protocol
  end

  # Stores the remote hash and caches it under the transport's remote name.
  def remote_hash=(new_remote_hash)
    @remote_hash = new_remote_hash
    REMOTE_HASHES[transport.remote_name] = remote_hash
  end

  # Writes a request message and reads a response or error message.
  #
  # Returns the decoded response datum. Raises AvroRemoteError when the
  # response carries the error flag, and Avro::AvroError on handshake failure
  # or when +message_name+ is unknown.
  def request(message_name, request_datum)
    # build handshake and call request
    buffer_writer = StringIO.new('', 'w+')
    buffer_encoder = Avro::IO::BinaryEncoder.new(buffer_writer)
    write_handshake_request(buffer_encoder)
    write_call_request(message_name, request_datum, buffer_encoder)

    # send the handshake and call request; block until call response
    call_response = transport.transceive(buffer_writer.string)

    # process the handshake and call response
    buffer_decoder = Avro::IO::BinaryDecoder.new(StringIO.new(call_response))
    if read_handshake_response(buffer_decoder)
      read_call_response(message_name, buffer_decoder)
    else
      # Only a 'NONE' match gets here, and it has just set send_protocol to
      # true. A second 'NONE' or 'CLIENT' raises, so this retries at most once.
      request(message_name, request_datum)
    end
  end

  # Encodes the handshake request. When no server hash is cached for this
  # transport, the local hash is sent as the guess and the local protocol is
  # assumed to be the remote one.
  def write_handshake_request(encoder)
    local_hash = local_protocol.md5
    server_hash = REMOTE_HASHES[transport.remote_name]
    unless server_hash
      server_hash = local_hash
      self.remote_protocol = local_protocol
    end

    request_datum = {
      'clientHash' => local_hash,
      'serverHash' => server_hash
    }
    request_datum['clientProtocol'] = local_protocol.to_s if send_protocol
    HANDSHAKE_REQUESTOR_WRITER.write(request_datum, encoder)
  end

  # The format of a call request is:
  #   * request metadata, a map with values of type bytes
  #   * the message name, an Avro string, followed by
  #   * the message parameters. Parameters are serialized according to
  #     the message's request declaration.
  #
  # Raises Avro::AvroError if the local protocol has no such message.
  def write_call_request(message_name, request_datum, encoder)
    # TODO request metadata (not yet implemented)
    META_WRITER.write({}, encoder)

    message = local_protocol.messages[message_name]
    raise Avro::AvroError, "Unknown message: #{message_name}" unless message
    encoder.write_string(message.name)

    write_request(message.request, request_datum, encoder)
  end

  # Serializes +request_datum+ according to +request_schema+.
  def write_request(request_schema, request_datum, encoder)
    datum_writer = Avro::IO::DatumWriter.new(request_schema)
    datum_writer.write(request_datum, encoder)
  end

  # Decodes the handshake response and updates the cached view of the server.
  #
  # Returns true when the call response that follows can be read ('BOTH' or
  # 'CLIENT'), false when the request has to be resent with the client
  # protocol attached ('NONE'). Raises Avro::AvroError when the server still
  # reports a mismatch after the protocol was sent, or on an unknown match.
  def read_handshake_response(decoder)
    handshake_response = HANDSHAKE_REQUESTOR_READER.read(decoder)
    match = handshake_response['match']

    case match
    when 'BOTH'
      self.send_protocol = false
      true
    when 'CLIENT'
      raise Avro::AvroError.new('Handshake failure. match == CLIENT') if send_protocol
      adopt_server_protocol(handshake_response)
      self.send_protocol = false
      true
    when 'NONE'
      raise Avro::AvroError.new('Handshake failure. match == NONE') if send_protocol
      adopt_server_protocol(handshake_response)
      self.send_protocol = true
      false
    else
      raise Avro::AvroError.new("Unexpected match: #{match}")
    end
  end

  # The format of a call response is:
  #   * response metadata, a map with values of type bytes
  #   * a one-byte error flag boolean, followed by either:
  #     * if the error flag is false,
  #       the message response, serialized per the message's response schema.
  #     * if the error flag is true,
  #       the error, serialized per the message's error union schema.
  #
  # Returns the decoded response; raises AvroRemoteError for an error payload.
  def read_call_response(message_name, decoder)
    # The metadata must be consumed to advance the decoder; it is not used.
    META_READER.read(decoder)

    # remote response schema
    remote_message_schema = remote_protocol.messages[message_name]
    unless remote_message_schema
      raise Avro::AvroError.new("Unknown remote message: #{message_name}")
    end

    # local response schema
    local_message_schema = local_protocol.messages[message_name]
    unless local_message_schema
      raise Avro::AvroError.new("Unknown local message: #{message_name}")
    end

    # error flag
    if decoder.read_boolean
      writers_schema = remote_message_schema.errors || SYSTEM_ERROR_SCHEMA
      readers_schema = local_message_schema.errors || SYSTEM_ERROR_SCHEMA
      raise read_error(writers_schema, readers_schema, decoder)
    else
      read_response(remote_message_schema.response,
                    local_message_schema.response,
                    decoder)
    end
  end

  # Decodes a successful response, resolving writer against reader schema.
  def read_response(writers_schema, readers_schema, decoder)
    datum_reader = Avro::IO::DatumReader.new(writers_schema, readers_schema)
    datum_reader.read(decoder)
  end

  # Decodes an error payload and returns it wrapped in an AvroRemoteError;
  # the caller is responsible for raising it.
  def read_error(writers_schema, readers_schema, decoder)
    datum_reader = Avro::IO::DatumReader.new(writers_schema, readers_schema)
    AvroRemoteError.new(datum_reader.read(decoder))
  end

  private

  # Records the protocol and hash that the server sent in its handshake.
  def adopt_server_protocol(handshake_response)
    self.remote_protocol = Avro::Protocol.parse(handshake_response['serverProtocol'])
    self.remote_hash = handshake_response['serverHash']
  end
end
| |
# Base class for the server side of a protocol interaction.
#
# Subclasses implement #call; #respond drives the handshake, request decoding
# and response encoding around it.
class Responder
  attr_reader :local_protocol, :local_hash, :protocol_cache

  # local_protocol:: protocol object; #md5, #messages and #to_s are called on it.
  def initialize(local_protocol)
    @local_protocol = local_protocol
    @local_hash = local_protocol.md5
    @protocol_cache = {}
    protocol_cache[local_hash] = local_protocol
  end

  # Called by a server to deserialize a request, compute and serialize
  # a response or error. Compare to 'handle()' in Thrift.
  #
  # Returns the encoded bytes to send back: the handshake response followed,
  # when the handshake succeeded, by the call response or an error payload.
  def respond(call_request)
    buffer_decoder = Avro::IO::BinaryDecoder.new(StringIO.new(call_request))
    buffer_writer = StringIO.new('', 'w+')
    buffer_encoder = Avro::IO::BinaryEncoder.new(buffer_writer)
    response_metadata = {}
    handshake_bytes = ''

    begin
      remote_protocol = process_handshake(buffer_decoder, buffer_encoder)
      # handshake failure: only the handshake response goes back
      return buffer_writer.string unless remote_protocol

      # Snapshot the encoded handshake so that a later failure can discard a
      # partially written call response without losing the handshake.
      handshake_bytes = buffer_writer.string.dup

      # read request using remote protocol; the metadata is consumed but unused
      META_READER.read(buffer_decoder)
      remote_message_name = buffer_decoder.read_string

      # get remote and local request schemas so we can do
      # schema resolution (one fine day)
      remote_message = remote_protocol.messages[remote_message_name]
      unless remote_message
        raise Avro::AvroError.new("Unknown remote message: #{remote_message_name}")
      end
      local_message = local_protocol.messages[remote_message_name]
      unless local_message
        raise Avro::AvroError.new("Unknown local message: #{remote_message_name}")
      end
      request = read_request(remote_message.request, local_message.request, buffer_decoder)

      # perform server logic
      error = nil
      begin
        response = call(local_message, request)
      rescue AvroRemoteError => e
        error = e
      rescue Exception => e
        # Deliberately broad: the default #call raises NotImplementedError,
        # which is a ScriptError and would slip past a StandardError rescue.
        error = AvroRemoteError.new(e.to_s)
      end

      # write response using local protocol
      META_WRITER.write(response_metadata, buffer_encoder)
      buffer_encoder.write_boolean(!!error)
      if error.nil?
        write_response(local_message.response, response, buffer_encoder)
      else
        write_error(local_message.errors || SYSTEM_ERROR_SCHEMA, error, buffer_encoder)
      end
    rescue Avro::AvroError => e
      # Encode a system error into a fresh buffer and return it after the
      # handshake. If the failure happened before the handshake completed,
      # handshake_bytes is still empty.
      error = AvroRemoteException.new(e.to_s)
      error_writer = StringIO.new('', 'w+')
      error_encoder = Avro::IO::BinaryEncoder.new(error_writer)
      META_WRITER.write(response_metadata, error_encoder)
      error_encoder.write_boolean(true)
      write_error(SYSTEM_ERROR_SCHEMA, error, error_encoder)
      return handshake_bytes + error_writer.string
    end
    buffer_writer.string
  end

  # Reads the handshake request, writes the handshake response, and returns
  # the client's protocol, or nil when the client's hash is unknown and no
  # protocol text was supplied (match 'NONE').
  def process_handshake(decoder, encoder)
    handshake_request = HANDSHAKE_RESPONDER_READER.read(decoder)
    handshake_response = {}

    # determine the remote protocol
    client_hash = handshake_request['clientHash']
    client_protocol = handshake_request['clientProtocol']
    remote_protocol = protocol_cache[client_hash]

    if !remote_protocol && client_protocol
      remote_protocol = Avro::Protocol.parse(client_protocol)
      protocol_cache[client_hash] = remote_protocol
    end

    # evaluate remote's guess of the local protocol
    server_hash = handshake_request['serverHash']
    handshake_response['match'] =
      if !remote_protocol
        'NONE'
      elsif local_hash == server_hash
        'BOTH'
      else
        'CLIENT'
      end

    # Anything short of a full match sends the server's protocol and hash.
    if handshake_response['match'] != 'BOTH'
      handshake_response['serverProtocol'] = local_protocol.to_s
      handshake_response['serverHash'] = local_hash
    end

    HANDSHAKE_RESPONDER_WRITER.write(handshake_response, encoder)
    remote_protocol
  end

  # Actual work done by server: cf. handler in thrift.
  # Subclasses must override; the base implementation raises NotImplementedError.
  def call(local_message, request)
    raise NotImplementedError
  end

  # Decodes the request parameters, resolving writer against reader schema.
  def read_request(writers_schema, readers_schema, decoder)
    datum_reader = Avro::IO::DatumReader.new(writers_schema, readers_schema)
    datum_reader.read(decoder)
  end

  # Serializes a successful response datum.
  def write_response(writers_schema, response_datum, encoder)
    datum_writer = Avro::IO::DatumWriter.new(writers_schema)
    datum_writer.write(response_datum, encoder)
  end

  # Serializes an error as its string form (error_exception.to_s).
  def write_error(writers_schema, error_exception, encoder)
    datum_writer = Avro::IO::DatumWriter.new(writers_schema)
    datum_writer.write(error_exception.to_s, encoder)
  end
end
| |
# A simple socket-based Transport implementation.
#
# Messages are framed as a sequence of buffers, each preceded by a 4-byte
# length packed with 'N', and terminated by a zero-length buffer.
class SocketTransport
  # NOTE(review): remote_name is never assigned in this class, so it reads
  # as nil — confirm whether callers are expected to key caches on it.
  attr_reader :sock, :remote_name

  # sock:: IO-like object responding to #read(length), #write(bytes), #close.
  def initialize(sock)
    @sock = sock
  end

  # Sends +request+ as a framed message and returns the framed reply's payload.
  def transceive(request)
    write_framed_message(request)
    read_framed_message
  end

  # Reads buffers until the zero-length terminator and returns their
  # concatenated payload. Raises ConnectionClosedException if the socket
  # runs dry before a buffer is complete.
  def read_framed_message
    message = []
    loop do
      buffer = StringIO.new
      buffer_length = read_buffer_length
      return message.join if buffer_length == 0

      while buffer.tell < buffer_length
        chunk = sock.read(buffer_length - buffer.tell)
        # IO#read(length) returns nil at EOF; without this check the loop
        # would never make progress on a truncated buffer.
        if chunk.nil? || chunk.empty?
          raise ConnectionClosedException.new("Socket read 0 bytes.")
        end
        buffer.write(chunk)
      end
      message << buffer.string
    end
  end

  # Writes +message+ as buffers of at most BUFFER_SIZE bytes each.
  # Lengths are measured in bytes, because that is what the header describes.
  def write_framed_message(message)
    message_length = message.bytesize
    total_bytes_sent = 0
    while total_bytes_sent < message_length
      buffer_length = [BUFFER_SIZE, message_length - total_bytes_sent].min
      write_buffer(message.byteslice(total_bytes_sent, buffer_length))
      total_bytes_sent += buffer_length
    end
    # A message is always terminated by a zero-length buffer.
    write_buffer_length(0)
  end

  # Writes one buffer: its length header, then the payload, retrying until
  # every byte has been accepted by the socket.
  def write_buffer(chunk)
    buffer_length = chunk.bytesize
    write_buffer_length(buffer_length)
    total_bytes_sent = 0
    while total_bytes_sent < buffer_length
      remaining = chunk.byteslice(total_bytes_sent, buffer_length - total_bytes_sent)
      bytes_sent = sock.write(remaining)
      if bytes_sent == 0
        raise ConnectionClosedException.new("Socket sent 0 bytes.")
      end
      total_bytes_sent += bytes_sent
    end
  end

  # Writes +n+ as the 4-byte length header.
  def write_buffer_length(n)
    bytes_sent = sock.write([n].pack('N'))
    if bytes_sent == 0
      raise ConnectionClosedException.new("socket sent 0 bytes")
    end
  end

  # Reads the 4-byte length header and returns it as an Integer.
  # Raises ConnectionClosedException when the header is missing or truncated.
  def read_buffer_length
    header = sock.read(BUFFER_HEADER_LENGTH)
    if header.nil? || header.empty?
      raise ConnectionClosedException.new("Socket read 0 bytes.")
    end
    if header.bytesize < BUFFER_HEADER_LENGTH
      raise ConnectionClosedException.new("Socket closed while reading buffer length.")
    end
    header.unpack('N')[0]
  end

  # Closes the underlying socket.
  def close
    sock.close
  end
end
| |
# Raised by FramedReader when the underlying reader yields no data.
class ConnectionClosedError < StandardError
end
| |
# Writes messages to an IO-like object using length-prefixed framing: each
# buffer is preceded by its length packed with 'N', and a zero-length buffer
# terminates the message.
class FramedWriter
  attr_reader :writer

  # writer:: object responding to #write and #<<; #to_s also needs #string.
  def initialize(writer)
    @writer = writer
  end

  # Writes +message+ as buffers of at most BUFFER_SIZE bytes each. Sizes are
  # measured in bytes so that the header matches the payload actually written,
  # including for strings with multibyte characters.
  def write_framed_message(message)
    message_size = message.bytesize
    total_bytes_sent = 0
    while total_bytes_sent < message_size
      buffer_size = [BUFFER_SIZE, message_size - total_bytes_sent].min
      write_buffer(message.byteslice(total_bytes_sent, buffer_size))
      total_bytes_sent += buffer_size
    end
    # Zero-length buffer marks the end of the message.
    write_buffer_size(0)
  end

  # Returns everything written so far (delegates to writer.string).
  def to_s; writer.string; end

  private

  # Writes one buffer: its size header followed by the payload.
  def write_buffer(chunk)
    write_buffer_size(chunk.bytesize)
    writer << chunk
  end

  # Writes +n+ as the 4-byte size header.
  def write_buffer_size(n)
    writer.write([n].pack('N'))
  end
end
| |
# Reads length-prefixed framed messages from an IO-like object: each buffer
# is preceded by its size packed with 'N', and a zero-size buffer ends the
# message.
class FramedReader
  attr_reader :reader

  # reader:: object responding to #read(length).
  def initialize(reader)
    @reader = reader
  end

  # Reads buffers until the zero-size terminator and returns their
  # concatenated payload. Raises ConnectionClosedError if the reader runs
  # out of data before the message is complete.
  def read_framed_message
    message = []
    loop do
      buffer_size = read_buffer_size
      return message.join if buffer_size == 0

      # String.new yields a mutable buffer; sizes are compared in bytes
      # because the header counts bytes.
      buffer = String.new
      while buffer.bytesize < buffer_size
        chunk = reader.read(buffer_size - buffer.bytesize)
        chunk_error?(chunk)
        buffer << chunk
      end
      message << buffer
    end
  end

  private

  # Reads the size header and returns it as an Integer.
  def read_buffer_size
    header = reader.read(BUFFER_HEADER_LENGTH)
    chunk_error?(header)
    if header.bytesize < BUFFER_HEADER_LENGTH
      raise ConnectionClosedError.new("Reader read a truncated buffer header")
    end
    header.unpack('N')[0]
  end

  # Raises ConnectionClosedError when a read produced no data. A read with a
  # positive length returns nil at EOF, so nil is treated like ''.
  def chunk_error?(chunk)
    raise ConnectionClosedError.new("Reader read 0 bytes") if chunk.nil? || chunk.empty?
  end
end
| |
# Only works for clients. Sigh.
#
# Client transport that POSTs each framed request to '/' over a Net::HTTP
# connection opened at construction time, and unframes the response body.
class HTTPTransceiver
  attr_reader :remote_name, :host, :port

  # Opens the HTTP connection to host:port; remote_name is "host:port".
  def initialize(host, port)
    @host = host
    @port = port
    @remote_name = "#{host}:#{port}"
    @conn = Net::HTTP.start(host, port)
  end

  # Frames +message+, posts it, and returns the unframed response payload.
  def transceive(message)
    framed_request = FramedWriter.new(StringIO.new)
    framed_request.write_framed_message(message)

    headers = {'Content-Type' => 'avro/binary'}
    response = @conn.post('/', framed_request.to_s, headers)

    response_reader = FramedReader.new(StringIO.new(response.body))
    response_reader.read_framed_message
  end
end
| end |