fix for: https://issues.apache.org/jira/browse/AMQNET-362
diff --git a/src/main/csharp/State/ConnectionStateTracker.cs b/src/main/csharp/State/ConnectionStateTracker.cs
index 352001e..62b0974 100644
--- a/src/main/csharp/State/ConnectionStateTracker.cs
+++ b/src/main/csharp/State/ConnectionStateTracker.cs
@@ -42,8 +42,8 @@
private bool _trackMessages = true;
private int _maxCacheSize = 256;
private int currentCacheSize;
- private readonly Dictionary<MessageId, Message> messageCache = new Dictionary<MessageId, Message>();
- private readonly Queue<MessageId> messageCacheFIFO = new Queue<MessageId>();
+ private readonly Dictionary<Object, Command> messageCache = new Dictionary<Object, Command>();
+ private readonly Queue<Object> messageCacheFIFO = new Queue<Object>();
protected void RemoveEldestInCache()
{
@@ -102,10 +102,17 @@
public void TrackBack(Command command)
{
- if(TrackMessages && command != null && command.IsMessage)
+ if (command != null)
{
- Message message = (Message) command;
- if(message.TransactionId == null)
+ if (TrackMessages && command.IsMessage)
+ {
+ Message message = (Message) command;
+ if(message.TransactionId == null)
+ {
+ currentCacheSize = currentCacheSize + 1;
+ }
+ }
+ else if (command.IsMessagePull)
{
currentCacheSize = currentCacheSize + 1;
}
@@ -119,6 +126,10 @@
{
ConnectionInfo info = connectionState.Info;
info.FailoverReconnect = true;
+ if (Tracer.IsDebugEnabled)
+ {
+ Tracer.Debug("conn: " + connectionState.Info.ConnectionId);
+ }
transport.Oneway(info);
DoRestoreTempDestinations(transport, connectionState);
@@ -133,19 +144,46 @@
DoRestoreTransactions(transport, connectionState);
}
}
- //now flush messages
- foreach(Message msg in messageCache.Values)
+
+ // Now flush messages
+ foreach(Command command in messageCache.Values)
{
- transport.Oneway(msg);
+ if (Tracer.IsDebugEnabled)
+ {
+ Tracer.Debug("Replaying command: " + command);
+ }
+
+ transport.Oneway(command);
}
}
private void DoRestoreTransactions(ITransport transport, ConnectionState connectionState)
{
AtomicCollection<TransactionState> transactionStates = connectionState.TransactionStates;
+ List<TransactionInfo> toRollback = new List<TransactionInfo>();
foreach(TransactionState transactionState in transactionStates)
{
+ // rollback any completed transactions - no way to know if commit got there
+ // or if reply went missing
+ if (transactionState.Commands.Count != 0)
+ {
+ Command lastCommand = transactionState.Commands[transactionState.Commands.Count - 1];
+ if (lastCommand.IsTransactionInfo)
+ {
+ TransactionInfo transactionInfo = lastCommand as TransactionInfo;
+ if (transactionInfo.Type == TransactionInfo.COMMIT_ONE_PHASE)
+ {
+ if (Tracer.IsDebugEnabled)
+ {
+ Tracer.Debug("rolling back potentially completed tx: " + transactionState.getId());
+ }
+ toRollback.Add(transactionInfo);
+ continue;
+ }
+ }
+ }
+
// replay the add and remove of short lived producers that may have been
// involved in the transaction
foreach (ProducerState producerState in transactionState.ProducerStates)
@@ -178,6 +216,18 @@
transport.Oneway(producerRemove);
}
}
+
+ foreach (TransactionInfo command in toRollback)
+ {
+ // respond to the outstanding commit
+ ExceptionResponse response = new ExceptionResponse();
+ response.Exception = new BrokerError();
+ response.Exception.Message =
+ "Transaction completion in doubt due to failover. Forcing rollback of " + command.TransactionId;
+ response.Exception.ExceptionClass = (new TransactionRolledBackException()).GetType().FullName;
+ response.CorrelationId = command.CommandId;
+ transport.Command(transport, response);
+ }
}
/// <summary>
@@ -189,6 +239,10 @@
// Restore the connection's sessions
foreach(SessionState sessionState in connectionState.SessionStates)
{
+ if (Tracer.IsDebugEnabled)
+ {
+ Tracer.Debug("Restoring session: " + sessionState.Info.SessionId);
+ }
transport.Oneway(sessionState.Info);
if(RestoreProducers)
@@ -262,6 +316,10 @@
// Restore the session's producers
foreach(ProducerState producerState in sessionState.ProducerStates)
{
+ if (Tracer.IsDebugEnabled)
+ {
+ Tracer.Debug("Restoring producer: " + producerState.Info.ProducerId);
+ }
transport.Oneway(producerState.Info);
}
}
@@ -662,6 +720,17 @@
return null;
}
+ public override Response processMessagePull(MessagePull pull)
+ {
+ if (pull != null)
+ {
+ // leave a single instance in the cache
+ String id = pull.Destination + "::" + pull.ConsumerId;
+ messageCache.Add(id, pull);
+ }
+ return null;
+ }
+
public bool RestoreConsumers
{
get { return _restoreConsumers; }
diff --git a/src/test/csharp/Transport/failover/FailoverTransactionTest.cs b/src/test/csharp/Transport/failover/FailoverTransactionTest.cs
index 203aedf..cacfddb 100644
--- a/src/test/csharp/Transport/failover/FailoverTransactionTest.cs
+++ b/src/test/csharp/Transport/failover/FailoverTransactionTest.cs
@@ -39,6 +39,85 @@
private readonly int MSG_COUNT = 2;
private readonly String destinationName = "FailoverTransactionTestQ";
+ [SetUp]
+ public override void SetUp()
+ {
+ base.SetUp();
+
+ this.connection = null;
+ this.interrupted = false;
+ this.resumed = false;
+ this.commitFailed = false;
+ }
+
+ [Test]
+ public void FailoverAfterCommitSentTest()
+ {
+ string uri = "failover:(tcpfaulty://${activemqhost}:61616?transport.useLogging=true)";
+ IConnectionFactory factory = new ConnectionFactory(NMSTestSupport.ReplaceEnvVar(uri));
+ using(connection = factory.CreateConnection() as Connection)
+ {
+ connection.ConnectionInterruptedListener +=
+ new ConnectionInterruptedListener(TransportInterrupted);
+ connection.ConnectionResumedListener +=
+ new ConnectionResumedListener(TransportResumed);
+
+ connection.Start();
+
+ ITransport transport = (connection as Connection).ITransport;
+ TcpFaultyTransport tcpFaulty = transport.Narrow(typeof(TcpFaultyTransport)) as TcpFaultyTransport;
+ Assert.IsNotNull(tcpFaulty);
+ tcpFaulty.OnewayCommandPostProcessor += this.FailOnCommitTransportHook;
+
+ using(ISession session = connection.CreateSession())
+ {
+ IDestination destination = session.GetQueue(destinationName);
+ PurgeQueue(connection, destination);
+ }
+
+ Tracer.Debug("Test is putting " + MSG_COUNT + " messages on the queue: " + destinationName);
+
+ using(ISession session = connection.CreateSession(AcknowledgementMode.Transactional))
+ {
+ IDestination destination = session.GetQueue(destinationName);
+ PutMsgIntoQueue(session, destination, false);
+
+ try
+ {
+ session.Commit();
+ Assert.Fail("Should have thrown a TransactionRolledBackException");
+ }
+ catch(TransactionRolledBackException)
+ {
+ }
+ catch
+ {
+ Assert.Fail("Should have thrown a TransactionRolledBackException");
+ }
+ }
+
+ Assert.IsTrue(this.interrupted);
+ Assert.IsTrue(this.resumed);
+
+ Tracer.Debug("Test is attempting to read " + MSG_COUNT +
+ " messages from the queue: " + destinationName);
+
+ using(ISession session = connection.CreateSession())
+ {
+ IDestination destination = session.GetQueue(destinationName);
+ IMessageConsumer consumer = session.CreateConsumer(destination);
+ for (int i = 0; i < MSG_COUNT; ++i)
+ {
+ IMessage msg = consumer.Receive(TimeSpan.FromSeconds(5));
+ Assert.IsNotNull(msg, "Should receive message[" + (i + 1) + "] after commit failed once.");
+ }
+ }
+ }
+
+ Assert.IsTrue(this.interrupted);
+ Assert.IsTrue(this.resumed);
+ }
+
[Test]
public void FailoverBeforeCommitSentTest()
{
@@ -69,24 +148,34 @@
using(ISession session = connection.CreateSession(AcknowledgementMode.Transactional))
{
IDestination destination = session.GetQueue(destinationName);
- PutMsgIntoQueue(session, destination);
+ PutMsgIntoQueue(session, destination, false);
+
+ try
+ {
+ session.Commit();
+ Assert.Fail("Should have thrown a TransactionRolledBackException");
+ }
+ catch(TransactionRolledBackException)
+ {
+ }
+ catch
+ {
+ Assert.Fail("Should have thrown a TransactionRolledBackException");
+ }
}
Assert.IsTrue(this.interrupted);
Assert.IsTrue(this.resumed);
- Tracer.Debug("Test is attempting to read " + MSG_COUNT +
- " messages from the queue: " + destinationName);
+ Tracer.Debug("Test is attempting to read a message from" +
+ destinationName + " but no messages are expected");
using(ISession session = connection.CreateSession())
{
IDestination destination = session.GetQueue(destinationName);
IMessageConsumer consumer = session.CreateConsumer(destination);
- for (int i = 0; i < MSG_COUNT; ++i)
- {
- IMessage msg = consumer.Receive(TimeSpan.FromSeconds(5));
- Assert.IsNotNull(msg, "Should receive message[" + (i + 1) + "] after commit failed once.");
- }
+ IMessage msg = consumer.Receive(TimeSpan.FromSeconds(5));
+ Assert.IsNull(msg, "Should not receive a message after commit failed.");
}
}
@@ -111,7 +200,6 @@
ITransport transport = (connection as Connection).ITransport;
TcpFaultyTransport tcpFaulty = transport.Narrow(typeof(TcpFaultyTransport)) as TcpFaultyTransport;
Assert.IsNotNull(tcpFaulty);
- tcpFaulty.OnewayCommandPreProcessor += this.FailOnCommitTransportHook;
using(ISession session = connection.CreateSession())
{
@@ -125,6 +213,8 @@
{
IDestination destination = session.GetQueue(destinationName);
PutMsgIntoQueue(session, destination, false);
+ tcpFaulty.Close();
+ PutMsgIntoQueue(session, destination, false);
session.Commit();
}