blob: 5a45e0cf9c06b90a210dd3035c901859669cf1b4 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.nemo.runtime.common.message.ncs;
import org.apache.nemo.runtime.common.ReplyFutureMap;
import org.apache.nemo.runtime.common.comm.ControlMessage;
import org.apache.nemo.runtime.common.message.*;
import org.apache.reef.exception.evaluator.NetworkException;
import org.apache.reef.io.network.Connection;
import org.apache.reef.io.network.ConnectionFactory;
import org.apache.reef.io.network.Message;
import org.apache.reef.io.network.NetworkConnectionService;
import org.apache.reef.tang.annotations.Parameter;
import org.apache.reef.wake.EventHandler;
import org.apache.reef.wake.IdentifierFactory;
import org.apache.reef.wake.remote.transport.LinkListener;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import javax.inject.Inject;
import java.net.SocketAddress;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.Future;
/**
 * Message environment backed by REEF's {@link NetworkConnectionService} (NCS).
 *
 * <p>A single connection factory for {@link ControlMessage.Message} is registered under the sender id. Incoming
 * messages are dispatched to the {@link MessageListener} registered for their listener id, and reply messages are
 * matched to outstanding requests through a {@link ReplyFutureMap}.
 */
public final class NcsMessageEnvironment implements MessageEnvironment {
  private static final Logger LOG = LoggerFactory.getLogger(NcsMessageEnvironment.class.getName());
  private static final String NCS_CONN_FACTORY_ID = "NCS_CONN_FACTORY_ID";

  private final NetworkConnectionService networkConnectionService;
  private final IdentifierFactory idFactory;
  private final String senderId;
  private final ReplyFutureMap<ControlMessage.Message> replyFutureMap;
  /** Listeners keyed by listener id. */
  private final ConcurrentMap<String, MessageListener> listenerConcurrentMap;
  /** Successfully opened connections keyed by receiver id; reused by {@link #asyncConnect}. */
  private final Map<String, Connection> receiverToConnectionMap;
  private final ConnectionFactory<ControlMessage.Message> connectionFactory;

  /**
   * Creates the environment and registers the NCS connection factory used for all control messages.
   *
   * @param networkConnectionService the NCS instance the connection factory is registered with.
   * @param idFactory                factory for NCS identifiers.
   * @param senderId                 id of this endpoint; also used as the local end point id of the factory.
   */
  @Inject
  private NcsMessageEnvironment(
    final NetworkConnectionService networkConnectionService,
    final IdentifierFactory idFactory,
    @Parameter(MessageParameters.SenderId.class) final String senderId) {
    this.networkConnectionService = networkConnectionService;
    this.idFactory = idFactory;
    this.senderId = senderId;
    this.replyFutureMap = new ReplyFutureMap<>();
    this.listenerConcurrentMap = new ConcurrentHashMap<>();
    this.receiverToConnectionMap = new ConcurrentHashMap<>();
    this.connectionFactory = networkConnectionService.registerConnectionFactory(
      idFactory.getNewInstance(NCS_CONN_FACTORY_ID),
      new ControlMessageCodec(),
      new NcsMessageHandler(),
      new NcsLinkListener(),
      idFactory.getNewInstance(senderId));
  }

  /**
   * Registers {@code listener} for messages addressed to {@code listenerId}.
   *
   * @param listenerId id the listener is registered under.
   * @param listener   the listener to register.
   * @param <T>        message type of the listener.
   * @throws RuntimeException if a listener is already registered for {@code listenerId}.
   */
  @Override
  public <T> void setupListener(final String listenerId, final MessageListener<T> listener) {
    if (listenerConcurrentMap.putIfAbsent(listenerId, listener) != null) {
      throw new RuntimeException("A listener for " + listenerId + " was already setup");
    }
  }

  /**
   * Unregisters the listener for {@code listenerId}, if any.
   *
   * @param listenerId id of the listener to remove.
   */
  @Override
  public void removeListener(final String listenerId) {
    listenerConcurrentMap.remove(listenerId);
  }

