blob: 9c8fb6c67e91ca5e75b54fb638392dd21e3d28c1 [file] [log] [blame]
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.tez.dag.app;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.URISyntaxException;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.Objects;
import com.google.common.collect.Maps;
import org.apache.hadoop.classification.InterfaceAudience;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.CommonConfigurationKeysPublic;
import org.apache.hadoop.ipc.ProtocolSignature;
import org.apache.hadoop.ipc.RPC;
import org.apache.hadoop.ipc.Server;
import org.apache.hadoop.net.NetUtils;
import org.apache.hadoop.security.Credentials;
import org.apache.hadoop.security.authorize.PolicyProvider;
import org.apache.hadoop.security.token.Token;
import org.apache.hadoop.yarn.api.records.ContainerId;
import org.apache.hadoop.yarn.api.records.LocalResource;
import org.apache.hadoop.yarn.util.ConverterUtils;
import org.apache.tez.common.ContainerContext;
import org.apache.tez.common.ContainerTask;
import org.apache.tez.common.TezConverterUtils;
import org.apache.tez.common.TezLocalResource;
import org.apache.tez.common.TezTaskUmbilicalProtocol;
import org.apache.tez.common.TezUtils;
import org.apache.tez.common.security.JobTokenIdentifier;
import org.apache.tez.common.security.JobTokenSecretManager;
import org.apache.tez.common.security.TokenCache;
import org.apache.tez.serviceplugins.api.ContainerEndReason;
import org.apache.tez.serviceplugins.api.TaskAttemptEndReason;
import org.apache.tez.serviceplugins.api.TaskCommunicator;
import org.apache.tez.serviceplugins.api.TaskCommunicatorContext;
import org.apache.tez.serviceplugins.api.TaskHeartbeatRequest;
import org.apache.tez.serviceplugins.api.TaskHeartbeatResponse;
import org.apache.tez.dag.api.TezConfiguration;
import org.apache.tez.dag.api.TezException;
import org.apache.tez.dag.api.TezUncheckedException;
import org.apache.tez.dag.api.event.VertexStateUpdate;
import org.apache.tez.dag.app.security.authorize.TezAMPolicyProvider;
import org.apache.tez.dag.records.TezTaskAttemptID;
import org.apache.tez.runtime.api.impl.TaskSpec;
import org.apache.tez.runtime.api.impl.TezHeartbeatRequest;
import org.apache.tez.runtime.api.impl.TezHeartbeatResponse;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * Default {@link TaskCommunicator} for the Tez AM. It hosts a {@link TezTaskUmbilicalProtocol}
 * Hadoop RPC server through which running containers pull task assignments
 * ({@code getTask}) and send heartbeats.
 *
 * <p>Registered containers and attempt-to-container assignments are tracked in two concurrent
 * maps. Within this class, the mutable fields of a {@link ContainerInfo} are read and written
 * only while holding that {@code ContainerInfo}'s monitor.
 */
@InterfaceAudience.Private
public class TezTaskCommunicatorImpl extends TaskCommunicator {

  private static final Logger LOG = LoggerFactory.getLogger(TezTaskCommunicatorImpl.class);

  /** Sentinel returned from getTask to containers that are invalid or no longer registered. */
  private static final ContainerTask TASK_FOR_INVALID_JVM = new ContainerTask(
      null, true, null, null, false);

  private final TezTaskUmbilicalProtocol taskUmbilical;

  protected final ConcurrentMap<ContainerId, ContainerInfo> registeredContainers =
      new ConcurrentHashMap<>();
  protected final ConcurrentMap<TezTaskAttemptID, ContainerId> attemptToContainerMap =
      new ConcurrentHashMap<>();

  protected final String tokenIdentifier;
  protected final Token<JobTokenIdentifier> sessionToken;
  protected final Configuration conf;
  protected InetSocketAddress address;

  protected volatile Server server;

  /**
   * Per-container bookkeeping: where the container runs, the task currently assigned to it,
   * and the state needed to de-duplicate retried heartbeats.
   */
  public static final class ContainerInfo {

    ContainerInfo(ContainerId containerId, String host, int port) {
      this.containerId = containerId;
      this.host = host;
      this.port = port;
    }

    final ContainerId containerId;
    public final String host;
    public final int port;

