blob: 9c0c5f0c976d78a3f04d9180bf8b5135df276b9c [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.nemo.runtime.executor.transfer;
import io.netty.bootstrap.Bootstrap;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelOption;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.group.ChannelGroup;
import io.netty.channel.group.ChannelGroupFuture;
import io.netty.channel.group.DefaultChannelGroup;
import io.netty.util.concurrent.DefaultThreadFactory;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.GlobalEventExecutor;
import org.apache.nemo.conf.JobConf;
import org.apache.nemo.runtime.common.NettyChannelImplementationSelector;
import org.apache.reef.io.network.naming.NameResolver;
import org.apache.reef.tang.annotations.Parameter;
import org.apache.reef.wake.Identifier;
import org.apache.reef.wake.remote.address.LocalAddressProvider;
import org.apache.reef.wake.remote.ports.TcpPortProvider;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import javax.inject.Inject;
import java.net.InetSocketAddress;
/**
* Bootstraps the server and connects to other servers on demand.
*/
final class ByteTransport implements AutoCloseable {
private static final Logger LOG = LoggerFactory.getLogger(ByteTransport.class);
private static final String SERVER_LISTENING = "byte:server:listening";
private static final String SERVER_WORKING = "byte:server:working";
private static final String CLIENT = "byte:client";
private final ChannelGroup channelGroup = new DefaultChannelGroup(GlobalEventExecutor.INSTANCE);
private final NameResolver nameResolver;
private final EventLoopGroup serverListeningGroup;
private final EventLoopGroup serverWorkingGroup;
private final EventLoopGroup clientGroup;
private final Bootstrap clientBootstrap;
private final Channel serverListeningChannel;
/**
* Constructs a byte transport and starts listening.
*
* @param nameResolver provides naming registry
* @param localExecutorId the id of this executor
* @param channelImplSelector provides implementation for netty channel
* @param channelInitializer initializes channel pipeline
* @param tcpPortProvider provides an iterator of random tcp ports
* @param localAddressProvider provides the local address of the node to bind to
* @param port the listening port; 0 means random assign using {@code tcpPortProvider}
* @param serverBacklog the maximum number of pending connections to the server
* @param numListeningThreads the number of listening threads of the server
* @param numWorkingThreads the number of working threads of the server
* @param numClientThreads the number of client threads
*/
@Inject
private ByteTransport(
final NameResolver nameResolver,
@Parameter(JobConf.ExecutorId.class) final String localExecutorId,
final NettyChannelImplementationSelector channelImplSelector,
final ByteTransportChannelInitializer channelInitializer,
final TcpPortProvider tcpPortProvider,
final LocalAddressProvider localAddressProvider,
@Parameter(JobConf.PartitionTransportServerPort.class) final int port,
@Parameter(JobConf.PartitionTransportServerBacklog.class) final int serverBacklog,
@Parameter(JobConf.PartitionTransportServerNumListeningThreads.class) final int numListeningThreads,
@Parameter(JobConf.PartitionTransportServerNumWorkingThreads.class) final int numWorkingThreads,
@Parameter(JobConf.PartitionTransportClientNumThreads.class) final int numClientThreads) {
this.nameResolver = nameResolver;
if (port < 0) {
throw new IllegalArgumentException(String.format("Invalid ByteTransportPort: %d", port));
}
final String host = localAddressProvider.getLocalAddress();
serverListeningGroup = channelImplSelector.newEventLoopGroup(numListeningThreads,
new DefaultThreadFactory(SERVER_LISTENING));
serverWorkingGroup = channelImplSelector.newEventLoopGroup(numWorkingThreads,
new DefaultThreadFactory(SERVER_WORKING));
clientGroup = channelImplSelector.newEventLoopGroup(numClientThreads, new DefaultThreadFactory(CLIENT));
clientBootstrap = new Bootstrap()
.group(clientGroup)
.channel(channelImplSelector.getChannelClass())
.handler(channelInitializer)
.option(ChannelOption.SO_REUSEADDR, true);
final ServerBootstrap serverBootstrap = new ServerBootstrap()
.group(serverListeningGroup, serverWorkingGroup)
.channel(channelImplSelector.getServerChannelClass())
.childHandler(channelInitializer)
.option(ChannelOption.SO_BACKLOG, serverBacklog)
.option(ChannelOption.SO_REUSEADDR, true);
Channel listeningChannel = null;
if (port == 0) {
for (final int candidatePort : tcpPortProvider) {
try {
final ChannelFuture future = serverBootstrap.bind(host, candidatePort).await();
if (future.cause() != null) {
LOG.debug(String.format("Cannot bind to %s:%d", host, candidatePort), future.cause());
} else if (!future.isSuccess()) {
LOG.debug("Cannot bind to {}:{}", host, candidatePort);
} else {
listeningChannel = future.channel();
break;
}
} catch (final InterruptedException e) {
Thread.currentThread().interrupt();
LOG.debug(String.format("Interrupted while binding to %s:%d", host, candidatePort), e);
}
}
if (listeningChannel == null) {
serverListeningGroup.shutdownGracefully();
serverWorkingGroup.shutdownGracefully();
clientGroup.shutdownGracefully();
LOG.error("Cannot bind to {} with tcpPortProvider", host);
throw new RuntimeException(String.format("Cannot bind to %s with tcpPortProvider", host));
}
} else {
try {
final ChannelFuture future = serverBootstrap.bind(host, port).await();
if (future.cause() != null) {
throw future.cause();
} else if (!future.isSuccess()) {
throw new RuntimeException("Cannot bind");
} else {
listeningChannel = future.channel();
}
} catch (final Throwable e) {
serverListeningGroup.shutdownGracefully();
serverWorkingGroup.shutdownGracefully();
clientGroup.shutdownGracefully();
LOG.error(String.format("Cannot bind to %s:%d when bind ChannelFuture", host, port), e);
throw new RuntimeException(String.format("Cannot bind to %s:%d when connect ChannelFuture", host, port), e);
}
}
serverListeningChannel = listeningChannel;
try {
final ByteTransportIdentifier identifier = new ByteTransportIdentifier(localExecutorId);
nameResolver.register(identifier, (InetSocketAddress) listeningChannel.localAddress());
} catch (final Exception e) {
LOG.error("Cannot register ByteTransport listening address to the naming registry", e);
throw new RuntimeException(e);
}
LOG.info("ByteTransport server in {} is listening at {}", localExecutorId, listeningChannel.localAddress());
}
/**
* Closes all channels and releases all resources.
*/
@Override
public void close() {
LOG.info("Stopping listening at {} and closing", serverListeningChannel.localAddress());
final ChannelFuture closeListeningChannelFuture = serverListeningChannel.close();
final ChannelGroupFuture channelGroupCloseFuture = channelGroup.close();
final Future serverListeningGroupCloseFuture = serverListeningGroup.shutdownGracefully();
final Future serverWorkingGroupCloseFuture = serverWorkingGroup.shutdownGracefully();
final Future clientGroupCloseFuture = clientGroup.shutdownGracefully();
closeListeningChannelFuture.awaitUninterruptibly();
channelGroupCloseFuture.awaitUninterruptibly();
serverListeningGroupCloseFuture.awaitUninterruptibly();
serverWorkingGroupCloseFuture.awaitUninterruptibly();
clientGroupCloseFuture.awaitUninterruptibly();
}
/**
* Connect to the {@link ByteTransport} server of the specified executor.
*
* @param remoteExecutorId the id of the executor
* @return a {@link ChannelFuture} for connecting
*/
ChannelFuture connectTo(final String remoteExecutorId) {
final InetSocketAddress address;
try {
final ByteTransportIdentifier identifier = new ByteTransportIdentifier(remoteExecutorId);
address = nameResolver.lookup(identifier);
} catch (final Exception e) {
LOG.error(String.format("Cannot lookup ByteTransport listening address of %s", remoteExecutorId), e);
throw new RuntimeException(e);
}
final ChannelFuture connectFuture = clientBootstrap.connect(address);
connectFuture.addListener(future -> {
if (future.isSuccess()) {
// Succeed to connect
LOG.debug("Connected to {}", remoteExecutorId);
return;
}
// Failed to connect (Not logging the cause here, which is not very useful)
LOG.error("Failed to connect to {}", remoteExecutorId);
});
return connectFuture;
}
/**
* @return {@link ChannelGroup} for active connections between this executor and remote executor.
*/
ChannelGroup getChannelGroup() {
return channelGroup;
}
/**
* {@link Identifier} for {@link ByteTransfer}.
*/
private static final class ByteTransportIdentifier implements Identifier {
private final String executorId;
/**
* Creates a {@link ByteTransportIdentifier}.
*
* @param executorId id of the {@link org.apache.nemo.runtime.executor.Executor}
*/
private ByteTransportIdentifier(final String executorId) {
this.executorId = executorId;
}
@Override
public String toString() {
return "byte://" + executorId;
}
@Override
public boolean equals(final Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
final ByteTransportIdentifier that = (ByteTransportIdentifier) o;
return executorId.equals(that.executorId);
}
@Override
public int hashCode() {
return executorId.hashCode();
}
}
}