blob: 8b45aa94759bd9b8057f0243ddf65da1acc1ce64 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
using Apache.NMS.ActiveMQ.Commands;
using Apache.NMS.ActiveMQ.Transport;
using Apache.NMS;
using System;
using System.Collections;
using System.IO;
using System.Text;
namespace Apache.NMS.ActiveMQ.Transport.Stomp
{
/// <summary>
/// Implements the <a href="https://stomp.github.io/">STOMP</a> wire format: marshals
/// outgoing commands as STOMP frames and unmarshals received frames into commands.
/// </summary>
public class StompWireFormat : IWireFormat
{
// Text encoding used to convert header lines and text message bodies to and from bytes.
private Encoding encoding = new UTF8Encoding();
// Transport that receives the responses this wire format generates locally (see SendCommand).
private ITransport transport;
// Synchronized table of the consumer ids that have been subscribed and not yet unsubscribed;
// each entry maps a ConsumerId to itself.
private IDictionary consumers = Hashtable.Synchronized(new Hashtable());
/// <summary>
/// Creates a wire format instance; no transport is assigned yet.
/// </summary>
public StompWireFormat() { }
/// <summary>
/// Gets or sets the transport used to deliver locally generated responses.
/// </summary>
public ITransport Transport
{
    get
    {
        return this.transport;
    }
    set
    {
        this.transport = value;
    }
}
/// <summary>
/// Gets the version number reported by this wire format; always 1.
/// </summary>
public int Version
{
    get
    {
        return 1;
    }
}
/// <summary>
/// Writes the given command to the writer as a STOMP frame.
/// </summary>
/// <remarks>
/// Commands with a dedicated STOMP mapping are written by the matching Write* method.
/// Any other <c>Command</c> is not sent to the broker; if it requires a response,
/// an empty <c>Response</c> is handed to the transport instead. Objects that are
/// not commands are only traced.
/// </remarks>
/// <param name="o">The command to marshal.</param>
/// <param name="binaryWriter">Writer that receives the frame bytes.</param>
public void Marshal(Object o, BinaryWriter binaryWriter)
{
    Tracer.Debug(">>>> " + o);
    StompFrameStream ds = new StompFrameStream(binaryWriter, encoding);

    // The checks run from the most specific mapping down to the generic Command fallback.
    ConnectionInfo connectionInfo = o as ConnectionInfo;
    if (connectionInfo != null)
    {
        WriteConnectionInfo(connectionInfo, ds);
        return;
    }

    ActiveMQMessage message = o as ActiveMQMessage;
    if (message != null)
    {
        WriteMessage(message, ds);
        return;
    }

    ConsumerInfo consumerInfo = o as ConsumerInfo;
    if (consumerInfo != null)
    {
        WriteConsumerInfo(consumerInfo, ds);
        return;
    }

    MessageAck ack = o as MessageAck;
    if (ack != null)
    {
        WriteMessageAck(ack, ds);
        return;
    }

    TransactionInfo transactionInfo = o as TransactionInfo;
    if (transactionInfo != null)
    {
        WriteTransactionInfo(transactionInfo, ds);
        return;
    }

    ShutdownInfo shutdownInfo = o as ShutdownInfo;
    if (shutdownInfo != null)
    {
        WriteShutdownInfo(shutdownInfo, ds);
        return;
    }

    RemoveInfo removeInfo = o as RemoveInfo;
    if (removeInfo != null)
    {
        WriteRemoveInfo(removeInfo, ds);
        return;
    }

    Command other = o as Command;
    if (other != null)
    {
        // No STOMP equivalent: answer locally so a waiting caller is released.
        if (other.ResponseRequired)
        {
            Response autoResponse = new Response();
            autoResponse.CorrelationId = other.CommandId;
            SendCommand(autoResponse);
            Tracer.Debug("#### Autorespond to command: " + o.GetType());
        }
        return;
    }

    Tracer.Debug("#### Ignored command: " + o.GetType());
}
/// <summary>
/// Reads one line-feed terminated line from the stream and decodes it with the
/// wire format's encoding.
/// </summary>
/// <param name="dis">Reader positioned at the start of a line.</param>
/// <returns>The decoded line without its terminating line feed; may be empty.</returns>
/// <exception cref="IOException">The stream ended before a line feed was read.</exception>
internal String ReadLine(BinaryReader dis)
{
    MemoryStream ms = new MemoryStream();
    while (true)
    {
        int nextByte;
        try
        {
            // ReadByte yields the raw octet. BinaryReader.Read() would decode a whole
            // character using the reader's own encoding, and narrowing that character
            // to a byte corrupts any non-ASCII (multi-byte UTF-8) header data.
            nextByte = dis.ReadByte();
        }
        catch (EndOfStreamException)
        {
            // Keep the exception type and message callers have always seen at end of stream.
            throw new IOException("Peer closed the stream.");
        }

        if (nextByte == 10)
        {
            // Line feed terminates the line and is not part of the result.
            break;
        }
        ms.WriteByte((byte) nextByte);
    }

    byte[] data = ms.ToArray();
    return encoding.GetString(data, 0, data.Length);
}
/// <summary>
/// Reads one STOMP frame from the stream and converts it into a command object.
/// </summary>
/// <remarks>
/// The frame is read as: a command line (blank lines before it are skipped),
/// header lines up to the first empty line, then the body. When a
/// <c>content-length</c> header is present exactly that many bytes are read,
/// followed by the terminating NUL byte; otherwise the body runs up to the
/// first NUL byte (or the end of the stream).
/// </remarks>
/// <param name="dis">Reader positioned at the start of a frame.</param>
/// <returns>The result of <c>CreateCommand</c> for the frame that was read.</returns>
/// <exception cref="IOException">The stream ended while reading the command or header lines.</exception>
public Object Unmarshal(BinaryReader dis)
{
    string command;
    do
    {
        command = ReadLine(dis);
    }
    while (command == "");
    Tracer.Debug("<<<< command: " + command);

    IDictionary headers = new Hashtable();
    string line;
    while ((line = ReadLine(dis)) != "")
    {
        int idx = line.IndexOf(':');
        if (idx > 0)
        {
            string key = line.Substring(0, idx);
            string value = line.Substring(idx + 1);
            headers[key] = value;
            Tracer.Debug("<<<< header: " + key + " = " + value);
        }
        // A line without a "name:value" shape is a bad header and is ignored.
    }

    byte[] content;
    string length = ToString(headers["content-length"]);
    if (length != null)
    {
        // The header is protocol data, so it is parsed independently of the thread culture.
        int size = Int32.Parse(length, System.Globalization.CultureInfo.InvariantCulture);
        content = dis.ReadBytes(size);

        // Read the terminating NULL byte for this frame.
        int nullByte = ReadFrameByte(dis);
        if (nullByte != 0)
        {
            Tracer.Debug("<<<< error reading frame null byte.");
        }
    }
    else
    {
        // No declared length: the body is everything up to the NUL terminator.
        // Bytes are read raw; decoding characters here and narrowing them to bytes
        // would corrupt non-ASCII text bodies.
        MemoryStream ms = new MemoryStream();
        int nextByte;
        while ((nextByte = ReadFrameByte(dis)) > 0)
        {
            ms.WriteByte((byte) nextByte);
        }
        content = ms.ToArray();
    }

    Object answer = CreateCommand(command, headers, content);
    Tracer.Debug("<<<< received: " + answer);
    return answer;
}

/// <summary>
/// Reads a single raw byte of frame data, returning -1 instead of throwing at end of stream.
/// </summary>
private static int ReadFrameByte(BinaryReader dis)
{
    try
    {
        return dis.ReadByte();
    }
    catch (EndOfStreamException)
    {
        return -1;
    }
}
/// <summary>
/// Converts a parsed STOMP frame into the corresponding command object.
/// </summary>
/// <remarks>
/// RECEIPT and CONNECTED frames become a <c>Response</c> correlated through the
/// <c>receipt-id</c> header (or <c>response-id</c> for CONNECTED). ERROR frames
/// become an <c>ExceptionResponse</c>, unless the receipt id carries the
/// <c>ignore:</c> prefix, in which case a plain <c>Response</c> is returned.
/// MESSAGE frames are delegated to <c>ReadMessage</c>.
/// </remarks>
/// <param name="command">The frame's command line.</param>
/// <param name="headers">The frame headers; headers that are consumed are removed.</param>
/// <param name="content">The frame body.</param>
/// <returns>The command, or null when the frame could not be mapped.</returns>
protected virtual Object CreateCommand(string command, IDictionary headers, byte[] content)
{
    switch (command)
    {
        case "RECEIPT":
        case "CONNECTED":
        {
            string receiptId = RemoveHeader(headers, "receipt-id");
            if (receiptId != null)
            {
                Response receipt = new Response();
                if (receiptId.StartsWith("ignore:"))
                {
                    receiptId = receiptId.Substring("ignore:".Length);
                }
                receipt.CorrelationId = Int32.Parse(receiptId);
                return receipt;
            }

            if (command == "CONNECTED")
            {
                string responseId = RemoveHeader(headers, "response-id");
                if (responseId != null)
                {
                    Response connected = new Response();
                    connected.CorrelationId = Int32.Parse(responseId);
                    return connected;
                }
            }

            // Nothing to correlate with: fall out to the unknown-command handling.
            break;
        }

        case "ERROR":
        {
            string receiptId = RemoveHeader(headers, "receipt-id");
            if (receiptId != null && receiptId.StartsWith("ignore:"))
            {
                // The sender asked for errors to be ignored: report plain completion.
                Response ignored = new Response();
                ignored.CorrelationId = Int32.Parse(receiptId.Substring("ignore:".Length));
                return ignored;
            }

            ExceptionResponse failure = new ExceptionResponse();
            if (receiptId != null)
            {
                failure.CorrelationId = Int32.Parse(receiptId);
            }

            BrokerError brokerError = new BrokerError();
            brokerError.Message = RemoveHeader(headers, "message");
            // TODO is this the right header?
            brokerError.ExceptionClass = RemoveHeader(headers, "exceptionClass");
            failure.Exception = brokerError;
            return failure;
        }

        case "MESSAGE":
            return ReadMessage(command, headers, content);
    }

    Tracer.Error("Unknown command: " + command + " headers: " + headers);
    return null;
}
/// <summary>
/// Builds a message dispatch from a MESSAGE frame.
/// </summary>
/// <remarks>
/// A frame carrying a <c>content-length</c> header becomes a bytes message;
/// any other frame becomes a text message decoded with the wire format's
/// encoding. Known headers are moved onto the message, and every header that
/// remains afterwards is copied into the message properties.
/// </remarks>
/// <param name="command">The frame's command line (not used by this implementation).</param>
/// <param name="headers">The frame headers; known headers are removed as they are consumed.</param>
/// <param name="content">The frame body.</param>
/// <returns>A <c>MessageDispatch</c> wrapping the decoded message.</returns>
protected virtual Command ReadMessage(string command, IDictionary headers, byte[] content)
{
    ActiveMQMessage message;
    if (!headers.Contains("content-length"))
    {
        message = new ActiveMQTextMessage(encoding.GetString(content, 0, content.Length));
    }
    else
    {
        message = new ActiveMQBytesMessage();
        message.Content = content;
    }

    // Headers with a dedicated message field.
    message.Type = RemoveHeader(headers, "type");
    message.Destination = StompHelper.ToDestination(RemoveHeader(headers, "destination"));
    message.ReplyTo = StompHelper.ToDestination(RemoveHeader(headers, "reply-to"));
    message.TargetConsumerId = StompHelper.ToConsumerId(RemoveHeader(headers, "subscription"));
    message.CorrelationId = RemoveHeader(headers, "correlation-id");
    message.MessageId = StompHelper.ToMessageId(RemoveHeader(headers, "message-id"));
    message.Persistent = StompHelper.ToBool(RemoveHeader(headers, "persistent"), true);

    // Optional numeric headers: only applied when the frame carried them.
    string priority = RemoveHeader(headers, "priority");
    if (priority != null)
    {
        message.Priority = Byte.Parse(priority);
    }

    string timestamp = RemoveHeader(headers, "timestamp");
    if (timestamp != null)
    {
        message.Timestamp = Int64.Parse(timestamp);
    }

    string expires = RemoveHeader(headers, "expires");
    if (expires != null)
    {
        message.Expiration = Int64.Parse(expires);
    }

    // Whatever is left is exposed as a generic message property.
    foreach (string name in headers.Keys)
    {
        Object propertyValue = headers[name];
        if (propertyValue != null && name == "NMSXGroupSeq")
        {
            // Coerce this standard header extension to an integer.
            propertyValue = Int32.Parse(propertyValue.ToString());
        }
        message.Properties[name] = propertyValue;
    }

    MessageDispatch dispatch = new MessageDispatch();
    dispatch.Message = message;
    dispatch.ConsumerId = message.TargetConsumerId;
    dispatch.Destination = message.Destination;
    return dispatch;
}
/// <summary>
/// Writes a CONNECT frame carrying the client id and credentials.
/// </summary>
/// <param name="command">The connection info; its ResponseRequired flag is forced to true.</param>
/// <param name="ss">The frame stream to write to.</param>
protected virtual void WriteConnectionInfo(ConnectionInfo command, StompFrameStream ss)
{
    // A response is always requested for the connect.
    command.ResponseRequired = true;

    ss.WriteCommand(command, "CONNECT");
    ss.WriteHeader("client-id", command.ClientId);
    ss.WriteHeader("login", command.UserName);
    ss.WriteHeader("passcode", command.Password);

    if (command.ResponseRequired)
    {
        ss.WriteHeader("request-id", command.CommandId);
    }

    ss.Flush();
}
/// <summary>
/// Writes a DISCONNECT frame.
/// </summary>
/// <param name="command">The shutdown command; asserted (debug builds) not to require a response.</param>
/// <param name="ss">The frame stream to write to.</param>
protected virtual void WriteShutdownInfo(ShutdownInfo command, StompFrameStream ss)
{
    ss.WriteCommand(command, "DISCONNECT");

    bool responseRequired = command.ResponseRequired;
    System.Diagnostics.Debug.Assert(!responseRequired);

    ss.Flush();
}
/// <summary>
/// Writes a SUBSCRIBE frame for the consumer and records its id as subscribed.
/// </summary>
/// <param name="command">The consumer to subscribe.</param>
/// <param name="ss">The frame stream to write to.</param>
protected virtual void WriteConsumerInfo(ConsumerInfo command, StompFrameStream ss)
{
    ss.WriteCommand(command, "SUBSCRIBE");
    ss.WriteHeader("destination", StompHelper.ToStomp(command.Destination));
    ss.WriteHeader("id", StompHelper.ToStomp(command.ConsumerId));
    ss.WriteHeader("durable-subscriber-name", command.SubscriptionName);
    ss.WriteHeader("selector", command.Selector);
    if (command.NoLocal)
    {
        ss.WriteHeader("no-local", command.NoLocal);
    }
    ss.WriteHeader("ack", "client");

    // ActiveMQ extensions to STOMP
    ss.WriteHeader("activemq.dispatchAsync", command.DispatchAsync);
    if (command.Exclusive)
    {
        ss.WriteHeader("activemq.exclusive", command.Exclusive);
    }

    string subscriptionName = command.SubscriptionName;
    if (subscriptionName != null)
    {
        ss.WriteHeader("activemq.subscriptionName", subscriptionName);
        // An older 4.0 broker needs this (misspelled) header as well in order
        // to pick up the subscription name.
        ss.WriteHeader("activemq.subcriptionName", subscriptionName);
    }

    ss.WriteHeader("activemq.maximumPendingMessageLimit", command.MaximumPendingMessageLimit);
    ss.WriteHeader("activemq.prefetchSize", command.PrefetchSize);
    ss.WriteHeader("activemq.priority", command.Priority);
    if (command.Retroactive)
    {
        ss.WriteHeader("activemq.retroactive", command.Retroactive);
    }

    // Remember the consumer so it can be unsubscribed when its session is removed.
    consumers[command.ConsumerId] = command.ConsumerId;
    ss.Flush();
}
/// <summary>
/// Translates the removal of a consumer, session, producer or connection into
/// UNSUBSCRIBE frames.
/// </summary>
/// <remarks>
/// Removing a consumer unsubscribes it. Removing a session unsubscribes every
/// recorded consumer that belongs to it. When nothing was unsubscribed (a
/// session without consumers, a producer, or a connection) and the command
/// requires a response, a single UNSUBSCRIBE frame is written with the extra
/// <c>true</c> argument to <c>WriteCommand</c>. Other object ids are ignored.
/// </remarks>
/// <param name="command">The remove command.</param>
/// <param name="ss">The frame stream to write to.</param>
protected virtual void WriteRemoveInfo(RemoveInfo command, StompFrameStream ss)
{
    object id = command.ObjectId;
    if (id is ConsumerId)
    {
        UnsubscribeConsumer(command, (ConsumerId) id, ss);
    }
    else if (id is SessionId)
    {
        // When a session is removed, its consumers need to be removed too.
        SessionId sessionId = (SessionId) id;
        ArrayList matches = new ArrayList();

        // A synchronized Hashtable only protects individual operations; enumeration
        // must hold the SyncRoot, otherwise a concurrent subscribe/unsubscribe
        // invalidates the enumerator and the loop throws.
        lock (consumers.SyncRoot)
        {
            foreach (DictionaryEntry entry in consumers)
            {
                ConsumerId t = (ConsumerId) entry.Key;
                if (sessionId.ConnectionId == t.ConnectionId && sessionId.Value == t.SessionId)
                {
                    matches.Add(t);
                }
            }
        }

        // Write the frames outside the lock so stream I/O never blocks other users of the table.
        foreach (ConsumerId consumerId in matches)
        {
            UnsubscribeConsumer(command, consumerId, ss);
        }

        if (matches.Count == 0 && command.ResponseRequired)
        {
            ss.WriteCommand(command, "UNSUBSCRIBE", true);
            ss.WriteHeader("id", sessionId);
            ss.Flush();
        }
    }
    else if (id is ProducerId || id is ConnectionId)
    {
        if (command.ResponseRequired)
        {
            ss.WriteCommand(command, "UNSUBSCRIBE", true);
            ss.WriteHeader("id", id);
            ss.Flush();
        }
    }
}

/// <summary>
/// Writes an UNSUBSCRIBE frame for one consumer and forgets its recorded subscription.
/// </summary>
private void UnsubscribeConsumer(RemoveInfo command, ConsumerId consumerId, StompFrameStream ss)
{
    ss.WriteCommand(command, "UNSUBSCRIBE");
    ss.WriteHeader("id", StompHelper.ToStomp(consumerId));
    ss.Flush();
    consumers.Remove(consumerId);
}
/// <summary>
/// Writes a BEGIN, COMMIT or ABORT frame for a local transaction.
/// </summary>
/// <remarks>
/// Commands whose transaction id is not a <c>LocalTransactionId</c> are
/// ignored. One-phase commit maps to COMMIT and rollback to ABORT, both with
/// a response forced; every other transaction type is written as BEGIN.
/// </remarks>
/// <param name="command">The transaction command.</param>
/// <param name="ss">The frame stream to write to.</param>
protected virtual void WriteTransactionInfo(TransactionInfo command, StompFrameStream ss)
{
    TransactionId id = command.TransactionId;
    if (!(id is LocalTransactionId))
    {
        return;
    }

    TransactionType transactionType = (TransactionType) command.Type;
    string type;
    if (transactionType == TransactionType.CommitOnePhase)
    {
        command.ResponseRequired = true;
        type = "COMMIT";
    }
    else if (transactionType == TransactionType.Rollback)
    {
        command.ResponseRequired = true;
        type = "ABORT";
    }
    else
    {
        type = "BEGIN";
    }

    Tracer.Debug(">>> For transaction type: " + transactionType + " we are using command type: " + type);
    ss.WriteCommand(command, type);
    ss.WriteHeader("transaction", StompHelper.ToStomp(id));
    ss.Flush();
}
/// <summary>
/// Writes a SEND frame for the message.
/// </summary>
/// <remarks>
/// Optional headers are written only when they differ from their unset value
/// (null, an expiration of 0, or a priority of 4). Text messages are sent as
/// the encoded text without an explicit content length; all other messages
/// are sent as raw content with the content length set. Every message
/// property is written as an additional header.
/// </remarks>
/// <param name="command">The message to send.</param>
/// <param name="ss">The frame stream to write to.</param>
protected virtual void WriteMessage(ActiveMQMessage command, StompFrameStream ss)
{
    // Priority value that is left off the frame.
    const int DefaultPriority = 4;

    ss.WriteCommand(command, "SEND");
    ss.WriteHeader("destination", StompHelper.ToStomp(command.Destination));
    if (command.ReplyTo != null)
    {
        ss.WriteHeader("reply-to", StompHelper.ToStomp(command.ReplyTo));
    }
    if (command.CorrelationId != null)
    {
        ss.WriteHeader("correlation-id", command.CorrelationId);
    }
    if (command.Expiration != 0)
    {
        ss.WriteHeader("expires", command.Expiration);
    }
    if (command.Priority != DefaultPriority)
    {
        ss.WriteHeader("priority", command.Priority);
    }
    if (command.Type != null)
    {
        ss.WriteHeader("type", command.Type);
    }
    if (command.TransactionId != null)
    {
        ss.WriteHeader("transaction", StompHelper.ToStomp(command.TransactionId));
    }
    ss.WriteHeader("persistent", command.Persistent);

    // Force the content to be marshalled.
    command.BeforeMarshall(null);

    ActiveMQTextMessage textMessage = command as ActiveMQTextMessage;
    if (textMessage != null)
    {
        // A text message may have no text at all; Encoding.GetBytes rejects null,
        // so such a message is sent with an empty body instead of failing the send.
        string text = textMessage.Text;
        ss.Content = (text != null) ? encoding.GetBytes(text) : new byte[0];
    }
    else
    {
        byte[] body = command.Content;
        ss.Content = body;
        ss.ContentLength = (body != null) ? body.Length : 0;
    }

    IPrimitiveMap map = command.Properties;
    foreach (string key in map.Keys)
    {
        ss.WriteHeader(key, map[key]);
    }
    ss.Flush();
}
/// <summary>
/// Writes an ACK frame for the last message id named by the acknowledgement.
/// </summary>
/// <param name="command">The acknowledgement to send.</param>
/// <param name="ss">The frame stream to write to.</param>
protected virtual void WriteMessageAck(MessageAck command, StompFrameStream ss)
{
    ss.WriteCommand(command, "ACK", true);

    // TODO handle bulk ACKs?
    ss.WriteHeader("message-id", StompHelper.ToStomp(command.LastMessageId));

    bool transacted = command.TransactionId != null;
    if (transacted)
    {
        ss.WriteHeader("transaction", StompHelper.ToStomp(command.TransactionId));
    }

    ss.Flush();
}
/// <summary>
/// Hands a locally generated command to the configured transport.
/// </summary>
/// <remarks>
/// When no transport has been assigned the command is dropped and a fatal
/// trace message is written.
/// </remarks>
/// <param name="command">The command to deliver.</param>
protected virtual void SendCommand(Command command)
{
    if (transport != null)
    {
        transport.Command(transport, command);
        return;
    }

    Tracer.Fatal("No transport configured so cannot return command: " + command);
}
/// <summary>
/// Removes a header from the table and returns its value as a string.
/// </summary>
/// <param name="headers">The header table to read from and update.</param>
/// <param name="name">The header name.</param>
/// <returns>
/// The header value converted with ToString(), or null when the header is
/// missing or its value is null (in which case the table is left untouched).
/// </returns>
protected virtual string RemoveHeader(IDictionary headers, string name)
{
    object headerValue = headers[name];
    if (headerValue != null)
    {
        headers.Remove(name);
        return headerValue.ToString();
    }

    return null;
}
/// <summary>
/// Converts a value to its string form, passing null through unchanged.
/// </summary>
/// <param name="value">The value to convert; may be null.</param>
/// <returns>value.ToString(), or null when value is null.</returns>
protected virtual string ToString(object value)
{
    return (value == null) ? null : value.ToString();
}
}
}