| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| package org.apache.aries.rsa.provider.fastbin.tcp; |
| |
| import java.io.IOException; |
| import java.net.InetAddress; |
| import java.net.InetSocketAddress; |
| import java.net.URI; |
| import java.net.UnknownHostException; |
| import java.nio.ByteBuffer; |
| import java.nio.channels.ReadableByteChannel; |
| import java.nio.channels.SelectionKey; |
| import java.nio.channels.SocketChannel; |
| import java.nio.channels.WritableByteChannel; |
| import java.util.LinkedList; |
| import java.util.Map; |
| import java.util.concurrent.TimeUnit; |
| |
| import org.apache.aries.rsa.provider.fastbin.io.ProtocolCodec; |
| import org.apache.aries.rsa.provider.fastbin.io.Transport; |
| import org.apache.aries.rsa.provider.fastbin.io.TransportListener; |
| import org.fusesource.hawtdispatch.Dispatch; |
| import org.fusesource.hawtdispatch.DispatchQueue; |
| import org.fusesource.hawtdispatch.DispatchSource; |
| import org.slf4j.Logger; |
| import org.slf4j.LoggerFactory; |
| |
| public class TcpTransport implements Transport { |
| |
| private static final Logger LOG = LoggerFactory.getLogger(TcpTransport.class); |
| |
| protected State _serviceState = new CREATED(); |
| |
| protected Map<String, Object> socketOptions; |
| |
| protected URI remoteLocation; |
| protected URI localLocation; |
| protected TransportListener listener; |
| protected String remoteAddress; |
| protected ProtocolCodec codec; |
| |
| protected SocketChannel channel; |
| |
| protected SocketState socketState = new DISCONNECTED(); |
| |
| protected DispatchQueue dispatchQueue; |
| private DispatchSource readSource; |
| private DispatchSource writeSource; |
| |
| protected boolean useLocalHost = true; |
| |
| int max_read_rate; |
| int max_write_rate; |
| protected RateLimitingChannel rateLimitingChannel; |
| |
| boolean drained = true; |
| |
| private final Runnable CANCEL_HANDLER = new Runnable() { |
| public void run() { |
| socketState.onCanceled(); |
| } |
| }; |
| |
| final public void start() { |
| start(null); |
| } |
| |
| final public void stop() { |
| stop(null); |
| } |
| |
| final public void start(final Runnable onCompleted) { |
| queue().execute(new Runnable() { |
| public void run() { |
| if (_serviceState.isCreated() || _serviceState.isStopped()) { |
| final STARTING state = new STARTING(); |
| state.add(onCompleted); |
| _serviceState = state; |
| _start(new Runnable() { |
| public void run() { |
| _serviceState = new STARTED(); |
| state.done(); |
| } |
| }); |
| } else if (_serviceState.isStarting()) { |
| _serviceState.add(onCompleted); |
| } else if (_serviceState.isStarted()) { |
| if (onCompleted != null) { |
| onCompleted.run(); |
| } |
| } else { |
| if (onCompleted != null) { |
| onCompleted.run(); |
| } |
| LOG.error("start should not be called from state: " + _serviceState); |
| } |
| } |
| }); |
| } |
| |
| final public void stop(final Runnable onCompleted) { |
| queue().execute(new Runnable() { |
| public void run() { |
| if (_serviceState instanceof STARTED) { |
| final STOPPING state = new STOPPING(); |
| state.add(onCompleted); |
| _serviceState = state; |
| _stop(new Runnable() { |
| public void run() { |
| _serviceState = new STOPPED(); |
| state.done(); |
| } |
| }); |
| } else if (_serviceState instanceof STOPPING) { |
| _serviceState.add(onCompleted); |
| } else if (_serviceState instanceof STOPPED) { |
| if (onCompleted != null) { |
| onCompleted.run(); |
| } |
| } else { |
| if (onCompleted != null) { |
| onCompleted.run(); |
| } |
| LOG.error("stop should not be called from state: " + _serviceState); |
| } |
| } |
| }); |
| } |
| |
| protected State getServiceState() { |
| return _serviceState; |
| } |
| |
| public void connected(SocketChannel channel) throws IOException, Exception { |
| this.channel = channel; |
| |
| if( codec !=null ) { |
| initializeCodec(); |
| } |
| |
| this.channel.configureBlocking(false); |
| this.remoteAddress = channel.socket().getRemoteSocketAddress().toString(); |
| channel.socket().setSoLinger(true, 0); |
| channel.socket().setTcpNoDelay(true); |
| |
| this.socketState = new CONNECTED(); |
| } |
| |
| protected void initializeCodec() { |
| codec.setReadableByteChannel(readChannel()); |
| codec.setWritableByteChannel(writeChannel()); |
| } |
| |
| public void connecting(URI remoteLocation, URI localLocation) throws IOException, Exception { |
| this.channel = SocketChannel.open(); |
| this.channel.configureBlocking(false); |
| this.remoteLocation = remoteLocation; |
| this.localLocation = localLocation; |
| |
| if (localLocation != null) { |
| InetSocketAddress localAddress = new InetSocketAddress(InetAddress.getByName(localLocation.getHost()), localLocation.getPort()); |
| channel.socket().bind(localAddress); |
| } |
| |
| String host = resolveHostName(remoteLocation.getHost()); |
| InetSocketAddress remoteAddress = new InetSocketAddress(host, remoteLocation.getPort()); |
| channel.connect(remoteAddress); |
| this.socketState = new CONNECTING(); |
| } |
| |
| |
| public DispatchQueue queue() { |
| return dispatchQueue; |
| } |
| |
| public void setDispatchQueue(DispatchQueue queue) { |
| this.dispatchQueue = queue; |
| } |
| |
| public void _start(Runnable onCompleted) { |
| try { |
| if (socketState.isConnecting()) { |
| trace("connecting..."); |
| // this allows the connect to complete.. |
| readSource = Dispatch.createSource(channel, SelectionKey.OP_CONNECT, dispatchQueue); |
| readSource.setEventHandler(new Runnable() { |
| public void run() { |
| if (!(getServiceState().isStarted())) { |
| return; |
| } |
| try { |
| trace("connected."); |
| channel.finishConnect(); |
| readSource.setCancelHandler(null); |
| readSource.cancel(); |
| readSource=null; |
| socketState = new CONNECTED(); |
| onConnected(); |
| } catch (IOException e) { |
| onTransportFailure(e); |
| } |
| } |
| }); |
| readSource.setCancelHandler(CANCEL_HANDLER); |
| readSource.resume(); |
| |
| } else if (socketState.isConnected()) { |
| dispatchQueue.execute(new Runnable() { |
| public void run() { |
| try { |
| trace("was connected."); |
| onConnected(); |
| } catch (IOException e) { |
| onTransportFailure(e); |
| } |
| } |
| }); |
| } else { |
| System.err.println("cannot be started. socket state is: "+socketState); |
| } |
| } finally { |
| if( onCompleted!=null ) { |
| onCompleted.run(); |
| } |
| } |
| } |
| |
| public void _stop(final Runnable onCompleted) { |
| trace("stopping.. at state: "+socketState); |
| socketState.onStop(onCompleted); |
| } |
| |
| protected String resolveHostName(String host) throws UnknownHostException { |
| try { |
| if(isUseLocalHost()) { |
| String localName = InetAddress.getLocalHost().getHostName(); |
| if (localName != null) { |
| if (localName.equals(host)) { |
| return "localhost"; |
| } |
| } |
| } |
| } catch (Exception e) { |
| LOG.warn("Failed to resolve local host address", e); |
| } |
| return host; |
| } |
| |
| protected void onConnected() throws IOException { |
| |
| readSource = Dispatch.createSource(channel, SelectionKey.OP_READ, dispatchQueue); |
| writeSource = Dispatch.createSource(channel, SelectionKey.OP_WRITE, dispatchQueue); |
| |
| readSource.setCancelHandler(CANCEL_HANDLER); |
| writeSource.setCancelHandler(CANCEL_HANDLER); |
| |
| readSource.setEventHandler(new Runnable() { |
| public void run() { |
| drainInbound(); |
| } |
| }); |
| writeSource.setEventHandler(new Runnable() { |
| public void run() { |
| drainOutbound(); |
| } |
| }); |
| |
| if( max_read_rate!=0 || max_write_rate!=0 ) { |
| rateLimitingChannel = new RateLimitingChannel(); |
| scheduleRateAllowanceReset(); |
| } |
| |
| remoteAddress = channel.socket().getRemoteSocketAddress().toString(); |
| listener.onTransportConnected(this); |
| } |
| |
| private void scheduleRateAllowanceReset() { |
| dispatchQueue.executeAfter(1, TimeUnit.SECONDS, new Runnable(){ |
| public void run() { |
| if (!socketState.isConnected()) { |
| return; |
| } |
| rateLimitingChannel.resetAllowance(); |
| scheduleRateAllowanceReset(); |
| } |
| }); |
| } |
| |
| private void dispose() { |
| if( readSource!=null ) { |
| readSource.cancel(); |
| readSource=null; |
| } |
| |
| if( writeSource!=null ) { |
| writeSource.cancel(); |
| writeSource=null; |
| } |
| this.codec = null; |
| } |
| |
| public void onTransportFailure(IOException error) { |
| listener.onTransportFailure(this, error); |
| socketState.onCanceled(); |
| } |
| |
| |
| public boolean full() { |
| return codec.full(); |
| } |
| |
| public boolean offer(Object command) { |
| assert Dispatch.getCurrentQueue() == dispatchQueue; |
| try { |
| if (!socketState.isConnected()) { |
| throw new IOException("Not connected."); |
| } |
| if (!getServiceState().isStarted()) { |
| throw new IOException("Not running."); |
| } |
| |
| ProtocolCodec.BufferState rc = codec.write(command); |
| switch (rc ) { |
| case FULL: |
| return false; |
| default: |
| if( drained ) { |
| drained = false; |
| resumeWrite(); |
| } |
| return true; |
| } |
| } catch (IOException e) { |
| onTransportFailure(e); |
| return false; |
| } |
| |
| } |
| |
| |
| /** |
| * |
| */ |
| protected void drainOutbound() { |
| assert Dispatch.getCurrentQueue() == dispatchQueue; |
| if (!getServiceState().isStarted() || !socketState.isConnected()) { |
| return; |
| } |
| try { |
| if( codec.flush() == ProtocolCodec.BufferState.WAS_EMPTY && flush() ) { |
| if( !drained ) { |
| drained = true; |
| suspendWrite(); |
| listener.onRefill(this); |
| } |
| } |
| } catch (IOException e) { |
| onTransportFailure(e); |
| } |
| } |
| |
| protected boolean flush() throws IOException { |
| return true; |
| } |
| |
| protected void drainInbound() { |
| if (!getServiceState().isStarted() || readSource.isSuspended()) { |
| return; |
| } |
| try { |
| long initial = codec.getReadCounter(); |
| // Only process up to 64k worth of data at a time so we can give |
| // other connections a chance to process their requests. |
| while( codec.getReadCounter()-initial < 1024*64 ) { |
| Object command = codec.read(); |
| if ( command!=null ) { |
| try { |
| listener.onTransportCommand(this, command); |
| } catch (Throwable e) { |
| onTransportFailure(new IOException("Transport listener failure.")); |
| } |
| |
| // the transport may be suspended after processing a command. |
| if (getServiceState().isStopped() || readSource.isSuspended()) { |
| return; |
| } |
| } else { |
| return; |
| } |
| } |
| } catch (IOException e) { |
| onTransportFailure(e); |
| } |
| } |
| |
| |
| public String getRemoteAddress() { |
| return remoteAddress; |
| } |
| |
| public void suspendRead() { |
| if( isConnected() && readSource!=null ) { |
| readSource.suspend(); |
| } |
| } |
| |
| |
| public void resumeRead() { |
| if( isConnected() && readSource!=null ) { |
| if( rateLimitingChannel!=null ) { |
| rateLimitingChannel.resumeRead(); |
| } else { |
| _resumeRead(); |
| } |
| } |
| } |
| private void _resumeRead() { |
| readSource.resume(); |
| dispatchQueue.execute(new Runnable(){ |
| public void run() { |
| drainInbound(); |
| } |
| }); |
| } |
| |
| protected void suspendWrite() { |
| if( isConnected() && writeSource!=null ) { |
| writeSource.suspend(); |
| } |
| } |
| protected void resumeWrite() { |
| if( isConnected() && writeSource!=null ) { |
| writeSource.resume(); |
| dispatchQueue.execute(new Runnable(){ |
| public void run() { |
| drainOutbound(); |
| } |
| }); |
| } |
| } |
| |
| public TransportListener getTransportListener() { |
| return listener; |
| } |
| |
| public void setTransportListener(TransportListener listener) { |
| this.listener = listener; |
| } |
| |
| public ProtocolCodec getProtocolCodec() { |
| return codec; |
| } |
| |
| public void setProtocolCodec(ProtocolCodec protocolCodec) { |
| this.codec = protocolCodec; |
| if( channel!=null && codec!=null ) { |
| initializeCodec(); |
| } |
| } |
| |
| public boolean isConnected() { |
| return socketState.isConnected(); |
| } |
| |
| public boolean isDisposed() { |
| return getServiceState().isStopped() || getServiceState().isStopping(); |
| } |
| |
| public void setSocketOptions(Map<String, Object> socketOptions) { |
| this.socketOptions = socketOptions; |
| } |
| |
| public boolean isUseLocalHost() { |
| return useLocalHost; |
| } |
| |
| /** |
| * Sets whether 'localhost' or the actual local host name should be used to |
| * make local connections. On some operating systems such as Macs its not |
| * possible to connect as the local host name so localhost is better. |
| */ |
| public void setUseLocalHost(boolean useLocalHost) { |
| this.useLocalHost = useLocalHost; |
| } |
| |
| |
| private void trace(String message) { |
| if( LOG.isTraceEnabled() ) { |
| final String label = dispatchQueue.getLabel(); |
| if( label !=null ) { |
| LOG.trace(label +" | "+message); |
| } else { |
| LOG.trace(message); |
| } |
| } |
| } |
| |
| public SocketChannel getSocketChannel() { |
| return channel; |
| } |
| |
| public ReadableByteChannel readChannel() { |
| if(rateLimitingChannel!=null) { |
| return rateLimitingChannel; |
| } else { |
| return channel; |
| } |
| } |
| |
| public WritableByteChannel writeChannel() { |
| if(rateLimitingChannel!=null) { |
| return rateLimitingChannel; |
| } else { |
| return channel; |
| } |
| } |
| |
| public int getMax_read_rate() { |
| return max_read_rate; |
| } |
| |
| public void setMax_read_rate(int max_read_rate) { |
| this.max_read_rate = max_read_rate; |
| } |
| |
| public int getMax_write_rate() { |
| return max_write_rate; |
| } |
| |
| public void setMax_write_rate(int max_write_rate) { |
| this.max_write_rate = max_write_rate; |
| } |
| |
| class RateLimitingChannel implements ReadableByteChannel, WritableByteChannel { |
| |
| int read_allowance = max_read_rate; |
| boolean read_suspended = false; |
| int read_resume_counter = 0; |
| int write_allowance = max_write_rate; |
| boolean write_suspended = false; |
| |
| public void resetAllowance() { |
| if( read_allowance != max_read_rate || write_allowance != max_write_rate) { |
| read_allowance = max_read_rate; |
| write_allowance = max_write_rate; |
| if( write_suspended ) { |
| write_suspended = false; |
| resumeWrite(); |
| } |
| if( read_suspended ) { |
| read_suspended = false; |
| resumeRead(); |
| for( int i=0; i < read_resume_counter ; i++ ) { |
| resumeRead(); |
| } |
| } |
| } |
| } |
| |
| public int read(ByteBuffer dst) throws IOException { |
| if( max_read_rate==0 ) { |
| return channel.read(dst); |
| } else { |
| int remaining = dst.remaining(); |
| if( read_allowance ==0 || remaining ==0 ) { |
| return 0; |
| } |
| |
| int reduction = 0; |
| if( remaining > read_allowance) { |
| reduction = remaining - read_allowance; |
| dst.limit(dst.limit() - reduction); |
| } |
| int rc; |
| try { |
| rc = channel.read(dst); |
| read_allowance -= rc; |
| } finally { |
| if( reduction!=0 ) { |
| if( dst.remaining() == 0 ) { |
| // we need to suspend the read now until we get |
| // a new allowance.. |
| readSource.suspend(); |
| read_suspended = true; |
| } |
| dst.limit(dst.limit() + reduction); |
| } |
| } |
| return rc; |
| } |
| } |
| |
| public int write(ByteBuffer src) throws IOException { |
| if( max_write_rate==0 ) { |
| return channel.write(src); |
| } else { |
| int remaining = src.remaining(); |
| if( write_allowance ==0 || remaining ==0 ) { |
| return 0; |
| } |
| |
| int reduction = 0; |
| if( remaining > write_allowance) { |
| reduction = remaining - write_allowance; |
| src.limit(src.limit() - reduction); |
| } |
| int rc; |
| try { |
| rc = channel.write(src); |
| write_allowance -= rc; |
| } finally { |
| if( reduction!=0 ) { |
| if( src.remaining() == 0 ) { |
| // we need to suspend the read now until we get |
| // a new allowance.. |
| write_suspended = true; |
| suspendWrite(); |
| } |
| src.limit(src.limit() + reduction); |
| } |
| } |
| return rc; |
| } |
| } |
| |
| public boolean isOpen() { |
| return channel.isOpen(); |
| } |
| |
| public void close() throws IOException { |
| channel.close(); |
| } |
| |
| public void resumeRead() { |
| if( read_suspended ) { |
| read_resume_counter += 1; |
| } else { |
| _resumeRead(); |
| } |
| } |
| |
| } |
| |
| // |
| // Transport states |
| // |
| |
| public static abstract class State { |
| LinkedList<Runnable> callbacks = new LinkedList<>(); |
| |
| void add(Runnable r) { |
| if (r != null) { |
| callbacks.add(r); |
| } |
| } |
| |
| void done() { |
| for (Runnable callback : callbacks) { |
| callback.run(); |
| } |
| } |
| |
| public String toString() { |
| return getClass().getSimpleName(); |
| } |
| |
| boolean is(Class<? extends State> clazz) { |
| return getClass() == clazz; |
| } |
| |
| public boolean isCreated() { |
| return is(CREATED.class); |
| } |
| |
| public boolean isStarting() { |
| return is(STARTING.class); |
| } |
| |
| public boolean isStarted() { |
| return is(STARTED.class); |
| } |
| |
| public boolean isStopping() { |
| return is(STOPPING.class); |
| } |
| |
| public boolean isStopped() { |
| return is(STOPPED.class); |
| } |
| } |
| |
| public static final class CREATED extends State { |
| } |
| |
| public static final class STARTING extends State { |
| } |
| |
| public static final class STARTED extends State { |
| } |
| |
| public static final class STOPPING extends State { |
| } |
| |
| public static final class STOPPED extends State { |
| } |
| |
| |
| // |
| // Socket states |
| // |
| |
| abstract static class SocketState { |
| void onStop(Runnable onCompleted) { |
| } |
| void onCanceled() { |
| } |
| boolean is(Class<? extends SocketState> clazz) { |
| return getClass()==clazz; |
| } |
| boolean isConnecting() { |
| return is(CONNECTING.class); |
| } |
| boolean isConnected() { |
| return is(CONNECTED.class); |
| } |
| } |
| |
| class DISCONNECTED extends SocketState { |
| } |
| |
| class CONNECTING extends SocketState { |
| void onStop(Runnable onCompleted) { |
| trace("CONNECTING.onStop"); |
| CANCELING state = new CANCELING(); |
| socketState = state; |
| state.onStop(onCompleted); |
| } |
| void onCanceled() { |
| trace("CONNECTING.onCanceled"); |
| CANCELING state = new CANCELING(); |
| socketState = state; |
| state.onCanceled(); |
| } |
| } |
| |
| class CONNECTED extends SocketState { |
| void onStop(Runnable onCompleted) { |
| trace("CONNECTED.onStop"); |
| CANCELING state = new CANCELING(); |
| socketState = state; |
| state.add(createDisconnectTask()); |
| state.onStop(onCompleted); |
| } |
| void onCanceled() { |
| trace("CONNECTED.onCanceled"); |
| CANCELING state = new CANCELING(); |
| socketState = state; |
| state.add(createDisconnectTask()); |
| state.onCanceled(); |
| } |
| Runnable createDisconnectTask() { |
| return new Runnable(){ |
| public void run() { |
| listener.onTransportDisconnected(TcpTransport.this); |
| } |
| }; |
| } |
| } |
| |
| class CANCELING extends SocketState { |
| private LinkedList<Runnable> runnables = new LinkedList<>(); |
| private int remaining; |
| private boolean dispose; |
| |
| public CANCELING() { |
| if( readSource!=null ) { |
| remaining++; |
| readSource.cancel(); |
| } |
| if( writeSource!=null ) { |
| remaining++; |
| writeSource.cancel(); |
| } |
| } |
| void onStop(Runnable onCompleted) { |
| trace("CANCELING.onCompleted"); |
| add(onCompleted); |
| dispose = true; |
| } |
| void add(Runnable onCompleted) { |
| if( onCompleted!=null ) { |
| runnables.add(onCompleted); |
| } |
| } |
| void onCanceled() { |
| trace("CANCELING.onCanceled"); |
| remaining--; |
| if( remaining!=0 ) { |
| return; |
| } |
| try { |
| channel.close(); |
| } catch (IOException ignore) { |
| } |
| socketState = new CANCELED(dispose); |
| for (Runnable runnable : runnables) { |
| runnable.run(); |
| } |
| if (dispose) { |
| dispose(); |
| } |
| } |
| } |
| |
| class CANCELED extends SocketState { |
| private boolean disposed; |
| |
| public CANCELED(boolean disposed) { |
| this.disposed=disposed; |
| } |
| |
| void onStop(Runnable onCompleted) { |
| trace("CANCELED.onStop"); |
| if( !disposed ) { |
| disposed = true; |
| dispose(); |
| } |
| onCompleted.run(); |
| } |
| } |
| |
| |
| } |