More work on client side of connection processing
diff --git a/src/Proton.Client/Client/Implementation/ClientConnection.cs b/src/Proton.Client/Client/Implementation/ClientConnection.cs
index 2c05ebc..bb223c1 100644
--- a/src/Proton.Client/Client/Implementation/ClientConnection.cs
+++ b/src/Proton.Client/Client/Implementation/ClientConnection.cs
@@ -29,7 +29,11 @@
namespace Apache.Qpid.Proton.Client.Implementation
{
- // TODO
+ /// <summary>
+ /// The client connection class manages a single connection to a remote AMQP
+ /// peer and handles connection errors and reconnection operations if those
+ /// are enabled.
+ /// </summary>
public class ClientConnection : IConnection
{
private static IProtonLogger LOG = ProtonLoggerFactory.GetLogger<ClientConnection>();
@@ -109,16 +113,17 @@
{
try
{
- CloseAsync(error).GetAwaiter().GetResult();
+ DoCloseAsync(error).GetAwaiter().GetResult();
}
catch (Exception)
{
+ // Ignore any exception as we are closed regardless
}
}
public Task<IConnection> CloseAsync(IErrorCondition error = null)
{
- throw new System.NotImplementedException();
+ return DoCloseAsync(error);
}
public void Dispose()
@@ -671,6 +676,70 @@
#region private connection utility methods
+ private Task<IConnection> DoCloseAsync(IErrorCondition error)
+ {
+ if (closed.CompareAndSet(false, true))
+ {
+ try
+ {
+ ioContext.EventLoop.Execute(() =>
+ {
+ LOG.Trace("Close requested for connection: {0}", this);
+
+ if (protonConnection.IsLocallyOpen)
+ {
+ protonConnection.ErrorCondition = ClientErrorCondition.AsProtonErrorCondition(error);
+
+ try
+ {
+ protonConnection.Close();
+ }
+ catch (Exception)
+ {
+ // Engine error handler will kick in if the write of Close fails
+ }
+ }
+ else
+ {
+ engine.Shutdown();
+ }
+ });
+ }
+ catch (RejectedExecutionException rje)
+ {
+ LOG.Trace("Close task rejected from the event loop", rje);
+ }
+ finally
+ {
+ try
+ {
+ // TODO: Blocking here isn't ideal but for now we want to await
+ /// the remote sending the close performative back to us
+ /// before dropping the connection. We should probably schedule
+ /// a task that closes the connection and completes the close
+ /// future if the remote hasn't responded by then.
+ closeFuture.Task.GetAwaiter().GetResult();
+ }
+ catch (Exception)
+ {
+ // Ignore error as we are closed regardless
+ }
+ finally
+ {
+ try
+ {
+ transport.Close();
+ }
+ catch (Exception) { }
+
+ ioContext.Shutdown();
+ }
+ }
+ }
+
+ return closeFuture.Task;
+ }
+
private void InitializeProtonResources(ReconnectLocation location)
{
if (options.SaslOptions.SaslEnabled)
diff --git a/src/Proton.Client/Client/Implementation/ClientInstance.cs b/src/Proton.Client/Client/Implementation/ClientInstance.cs
index b824004..73bbd04 100644
--- a/src/Proton.Client/Client/Implementation/ClientInstance.cs
+++ b/src/Proton.Client/Client/Implementation/ClientInstance.cs
@@ -53,10 +53,11 @@
{
try
{
- CloseAsync().Wait();
+ CloseAsync().GetAwaiter().GetResult();
}
catch (Exception)
{
+ // Ignore exceptions as we are closed regardless.
}
}
diff --git a/src/Proton.Client/Client/Transport/TcpTransport.cs b/src/Proton.Client/Client/Transport/TcpTransport.cs
index bc00488..ce0ce5e 100644
--- a/src/Proton.Client/Client/Transport/TcpTransport.cs
+++ b/src/Proton.Client/Client/Transport/TcpTransport.cs
@@ -48,7 +48,7 @@
private EndPoint channelEndpoint;
private Task readLoop;
private Task writeLoop;
- private TcpClient channel;
+ private Socket channel;
private Stream socketReader;
private Stream socketWriter;
private volatile bool connected;
@@ -105,7 +105,21 @@
{
}
- channel?.Close();
+ try
+ {
+ channel?.Shutdown(SocketShutdown.Both);
+ }
+ catch (Exception)
+ {
+ }
+
+ try
+ {
+ channel?.Close();
+ }
+ catch (Exception)
+ {
+ }
}
}
@@ -119,7 +133,7 @@
IPAddress address = Dns.GetHostEntry(host).AddressList[0];
- channel = new TcpClient(address.AddressFamily);
+ channel = new Socket(address.AddressFamily, SocketType.Stream, ProtocolType.Tcp);
channel.BeginConnect(address, port, new AsyncCallback(ConnectCallback), this);
return this;
@@ -170,8 +184,8 @@
private void CompleteConnection()
{
- socketReader = channel.GetStream();
- socketWriter = channel.GetStream();
+ socketReader = new NetworkStream(channel);
+ socketWriter = new NetworkStream(channel);
// TODO: This currently creates two threads for each transport which could
// be reduced to one or none at some point using async Tasks
@@ -210,6 +224,13 @@
if (bytesRead == 0)
{
_ = channelOutputSource.TryComplete();
+ try
+ {
+ channel.Shutdown(SocketShutdown.Both);
+ }
+ catch(Exception)
+ {}
+
// End of stream
// TODO mark as disconnected to fail writes
if (!closed)
@@ -242,15 +263,19 @@
{
private readonly IProtonBuffer buffer;
private readonly Action completion;
+ private readonly bool flush;
- public ChannelWrite(IProtonBuffer buffer, Action completion)
+ public ChannelWrite(IProtonBuffer buffer, Action completion, bool flush = true)
{
this.buffer = buffer;
this.completion = completion;
+ this.flush = flush;
}
public IProtonBuffer Buffer => buffer;
+ public bool IsFlushRequired => flush;
+
public bool HasCompletion => completion != null;
public Action Completion => completion;
@@ -262,7 +287,9 @@
{
while (await channelOutputSink.WaitToReadAsync().ConfigureAwait(false))
{
- while (channelOutputSink.TryRead(out ChannelWrite write))
+ ChannelWrite write = null;
+
+ while (channelOutputSink.TryRead(out write))
{
write.Buffer.ForEachReadableComponent(0, (idx, x) => WriteComponent(x));
@@ -274,7 +301,10 @@
}
}
- await socketWriter.FlushAsync().ConfigureAwait(false);
+ if (write?.IsFlushRequired ?? false)
+ {
+ await socketWriter.FlushAsync().ConfigureAwait(false);
+ }
}
}
diff --git a/test/Proton.TestPeer.Tests/Driver/ProtonBaseTestFixture.cs b/test/Proton.TestPeer.Tests/Driver/ProtonBaseTestFixture.cs
index c5ec717..4b8916e 100644
--- a/test/Proton.TestPeer.Tests/Driver/ProtonBaseTestFixture.cs
+++ b/test/Proton.TestPeer.Tests/Driver/ProtonBaseTestFixture.cs
@@ -47,13 +47,11 @@
config.AddRule(NLog.LogLevel.Debug, NLog.LogLevel.Fatal, logconsole);
config.AddRule(NLog.LogLevel.Trace, NLog.LogLevel.Fatal, logfile);
- loggerFactory = NullLoggerFactory.Instance;
- // loggerFactory = LoggerFactory.Create(builder =>
- // builder.ClearProviders().SetMinimumLevel(LogLevel.Trace).AddNLog(config)
- // );
+ loggerFactory = LoggerFactory.Create(builder =>
+ builder.ClearProviders().SetMinimumLevel(LogLevel.Trace).AddNLog(config)
+ );
- logger = NullLogger.Instance;
- //loggerFactory.CreateLogger(GetType().Name);
+ logger = loggerFactory.CreateLogger(GetType().Name);
AppDomain currentDomain = AppDomain.CurrentDomain;
currentDomain.UnhandledException += new UnhandledExceptionEventHandler(UncaughtExceptionHander);