blob: 0a44b138da8bd68f1a4e09ee8df3ac153c28c12a [file] [log] [blame]
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.avro.ipc;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CopyOnWriteArrayList;
import org.apache.avro.AvroRuntimeException;
import org.apache.avro.Protocol;
import org.apache.avro.Schema;
import org.apache.avro.Protocol.Message;
import org.apache.avro.generic.GenericDatumReader;
import org.apache.avro.generic.GenericDatumWriter;
import org.apache.avro.io.DecoderFactory;
import org.apache.avro.io.BinaryDecoder;
import org.apache.avro.io.Decoder;
import org.apache.avro.io.Encoder;
import org.apache.avro.io.BinaryEncoder;
import org.apache.avro.io.EncoderFactory;
import org.apache.avro.specific.SpecificDatumReader;
import org.apache.avro.specific.SpecificDatumWriter;
import org.apache.avro.util.ByteBufferInputStream;
import org.apache.avro.util.ByteBufferOutputStream;
import org.apache.avro.util.Utf8;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * Base class for the client side of a protocol interaction.
 *
 * <p>Subclasses supply datum encoding through {@link #writeRequest},
 * {@link #readResponse} and {@link #readError}. This class frames each call
 * with a handshake (while the transceiver is not yet connected), call
 * metadata and the message name, and invokes registered {@link RPCPlugin}s
 * around each call.
 */
public abstract class Requestor {
  private static final Logger LOG = LoggerFactory.getLogger(Requestor.class);

  /** Schema of call metadata: a map whose values are bytes. */
  private static final Schema META =
    Schema.createMap(Schema.create(Schema.Type.BYTES));
  private static final GenericDatumReader<Map<CharSequence,ByteBuffer>>
    META_READER = new GenericDatumReader<Map<CharSequence,ByteBuffer>>(META);
  private static final GenericDatumWriter<Map<CharSequence,ByteBuffer>>
    META_WRITER = new GenericDatumWriter<Map<CharSequence,ByteBuffer>>(META);

  private final Protocol local;
  /** Remote protocol; set by the handshake or taken from the static caches. */
  private Protocol remote;
  /** Whether the next handshake must carry the full local protocol text. */
  private boolean sendLocalText;
  private final Transceiver transceiver;

  /**
   * Plugins invoked around each call, in insertion order. A copy-on-write
   * list is used because plugins are added rarely but iterated on every
   * request; iterating a {@code Collections.synchronizedList} without holding
   * its lock could throw {@code ConcurrentModificationException} when
   * {@link #addRPCPlugin} runs concurrently.
   */
  protected List<RPCPlugin> rpcMetaPlugins;

  public Protocol getLocal() { return local; }

  public Transceiver getTransceiver() { return transceiver; }

  /**
   * @param local the protocol this client speaks
   * @param transceiver the transport used to exchange request/response bytes
   */
  protected Requestor(Protocol local, Transceiver transceiver)
    throws IOException {
    this.local = local;
    this.transceiver = transceiver;
    this.rpcMetaPlugins = new CopyOnWriteArrayList<RPCPlugin>();
  }

  /**
   * Adds a new plugin to manipulate RPC metadata. Plugins
   * are executed in the order that they are added.
   * @param plugin a plugin that will manipulate RPC metadata
   */
  public void addRPCPlugin(RPCPlugin plugin) {
    rpcMetaPlugins.add(plugin);
  }

  private static final EncoderFactory ENCODER_FACTORY = new EncoderFactory();
  /** Reusable buffered encoder; only used from the synchronized request(). */
  private BinaryEncoder encoder =
    ENCODER_FACTORY.binaryEncoder(new ByteBufferOutputStream(), null);

  /**
   * Writes a request message and reads a response or error message.
   *
   * <p>While the transceiver is not connected, the request is re-sent until
   * the handshake is established.
   *
   * @param messageName name of a message in the local protocol
   * @param request the request datum, written with {@link #writeRequest}
   * @return the response datum read with {@link #readResponse}, or
   *         {@code null} for a one-way message once the transceiver is
   *         connected
   * @throws AvroRuntimeException if the message is missing from the local or
   *         remote protocol, or (on a connected transceiver) the two
   *         protocols disagree on whether it is one-way
   * @throws Exception the error read by {@link #readError} when the remote
   *         reports a failure, or any I/O or plugin failure
   */
  public synchronized Object request(String messageName, Object request)
    throws Exception {
    Transceiver t = getTransceiver();
    // The local message cannot change between handshake retries, so resolve
    // and validate it once, before the loop.
    Message m = getLocal().getMessages().get(messageName);
    if (m == null)
      throw new AvroRuntimeException("Not a local message: "+messageName);
    BinaryDecoder in = null;
    RPCContext context = new RPCContext();
    context.setMessage(m);
    do {
      ByteBufferOutputStream bbo = new ByteBufferOutputStream();
      //safe to use encoder because this is synchronized
      BinaryEncoder out = ENCODER_FACTORY.binaryEncoder(bbo, encoder);
      // use local protocol to write request
      writeRequest(m.getRequest(), request, out); // write request payload
      out.flush();
      // Capture the payload now; it is appended after the header
      // (handshake, metadata, message name) written below.
      List<ByteBuffer> payload = bbo.getBufferList();
      writeHandshake(out); // prepend handshake if needed
      context.setRequestPayload(payload);
      for (RPCPlugin plugin : rpcMetaPlugins) {
        plugin.clientSendRequest(context); // get meta-data from plugins
      }
      META_WRITER.write(context.requestCallMeta(), out);
      out.writeString(m.getName()); // write message name
      out.flush();
      bbo.append(payload);
      List<ByteBuffer> requestBytes = bbo.getBufferList();
      if (m.isOneWay() && t.isConnected()) { // send one-way message
        t.writeBuffers(requestBytes);
        return null;
      } else { // two-way message, or one-way still needing a handshake
        List<ByteBuffer> response = t.transceive(requestBytes);
        ByteBufferInputStream bbi = new ByteBufferInputStream(response);
        in = DecoderFactory.get().binaryDecoder(bbi, in);
      }
    } while (!readHandshake(in));

    // use remote protocol to read response
    // NOTE(review): `remote` is only assigned by writeHandshake/setRemote, and
    // writeHandshake returns early when the transceiver is already connected.
    // If the transceiver was connected before this requestor's first call,
    // `remote` may still be null here -- confirm callers never do that.
    Message rm = remote.getMessages().get(messageName);
    if (rm == null)
      throw new AvroRuntimeException("Not a remote message: "+messageName);
    if ((m.isOneWay() != rm.isOneWay()) && t.isConnected())
      throw new AvroRuntimeException("Not both one-way messages: "+messageName);
    if (m.isOneWay() && t.isConnected()) return null; // one-way w/ handshake
    context.setResponseCallMeta(META_READER.read(null, in));
    if (!in.readBoolean()) { // no error
      Object response = readResponse(rm.getResponse(), in);
      context.setResponse(response);
      for (RPCPlugin plugin : rpcMetaPlugins) {
        plugin.clientReceiveResponse(context);
      }
      return response;
    } else {
      Exception error = readError(rm.getErrors(), in);
      context.setError(error);
      for (RPCPlugin plugin : rpcMetaPlugins) {
        plugin.clientReceiveResponse(context);
      }
      throw error;
    }
  }

  /** Last server protocol hash seen per remote name; shared by all instances. */
  private static final Map<String,MD5> REMOTE_HASHES =
    Collections.synchronizedMap(new HashMap<String,MD5>());
  /** Parsed server protocols keyed by hash; shared by all instances. */
  private static final Map<MD5,Protocol> REMOTE_PROTOCOLS =
    Collections.synchronizedMap(new HashMap<MD5,Protocol>());

  private static final SpecificDatumWriter<HandshakeRequest> HANDSHAKE_WRITER =
    new SpecificDatumWriter<HandshakeRequest>(HandshakeRequest.class);
  private static final SpecificDatumReader<HandshakeResponse> HANDSHAKE_READER =
    new SpecificDatumReader<HandshakeResponse>(HandshakeResponse.class);

  /**
   * Writes a handshake request to {@code out} unless the transceiver is
   * already connected.
   *
   * <p>The server hash is taken from the cache for this remote name; if none
   * is cached, the remote is guessed to be the local protocol. The full local
   * protocol text is included only when {@code sendLocalText} is set (after a
   * NONE handshake response). Plugins may add handshake metadata via
   * {@code clientStartConnect}.
   */
  private void writeHandshake(Encoder out) throws IOException {
    if (getTransceiver().isConnected()) return;
    MD5 localHash = new MD5();
    localHash.bytes(local.getMD5());
    String remoteName = transceiver.getRemoteName();
    MD5 remoteHash = REMOTE_HASHES.get(remoteName);
    remote = REMOTE_PROTOCOLS.get(remoteHash);
    if (remoteHash == null) { // guess remote is local
      remoteHash = localHash;
      remote = local;
    }
    HandshakeRequest handshake = new HandshakeRequest();
    handshake.clientHash = localHash;
    handshake.serverHash = remoteHash;
    if (sendLocalText)
      handshake.clientProtocol = new Utf8(local.toString());
    RPCContext context = new RPCContext();
    context.setHandshakeRequest(handshake);
    for (RPCPlugin plugin : rpcMetaPlugins) {
      plugin.clientStartConnect(context);
    }
    handshake.meta = context.requestHandshakeMeta();
    HANDSHAKE_WRITER.write(handshake, out);
  }

  /**
   * Reads the handshake response, unless the transceiver is already connected.
   *
   * <ul>
   * <li>BOTH: established; stop sending the local protocol text.</li>
   * <li>CLIENT: adopt the server protocol from the response; established.</li>
   * <li>NONE: adopt the server protocol and resend with the local text.</li>
   * </ul>
   * On success the remote protocol is handed to the transceiver.
   *
   * @return {@code true} if the handshake is established (or was not needed)
   */
  private boolean readHandshake(Decoder in) throws IOException {
    if (getTransceiver().isConnected()) return true;
    boolean established = false;
    HandshakeResponse handshake = HANDSHAKE_READER.read(null, in);
    switch (handshake.match) {
    case BOTH:
      established = true;
      sendLocalText = false;
      break;
    case CLIENT:
      LOG.debug("Handshake match = CLIENT");
      setRemote(handshake);
      established = true;
      sendLocalText = false;
      break;
    case NONE:
      LOG.debug("Handshake match = NONE");
      setRemote(handshake);
      sendLocalText = true;
      break;
    default:
      throw new AvroRuntimeException("Unexpected match: "+handshake.match);
    }
    RPCContext context = new RPCContext();
    context.setHandshakeResponse(handshake);
    for (RPCPlugin plugin : rpcMetaPlugins) {
      plugin.clientFinishConnect(context);
    }
    if (established)
      getTransceiver().setRemote(remote);
    return established;
  }

  /** Parses the server protocol from the response and records it in the caches. */
  private void setRemote(HandshakeResponse handshake) {
    remote = Protocol.parse(handshake.serverProtocol.toString());
    MD5 remoteHash = (MD5)handshake.serverHash;
    REMOTE_HASHES.put(transceiver.getRemoteName(), remoteHash);
    if (!REMOTE_PROTOCOLS.containsKey(remoteHash))
      REMOTE_PROTOCOLS.put(remoteHash, remote);
  }

  /** Return the remote protocol. Force a handshake if required. */
  public synchronized Protocol getRemote() throws IOException {
    if (remote != null) return remote; // already have it
    MD5 remoteHash = REMOTE_HASHES.get(transceiver.getRemoteName());
    remote = REMOTE_PROTOCOLS.get(remoteHash);
    if (remote != null) return remote; // already cached
    // force handshake
    ByteBufferOutputStream bbo = new ByteBufferOutputStream();
    // direct because the payload is tiny.
    Encoder out = ENCODER_FACTORY.directBinaryEncoder(bbo, null);
    writeHandshake(out);
    out.writeInt(0); // empty metadata
    out.writeString(""); // bogus message name
    List<ByteBuffer> response =
      getTransceiver().transceive(bbo.getBufferList());
    ByteBufferInputStream bbi = new ByteBufferInputStream(response);
    BinaryDecoder in =
      DecoderFactory.get().binaryDecoder(bbi, null);
    readHandshake(in);
    return this.remote;
  }

  /** Writes a request message. */
  public abstract void writeRequest(Schema schema, Object request,
                                    Encoder out) throws IOException;

  /** Reads a response message. */
  public abstract Object readResponse(Schema schema, Decoder in)
    throws IOException;

  /** Reads an error message. */
  public abstract Exception readError(Schema schema, Decoder in)
    throws IOException;
}