// blob: bfbbda9665e49b63523b617572b005a0df611c66 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.nemo.runtime.executor.bytetransfer;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.channel.group.ChannelGroup;
import org.apache.nemo.runtime.common.comm.ControlMessage.ByteTransferContextSetupMessage;
import org.apache.nemo.runtime.common.comm.ControlMessage.ByteTransferDataDirection;
import org.apache.nemo.runtime.executor.bytetransfer.ByteTransferContext.ContextId;
import org.apache.nemo.runtime.executor.data.BlockManagerWorker;
import org.apache.nemo.runtime.executor.data.PipeManagerWorker;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;
/**
* Manages multiple transport contexts for one channel.
*/
final class ContextManager extends SimpleChannelInboundHandler<ByteTransferContextSetupMessage> {
private final PipeManagerWorker pipeManagerWorker;
private final BlockManagerWorker blockManagerWorker;
private final ByteTransfer byteTransfer;
private final ChannelGroup channelGroup;
private final String localExecutorId;
private final Channel channel;
private volatile String remoteExecutorId = null;
private final ConcurrentMap<Integer, ByteInputContext> inputContextsInitiatedByLocal = new ConcurrentHashMap<>();
private final ConcurrentMap<Integer, ByteOutputContext> outputContextsInitiatedByLocal = new ConcurrentHashMap<>();
private final ConcurrentMap<Integer, ByteInputContext> inputContextsInitiatedByRemote = new ConcurrentHashMap<>();
private final ConcurrentMap<Integer, ByteOutputContext> outputContextsInitiatedByRemote = new ConcurrentHashMap<>();
private final AtomicInteger nextInputTransferIndex = new AtomicInteger(0);
private final AtomicInteger nextOutputTransferIndex = new AtomicInteger(0);
/**
* Creates context manager for this channel.
*
* @param pipeManagerWorker provides handler for new contexts by remote executors
* @param blockManagerWorker provides handler for new contexts by remote executors
* @param byteTransfer provides channel caching
* @param channelGroup to cleanup this channel when closing {@link ByteTransport}
* @param localExecutorId local executor id
* @param channel the {@link Channel} to manage
*/
ContextManager(final PipeManagerWorker pipeManagerWorker,
final BlockManagerWorker blockManagerWorker,
final ByteTransfer byteTransfer,
final ChannelGroup channelGroup,
final String localExecutorId,
final Channel channel) {
this.pipeManagerWorker = pipeManagerWorker;
this.blockManagerWorker = blockManagerWorker;
this.byteTransfer = byteTransfer;
this.channelGroup = channelGroup;
this.localExecutorId = localExecutorId;
this.channel = channel;
}
/**
* @return channel for this context manager.
*/
Channel getChannel() {
return channel;
}
/**
* Returns {@link ByteInputContext} to provide {@link io.netty.buffer.ByteBuf}s on.
*
* @param dataDirection the data direction
* @param transferIndex transfer index
* @return the {@link ByteInputContext} corresponding to the pair of {@code dataDirection} and {@code transferIndex}
*/
ByteInputContext getInputContext(final ByteTransferDataDirection dataDirection,
final int transferIndex) {
final ConcurrentMap<Integer, ByteInputContext> contexts =
dataDirection == ByteTransferDataDirection.INITIATOR_SENDS_DATA
? inputContextsInitiatedByRemote : inputContextsInitiatedByLocal;
return contexts.get(transferIndex);
}
/**
* Responds to new transfer contexts by a remote executor.
*
* @param ctx netty {@link ChannelHandlerContext}
* @param message context setup message from the remote executor
* @throws Exception exceptions from handler
*/
@Override
protected void channelRead0(final ChannelHandlerContext ctx, final ByteTransferContextSetupMessage message)
throws Exception {
setRemoteExecutorId(message.getInitiatorExecutorId());
byteTransfer.onNewContextByRemoteExecutor(message.getInitiatorExecutorId(), channel);
final ByteTransferDataDirection dataDirection = message.getDataDirection();
final int transferIndex = message.getTransferIndex();
final boolean isPipe = message.getIsPipe();
final ContextId contextId =
new ContextId(remoteExecutorId, localExecutorId, dataDirection, transferIndex, isPipe);
final byte[] contextDescriptor = message.getContextDescriptor().toByteArray();
if (dataDirection == ByteTransferDataDirection.INITIATOR_SENDS_DATA) {
final ByteInputContext context = inputContextsInitiatedByRemote.compute(transferIndex, (index, existing) -> {
if (existing != null) {
throw new RuntimeException(String.format("Duplicate ContextId: %s", contextId));
}
return new ByteInputContext(remoteExecutorId, contextId, contextDescriptor, this);
});
if (isPipe) {
pipeManagerWorker.onInputContext(context);
} else {
blockManagerWorker.onInputContext(context);
}
} else {
final ByteOutputContext context = outputContextsInitiatedByRemote.compute(transferIndex, (idx, existing) -> {
if (existing != null) {
throw new RuntimeException(String.format("Duplicate ContextId: %s", contextId));
}
return new ByteOutputContext(remoteExecutorId, contextId, contextDescriptor, this);
});
if (isPipe) {
pipeManagerWorker.onOutputContext(context);
} else {
blockManagerWorker.onOutputContext(context);
}
}
}
/**
* Removes the specified contexts from map.
*
* @param context the {@link ByteTransferContext} to remove.
*/
void onContextExpired(final ByteTransferContext context) {
final ContextId contextId = context.getContextId();
final ConcurrentMap<Integer, ? extends ByteTransferContext> contexts = context instanceof ByteInputContext
? (contextId.getDataDirection() == ByteTransferDataDirection.INITIATOR_SENDS_DATA
? inputContextsInitiatedByRemote : inputContextsInitiatedByLocal)
: (contextId.getDataDirection() == ByteTransferDataDirection.INITIATOR_SENDS_DATA
? outputContextsInitiatedByLocal : outputContextsInitiatedByRemote);
contexts.remove(contextId.getTransferIndex(), context);
}
/**
* Initiates a context and stores to the specified map.
*
* @param contexts map for storing context
* @param transferIndexCounter counter for generating transfer index
* @param dataDirection data direction to include in the context id
* @param contextGenerator a function that returns context from context id
* @param executorId id of the remote executor
* @param <T> {@link ByteInputContext} or {@link ByteOutputContext}
* @param isPipe is a pipe context
* @return generated context
*/
<T extends ByteTransferContext> T newContext(final ConcurrentMap<Integer, T> contexts,
final AtomicInteger transferIndexCounter,
final ByteTransferDataDirection dataDirection,
final Function<ContextId, T> contextGenerator,
final String executorId,
final boolean isPipe) {
setRemoteExecutorId(executorId);
final int transferIndex = transferIndexCounter.getAndIncrement();
final ContextId contextId = new ContextId(localExecutorId, executorId, dataDirection, transferIndex, isPipe);
final T context = contexts.compute(transferIndex, (index, existingContext) -> {
if (existingContext != null) {
throw new RuntimeException(String.format("Duplicate ContextId: %s", contextId));
}
return contextGenerator.apply(contextId);
});
channel.writeAndFlush(context).addListener(context.getChannelWriteListener());
return context;
}
/**
* Create a new {@link ByteInputContext}.
*
* @param executorId target executor id
* @param contextDescriptor the context descriptor
* @param isPipe is pipe
* @return new {@link ByteInputContext}
*/
ByteInputContext newInputContext(final String executorId, final byte[] contextDescriptor, final boolean isPipe) {
return newContext(inputContextsInitiatedByLocal, nextInputTransferIndex,
ByteTransferDataDirection.INITIATOR_RECEIVES_DATA,
contextId -> new ByteInputContext(executorId, contextId, contextDescriptor, this),
executorId, isPipe);
}
/**
* Create a new {@link ByteOutputContext}.
*
* @param executorId target executor id
* @param contextDescriptor the context descriptor
* @param isPipe is pipe
* @return new {@link ByteOutputContext}
*/
ByteOutputContext newOutputContext(final String executorId, final byte[] contextDescriptor, final boolean isPipe) {
return newContext(outputContextsInitiatedByLocal, nextOutputTransferIndex,
ByteTransferDataDirection.INITIATOR_SENDS_DATA,
contextId -> new ByteOutputContext(executorId, contextId, contextDescriptor, this),
executorId, isPipe);
}
/**
* Set this contest manager as connected to the specified remote executor.
*
* @param executorId the remote executor id
*/
private void setRemoteExecutorId(final String executorId) {
if (remoteExecutorId == null) {
remoteExecutorId = executorId;
} else if (!executorId.equals(remoteExecutorId)) {
throw new RuntimeException(String.format("Wrong ContextManager: (%s != %s)", executorId, remoteExecutorId));
}
}
@Override
public void channelActive(final ChannelHandlerContext ctx) {
channelGroup.add(ctx.channel());
}
@Override
public void channelInactive(final ChannelHandlerContext ctx) {
channelGroup.remove(ctx.channel());
final Throwable cause = new Exception("Channel closed");
throwChannelErrorOnContexts(inputContextsInitiatedByLocal, cause);
throwChannelErrorOnContexts(outputContextsInitiatedByLocal, cause);
throwChannelErrorOnContexts(inputContextsInitiatedByRemote, cause);
throwChannelErrorOnContexts(outputContextsInitiatedByRemote, cause);
}
/**
* Invoke {@link ByteTransferContext#onChannelError(Throwable)} on the specified contexts.
*
* @param contexts map storing the contexts
* @param cause the error
* @param <T> {@link ByteInputContext} or {@link ByteOutputContext}
*/
private <T extends ByteTransferContext> void throwChannelErrorOnContexts(final ConcurrentMap<Integer, T> contexts,
final Throwable cause) {
for (final ByteTransferContext context : contexts.values()) {
context.onChannelError(cause);
}
}
}