    /** Last heartbeat response; re-sent when the same request id arrives again. */
    TezHeartbeatResponse lastResponse = null;
    /** Task currently assigned to this container, or null when none is assigned. */
    TaskSpec taskSpec = null;
    /** Request id of the last heartbeat processed for this container. */
    long lastRequestId = 0;
    Map<String, LocalResource> additionalLRs = null;
    Credentials credentials = null;
    boolean credentialsChanged = false;
    /** True once the current taskSpec has been handed to the container via getTask. */
    boolean taskPulled = false;

    /** Clears the current task assignment. Heartbeat sequencing state is left untouched. */
    void reset() {
      taskSpec = null;
      additionalLRs = null;
      credentials = null;
      credentialsChanged = false;
      taskPulled = false;
    }
  }

  /**
   * Construct the service.
   *
   * @param taskCommunicatorContext context supplying the application attempt id, AM
   *     credentials (for the session token) and the user payload holding the configuration
   * @throws TezUncheckedException if the user payload cannot be parsed into a Configuration
   */
  public TezTaskCommunicatorImpl(TaskCommunicatorContext taskCommunicatorContext) {
    super(taskCommunicatorContext);
    this.taskUmbilical = new TezTaskUmbilicalProtocolImpl();
    this.tokenIdentifier = taskCommunicatorContext.getApplicationAttemptId().getApplicationId().toString();
    this.sessionToken = TokenCache.getSessionToken(taskCommunicatorContext.getAMCredentials());
    try {
      conf = TezUtils.createConfFromUserPayload(getContext().getInitialUserPayload());
    } catch (IOException e) {
      throw new TezUncheckedException(
          "Unable to parse user payload for " + TezTaskCommunicatorImpl.class.getSimpleName(), e);
    }
  }

  @Override
  public void start() {
    startRpcServer();
  }

  @Override
  public void shutdown() {
    stopRpcServer();
  }

  /**
   * Builds and starts the umbilical RPC server on an ephemeral port (subject to the configured
   * port range). Records the connectable address in {@link #address}.
   *
   * @throws TezUncheckedException wrapping any IOException raised while building/starting
   */
  protected void startRpcServer() {
    try {
      JobTokenSecretManager jobTokenSecretManager =
          new JobTokenSecretManager();
      jobTokenSecretManager.addTokenForJob(tokenIdentifier, sessionToken);

      server = new RPC.Builder(conf)
          .setProtocol(TezTaskUmbilicalProtocol.class)
          .setBindAddress("0.0.0.0")
          .setPort(0)
          .setInstance(taskUmbilical)
          .setNumHandlers(
              conf.getInt(TezConfiguration.TEZ_AM_TASK_LISTENER_THREAD_COUNT,
                  TezConfiguration.TEZ_AM_TASK_LISTENER_THREAD_COUNT_DEFAULT))
          .setPortRangeConfig(TezConfiguration.TEZ_AM_TASK_AM_PORT_RANGE)
          .setSecretManager(jobTokenSecretManager).build();

      // Enable service authorization?
      if (conf.getBoolean(
          CommonConfigurationKeysPublic.HADOOP_SECURITY_AUTHORIZATION,
          false)) {
        refreshServiceAcls(conf, new TezAMPolicyProvider());
      }

      server.start();
      InetSocketAddress serverBindAddress = NetUtils.getConnectAddress(server);
      // Advertise the canonical host name rather than the wildcard bind address.
      this.address = NetUtils.createSocketAddrForHost(
          serverBindAddress.getAddress().getCanonicalHostName(),
          serverBindAddress.getPort());
      LOG.info("Instantiated TezTaskCommunicator RPC at " + this.address);
    } catch (IOException e) {
      throw new TezUncheckedException(e);
    }
  }

  /** Stops the RPC server if it was started. Calling it more than once is a no-op. */
  protected void stopRpcServer() {
    if (server != null) {
      server.stop();
      server = null;
    }
  }

  protected Configuration getConf() {
    return this.conf;
  }

  private void refreshServiceAcls(Configuration configuration,
                                  PolicyProvider policyProvider) {
    this.server.refreshServiceAcl(configuration, policyProvider);
  }

