blob: 61893fceb0e85efbd527c17686b70861bfda9d9a [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
using System;
using System.Net.Sockets;
using System.Threading;
using System.Threading.Tasks;
namespace Org.Apache.REEF.Wake.Remote.Impl
{
/// <summary>
/// Performs low level network IO operations between hosts.
/// </summary>
/// <remarks>
/// Every message is framed as:
/// (1) a 4-byte frame length (payload length + 4) in big-endian byte order,
/// (2) a 4-byte payload length in the machine's native byte order (as produced by
///     <see cref="BitConverter.GetBytes(int)"/>), and
/// (3) the payload bytes.
/// The read methods expect exactly this layout. Readers return <c>null</c> both when the
/// stream ends and when a frame carries an empty payload.
/// </remarks>
public sealed class Channel
{
    /// <summary>Size in bytes of each of the two length fields in the frame header.</summary>
    private const int LengthFieldSize = sizeof(int);

    private readonly NetworkStream _stream;

    /// <summary>
    /// Constructs a new Channel with the connected NetworkStream.
    /// </summary>
    /// <param name="stream">The connected stream</param>
    /// <exception cref="ArgumentNullException">If <paramref name="stream"/> is null.</exception>
    public Channel(NetworkStream stream)
    {
        if (stream == null)
        {
            throw new ArgumentNullException("stream");
        }

        _stream = stream;
    }

    /// <summary>
    /// Sends a message to the connected client synchronously.
    /// </summary>
    /// <param name="message">The message to send</param>
    /// <exception cref="ArgumentNullException">If <paramref name="message"/> is null.</exception>
    public void Write(byte[] message)
    {
        if (message == null)
        {
            throw new ArgumentNullException("message");
        }

        byte[] messageBuffer = GenerateMessageBuffer(message);
        _stream.Write(messageBuffer, 0, messageBuffer.Length);
    }

    /// <summary>
    /// Sends a message to the connected client asynchronously.
    /// </summary>
    /// <param name="message">The message to send</param>
    /// <param name="token">The cancellation token</param>
    /// <returns>The awaitable write task. It faults with
    /// <see cref="ArgumentNullException"/> if <paramref name="message"/> is null.</returns>
    public async Task WriteAsync(byte[] message, CancellationToken token)
    {
        // Same validation as Write, so callers see ArgumentNullException rather than
        // a NullReferenceException thrown from inside GenerateMessageBuffer.
        if (message == null)
        {
            throw new ArgumentNullException("message");
        }

        byte[] messageBuffer = GenerateMessageBuffer(message);
        await _stream.WriteAsync(messageBuffer, 0, messageBuffer.Length, token).ConfigureAwait(false);
    }

    /// <summary>
    /// Reads one incoming message synchronously.
    /// The 8-byte frame header (see class remarks) is consumed first.
    /// </summary>
    /// <returns>The message payload, or null if the stream ended, the frame length was
    /// zero, or the payload was empty.</returns>
    public byte[] Read()
    {
        int payloadLength = ReadMessageLength();
        if (payloadLength == 0)
        {
            return null;
        }

        // NOTE(review): a negative payload length from the peer is not validated here;
        // allocating the buffer will throw.
        return ReadBytes(payloadLength);
    }

    /// <summary>
    /// Reads one incoming message asynchronously.
    /// The 8-byte frame header (see class remarks) is consumed first.
    /// </summary>
    /// <param name="token">The cancellation token</param>
    /// <returns>The message payload, or null if the stream ended, the frame length was
    /// zero, or the payload was empty.</returns>
    public async Task<byte[]> ReadAsync(CancellationToken token)
    {
        int payloadLength = await GetMessageLengthAsync(token).ConfigureAwait(false);
        if (payloadLength == 0)
        {
            return null;
        }

        return await ReadBytesAsync(payloadLength, token).ConfigureAwait(false);
    }

    /// <summary>
    /// Reads exactly <paramref name="bytesToRead"/> bytes from the network stream,
    /// looping because a single Read may return fewer bytes than requested.
    /// </summary>
    /// <param name="bytesToRead">The number of bytes to read</param>
    /// <returns>A buffer holding the requested bytes, or null if the stream ended
    /// before all of them arrived.</returns>
    private byte[] ReadBytes(int bytesToRead)
    {
        int totalBytesRead = 0;
        byte[] buffer = new byte[bytesToRead];
        while (totalBytesRead < bytesToRead)
        {
            int bytesRead = _stream.Read(buffer, totalBytesRead, bytesToRead - totalBytesRead);
            if (bytesRead == 0)
            {
                // Zero bytes means the remote side closed the connection (end of stream).
                return null;
            }

            totalBytesRead += bytesRead;
        }

        return buffer;
    }

    /// <summary>
    /// Asynchronously reads exactly <paramref name="bytesToRead"/> bytes from the network
    /// stream, looping because a single ReadAsync may return fewer bytes than requested.
    /// </summary>
    /// <param name="bytesToRead">The number of bytes to read</param>
    /// <param name="token">The cancellation token</param>
    /// <returns>A buffer holding the requested bytes, or null if the stream ended
    /// before all of them arrived.</returns>
    private async Task<byte[]> ReadBytesAsync(int bytesToRead, CancellationToken token)
    {
        int totalBytesRead = 0;
        byte[] buffer = new byte[bytesToRead];
        while (totalBytesRead < bytesToRead)
        {
            int bytesRead = await _stream
                .ReadAsync(buffer, totalBytesRead, bytesToRead - totalBytesRead, token)
                .ConfigureAwait(false);
            if (bytesRead == 0)
            {
                // Zero bytes means the remote side closed the connection (end of stream).
                return null;
            }

            totalBytesRead += bytesRead;
        }

        return buffer;
    }

    /// <summary>
    /// Builds the on-wire frame: big-endian frame length (payload length + 4),
    /// native-order payload length, then the payload itself.
    /// </summary>
    /// <param name="message">The message to send; must not be null.</param>
    /// <returns>The framed buffer ready to be written to the stream.</returns>
    private static byte[] GenerateMessageBuffer(byte[] message)
    {
        // The frame length counts the payload-length field plus the payload.
        byte[] frameLength = BitConverter.GetBytes(message.Length + LengthFieldSize);
        if (BitConverter.IsLittleEndian)
        {
            Array.Reverse(frameLength);
        }

        // Intentionally left in native byte order; DecodePayloadLength reads it the same way.
        byte[] payloadLength = BitConverter.GetBytes(message.Length);

        byte[] messageBuffer = new byte[frameLength.Length + payloadLength.Length + message.Length];
        int offset = CopyBytes(frameLength, messageBuffer, 0);
        offset += CopyBytes(payloadLength, messageBuffer, offset);
        CopyBytes(message, messageBuffer, offset);
        return messageBuffer;
    }

    /// <summary>
    /// Reads the 8-byte frame header synchronously and returns the payload length.
    /// The payload-length field is only read when the frame length is non-zero.
    /// </summary>
    /// <returns>The payload length in bytes, or 0 if the stream ended or the frame
    /// length was zero.</returns>
    private int ReadMessageLength()
    {
        byte[] frameLengthBytes = ReadBytes(LengthFieldSize);
        if (frameLengthBytes == null || DecodeFrameLength(frameLengthBytes) == 0)
        {
            return 0;
        }

        return DecodePayloadLength(ReadBytes(LengthFieldSize));
    }

    /// <summary>
    /// Reads the 8-byte frame header asynchronously and returns the payload length.
    /// The payload-length field is only read when the frame length is non-zero.
    /// </summary>
    /// <param name="token">The cancellation token</param>
    /// <returns>The payload length in bytes, or 0 if the stream ended or the frame
    /// length was zero.</returns>
    private async Task<int> GetMessageLengthAsync(CancellationToken token)
    {
        byte[] frameLengthBytes = await ReadBytesAsync(LengthFieldSize, token).ConfigureAwait(false);
        if (frameLengthBytes == null || DecodeFrameLength(frameLengthBytes) == 0)
        {
            return 0;
        }

        // The second header field is read asynchronously too, so this method neither
        // blocks the calling thread nor ignores the cancellation token.
        byte[] payloadLengthBytes = await ReadBytesAsync(LengthFieldSize, token).ConfigureAwait(false);
        return DecodePayloadLength(payloadLengthBytes);
    }

    /// <summary>
    /// Decodes the big-endian frame-length field. Only its zero/non-zero status is
    /// used by the readers.
    /// </summary>
    /// <param name="frameLengthBytes">The 4 raw bytes. They are reversed in place on
    /// little-endian machines.</param>
    /// <returns>The decoded frame length.</returns>
    private static int DecodeFrameLength(byte[] frameLengthBytes)
    {
        if (BitConverter.IsLittleEndian)
        {
            Array.Reverse(frameLengthBytes);
        }

        return BitConverter.ToInt32(frameLengthBytes, 0);
    }

    /// <summary>
    /// Decodes the native-byte-order payload-length field.
    /// </summary>
    /// <param name="payloadLengthBytes">The 4 raw bytes, or null if the stream ended.</param>
    /// <returns>The payload length, or 0 when <paramref name="payloadLengthBytes"/> is null.</returns>
    private static int DecodePayloadLength(byte[] payloadLengthBytes)
    {
        return payloadLengthBytes == null ? 0 : BitConverter.ToInt32(payloadLengthBytes, 0);
    }

    /// <summary>
    /// Copies the entire source buffer into the destination buffer at the specified
    /// destination offset.
    /// </summary>
    /// <param name="source">The source buffer to be copied</param>
    /// <param name="dest">The destination buffer to copy to</param>
    /// <param name="destOffset">The offset in the destination buffer at which copying begins.</param>
    /// <returns>The number of bytes copied</returns>
    private static int CopyBytes(byte[] source, byte[] dest, int destOffset)
    {
        Buffer.BlockCopy(source, 0, dest, destOffset, source.Length);
        return source.Length;
    }
}
}