  /**
   * Returns a sender toward {@code receiverId}.
   *
   * <p>The connection to a receiver is opened on first use and cached, so later calls for the same receiver reuse
   * it. The returned future is already completed: with the sender on success, or exceptionally with the
   * {@link NetworkException} raised while opening the connection.
   *
   * @param receiverId id of the remote endpoint.
   * @param listenerId id of the remote listener (not used when establishing the connection).
   * @param <T>        message type of the sender.
   * @return a completed future holding the sender, or a failed future.
   */
  @Override
  public <T> Future<MessageSender<T>> asyncConnect(final String receiverId, final String listenerId) {
    try {
      final Connection connection = getOrOpenConnection(receiverId);
      return CompletableFuture.completedFuture((MessageSender) new NcsMessageSender(connection, replyFutureMap));
    } catch (final NetworkException e) {
      final CompletableFuture<MessageSender<T>> failedFuture = new CompletableFuture<>();
      failedFuture.completeExceptionally(e);
      return failedFuture;
    }
  }

  /**
   * Returns the cached connection toward {@code receiverId}, or creates, opens and caches a new one.
   *
   * <p>NOTE(review): cached connections are never evicted. If a {@link MessageSender} closes its underlying
   * connection, later senders for the same receiver would be handed the closed connection — confirm senders do not
   * close the shared connection, or add eviction on close.
   *
   * @param receiverId id of the remote endpoint.
   * @return an opened connection toward {@code receiverId}.
   * @throws NetworkException if opening the connection fails; nothing is cached in that case.
   */
  private Connection getOrOpenConnection(final String receiverId) throws NetworkException {
    final Connection cached = receiverToConnectionMap.get(receiverId);
    if (cached != null) {
      return cached;
    }
    final Connection opened = connectionFactory.newConnection(idFactory.getNewInstance(receiverId));
    opened.open();
    // Cache only after open() succeeded so that a failed attempt is retried on the next call.
    // If another thread cached a connection concurrently, this call still uses the one it opened.
    receiverToConnectionMap.putIfAbsent(receiverId, opened);
    return opened;
  }

  @Override
  public String getId() {
    return senderId;
  }

  /**
   * Closes the underlying NCS and drops all cached connections.
   *
   * @throws Exception if closing the NCS fails.
   */
  @Override
  public void close() throws Exception {
    try {
      networkConnectionService.close();
    } finally {
      receiverToConnectionMap.clear();
    }
  }

  /**
   * Looks up the listener registered for {@code listenerId}.
   *
   * @param listenerId     id the message is addressed to.
   * @param controlMessage the message being dispatched (used only for the error message).
   * @return the registered listener.
   * @throws IllegalStateException if no listener is registered for {@code listenerId}.
   */
  private MessageListener getListener(final String listenerId, final ControlMessage.Message controlMessage) {
    final MessageListener listener = listenerConcurrentMap.get(listenerId);
    if (listener == null) {
      throw new IllegalStateException(
        "No listener is set up for " + listenerId + " (message type: " + controlMessage.getType() + ")");
    }
    return listener;
  }

  /**
   * Message handler for NCS: classifies each incoming control message and dispatches it.
   */
  private final class NcsMessageHandler implements EventHandler<Message<ControlMessage.Message>> {
    @Override
    public void onNext(final Message<ControlMessage.Message> messages) {
      final ControlMessage.Message controlMessage = extractSingleMessage(messages);
      final MessageType messageType = getMsgType(controlMessage);
      switch (messageType) {
        case Send:
          processSendMessage(controlMessage);
          break;
        case Request:
          processRequestMessage(controlMessage);
          break;
        case Reply:
          processReplyMessage(controlMessage);
          break;
        default:
          throw new IllegalArgumentException(controlMessage.toString());
      }
    }

    /** Delivers a fire-and-forget message to its listener. */
    private void processSendMessage(final ControlMessage.Message controlMessage) {
      final String listenerId = controlMessage.getListenerId();
      getListener(listenerId, controlMessage).onMessage(controlMessage);
    }

    /** Delivers a request to its listener together with a context the listener can reply through. */
    private void processRequestMessage(final ControlMessage.Message controlMessage) {
      final String listenerId = controlMessage.getListenerId();
      final String executorId = getExecutorId(controlMessage);
      final MessageContext messageContext = new NcsMessageContext(executorId, connectionFactory, idFactory);
      getListener(listenerId, controlMessage).onMessageWithContext(controlMessage, messageContext);
    }

