[REEF-1906] Make the ProtocolSerializer Java class injectable by Tang
Summary of changes:
* Add `@Inject` annotation to `ProtocolSerializer` constructor
* Create a Tang named parameter class `ProtocolSerializerNamespace`
* Implement a static `ProtocolSerializer.getClassId()` method so that all serialization code derives the class ID in the same way
* Use injection in the `ProtocolSerializerTest` unit tests
* Minor refactoring and logging improvements in and around the message serialization code
JIRA: [REEF-1906](https://issues.apache.org/jira/browse/REEF-1906)
Pull request:
This closes #1395
diff --git a/lang/java/reef-wake/wake/src/main/java/org/apache/reef/wake/avro/ProtocolSerializer.java b/lang/java/reef-wake/wake/src/main/java/org/apache/reef/wake/avro/ProtocolSerializer.java
index ad10d5a..a0c5cff 100644
--- a/lang/java/reef-wake/wake/src/main/java/org/apache/reef/wake/avro/ProtocolSerializer.java
+++ b/lang/java/reef-wake/wake/src/main/java/org/apache/reef/wake/avro/ProtocolSerializer.java
@@ -21,12 +21,15 @@
import org.apache.avro.specific.SpecificDatumReader;
import org.apache.avro.specific.SpecificRecord;
import org.apache.avro.specific.SpecificRecordBase;
+import org.apache.reef.tang.annotations.Parameter;
import org.apache.reef.wake.MultiObserver;
import org.apache.reef.wake.avro.message.Header;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
+import java.io.IOException;
import java.io.InputStream;
+import java.lang.reflect.InvocationTargetException;
import java.util.List;
import java.util.Map;
import java.util.HashMap;
@@ -35,6 +38,8 @@
import io.github.lukehutch.fastclasspathscanner.FastClasspathScanner;
import io.github.lukehutch.fastclasspathscanner.scanner.ScanResult;
+import javax.inject.Inject;
+
/**
* The ProtocolSerializer generates serializers and deserializers for
* all of the Avro messages contained in a specified package. The name
@@ -43,10 +48,13 @@
* would sit in the org.foo.me.messages package.
*/
public final class ProtocolSerializer {
+
private static final Logger LOG = Logger.getLogger(ProtocolSerializer.class.getName());
+
// Maps for mapping message class names to serializer and deserializer classes.
private final Map<String, IMessageSerializer> nameToSerializerMap = new HashMap<>();
private final Map<String, IMessageDeserializer> nameToDeserializerMap = new HashMap<>();
+
private final SpecificDatumReader<Header> headerReader = new SpecificDatumReader<>(Header.class);
/**
@@ -54,7 +62,10 @@
* @param messagePackage A string which contains the full name of the
* package containing the protocol messages.
*/
- public ProtocolSerializer(final String messagePackage) {
+ @Inject
+ private ProtocolSerializer(
+ @Parameter(ProtocolSerializerNamespace.class) final String messagePackage) {
+
// Build a list of the message reflection classes.
final ScanResult scanResult = new FastClasspathScanner(messagePackage).scan();
final List<String> scanNames = scanResult.getNamesOfSubclassesOf(SpecificRecordBase.class);
@@ -63,25 +74,33 @@
// Add the header message from the org.apache.reef.wake.avro.message package.
messageClasses.add(Header.class);
- try {
- // Register all of the messages in the specified package.
- for (final Class<?> cls : messageClasses) {
- this.register(cls);
- }
- } catch (final Exception e) {
- throw new RuntimeException("Message registration failed", e);
+ // Register all of the messages in the specified package.
+ for (final Class<?> cls : messageClasses) {
+ this.register(cls);
}
}
/**
+ * Get a canonical string ID of the class. This ID is then used as a key to find
+ * serializer and deserializer of the message payload. We need a separate method
+ * for it to make sure all parties use the same algorithm to get the class ID.
+ * @param clazz class of the message to be serialized/deserialized.
+ * @return canonical string ID of the class.
+ */
+ public static String getClassId(final Class<?> clazz) {
+ return clazz.getCanonicalName();
+ }
+
+ /**
* Instantiates and adds a message serializer/deserializer for the message.
* @param msgMetaClass The reflection class for the message.
* @param <TMessage> The Java type of the message being registered.
*/
public <TMessage> void register(final Class<TMessage> msgMetaClass) {
- LOG.log(Level.INFO, "Registering message: {0}", msgMetaClass.getSimpleName());
- nameToSerializerMap.put(msgMetaClass.getSimpleName(), SerializationFactory.createSerializer(msgMetaClass));
- nameToDeserializerMap.put(msgMetaClass.getSimpleName(), SerializationFactory.createDeserializer(msgMetaClass));
+ final String classId = getClassId(msgMetaClass);
+ LOG.log(Level.INFO, "Registering message: {0}", classId);
+ this.nameToSerializerMap.put(classId, SerializationFactory.createSerializer(msgMetaClass));
+ this.nameToDeserializerMap.put(classId, SerializationFactory.createDeserializer(msgMetaClass));
}
/**
@@ -90,18 +109,21 @@
* @param sequence The unique sequence number of the message.
*/
public byte[] write(final SpecificRecord message, final long sequence) {
- try (final ByteArrayOutputStream outputStream = new ByteArrayOutputStream()) {
- final String name = message.getClass().getSimpleName();
- LOG.log(Level.FINE, "Serializing message: {0}", name);
- final IMessageSerializer serializer = nameToSerializerMap.get(name);
+ final String classId = getClassId(message.getClass());
+ try (final ByteArrayOutputStream outputStream = new ByteArrayOutputStream()) {
+
+ LOG.log(Level.FINEST, "Serializing message: {0}", classId);
+
+ final IMessageSerializer serializer = this.nameToSerializerMap.get(classId);
if (serializer != null) {
serializer.serialize(outputStream, message, sequence);
}
return outputStream.toByteArray();
- } catch (final Exception e) {
- throw new RuntimeException("Failure writing message: " + message.getClass().getCanonicalName(), e);
+
+ } catch (final IOException e) {
+ throw new RuntimeException("Failure writing message: " + classId, e);
}
}
@@ -112,24 +134,29 @@
* to process the deserialized message.
*/
public void read(final byte[] messageBytes, final MultiObserver observer) {
+
try (final InputStream inputStream = new ByteArrayInputStream(messageBytes)) {
+
// Binary decoder for both the header and the message.
final BinaryDecoder decoder = DecoderFactory.get().binaryDecoder(inputStream, null);
// Read the header message.
- final Header header = headerReader.read(null, decoder);
- LOG.log(Level.FINE, "Deserializing Avro message: {0}", header.getClassName());
+ final Header header = this.headerReader.read(null, decoder);
+ final String classId = header.getClassName().toString();
+ LOG.log(Level.FINEST, "Deserializing Avro message: {0}", classId);
// Get the appropriate deserializer and deserialize the message.
- final IMessageDeserializer deserializer = nameToDeserializerMap.get(header.getClassName().toString());
+ final IMessageDeserializer deserializer = this.nameToDeserializerMap.get(classId);
if (deserializer != null) {
deserializer.deserialize(decoder, observer, header.getSequence());
} else {
- throw new RuntimeException("Request to deserialize unknown message type: " + header.getClassName());
+ throw new RuntimeException("Request to deserialize unknown message type: " + classId);
}
- } catch (final Exception e) {
- throw new RuntimeException("Failure reading message: ", e);
+ } catch (final IOException e) {
+ throw new RuntimeException("Failure reading message", e);
+ } catch (final InvocationTargetException | IllegalAccessException e) {
+ throw new RuntimeException("Error deserializing message body", e);
}
}
}
diff --git a/lang/java/reef-wake/wake/src/main/java/org/apache/reef/wake/avro/ProtocolSerializerNamespace.java b/lang/java/reef-wake/wake/src/main/java/org/apache/reef/wake/avro/ProtocolSerializerNamespace.java
new file mode 100644
index 0000000..df0fe35
--- /dev/null
+++ b/lang/java/reef-wake/wake/src/main/java/org/apache/reef/wake/avro/ProtocolSerializerNamespace.java
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.reef.wake.avro;
+
+import org.apache.reef.tang.annotations.Name;
+import org.apache.reef.tang.annotations.NamedParameter;
+
+/**
+ * ProtocolSerializer parameter: full name of the package containing protocol messages.
+ */
+@NamedParameter(doc = "full name of the package containing protocol messages",
+ short_name = "protocol_serializer_namespace")
+public final class ProtocolSerializerNamespace implements Name<String> {
+ /** Do not instantiate that class. */
+ private ProtocolSerializerNamespace() { }
+}
diff --git a/lang/java/reef-wake/wake/src/main/java/org/apache/reef/wake/avro/impl/MessageSerializerImpl.java b/lang/java/reef-wake/wake/src/main/java/org/apache/reef/wake/avro/impl/MessageSerializerImpl.java
index f225cbb..6844d24 100644
--- a/lang/java/reef-wake/wake/src/main/java/org/apache/reef/wake/avro/impl/MessageSerializerImpl.java
+++ b/lang/java/reef-wake/wake/src/main/java/org/apache/reef/wake/avro/impl/MessageSerializerImpl.java
@@ -23,6 +23,7 @@
import org.apache.avro.specific.SpecificDatumWriter;
import org.apache.avro.specific.SpecificRecord;
import org.apache.reef.wake.avro.IMessageSerializer;
+import org.apache.reef.wake.avro.ProtocolSerializer;
import org.apache.reef.wake.avro.message.Header;
import java.io.ByteArrayOutputStream;
@@ -43,7 +44,7 @@
* @param msgMetaClass The reflection class for the message.
*/
public MessageSerializerImpl(final Class<TMessage> msgMetaClass) {
- this.msgMetaClassName = msgMetaClass.getSimpleName();
+ this.msgMetaClassName = ProtocolSerializer.getClassId(msgMetaClass);
this.messageWriter = new SpecificDatumWriter<>(msgMetaClass);
}
diff --git a/lang/java/reef-wake/wake/src/main/java/org/apache/reef/wake/impl/MultiObserverImpl.java b/lang/java/reef-wake/wake/src/main/java/org/apache/reef/wake/impl/MultiObserverImpl.java
index fb3935d..1ca7b05 100644
--- a/lang/java/reef-wake/wake/src/main/java/org/apache/reef/wake/impl/MultiObserverImpl.java
+++ b/lang/java/reef-wake/wake/src/main/java/org/apache/reef/wake/impl/MultiObserverImpl.java
@@ -31,10 +31,11 @@
/**
* The MultiObserverImpl class uses reflection to discover which onNext()
* event processing methods are defined and then map events to them.
- * @param <TSubCls> The subclass derived from MultiObserverImpl.
*/
-public abstract class MultiObserverImpl<TSubCls> implements MultiObserver {
+public abstract class MultiObserverImpl implements MultiObserver {
+
private static final Logger LOG = Logger.getLogger(MultiObserverImpl.class.getName());
+
private final Map<String, Method> methodMap = new HashMap<>();
/**
@@ -62,8 +63,8 @@
* @param <TEvent> The type of the event being processed.
*/
private <TEvent> void unimplemented(final long identifier, final TEvent event) {
- LOG.log(Level.INFO, "Unimplemented event: [{0}]: {1}",
- new String[]{String.valueOf(identifier), event.getClass().getName()});
+ LOG.log(Level.SEVERE, "Unimplemented event: [{0}]: {1}", new Object[] {identifier, event});
+ throw new RuntimeException("Event not supported: " + event);
}
/**
@@ -74,13 +75,13 @@
*/
@Override
public <TEvent> void onNext(final long identifier, final TEvent event)
- throws IllegalAccessException, InvocationTargetException {
+ throws IllegalAccessException, InvocationTargetException {
// Get the reflection method for this call.
final Method onNext = methodMap.get(event.getClass().getName());
if (onNext != null) {
// Process the event.
- onNext.invoke((TSubCls) this, identifier, event);
+ onNext.invoke(this, identifier, event);
} else {
// Log the unprocessed event.
unimplemented(identifier, event);
diff --git a/lang/java/reef-wake/wake/src/test/java/org/apache/reef/wake/test/avro/ProtocolSerializerTest.java b/lang/java/reef-wake/wake/src/test/java/org/apache/reef/wake/test/avro/ProtocolSerializerTest.java
index 31d142a..33a8c2b 100644
--- a/lang/java/reef-wake/wake/src/test/java/org/apache/reef/wake/test/avro/ProtocolSerializerTest.java
+++ b/lang/java/reef-wake/wake/src/test/java/org/apache/reef/wake/test/avro/ProtocolSerializerTest.java
@@ -18,19 +18,18 @@
*/
package org.apache.reef.wake.test.avro;
+import org.apache.reef.tang.Configuration;
import org.apache.reef.tang.Injector;
import org.apache.reef.tang.Tang;
import org.apache.reef.tang.exceptions.InjectionException;
import org.apache.reef.wake.EventHandler;
import org.apache.reef.wake.avro.ProtocolSerializer;
-import org.apache.reef.wake.impl.LoggingEventHandler;
+import org.apache.reef.wake.avro.ProtocolSerializerNamespace;
import org.apache.reef.wake.impl.MultiObserverImpl;
import org.apache.reef.wake.remote.*;
-import org.apache.reef.wake.remote.address.LocalAddressProvider;
-import org.apache.reef.wake.remote.impl.ByteCodec;
-import org.apache.reef.wake.remote.ports.TcpPortProvider;
import org.apache.reef.wake.test.avro.message.AvroTestMessage;
+import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TestName;
@@ -46,17 +45,37 @@
* exchanged between two remote manager classes.
*/
public final class ProtocolSerializerTest {
- private static final Logger LOG = Logger.getLogger(ProtocolSerializer.class.getName());
+
+ private static final Logger LOG = Logger.getLogger(ProtocolSerializerTest.class.getName());
@Rule
public final TestName name = new TestName();
+ private RemoteManagerFactory remoteManagerFactory;
+ private ProtocolSerializer serializer;
+
+ @Before
+ public void setup() throws InjectionException {
+
+ final Tang tang = Tang.Factory.getTang();
+
+ final Configuration config = tang.newConfigurationBuilder()
+ .bindNamedParameter(ProtocolSerializerNamespace.class, "org.apache.reef.wake.test.avro.message")
+ .build();
+
+ final Injector injector = tang.newInjector(config);
+
+ remoteManagerFactory = injector.getInstance(RemoteManagerFactory.class);
+ serializer = injector.getInstance(ProtocolSerializer.class);
+ }
+
/**
* Verify Avro message can be serialized and deserialized
* between two remote managers.
*/
@Test
- public void testProtocolSerializerTest() throws Exception {
+ public void testProtocolSerializerTest() throws InterruptedException {
+
final int[] numbers = {12, 25};
final String[] strings = {"The first string", "The second string"};
@@ -65,8 +84,8 @@
final BlockingQueue<byte[]> queue2 = new LinkedBlockingQueue<>();
// Remote managers for sending and receiving byte messages.
- final RemoteManager remoteManager1 = getTestRemoteManager("RemoteManagerOne");
- final RemoteManager remoteManager2 = getTestRemoteManager("RemoteManagerTwo");
+ final RemoteManager remoteManager1 = remoteManagerFactory.getInstance("RemoteManagerOne");
+ final RemoteManager remoteManager2 = remoteManagerFactory.getInstance("RemoteManagerTwo");
// Register message handlers for byte level messages.
remoteManager1.registerHandler(byte[].class, new ByteMessageObserver(queue1));
@@ -75,8 +94,6 @@
final EventHandler<byte[]> sender1 = remoteManager1.getHandler(remoteManager2.getMyIdentifier(), byte[].class);
final EventHandler<byte[]> sender2 = remoteManager2.getHandler(remoteManager1.getMyIdentifier(), byte[].class);
- final ProtocolSerializer serializer = new ProtocolSerializer("org.apache.reef.wake.test.avro.message");
-
sender1.onNext(serializer.write(new AvroTestMessage(numbers[0], strings[0]), 1));
sender2.onNext(serializer.write(new AvroTestMessage(numbers[1], strings[1]), 2));
@@ -93,30 +110,8 @@
assertEquals(strings[1], avroObserver1.getDataString());
}
- /**
- * Build a remote manager on the local IP address with an unused port.
- * @param identifier The identifier of the remote manager.
- * @return A RemoteManager instance listing on the local IP address
- * with a unique port number.
- */
- private RemoteManager getTestRemoteManager(final String identifier) throws InjectionException {
- final int port = 0;
- final boolean order = true;
- final int retries = 3;
- final int timeOut = 10000;
-
- final Injector injector = Tang.Factory.getTang().newInjector();
- final LocalAddressProvider localAddressProvider = injector.getInstance(LocalAddressProvider.class);
- final TcpPortProvider tcpPortProvider = injector.getInstance(TcpPortProvider.class);
- final RemoteManagerFactory remoteManagerFactory = injector.getInstance(RemoteManagerFactory.class);
-
- return remoteManagerFactory.getInstance(
- identifier, localAddressProvider.getLocalAddress(), port, new ByteCodec(),
- new LoggingEventHandler<Throwable>(), order, retries, timeOut,
- localAddressProvider, tcpPortProvider);
- }
-
private final class ByteMessageObserver implements EventHandler<RemoteMessage<byte[]>> {
+
private final BlockingQueue<byte[]> queue;
/**
@@ -138,7 +133,8 @@
/**
* Processes messages from the network remote manager.
*/
- public final class AvroMessageObserver extends MultiObserverImpl<AvroMessageObserver> {
+ public final class AvroMessageObserver extends MultiObserverImpl {
+
private int number;
private String dataString;