blob: dff30b118c60dcab3f46eb404f3798463923ba6a [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
using System;
using System.Collections.Generic;
using System.IO;
using System.Reflection;
using Microsoft.Hadoop.Avro;
using Org.Apache.REEF.Utilities.Logging;
using org.apache.reef.wake.avro.message;
namespace Org.Apache.REEF.Wake.Avro
{
/// <summary>
/// Given a namespace of Avro messages which represent a protocol, the ProtocolSerailizer
/// class reflects all of the message classes and builds Avro seriailziers and deserializers.
/// Avro messages are then serialized/deserilaized to and from byte[] arrays using the
/// Read/Write methods. A transport such as a RemoteObserver using a ByteCodec can then
/// be used to send and receive the serialized messages.
/// </summary>
public sealed class ProtocolSerializer
{
private static readonly Logger Logr = Logger.GetLogger(typeof(ProtocolSerializer));
// Delagates for message serializers and deserializers.
private delegate void Serialize(MemoryStream stream, object message);
private delegate void Deserialize(MemoryStream stream, object observer, long sequence);
// Message type to serialize/derserialize delagate.
private readonly SortedDictionary<string, Serialize> serializeMap = new SortedDictionary<string, Serialize>();
private readonly SortedDictionary<string, Deserialize> deserializeMap = new SortedDictionary<string, Deserialize>();
private readonly IAvroSerializer<Header> headerSerializer = AvroSerializer.Create<Header>();
/// <summary>
/// Register all of the protocol messages using reflection.
/// </summary>
/// <param name="assembly">The Assembley object which contains the namespace of the message classes.</param>
/// <param name="messageNamespace">A string which contains the namespace the protocol messages.</param>
public ProtocolSerializer(Assembly assembly, string messageNamespace)
{
MethodInfo registerInfo = typeof(ProtocolSerializer).GetMethod("Register", BindingFlags.Instance | BindingFlags.NonPublic);
MethodInfo genericInfo;
Logr.Log(Level.Info, "Retrieving types for assembly: {0}", assembly.FullName);
List<Type> types = new List<Type>(assembly.GetTypes());
types.Add(typeof(Header));
foreach (Type type in types)
{
string name = type.FullName;
if (name.StartsWith(messageNamespace))
{
genericInfo = registerInfo.MakeGenericMethod(new[] { type });
genericInfo.Invoke(this, null);
}
}
}
/// <summary>
/// Generate and store the metadata necessary to serialze and deserialize a specific message type.
/// </summary>
/// <typeparam name="TMessage">The class type of the message being registered.</typeparam>
internal void Register<TMessage>()
{
Logr.Log(Level.Info, "Registering message type: {0} {1}", typeof(TMessage).FullName, typeof(TMessage).Name);
IAvroSerializer<TMessage> messageSerializer = AvroSerializer.Create<TMessage>();
Serialize serialize = (MemoryStream stream, object message) =>
{
messageSerializer.Serialize(stream, (TMessage)message);
};
serializeMap.Add(typeof(TMessage).Name, serialize);
Deserialize deserialize = (MemoryStream stream, object observer, long sequence) =>
{
TMessage message = messageSerializer.Deserialize(stream);
IObserver<MessageInstance<TMessage>> msgObserver = observer as IObserver<MessageInstance<TMessage>>;
if (msgObserver != null)
{
msgObserver.OnNext(new MessageInstance<TMessage>(sequence, message));
}
else
{
Logr.Log(Level.Warning, "Unhandled message received: {0}", message);
}
};
deserializeMap.Add(typeof(TMessage).Name, deserialize);
}
/// <summary>
/// Serialize the input message and return a byte array.
/// </summary>
/// <param name="message">An object reference to the messeage to be serialized</param>
/// <param name="sequence">A long which cotains the higher level protocols sequence number for the message.</param>
/// <returns>A byte array containing the serialized header and message.</returns>
public byte[] Write(object message, long sequence)
{
string name = message.GetType().Name;
Logr.Log(Level.Info, "Serializing message: {0}", name);
try
{
using (MemoryStream stream = new MemoryStream())
{
Header header = new Header(sequence, name);
headerSerializer.Serialize(stream, header);
Serialize serialize;
if (serializeMap.TryGetValue(name, out serialize))
{
serialize(stream, message);
}
else
{
throw new SeializationException("Request to serialize unknown message type: " + name);
}
return stream.GetBuffer();
}
}
catch (Exception e)
{
Logr.Log(Level.Error, "Failure writing message.", e);
throw e;
}
}
/// <summary>
/// Read a message from the input byte array.
/// </summary>
/// <param name="data">The byte array containing a header message and message to be deserialized.</param>
/// <param name="observer">An object which implements the IObserver<> interface for the message being deserialized.</param>
public void Read(byte[] data, object observer)
{
try
{
using (MemoryStream stream = new MemoryStream(data))
{
Header head = headerSerializer.Deserialize(stream);
Deserialize deserialize;
if (deserializeMap.TryGetValue(head.className, out deserialize))
{
deserialize(stream, observer, head.sequence);
}
else
{
throw new SeializationException("Request to deserialize unknown message type: " + head.className);
}
}
}
catch (Exception e)
{
Logr.Log(Level.Error, "Failure reading message.", e);
throw e;
}
}
}
}