    /** Completes the pending future of the request this reply answers. */
    private void processReplyMessage(final ControlMessage.Message controlMessage) {
      final long requestId = getRequestId(controlMessage);
      replyFutureMap.onSuccessMessage(requestId, controlMessage);
    }
  }

  /**
   * LinkListener for NCS: successful sends are ignored, failures are logged.
   */
  private final class NcsLinkListener implements LinkListener<Message<ControlMessage.Message>> {
    @Override
    public void onSuccess(final Message<ControlMessage.Message> messages) {
      // No-ops.
    }

    @Override
    public void onException(final Throwable throwable,
                            final SocketAddress socketAddress,
                            final Message<ControlMessage.Message> messages) {
      // TODO #140: Properly classify and handle each RPC failure
      // Not logging the stacktrace here, as it's not very useful. The throwable is passed as a String so that
      // SLF4J does not treat it as an exception argument and print its stacktrace.
      LOG.error("NCS Exception (remote: {}): {}", socketAddress, String.valueOf(throwable));
    }
  }

  /**
   * Returns the first payload of an NCS message. Any further payloads are ignored, and an empty payload list makes
   * the iterator throw {@link java.util.NoSuchElementException}.
   *
   * <p>NOTE(review): assumes each NCS message carries exactly one control message — confirm against the sender.
   */
  private ControlMessage.Message extractSingleMessage(final Message<ControlMessage.Message> messages) {
    return messages.getData().iterator().next();
  }

  /**
   * Send: Messages sent without expecting a reply.
   * Request: Messages sent to get a reply.
   * Reply: Messages that reply to a request.
   * <p>
   * These names may not match the terminology conventionally used in RPC frameworks;
   * revisit them if the messaging layer is reworked.
   */
  enum MessageType {
    Send,
    Request,
    Reply
  }

  /**
   * Classifies a control message by its type.
   *
   * @throws IllegalArgumentException for a message type not listed here.
   */
  private MessageType getMsgType(final ControlMessage.Message controlMessage) {
    switch (controlMessage.getType()) {
      case TaskStateChanged:
      case ScheduleTask:
      case BlockStateChanged:
      case ExecutorFailed:
      case RunTimePassMessage:
      case ExecutorDataCollected:
      case MetricMessageReceived:
      case RequestMetricFlush:
      case MetricFlushed:
      case PipeInit:
        return MessageType.Send;
      case RequestBlockLocation:
      case RequestBroadcastVariable:
      case RequestPipeLoc:
        return MessageType.Request;
      case BlockLocationInfo:
      case InMasterBroadcastVariable:
      case PipeLocInfo:
        return MessageType.Reply;
      default:
        throw new IllegalArgumentException(controlMessage.toString());
    }
  }

  /**
   * Extracts the requesting executor's id from a request message.
   *
   * @throws IllegalArgumentException if the message is not a known request type.
   */
  private String getExecutorId(final ControlMessage.Message controlMessage) {
    switch (controlMessage.getType()) {
      case RequestBlockLocation:
        return controlMessage.getRequestBlockLocationMsg().getExecutorId();
      case RequestBroadcastVariable:
        return controlMessage.getRequestbroadcastVariableMsg().getExecutorId();
      case RequestPipeLoc:
        return controlMessage.getRequestPipeLocMsg().getExecutorId();
      default:
        throw new IllegalArgumentException(controlMessage.toString());
    }
  }

  /**
   * Extracts the id of the request that a reply message answers.
   *
   * @throws IllegalArgumentException if the message is not a known reply type.
   */
  private long getRequestId(final ControlMessage.Message controlMessage) {
    switch (controlMessage.getType()) {
      case BlockLocationInfo:
        return controlMessage.getBlockLocationInfoMsg().getRequestId();
      case InMasterBroadcastVariable:
        return controlMessage.getBroadcastVariableMsg().getRequestId();
      case PipeLocInfo:
        return controlMessage.getPipeLocInfoMsg().getRequestId();
      default:
        throw new IllegalArgumentException(controlMessage.toString());
    }
  }
}