blob: 2898be2d91f89a543f292f9d6ffbd3f61c49922e [file] [log] [blame]
/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.qpid.proton.hawtdispatch.impl;
import org.apache.qpid.proton.hawtdispatch.api.AmqpConnectOptions;
import org.apache.qpid.proton.hawtdispatch.api.Callback;
import org.apache.qpid.proton.hawtdispatch.api.ChainedCallback;
import org.apache.qpid.proton.hawtdispatch.api.TransportState;
import org.apache.qpid.proton.engine.*;
import org.apache.qpid.proton.engine.impl.EngineFactoryImpl;
import org.apache.qpid.proton.engine.impl.ProtocolTracer;
import org.fusesource.hawtbuf.Buffer;
import org.fusesource.hawtbuf.DataByteArrayOutputStream;
import org.fusesource.hawtbuf.UTF8Buffer;
import org.fusesource.hawtdispatch.*;
import org.fusesource.hawtdispatch.transport.DefaultTransportListener;
import org.fusesource.hawtdispatch.transport.SslTransport;
import org.fusesource.hawtdispatch.transport.TcpTransport;
import org.fusesource.hawtdispatch.transport.Transport;
import java.io.IOException;
import java.net.URI;
import java.util.Arrays;
import java.util.EnumSet;
import java.util.HashSet;
import java.util.LinkedList;
import static org.apache.qpid.proton.hawtdispatch.api.TransportState.*;
import static org.fusesource.hawtdispatch.Dispatch.NOOP;
/**
* @author <a href="http://hiramchirino.com">Hiram Chirino</a>
*/
public class AmqpTransport extends WatchBase {
private TransportState state = CREATED;
final DispatchQueue queue;
final ProtonJConnection connection;
private final EngineFactoryImpl engineFactory = new EngineFactoryImpl();
Transport hawtdispatchTransport;
ProtonJTransport protonTransport;
Throwable failure;
CustomDispatchSource<Defer,LinkedList<Defer>> defers;
public static final EnumSet<EndpointState> ALL_SET = EnumSet.allOf(EndpointState.class);
private AmqpTransport(DispatchQueue queue) {
this.queue = queue;
this.connection = engineFactory.createConnection();
defers = Dispatch.createSource(EventAggregators.<Defer>linkedList(), this.queue);
defers.setEventHandler(new Task(){
public void run() {
for( Defer defer: defers.getData() ) {
assert defer.defered = true;
defer.defered = false;
defer.run();
}
}
});
defers.resume();
}
static public AmqpTransport connect(AmqpConnectOptions options) {
AmqpConnectOptions opts = options.clone();
if( opts.getDispatchQueue() == null ) {
opts.setDispatchQueue(Dispatch.createQueue());
}
if( opts.getBlockingExecutor() == null ) {
opts.setBlockingExecutor(AmqpConnectOptions.getBlockingThreadPool());
}
return new AmqpTransport(opts.getDispatchQueue()).connecting(opts);
}
private AmqpTransport connecting(final AmqpConnectOptions options) {
assert state == CREATED;
try {
state = CONNECTING;
if( options.getLocalContainerId()!=null ) {
connection.setLocalContainerId(options.getLocalContainerId());
}
if( options.getRemoteContainerId()!=null ) {
connection.setContainer(options.getRemoteContainerId());
}
connection.setHostname(options.getHost().getHost());
Callback<Void> onConnect = new Callback<Void>() {
@Override
public void onSuccess(Void value) {
if( state == CONNECTED ) {
hawtdispatchTransport.setTransportListener(new AmqpTransportListener());
fireWatches();
}
}
@Override
public void onFailure(Throwable value) {
if( state == CONNECTED ) {
failure = value;
disconnect();
fireWatches();
}
}
};
if( options.getUser()!=null ) {
onConnect = new SaslClientHandler(options, onConnect);
}
createTransport(options, onConnect);
} catch (Throwable e) {
failure = e;
}
fireWatches();
return this;
}
public TransportState getState() {
return state;
}
/**
* Creates and start a transport to the AMQP server. Passes it to the onConnect
* once the transport is connected.
*
* @param onConnect
* @throws Exception
*/
void createTransport(AmqpConnectOptions options, final Callback<Void> onConnect) throws Exception {
final TcpTransport transport;
if( options.getSslContext() !=null ) {
SslTransport ssl = new SslTransport();
ssl.setSSLContext(options.getSslContext());
transport = ssl;
} else {
transport = new TcpTransport();
}
URI host = options.getHost();
if( host.getPort() == -1 ) {
if( options.getSslContext()!=null ) {
host = new URI(host.getScheme()+"://"+host.getHost()+":5672");
} else {
host = new URI(host.getScheme()+"://"+host.getHost()+":5671");
}
}
transport.setBlockingExecutor(options.getBlockingExecutor());
transport.setDispatchQueue(options.getDispatchQueue());
transport.setMaxReadRate(options.getMaxReadRate());
transport.setMaxWriteRate(options.getMaxWriteRate());
transport.setReceiveBufferSize(options.getReceiveBufferSize());
transport.setSendBufferSize(options.getSendBufferSize());
transport.setTrafficClass(options.getTrafficClass());
transport.setUseLocalHost(options.isUseLocalHost());
transport.connecting(host, options.getLocalAddress());
transport.setTransportListener(new DefaultTransportListener(){
public void onTransportConnected() {
if(state==CONNECTING) {
state = CONNECTED;
onConnect.onSuccess(null);
transport.resumeRead();
}
}
public void onTransportFailure(final IOException error) {
if(state==CONNECTING) {
onConnect.onFailure(error);
}
}
});
transport.connecting(host, options.getLocalAddress());
bind(transport);
transport.start(NOOP);
}
class SaslClientHandler extends ChainedCallback<Void, Void> {
private final AmqpConnectOptions options;
public SaslClientHandler(AmqpConnectOptions options, Callback<Void> next) {
super(next);
this.options = options;
}
public void onSuccess(final Void value) {
final Sasl s = protonTransport.sasl();
s.client();
pumpOut();
hawtdispatchTransport.setTransportListener(new AmqpTransportListener() {
Sasl sasl = s;
@Override
void process() {
if (sasl != null) {
sasl = processSaslEvent(sasl);
if (sasl == null) {
// once sasl handshake is done.. we need to read the protocol header again.
((AmqpProtocolCodec) hawtdispatchTransport.getProtocolCodec()).readProtocolHeader();
}
}
}
@Override
public void onTransportFailure(IOException error) {
next.onFailure(error);
}
@Override
void onFailure(Throwable error) {
next.onFailure(error);
}
boolean authSent = false;
private Sasl processSaslEvent(Sasl sasl) {
if (sasl.getOutcome() == Sasl.SaslOutcome.PN_SASL_OK) {
next.onSuccess(null);
return null;
}
HashSet<String> mechanisims = new HashSet<String>(Arrays.asList(sasl.getRemoteMechanisms()));
if (!authSent && !mechanisims.isEmpty()) {
if (!mechanisims.contains("PLAIN")) {
next.onFailure(Support.illegalState("Remote does not support plain password authentication."));
return null;
}
authSent = true;
DataByteArrayOutputStream os = new DataByteArrayOutputStream();
try {
os.write(new UTF8Buffer(options.getUser()));
os.writeByte(0);
if (options.getPassword() != null) {
os.write(new UTF8Buffer(options.getPassword()));
os.writeByte(0);
}
} catch (IOException e) {
throw new RuntimeException(e);
}
Buffer buffer = os.toBuffer();
sasl.setMechanisms(new String[]{"PLAIN"});
sasl.send(buffer.data, buffer.offset, buffer.length);
}
return sasl;
}
});
}
}
class SaslServerListener extends AmqpTransportListener {
Sasl sasl;
@Override
public void onTransportCommand(Object command) {
try {
if (command.getClass() == AmqpHeader.class) {
AmqpHeader header = (AmqpHeader)command;
switch( header.getProtocolId() ) {
case 3: // Client will be using SASL for auth..
if( listener!=null ) {
sasl = listener.processSaslConnect(protonTransport);
break;
}
default:
AmqpTransportListener listener = new AmqpTransportListener();
hawtdispatchTransport.setTransportListener(listener);
listener.onTransportCommand(command);
return;
}
command = header.getBuffer();
}
} catch (Exception e) {
onFailure(e);
}
super.onTransportCommand(command);
}
@Override
void process() {
if (sasl != null) {
sasl = listener.processSaslEvent(sasl);
}
if (sasl == null) {
// once sasl handshake is done.. we need to read the protocol header again.
((AmqpProtocolCodec) hawtdispatchTransport.getProtocolCodec()).readProtocolHeader();
hawtdispatchTransport.setTransportListener(new AmqpTransportListener());
}
}
}
static public AmqpTransport accept(Transport transport) {
return new AmqpTransport(transport.getDispatchQueue()).accepted(transport);
}
private AmqpTransport accepted(final Transport transport) {
state = CONNECTED;
bind(transport);
hawtdispatchTransport.setTransportListener(new SaslServerListener());
return this;
}
private void bind(final Transport transport) {
this.hawtdispatchTransport = transport;
this.protonTransport = engineFactory.createTransport();
this.protonTransport.bind(connection);
if( transport.getProtocolCodec()==null ) {
try {
transport.setProtocolCodec(new AmqpProtocolCodec());
} catch (Exception e) {
throw new RuntimeException(e);
}
}
}
public void defer(Defer defer) {
if( !defer.defered ) {
defer.defered = true;
defers.merge(defer);
}
}
public void pumpOut() {
assertExecuting();
defer(deferedPumpOut);
}
private Defer deferedPumpOut = new Defer() {
public void run() {
doPumpOut();
}
};
private void doPumpOut() {
switch(state) {
case CONNECTING:
case CONNECTED:
break;
default:
return;
}
int size = hawtdispatchTransport.getProtocolCodec().getWriteBufferSize();
byte data[] = new byte[size];
boolean done = false;
int pumped = 0;
while( !done && !hawtdispatchTransport.full() ) {
int count = protonTransport.output(data, 0, size);
if( count > 0 ) {
pumped += count;
boolean accepted = hawtdispatchTransport.offer(new Buffer(data, 0, count));
assert accepted: "Should be accepted since the transport was not full";
} else {
done = true;
}
}
if( pumped > 0 && !hawtdispatchTransport.full() ) {
listener.processRefill();
}
}
public Sasl sasl;
public void fireListenerEvents() {
fireWatches();
if( sasl!=null ) {
sasl = listener.processSaslEvent(sasl);
if( sasl==null ) {
// once sasl handshake is done.. we need to read the protocol header again.
((AmqpProtocolCodec)this.hawtdispatchTransport.getProtocolCodec()).readProtocolHeader();
}
}
context(connection).fireListenerEvents(listener);
Session session = connection.sessionHead(ALL_SET, ALL_SET);
while(session != null)
{
context(session).fireListenerEvents(listener);
session = session.next(ALL_SET, ALL_SET);
}
Link link = connection.linkHead(ALL_SET, ALL_SET);
while(link != null)
{
context(link).fireListenerEvents(listener);
link = link.next(ALL_SET, ALL_SET);
}
Delivery delivery = connection.getWorkHead();
while(delivery != null)
{
listener.processDelivery(delivery);
delivery = delivery.getWorkNext();
}
listener.processRefill();
}
public ProtonJConnection connection() {
return connection;
}
AmqpListener listener = new AmqpListener();
public AmqpListener getListener() {
return listener;
}
public void setListener(AmqpListener listener) {
this.listener = listener;
}
public EndpointContext context(Endpoint endpoint) {
EndpointContext context = (EndpointContext) endpoint.getContext();
if( context == null ) {
context = new EndpointContext(this, endpoint);
endpoint.setContext(context);
}
return context;
}
class AmqpTransportListener extends DefaultTransportListener {
@Override
public void onTransportConnected() {
if( listener!=null ) {
listener.processTransportConnected();
}
}
@Override
public void onRefill() {
if( listener!=null ) {
listener.processRefill();
}
}
@Override
public void onTransportCommand(Object command) {
if( state != CONNECTED ) {
return;
}
try {
Buffer buffer;
if (command.getClass() == AmqpHeader.class) {
buffer = ((AmqpHeader) command).getBuffer();
} else {
buffer = (Buffer) command;
}
protonTransport.input(buffer.data, buffer.offset, buffer.length);
process();
pumpOut();
} catch (Exception e) {
onFailure(e);
}
}
void process() {
fireListenerEvents();
}
@Override
public void onTransportFailure(IOException error) {
if( state==CONNECTED ) {
failure = error;
if( listener!=null ) {
listener.processTransportFailure(error);
fireWatches();
}
}
}
void onFailure(Throwable error) {
failure = error;
if( listener!=null ) {
listener.processFailure(error);
fireWatches();
}
}
}
public void disconnect() {
assertExecuting();
if( state == CONNECTING || state==CONNECTED) {
state = DISCONNECTING;
if( hawtdispatchTransport!=null ) {
hawtdispatchTransport.stop(new Task(){
public void run() {
state = DISCONNECTED;
hawtdispatchTransport = null;
protonTransport = null;
fireWatches();
}
});
}
}
}
public DispatchQueue queue() {
return queue;
}
public void assertExecuting() {
queue().assertExecuting();
}
public void onTransportConnected(final Callback<Void> cb) {
addWatch(new Watch() {
@Override
public boolean execute() {
if( failure !=null ) {
cb.onFailure(failure);
return true;
}
if( state!=CONNECTING ) {
cb.onSuccess(null);
return true;
}
return false;
}
});
}
public void onTransportDisconnected(final Callback<Void> cb) {
addWatch(new Watch() {
@Override
public boolean execute() {
if( state==DISCONNECTED ) {
cb.onSuccess(null);
return true;
}
return false;
}
});
}
public void onTransportFailure(final Callback<Throwable> cb) {
addWatch(new Watch() {
@Override
public boolean execute() {
if( failure!=null ) {
cb.onSuccess(failure);
return true;
}
return false;
}
});
}
public Throwable getFailure() {
return failure;
}
public void setProtocolTracer(ProtocolTracer protocolTracer) {
protonTransport.setProtocolTracer(protocolTracer);
}
public ProtocolTracer getProtocolTracer() {
return protonTransport.getProtocolTracer();
}
}