  /**
   * Registers a newly running container.
   *
   * @throws TezUncheckedException if the container id is already registered
   */
  @Override
  public void registerRunningContainer(ContainerId containerId, String host, int port) {
    ContainerInfo oldInfo = registeredContainers.putIfAbsent(containerId,
        new ContainerInfo(containerId, host, port));
    if (oldInfo != null) {
      throw new TezUncheckedException("Multiple registrations for containerId: " + containerId);
    }
  }

  /**
   * Forgets a container. If it had an attempt assigned, that attempt's mapping is dropped as
   * well. Unknown containers are ignored.
   */
  @Override
  public void registerContainerEnd(ContainerId containerId, ContainerEndReason endReason, String diagnostics) {
    ContainerInfo containerInfo = registeredContainers.remove(containerId);
    if (containerInfo != null) {
      synchronized(containerInfo) {
        if (containerInfo.taskSpec != null && containerInfo.taskSpec.getTaskAttemptID() != null) {
          attemptToContainerMap.remove(containerInfo.taskSpec.getTaskAttemptID());
        }
      }
    }
  }

  /**
   * Assigns a task attempt to a registered container. The container receives it on its next
   * getTask call.
   *
   * @throws NullPointerException if the container is not registered
   * @throws TezUncheckedException if the container already has an assignment, or the attempt
   *     is already registered to some container. In both cases no state is modified.
   */
  @Override
  public void registerRunningTaskAttempt(ContainerId containerId, TaskSpec taskSpec,
                                         Map<String, LocalResource> additionalResources,
                                         Credentials credentials, boolean credentialsChanged,
                                         int priority) {

    ContainerInfo containerInfo = registeredContainers.get(containerId);
    Objects.requireNonNull(containerInfo,
        String.format("Cannot register task attempt %s to unknown container %s",
            taskSpec.getTaskAttemptID(), containerId));
    synchronized (containerInfo) {
      if (containerInfo.taskSpec != null) {
        throw new TezUncheckedException(
            "Cannot register task: " + taskSpec.getTaskAttemptID() + " to container: " +
                containerId + " , with pre-existing assignment: " +
                containerInfo.taskSpec.getTaskAttemptID());
      }
      // Claim the attempt mapping before touching containerInfo. A duplicate registration
      // must not leave the container holding a task that getTask would hand out but that
      // heartbeat would then reject.
      ContainerId oldId = attemptToContainerMap.putIfAbsent(taskSpec.getTaskAttemptID(), containerId);
      if (oldId != null) {
        throw new TezUncheckedException(
            "Attempting to register an already registered taskAttempt with id: " +
                taskSpec.getTaskAttemptID() + " to containerId: " + containerId +
                ". Already registered to containerId: " + oldId);
      }
      containerInfo.taskSpec = taskSpec;
      containerInfo.additionalLRs = additionalResources;
      containerInfo.credentials = credentials;
      containerInfo.credentialsChanged = credentialsChanged;
      containerInfo.taskPulled = false;
    }
  }

  /**
   * Removes an attempt's assignment and clears its container's task state. Unknown attempts
   * and containers are logged and otherwise ignored.
   */
  @Override
  public void unregisterRunningTaskAttempt(TezTaskAttemptID taskAttemptID, TaskAttemptEndReason endReason, String diagnostics) {
    ContainerId containerId = attemptToContainerMap.remove(taskAttemptID);
    if(containerId == null) {
      LOG.warn("Unregister task attempt: " + taskAttemptID + " from unknown container");
      return;
    }
    ContainerInfo containerInfo = registeredContainers.get(containerId);
    if (containerInfo == null) {
      LOG.warn("Unregister task attempt: " + taskAttemptID +
          " from non-registered container: " + containerId);
      return;
    }
    synchronized (containerInfo) {
      // The attempt mapping was already removed above; only the container state remains.
      containerInfo.reset();
    }
  }

  @Override
  public InetSocketAddress getAddress() {
    return address;
  }

  @Override
  public void onVertexStateUpdated(VertexStateUpdate stateUpdate) {
    // Empty. Not registering, or expecting any updates.
  }

  @Override
  public void dagComplete(int dagIdentifier) {
    // Nothing to do at the moment. Some of the TODOs from TaskAttemptListener apply here.
  }

  @Override
  public Object getMetaInfo() {
    return address;
  }

  protected String getTokenIdentifier() {
    return tokenIdentifier;
  }

