| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| package org.apache.nemo.runtime.executor.bytetransfer; |
| |
| import io.netty.channel.Channel; |
| import io.netty.channel.ChannelHandlerContext; |
| import io.netty.channel.SimpleChannelInboundHandler; |
| import io.netty.channel.group.ChannelGroup; |
| import org.apache.nemo.runtime.common.comm.ControlMessage.ByteTransferContextSetupMessage; |
| import org.apache.nemo.runtime.common.comm.ControlMessage.ByteTransferDataDirection; |
| import org.apache.nemo.runtime.executor.bytetransfer.ByteTransferContext.ContextId; |
| import org.apache.nemo.runtime.executor.data.BlockManagerWorker; |
| import org.apache.nemo.runtime.executor.data.PipeManagerWorker; |
| |
| import java.util.concurrent.ConcurrentHashMap; |
| import java.util.concurrent.ConcurrentMap; |
| import java.util.concurrent.atomic.AtomicInteger; |
| import java.util.function.Function; |
| |
| /** |
| * Manages multiple transport contexts for one channel. |
| */ |
| final class ContextManager extends SimpleChannelInboundHandler<ByteTransferContextSetupMessage> { |
| |
| private final PipeManagerWorker pipeManagerWorker; |
| private final BlockManagerWorker blockManagerWorker; |
| private final ByteTransfer byteTransfer; |
| private final ChannelGroup channelGroup; |
| private final String localExecutorId; |
| private final Channel channel; |
| private volatile String remoteExecutorId = null; |
| |
| private final ConcurrentMap<Integer, ByteInputContext> inputContextsInitiatedByLocal = new ConcurrentHashMap<>(); |
| private final ConcurrentMap<Integer, ByteOutputContext> outputContextsInitiatedByLocal = new ConcurrentHashMap<>(); |
| private final ConcurrentMap<Integer, ByteInputContext> inputContextsInitiatedByRemote = new ConcurrentHashMap<>(); |
| private final ConcurrentMap<Integer, ByteOutputContext> outputContextsInitiatedByRemote = new ConcurrentHashMap<>(); |
| private final AtomicInteger nextInputTransferIndex = new AtomicInteger(0); |
| private final AtomicInteger nextOutputTransferIndex = new AtomicInteger(0); |
| |
| /** |
| * Creates context manager for this channel. |
| * |
| * @param pipeManagerWorker provides handler for new contexts by remote executors |
| * @param blockManagerWorker provides handler for new contexts by remote executors |
| * @param byteTransfer provides channel caching |
| * @param channelGroup to cleanup this channel when closing {@link ByteTransport} |
| * @param localExecutorId local executor id |
| * @param channel the {@link Channel} to manage |
| */ |
| ContextManager(final PipeManagerWorker pipeManagerWorker, |
| final BlockManagerWorker blockManagerWorker, |
| final ByteTransfer byteTransfer, |
| final ChannelGroup channelGroup, |
| final String localExecutorId, |
| final Channel channel) { |
| this.pipeManagerWorker = pipeManagerWorker; |
| this.blockManagerWorker = blockManagerWorker; |
| this.byteTransfer = byteTransfer; |
| this.channelGroup = channelGroup; |
| this.localExecutorId = localExecutorId; |
| this.channel = channel; |
| } |
| |
| /** |
| * @return channel for this context manager. |
| */ |
| Channel getChannel() { |
| return channel; |
| } |
| |
| /** |
| * Returns {@link ByteInputContext} to provide {@link io.netty.buffer.ByteBuf}s on. |
| * |
| * @param dataDirection the data direction |
| * @param transferIndex transfer index |
| * @return the {@link ByteInputContext} corresponding to the pair of {@code dataDirection} and {@code transferIndex} |
| */ |
| ByteInputContext getInputContext(final ByteTransferDataDirection dataDirection, |
| final int transferIndex) { |
| final ConcurrentMap<Integer, ByteInputContext> contexts = |
| dataDirection == ByteTransferDataDirection.INITIATOR_SENDS_DATA |
| ? inputContextsInitiatedByRemote : inputContextsInitiatedByLocal; |
| return contexts.get(transferIndex); |
| } |
| |
| /** |
| * Responds to new transfer contexts by a remote executor. |
| * |
| * @param ctx netty {@link ChannelHandlerContext} |
| * @param message context setup message from the remote executor |
| * @throws Exception exceptions from handler |
| */ |
| @Override |
| protected void channelRead0(final ChannelHandlerContext ctx, final ByteTransferContextSetupMessage message) |
| throws Exception { |
| setRemoteExecutorId(message.getInitiatorExecutorId()); |
| byteTransfer.onNewContextByRemoteExecutor(message.getInitiatorExecutorId(), channel); |
| final ByteTransferDataDirection dataDirection = message.getDataDirection(); |
| final int transferIndex = message.getTransferIndex(); |
| final boolean isPipe = message.getIsPipe(); |
| final ContextId contextId = |
| new ContextId(remoteExecutorId, localExecutorId, dataDirection, transferIndex, isPipe); |
| final byte[] contextDescriptor = message.getContextDescriptor().toByteArray(); |
| |
| if (dataDirection == ByteTransferDataDirection.INITIATOR_SENDS_DATA) { |
| final ByteInputContext context = inputContextsInitiatedByRemote.compute(transferIndex, (index, existing) -> { |
| if (existing != null) { |
| throw new RuntimeException(String.format("Duplicate ContextId: %s", contextId)); |
| } |
| return new ByteInputContext(remoteExecutorId, contextId, contextDescriptor, this); |
| }); |
| |
| if (isPipe) { |
| pipeManagerWorker.onInputContext(context); |
| } else { |
| blockManagerWorker.onInputContext(context); |
| } |
| } else { |
| final ByteOutputContext context = outputContextsInitiatedByRemote.compute(transferIndex, (idx, existing) -> { |
| if (existing != null) { |
| throw new RuntimeException(String.format("Duplicate ContextId: %s", contextId)); |
| } |
| return new ByteOutputContext(remoteExecutorId, contextId, contextDescriptor, this); |
| }); |
| if (isPipe) { |
| pipeManagerWorker.onOutputContext(context); |
| } else { |
| blockManagerWorker.onOutputContext(context); |
| } |
| } |
| } |
| |
| /** |
| * Removes the specified contexts from map. |
| * |
| * @param context the {@link ByteTransferContext} to remove. |
| */ |
| void onContextExpired(final ByteTransferContext context) { |
| final ContextId contextId = context.getContextId(); |
| final ConcurrentMap<Integer, ? extends ByteTransferContext> contexts = context instanceof ByteInputContext |
| ? (contextId.getDataDirection() == ByteTransferDataDirection.INITIATOR_SENDS_DATA |
| ? inputContextsInitiatedByRemote : inputContextsInitiatedByLocal) |
| : (contextId.getDataDirection() == ByteTransferDataDirection.INITIATOR_SENDS_DATA |
| ? outputContextsInitiatedByLocal : outputContextsInitiatedByRemote); |
| contexts.remove(contextId.getTransferIndex(), context); |
| } |
| |
| /** |
| * Initiates a context and stores to the specified map. |
| * |
| * @param contexts map for storing context |
| * @param transferIndexCounter counter for generating transfer index |
| * @param dataDirection data direction to include in the context id |
| * @param contextGenerator a function that returns context from context id |
| * @param executorId id of the remote executor |
| * @param <T> {@link ByteInputContext} or {@link ByteOutputContext} |
| * @param isPipe is a pipe context |
| * @return generated context |
| */ |
| <T extends ByteTransferContext> T newContext(final ConcurrentMap<Integer, T> contexts, |
| final AtomicInteger transferIndexCounter, |
| final ByteTransferDataDirection dataDirection, |
| final Function<ContextId, T> contextGenerator, |
| final String executorId, |
| final boolean isPipe) { |
| setRemoteExecutorId(executorId); |
| final int transferIndex = transferIndexCounter.getAndIncrement(); |
| final ContextId contextId = new ContextId(localExecutorId, executorId, dataDirection, transferIndex, isPipe); |
| final T context = contexts.compute(transferIndex, (index, existingContext) -> { |
| if (existingContext != null) { |
| throw new RuntimeException(String.format("Duplicate ContextId: %s", contextId)); |
| } |
| return contextGenerator.apply(contextId); |
| }); |
| channel.writeAndFlush(context).addListener(context.getChannelWriteListener()); |
| return context; |
| } |
| |
| /** |
| * Create a new {@link ByteInputContext}. |
| * |
| * @param executorId target executor id |
| * @param contextDescriptor the context descriptor |
| * @param isPipe is pipe |
| * @return new {@link ByteInputContext} |
| */ |
| ByteInputContext newInputContext(final String executorId, final byte[] contextDescriptor, final boolean isPipe) { |
| return newContext(inputContextsInitiatedByLocal, nextInputTransferIndex, |
| ByteTransferDataDirection.INITIATOR_RECEIVES_DATA, |
| contextId -> new ByteInputContext(executorId, contextId, contextDescriptor, this), |
| executorId, isPipe); |
| } |
| |
| /** |
| * Create a new {@link ByteOutputContext}. |
| * |
| * @param executorId target executor id |
| * @param contextDescriptor the context descriptor |
| * @param isPipe is pipe |
| * @return new {@link ByteOutputContext} |
| */ |
| ByteOutputContext newOutputContext(final String executorId, final byte[] contextDescriptor, final boolean isPipe) { |
| return newContext(outputContextsInitiatedByLocal, nextOutputTransferIndex, |
| ByteTransferDataDirection.INITIATOR_SENDS_DATA, |
| contextId -> new ByteOutputContext(executorId, contextId, contextDescriptor, this), |
| executorId, isPipe); |
| } |
| |
| /** |
| * Set this contest manager as connected to the specified remote executor. |
| * |
| * @param executorId the remote executor id |
| */ |
| private void setRemoteExecutorId(final String executorId) { |
| if (remoteExecutorId == null) { |
| remoteExecutorId = executorId; |
| } else if (!executorId.equals(remoteExecutorId)) { |
| throw new RuntimeException(String.format("Wrong ContextManager: (%s != %s)", executorId, remoteExecutorId)); |
| } |
| } |
| |
| @Override |
| public void channelActive(final ChannelHandlerContext ctx) { |
| channelGroup.add(ctx.channel()); |
| } |
| |
| @Override |
| public void channelInactive(final ChannelHandlerContext ctx) { |
| channelGroup.remove(ctx.channel()); |
| final Throwable cause = new Exception("Channel closed"); |
| throwChannelErrorOnContexts(inputContextsInitiatedByLocal, cause); |
| throwChannelErrorOnContexts(outputContextsInitiatedByLocal, cause); |
| throwChannelErrorOnContexts(inputContextsInitiatedByRemote, cause); |
| throwChannelErrorOnContexts(outputContextsInitiatedByRemote, cause); |
| } |
| |
| /** |
| * Invoke {@link ByteTransferContext#onChannelError(Throwable)} on the specified contexts. |
| * |
| * @param contexts map storing the contexts |
| * @param cause the error |
| * @param <T> {@link ByteInputContext} or {@link ByteOutputContext} |
| */ |
| private <T extends ByteTransferContext> void throwChannelErrorOnContexts(final ConcurrentMap<Integer, T> contexts, |
| final Throwable cause) { |
| for (final ByteTransferContext context : contexts.values()) { |
| context.onChannelError(cause); |
| } |
| } |
| } |