blob: e2b6a962d9eebe1d035d0ce8a2b3f1523d8d67e2 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.aries.rsa.provider.fastbin.tcp;
import java.io.EOFException;
import java.io.IOException;
import java.lang.reflect.Array;
import java.lang.reflect.Method;
import java.net.InetSocketAddress;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import org.apache.aries.rsa.provider.fastbin.api.Dispatched;
import org.apache.aries.rsa.provider.fastbin.api.ObjectSerializationStrategy;
import org.apache.aries.rsa.provider.fastbin.api.Serialization;
import org.apache.aries.rsa.provider.fastbin.api.SerializationStrategy;
import org.apache.aries.rsa.provider.fastbin.io.ServerInvoker;
import org.apache.aries.rsa.provider.fastbin.io.Transport;
import org.apache.aries.rsa.provider.fastbin.io.TransportAcceptListener;
import org.apache.aries.rsa.provider.fastbin.io.TransportListener;
import org.apache.aries.rsa.provider.fastbin.io.TransportServer;
import org.apache.aries.rsa.provider.fastbin.streams.StreamProvider;
import org.apache.aries.rsa.provider.fastbin.streams.StreamProviderImpl;
import org.fusesource.hawtbuf.Buffer;
import org.fusesource.hawtbuf.BufferEditor;
import org.fusesource.hawtbuf.DataByteArrayInputStream;
import org.fusesource.hawtbuf.DataByteArrayOutputStream;
import org.fusesource.hawtbuf.UTF8Buffer;
import org.fusesource.hawtdispatch.DispatchQueue;
import org.osgi.framework.ServiceException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * Server side of the fastbin TCP provider.
 * <p>
 * Binds a {@link TransportServer} to the given address and accepts connections using a
 * {@link LengthPrefixedCodec}. For every request frame it looks up the target service by id in
 * {@link #holders} and resolves the encoded method signature. It then schedules the invocation
 * on the service's own {@link DispatchQueue} (when the service implements {@link Dispatched})
 * or on {@link #blockingExecutor}. Responses are offered back to the transport that delivered
 * the request, from a task on {@link #queue}.
 */
@SuppressWarnings({"rawtypes", "unchecked"})
public class ServerInvokerImpl implements ServerInvoker, Dispatched {

    protected static final Logger LOGGER = LoggerFactory.getLogger(ServerInvokerImpl.class);

    /** Maps single-character JVM type descriptors of primitive types to their classes. */
    private static final Map<String, Class<?>> PRIMITIVE_TO_CLASS = new HashMap<>(8, 1.0F);

    static {
        PRIMITIVE_TO_CLASS.put("Z", boolean.class);
        PRIMITIVE_TO_CLASS.put("B", byte.class);
        PRIMITIVE_TO_CLASS.put("C", char.class);
        PRIMITIVE_TO_CLASS.put("S", short.class);
        PRIMITIVE_TO_CLASS.put("I", int.class);
        PRIMITIVE_TO_CLASS.put("J", long.class);
        PRIMITIVE_TO_CLASS.put("F", float.class);
        PRIMITIVE_TO_CLASS.put("D", double.class);
    }

    /** Runs invocations for services that do not implement {@link Dispatched}; shut down by {@link #stop(Runnable)}. */
    protected final ExecutorService blockingExecutor = Executors.newFixedThreadPool(8);
    protected final DispatchQueue queue;
    private final Map<String, SerializationStrategy> serializationStrategies;
    protected final TransportServer server;
    /**
     * Registered services keyed by id. Mutated only from tasks submitted to {@link #queue};
     * read from {@link #onCommand}, which presumably also runs on {@link #queue} because accepted
     * transports are given that queue (see {@link InvokerAcceptListener}) — confirm.
     */
    protected final Map<UTF8Buffer, ServiceFactoryHolder> holders = new HashMap<>();
    private StreamProvider streamProvider;

    /** Resolved, cached information about one remotely invocable method. */
    static class MethodData {
        private final SerializationStrategy serializationStrategy;
        final InvocationStrategy invocationStrategy;
        final Method method;

        MethodData(InvocationStrategy invocationStrategy, SerializationStrategy serializationStrategy, Method method) {
            this.invocationStrategy = invocationStrategy;
            this.serializationStrategy = serializationStrategy;
            this.method = method;
        }
    }

    /**
     * A registered service: its factory, the class loader used to resolve parameter types, and a
     * per-service cache from encoded method signature to {@link MethodData}.
     */
    class ServiceFactoryHolder {
        private final ServiceFactory factory;
        private final ClassLoader loader;
        private final Class clazz;
        private final HashMap<Buffer, MethodData> methodCache = new HashMap<>();

        public ServiceFactoryHolder(ServiceFactory factory, ClassLoader loader) {
            this.factory = factory;
            this.loader = loader;
            // Obtain an instance once only to learn its runtime class; always release it again.
            Object o = factory.get();
            try {
                clazz = o.getClass();
            } finally {
                factory.unget();
            }
        }

        /**
         * Resolves an encoded method signature of the form {@code name,desc1,desc2,...} against
         * the service class, caching the result.
         *
         * @param data the encoded signature as received from the client
         * @return the resolved method together with its invocation and serialization strategies
         * @throws NoSuchMethodException  if the service class has no matching public method
         * @throws ClassNotFoundException if a parameter type descriptor cannot be resolved
         * @throws RuntimeException       if the method names a serialization strategy that is not registered
         */
        private MethodData getMethodData(Buffer data) throws IOException, NoSuchMethodException, ClassNotFoundException {
            MethodData rc = methodCache.get(data);
            if (rc == null) {
                String[] parts = data.utf8().toString().split(",");
                String name = parts[0];
                Class[] params = new Class[parts.length - 1];
                for (int i = 0; i < params.length; i++) {
                    params[i] = decodeClass(parts[i + 1]);
                }
                Method method = clazz.getMethod(name, params);
                Serialization annotation = method.getAnnotation(Serialization.class);
                SerializationStrategy serializationStrategy;
                if (annotation != null) {
                    serializationStrategy = serializationStrategies.get(annotation.value());
                    if (serializationStrategy == null) {
                        throw new RuntimeException("Could not find the serialization strategy named: " + annotation.value());
                    }
                } else {
                    serializationStrategy = ObjectSerializationStrategy.INSTANCE;
                }
                final InvocationStrategy invocationStrategy = InvocationType.forMethod(method);
                rc = new MethodData(invocationStrategy, serializationStrategy, method);
                methodCache.put(data, rc);
            }
            return rc;
        }

        /**
         * Decodes one parameter type descriptor. Each leading {@code [} adds an array dimension;
         * an {@code L} prefix is followed by a class name loaded through {@link #loader};
         * otherwise the first character is looked up as a primitive type code.
         *
         * @throws ClassNotFoundException if the descriptor is empty, names an unknown primitive
         *                                code, or names a class the loader cannot find
         */
        private Class<?> decodeClass(String s) throws ClassNotFoundException {
            if (s.startsWith("[")) {
                Class<?> nested = decodeClass(s.substring(1));
                return Array.newInstance(nested, 0).getClass();
            }
            if (s.isEmpty()) {
                throw new ClassNotFoundException("Empty type descriptor");
            }
            String code = s.substring(0, 1);
            if (code.equals("L")) {
                return loader.loadClass(s.substring(1));
            }
            Class<?> primitive = PRIMITIVE_TO_CLASS.get(code);
            if (primitive == null) {
                // Previously a null type was passed on to getMethod; fail explicitly instead.
                throw new ClassNotFoundException("Unknown type descriptor: " + s);
            }
            return primitive;
        }
    }

    /**
     * Binds a TCP transport server to {@code address}; it is not started until {@link #start()}.
     *
     * @param address                 the address to bind
     * @param queue                   the dispatch queue used for the server, its transports and the service registry
     * @param serializationStrategies strategies selectable via {@link Serialization}, keyed by name
     */
    public ServerInvokerImpl(String address, DispatchQueue queue, Map<String, SerializationStrategy> serializationStrategies) throws Exception {
        this.queue = queue;
        this.serializationStrategies = serializationStrategies;
        this.server = new TcpTransportFactory().bind(address);
        this.server.setDispatchQueue(queue);
        this.server.setAcceptListener(new InvokerAcceptListener());
    }

    public InetSocketAddress getSocketAddress() {
        return this.server.getSocketAddress();
    }

    public DispatchQueue queue() {
        return queue;
    }

    public String getConnectAddress() {
        return this.server.getConnectAddress();
    }

    @Override
    public StreamProvider getStreamProvider() {
        return streamProvider;
    }

    /** Asynchronously registers (or replaces) the service with the given id on {@link #queue}. */
    public void registerService(final String id, final ServiceFactory service, final ClassLoader classLoader) {
        queue().execute(new Runnable() {
            @Override
            public void run() {
                holders.put(new UTF8Buffer(id), new ServiceFactoryHolder(service, classLoader));
            }
        });
    }

    /** Asynchronously removes the service with the given id on {@link #queue}. */
    public void unregisterService(final String id) {
        queue().execute(new Runnable() {
            @Override
            public void run() {
                holders.remove(new UTF8Buffer(id));
            }
        });
    }

    public void start() throws Exception {
        start(null);
    }

    /** Registers the stream provider service, then starts the transport server, passing {@code onComplete} through. */
    public void start(Runnable onComplete) throws Exception {
        registerStreamProvider();
        this.server.start(onComplete);
    }

    /** Creates the {@link StreamProviderImpl} and registers it under {@link StreamProvider#STREAM_PROVIDER_SERVICE_NAME}. */
    private void registerStreamProvider() {
        streamProvider = new StreamProviderImpl();
        registerService(StreamProvider.STREAM_PROVIDER_SERVICE_NAME, new ServerInvoker.ServiceFactory() {
            @Override
            public Object get() {
                return streamProvider;
            }

            @Override
            public void unget() {
                // nothing to do
            }
        }, getClass().getClassLoader());
    }

    public void stop() {
        stop(null);
    }

    /** Stops the transport server; once stopped, shuts down {@link #blockingExecutor} and runs {@code onComplete} if non-null. */
    public void stop(final Runnable onComplete) {
        this.server.stop(new Runnable() {
            @Override
            public void run() {
                blockingExecutor.shutdown();
                if (onComplete != null) {
                    onComplete.run();
                }
            }
        });
    }

    /**
     * Decodes one request frame and schedules its invocation.
     * <p>
     * The frame is read as follows:
     * <ul>
     *   <li>a 4-byte size field (skipped);</li>
     *   <li>a var-long correlation id;</li>
     *   <li>a var-int length-prefixed service id;</li>
     *   <li>a var-int length-prefixed encoded method signature.</li>
     * </ul>
     * The remaining bytes are left in the stream for the invocation strategy. If the service or
     * method cannot be resolved, an error task built around a {@link ServiceException} is
     * scheduled instead.
     * <p>
     * Each {@code factory.get()} done here is paired with one {@code factory.unget()}. The
     * {@link SendTask}'s completion callback performs it once the task owns the request.
     * Otherwise this method performs it before giving up on the request.
     */
    protected void onCommand(final Transport transport, Object data) {
        try {
            final DataByteArrayInputStream bais = new DataByteArrayInputStream((Buffer) data);
            // Size field: not needed here (presumably already consumed for framing by LengthPrefixedCodec).
            bais.readInt();
            final long correlation = bais.readVarLong();
            // Use UTF8Buffer instead of String to avoid encoding/decoding UTF-8 strings
            // for every request.
            final UTF8Buffer service = readBuffer(bais).utf8();
            final Buffer encodedMethod = readBuffer(bais);

            final ServiceFactoryHolder holder = holders.get(service);
            if (holder == null) {
                String message = "The requested service {" + service + "} is not available";
                LOGGER.warn(message);
                blockingExecutor.execute(new SendTask(bais, correlation, transport, message));
                return;
            }

            final Object svc = holder.factory.get();
            final MethodData methodData;
            try {
                methodData = holder.getMethodData(encodedMethod);
            } catch (ReflectiveOperationException reflectionEx) {
                // The error task has no holder and therefore never calls unget(): release now.
                holder.factory.unget();
                final String methodName = encodedMethod.utf8().toString();
                String message = "The requested method {" + methodName + "} is not available";
                LOGGER.warn(message);
                blockingExecutor.execute(new SendTask(bais, correlation, transport, message));
                return;
            } catch (IOException | RuntimeException e) {
                holder.factory.unget();
                throw e;
            }

            Executor executor;
            if (svc instanceof Dispatched) {
                executor = ((Dispatched) svc).queue();
            } else {
                executor = blockingExecutor;
            }
            try {
                executor.execute(new SendTask(svc, bais, holder, correlation, methodData, transport));
            } catch (RuntimeException e) {
                // e.g. rejected after stop(): the task will never run, so it cannot unget().
                holder.factory.unget();
                throw e;
            }
        } catch (Exception e) {
            LOGGER.info("Error while reading request", e);
        }
    }

    /** Reads a var-int length followed by that many bytes. */
    private Buffer readBuffer(DataByteArrayInputStream bais) throws IOException {
        byte[] b = new byte[bais.readVarInt()];
        bais.readFully(b);
        return new Buffer(b);
    }

    /** Configures each accepted connection with the length-prefixed codec, {@link #queue} and a request listener. */
    class InvokerAcceptListener implements TransportAcceptListener {

        public void onAccept(TransportServer transportServer, TcpTransport transport) {
            transport.setProtocolCodec(new LengthPrefixedCodec());
            transport.setDispatchQueue(queue());
            transport.setTransportListener(new InvokerTransportListener());
            transport.start();
        }

        public void onAcceptError(TransportServer transportServer, Exception error) {
            LOGGER.info("Error accepting incoming connection", error);
        }
    }

    /** Forwards incoming commands to {@link #onCommand} and logs unexpected transport failures. */
    class InvokerTransportListener implements TransportListener {

        public void onTransportCommand(Transport transport, Object command) {
            ServerInvokerImpl.this.onCommand(transport, command);
        }

        public void onRefill(Transport transport) {
        }

        public void onTransportFailure(Transport transport, IOException error) {
            // EOF and failures on already-disposed transports are normal disconnects.
            if (!transport.isDisposed() && !(error instanceof EOFException)) {
                LOGGER.info("Transport failure", error);
            }
        }

        public void onTransportConnected(Transport transport) {
            transport.resumeRead();
        }

        public void onTransportDisconnected(Transport transport) {
        }
    }

    /**
     * Runs one invocation (or error reply) and offers the encoded response to the transport.
     * When {@code holder} is non-null, the service is released via {@code factory.unget()} in the
     * strategy's completion callback.
     */
    private final class SendTask implements Runnable {
        private final Object svc;
        private final DataByteArrayInputStream bais;
        private final ServiceFactoryHolder holder;
        private final long correlation;
        private final MethodData methodData;
        private final Transport transport;

        private SendTask(Object svc, DataByteArrayInputStream bais, ServiceFactoryHolder holder, long correlation, MethodData methodData, Transport transport) {
            this.svc = svc;
            this.bais = bais;
            this.holder = holder;
            this.correlation = correlation;
            this.methodData = methodData;
            this.transport = transport;
        }

        /** Error-reply task: the "service" is a {@link ServiceException} with the given message and no holder. */
        private SendTask(DataByteArrayInputStream bais, long correlation, Transport transport, String errorMessage) {
            this(new ServiceException(errorMessage), bais, null, correlation, new MethodData(new BlockingInvocationStrategy(), ObjectSerializationStrategy.INSTANCE, null), transport);
        }

        @Override
        public void run() {
            final DataByteArrayOutputStream baos = new DataByteArrayOutputStream();
            try {
                baos.writeInt(0); // make space for the size field.
                baos.writeVarLong(correlation);
            } catch (IOException e) { // should not happen
                LOGGER.error("Failed to write to buffer", e);
                throw new RuntimeException(e);
            }
            // Decode the remaining args here, on the target's executor, rather than on the thread
            // that called onCommand, to keep that work off the dispatch queue.
            ClassLoader loader = holder == null ? getClass().getClassLoader() : holder.loader;
            methodData.invocationStrategy.service(methodData.serializationStrategy, loader, methodData.method, svc, bais, baos, new Runnable() {
                @Override
                public void run() {
                    if (holder != null) {
                        holder.factory.unget();
                    }
                    final Buffer command = baos.toBuffer();
                    // Update the size field.
                    BufferEditor editor = command.buffer().bigEndianEditor();
                    editor.writeInt(command.length);
                    queue().execute(new Runnable() {
                        @Override
                        public void run() {
                            transport.offer(command);
                        }
                    });
                }
            });
        }
    }
}