  protected Token<JobTokenIdentifier> getSessionToken() {
    return sessionToken;
  }

  public TezTaskUmbilicalProtocol getUmbilical() {
    return this.taskUmbilical;
  }

  /** RPC endpoint invoked by containers. */
  private class TezTaskUmbilicalProtocolImpl implements TezTaskUmbilicalProtocol {

    /**
     * Returns the task assigned to the requesting container. The result is one of:
     * <ul>
     *   <li>the invalid-JVM sentinel, for a bad request or an unknown container;</li>
     *   <li>null, when nothing is assigned yet or the task was already pulled.</li>
     * </ul>
     * When a real task is returned, the context is told it was submitted and started.
     */
    @Override
    public ContainerTask getTask(ContainerContext containerContext) throws IOException {
      ContainerTask task = null;
      if (containerContext == null || containerContext.getContainerIdentifier() == null) {
        LOG.info("Invalid task request with an empty containerContext or containerId");
        task = TASK_FOR_INVALID_JVM;
      } else {
        ContainerId containerId = ConverterUtils.toContainerId(containerContext
            .getContainerIdentifier());
        if (LOG.isDebugEnabled()) {
          LOG.debug("Container with id: " + containerId + " asked for a task");
        }
        task = getContainerTask(containerId);
        if (task != null && !task.shouldDie()) {
          getContext().taskSubmitted(task.getTaskSpec().getTaskAttemptID(), containerId);
          getContext().taskStartedRemotely(task.getTaskSpec().getTaskAttemptID());
        }
      }
      if (LOG.isDebugEnabled()) {
        LOG.debug("getTask returning task: " + task);
      }
      return task;
    }

    @Override
    public boolean canCommit(TezTaskAttemptID taskAttemptId) throws IOException {
      return getContext().canCommit(taskAttemptId);
    }

    /**
     * Handles a container heartbeat.
     * <ul>
     *   <li>Unknown containers are told to die.</li>
     *   <li>A repeat of the last processed request id gets the cached response.</li>
     *   <li>Otherwise, when the request carries an attempt, the attempt must be mapped to this
     *       container and the request id must be exactly one greater than the last processed
     *       id.</li>
     * </ul>
     *
     * @throws TezException if the attempt is not mapped to this container or the request id is
     *     out of sequence
     */
    @Override
    public TezHeartbeatResponse heartbeat(TezHeartbeatRequest request) throws IOException,
        TezException {
      ContainerId containerId = ConverterUtils.toContainerId(request.getContainerIdentifier());
      long requestId = request.getRequestId();
      if (LOG.isDebugEnabled()) {
        LOG.debug("Received heartbeat from container"
            + ", request=" + request);
      }

      ContainerInfo containerInfo = registeredContainers.get(containerId);
      if (containerInfo == null) {
        LOG.warn("Received task heartbeat from unknown container with id: " + containerId +
            ", asking it to die");
        TezHeartbeatResponse response = new TezHeartbeatResponse();
        response.setLastRequestId(requestId);
        response.setShouldDie();
        return response;
      }

      synchronized (containerInfo) {
        if (containerInfo.lastRequestId == requestId) {
          LOG.warn("Old sequenceId received: " + requestId
              + ", Re-sending last response to client");
          return containerInfo.lastResponse;
        }
      }

      TezHeartbeatResponse response = new TezHeartbeatResponse();
      TezTaskAttemptID taskAttemptID = request.getCurrentTaskAttemptID();
      if (taskAttemptID != null) {
        synchronized (containerInfo) {
          ContainerId containerIdFromMap = attemptToContainerMap.get(taskAttemptID);
          if (containerIdFromMap == null || !containerIdFromMap.equals(containerId)) {
            throw new TezException("Attempt " + taskAttemptID
                + " is not recognized for heartbeat");
          }

          // Compute numerically first: inline "..." + lastRequestId + 1 would concatenate "1".
          long expectedRequestId = containerInfo.lastRequestId + 1;
          if (expectedRequestId != requestId) {
            throw new TezException("Container " + containerId
                + " has invalid request id. Expected: "
                + expectedRequestId
                + " and actual: " + requestId);
          }
        }
        TaskHeartbeatRequest tRequest = new TaskHeartbeatRequest(request.getContainerIdentifier(),
            request.getCurrentTaskAttemptID(), request.getEvents(), request.getStartIndex(),
            request.getPreRoutedStartIndex(), request.getMaxEvents());
        // Called without holding the container monitor.
        TaskHeartbeatResponse tResponse = getContext().heartbeat(tRequest);
        response.setEvents(tResponse.getEvents());
        response.setNextFromEventId(tResponse.getNextFromEventId());
        response.setNextPreRoutedEventId(tResponse.getNextPreRoutedEventId());
      }
      response.setLastRequestId(requestId);
      // Publish under the same monitor used by the readers above so other RPC handler
      // threads observe a consistent (lastRequestId, lastResponse) pair.
      synchronized (containerInfo) {
        containerInfo.lastRequestId = requestId;
        containerInfo.lastResponse = response;
      }
      return response;
    }

    // TODO Remove this method once we move to the Protobuf RPC engine
    @Override
    public long getProtocolVersion(String protocol, long clientVersion) throws IOException {
      return versionID;
    }

    // TODO Remove this method once we move to the Protobuf RPC engine
    @Override
    public ProtocolSignature getProtocolSignature(String protocol, long clientVersion,
                                                  int clientMethodsHash) throws IOException {
      return ProtocolSignature.getProtocolSignature(this, protocol,
          clientVersion, clientMethodsHash);
    }
  }

  /**
   * Resolves the task for a container's getTask request.
   *
   * @return TASK_FOR_INVALID_JVM if the container is not registered; a newly built
   *     ContainerTask the first time an assigned task is requested; null otherwise
   */
  private ContainerTask getContainerTask(ContainerId containerId) throws IOException {
    ContainerInfo containerInfo = registeredContainers.get(containerId);
    ContainerTask task;
    if (containerInfo == null) {
      if (getContext().isKnownContainer(containerId)) {
        LOG.info("Container with id: " + containerId
            + " is valid, but no longer registered, and will be killed");
      } else {
        LOG.info("Container with id: " + containerId
            + " is invalid and will be killed");
      }
      task = TASK_FOR_INVALID_JVM;
    } else {
      synchronized (containerInfo) {
        getContext().containerAlive(containerId);
        if (containerInfo.taskSpec != null) {
          if (!containerInfo.taskPulled) {
            containerInfo.taskPulled = true;
            task = constructContainerTask(containerInfo);
          } else {
            if (LOG.isDebugEnabled()) {
              LOG.debug("Task " + containerInfo.taskSpec.getTaskAttemptID() +
                  " already sent to container: " + containerId);
            }
            task = null;
          }
        } else {
          task = null;
          if (LOG.isDebugEnabled()) {
            LOG.debug("No task assigned yet for running container: " + containerId);
          }
        }
      }
    }
    return task;
  }

  /** Builds the ContainerTask for the container's current assignment. Caller holds its monitor. */
  private ContainerTask constructContainerTask(ContainerInfo containerInfo) throws IOException {
    return new ContainerTask(containerInfo.taskSpec, false,
        convertLocalResourceMap(containerInfo.additionalLRs), containerInfo.credentials,
        containerInfo.credentialsChanged);
  }

  /**
   * Converts YARN LocalResources to Tez equivalents. A null input yields an empty map.
   *
   * @throws IOException wrapping a URISyntaxException from a malformed resource URL
   */
  private Map<String, TezLocalResource> convertLocalResourceMap(Map<String, LocalResource> ylrs)
      throws IOException {
    Map<String, TezLocalResource> tlrs = Maps.newHashMap();
    if (ylrs != null) {
      for (Map.Entry<String, LocalResource> ylrEntry : ylrs.entrySet()) {
        TezLocalResource tlr;
        try {
          tlr = TezConverterUtils.convertYarnLocalResourceToTez(ylrEntry.getValue());
        } catch (URISyntaxException e) {
          throw new IOException(e);
        }
        tlrs.put(ylrEntry.getKey(), tlr);
      }
    }
    return tlrs;
  }

  protected ContainerInfo getContainerInfo(ContainerId containerId) {
    return registeredContainers.get(containerId);
  }

  protected ContainerId getContainerForAttempt(TezTaskAttemptID taskAttemptId) {
    return attemptToContainerMap.get(taskAttemptId);
  }
}