blob: bb72d013c7851fb208f18f7b35f95fba9b76be6c [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.nemo.runtime.master.scheduler;
import com.google.common.collect.Sets;
import org.apache.nemo.common.Pair;
import org.apache.nemo.common.dag.DAG;
import org.apache.nemo.common.exception.UnknownFailureCauseException;
import org.apache.nemo.common.ir.Readable;
import org.apache.nemo.common.ir.edge.executionproperty.MessageIdEdgeProperty;
import org.apache.nemo.common.ir.vertex.executionproperty.IgnoreSchedulingTempDataReceiverProperty;
import org.apache.nemo.common.ir.vertex.executionproperty.MessageIdVertexProperty;
import org.apache.nemo.runtime.common.RuntimeIdManager;
import org.apache.nemo.runtime.common.plan.*;
import org.apache.nemo.runtime.common.state.BlockState;
import org.apache.nemo.runtime.common.state.StageState;
import org.apache.nemo.runtime.common.state.TaskState;
import org.apache.nemo.runtime.master.BlockManagerMaster;
import org.apache.nemo.runtime.master.PlanStateManager;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.*;
import java.util.function.Function;
import java.util.stream.Collectors;
/**
* Utility methods regarding schedulers.
*/
public final class BatchSchedulerUtils {
private static final Logger LOG = LoggerFactory.getLogger(BatchSchedulerUtils.class.getName());
/**
* Private constructor for utility class.
*/
private BatchSchedulerUtils() {
}
static Optional<List<Stage>> selectEarliestSchedulableGroup(final List<List<Stage>> sortedScheduleGroups,
final PlanStateManager planStateManager) {
if (sortedScheduleGroups == null) {
return Optional.empty();
}
return sortedScheduleGroups.stream()
.filter(scheduleGroup -> scheduleGroup.stream()
.map(Stage::getId)
.map(planStateManager::getStageState)
.anyMatch(state -> state.equals(StageState.State.INCOMPLETE))) // any incomplete stage in the group
.findFirst(); // selects the one with the smallest scheduling group index.
}
static List<Task> selectSchedulableTasks(final PlanStateManager planStateManager,
final BlockManagerMaster blockManagerMaster,
final Stage stageToSchedule) {
if (stageToSchedule.getPropertyValue(IgnoreSchedulingTempDataReceiverProperty.class).orElse(false)) {
// Ignore ghost stage.
for (final String taskId : planStateManager.getTaskAttemptsToSchedule(stageToSchedule.getId())) {
planStateManager.onTaskStateChanged(taskId, TaskState.State.EXECUTING);
planStateManager.onTaskStateChanged(taskId, TaskState.State.COMPLETE);
}
return Collections.emptyList();
}
final List<StageEdge> stageIncomingEdges =
planStateManager.getPhysicalPlan().getStageDAG().getIncomingEdgesOf(stageToSchedule.getId());
final List<StageEdge> stageOutgoingEdges =
planStateManager.getPhysicalPlan().getStageDAG().getOutgoingEdgesOf(stageToSchedule.getId());
// Create and return tasks.
final List<Map<String, Readable>> vertexIdToReadables = stageToSchedule.getVertexIdToReadables();
final List<String> taskIdsToSchedule = planStateManager.getTaskAttemptsToSchedule(stageToSchedule.getId());
final List<Task> tasks = new ArrayList<>(taskIdsToSchedule.size());
taskIdsToSchedule.forEach(taskId -> {
final Set<String> blockIds = BatchSchedulerUtils.getOutputBlockIds(planStateManager, taskId);
blockManagerMaster.onProducerTaskScheduled(taskId, blockIds);
final int taskIdx = RuntimeIdManager.getIndexFromTaskId(taskId);
tasks.add(new Task(
planStateManager.getPhysicalPlan().getPlanId(),
taskId,
stageToSchedule.getExecutionProperties(),
stageToSchedule.getSerializedIRDAG(),
stageIncomingEdges,
stageOutgoingEdges,
vertexIdToReadables.get(taskIdx)));
});
return tasks;
}
////////////////////////////////////////////////////////////////////// Task state change handlers
/**
* Action after task execution has been completed.
* Note this method should not be invoked when the previous state of the task is ON_HOLD.
*
* @param executorRegistry the registry for available executors.
* @param executorId id of the executor.
* @param taskId the ID of the task completed.
*/
static void onTaskExecutionComplete(final ExecutorRegistry executorRegistry,
final String executorId,
final String taskId) {
LOG.debug("{} completed in {}", taskId, executorId);
executorRegistry.updateExecutor(executorId, (executor, state) -> {
executor.onTaskExecutionComplete(taskId);
return Pair.of(executor, state);
});
}
/**
* Get the target edges of dynamic optimization.
* The edges are annotated with {@link MessageIdEdgeProperty}, which are outgoing edges of
* parents of the stage put on hold.
* <p>
* See {@link org.apache.nemo.compiler.optimizer.pass.compiletime.reshaping.SkewReshapingPass}
* for setting the target edges of dynamic optimization.
*
* @param taskId the task ID that sent stage-level aggregated message for dynamic optimization.
* @return the edges to optimize.
*/
static Set<StageEdge> getEdgesToOptimize(final PlanStateManager planStateManager,
final String taskId) {
final DAG<Stage, StageEdge> stageDag = planStateManager.getPhysicalPlan().getStageDAG();
// Get a stage including the given task
final Stage stagePutOnHold = stageDag.getVertices().stream()
.filter(stage -> stage.getId().equals(RuntimeIdManager.getStageIdFromTaskId(taskId)))
.findFirst()
.orElseThrow(RuntimeException::new);
// Stage put on hold, i.e. stage with vertex containing MessageAggregatorTransform
// should have a parent stage whose outgoing edges contain the target edge of dynamic optimization.
final List<Integer> messageIds = stagePutOnHold.getIRDAG()
.getVertices()
.stream()
.filter(v -> v.getPropertyValue(MessageIdVertexProperty.class).isPresent())
.map(v -> v.getPropertyValue(MessageIdVertexProperty.class).get())
.collect(Collectors.toList());
if (messageIds.size() != 1) {
throw new IllegalStateException("Must be exactly one vertex with the message id: " + messageIds.toString());
}
final int messageId = messageIds.get(0);
final Set<StageEdge> targetEdges = new HashSet<>();
// Get edges with identical MessageIdEdgeProperty (except the put on hold stage)
for (final Stage stage : stageDag.getVertices()) {
final Set<StageEdge> targetEdgesFound = stageDag.getOutgoingEdgesOf(stage).stream()
.filter(candidateEdge -> {
final Optional<HashSet<Integer>> candidateMCId =
candidateEdge.getPropertyValue(MessageIdEdgeProperty.class);
return candidateMCId.isPresent() && candidateMCId.get().contains(messageId);
})
.collect(Collectors.toSet());
targetEdges.addAll(targetEdgesFound);
}
return targetEdges;
}
/**
* Action for after task execution has failed but it's recoverable.
*
* @param executorId the ID of the executor
* @param taskId the ID of the task
* @param failureCause the cause of failure
*/
static void onTaskExecutionFailedRecoverable(final PlanStateManager planStateManager,
final BlockManagerMaster blockManagerMaster,
final ExecutorRegistry executorRegistry,
final String executorId,
final String taskId,
final TaskState.RecoverableTaskFailureCause failureCause) {
LOG.info("{} failed in {} by {}", taskId, executorId, failureCause);
executorRegistry.updateExecutor(executorId, (executor, state) -> {
executor.onTaskExecutionFailed(taskId);
return Pair.of(executor, state);
});
switch (failureCause) {
// Previous task must be re-executed, and incomplete tasks of the belonging stage must be rescheduled.
case INPUT_READ_FAILURE:
// TODO #54: Handle remote data fetch failures
case OUTPUT_WRITE_FAILURE:
blockManagerMaster.onProducerTaskFailed(taskId);
break;
default:
throw new UnknownFailureCauseException(new Throwable("Unknown cause: " + failureCause));
}
retryTasksAndRequiredParents(planStateManager, blockManagerMaster, Collections.singleton(taskId));
}
/**
* Action for after task execution is put on hold.
*
* @param executorId the ID of the executor.
* @param taskId the ID of the task.
*/
static Optional<PhysicalPlan> onTaskExecutionOnHold(final PlanStateManager planStateManager,
final ExecutorRegistry executorRegistry,
final PlanRewriter planRewriter,
final String executorId,
final String taskId) {
LOG.info("{} put on hold in {}", new Object[]{taskId, executorId});
executorRegistry.updateExecutor(executorId, (executor, state) -> {
executor.onTaskExecutionComplete(taskId);
return Pair.of(executor, state);
});
final String stageIdForTaskUponCompletion = RuntimeIdManager.getStageIdFromTaskId(taskId);
final boolean stageComplete =
planStateManager.getStageState(stageIdForTaskUponCompletion).equals(StageState.State.COMPLETE);
final Set<StageEdge> targetEdges = getEdgesToOptimize(planStateManager, taskId);
if (targetEdges.isEmpty()) {
throw new RuntimeException("No edges specified for data skew optimization");
}
if (stageComplete) {
return Optional.of(planRewriter.rewrite(getMessageId(targetEdges)));
} else {
return Optional.empty();
}
}
/**
* Process the RuntimePassMessage.
*
* @param planStateManager to get the edges for the optimization.
* @param planRewriter for rewriting the plan later on.
* @param taskId that generated the message.
* @param data of the message.
*/
public static void onRunTimePassMessage(final PlanStateManager planStateManager, final PlanRewriter planRewriter,
final String taskId, final Object data) {
final Set<StageEdge> targetEdges = BatchSchedulerUtils.getEdgesToOptimize(planStateManager, taskId);
planRewriter.accumulate(getMessageId(targetEdges), targetEdges, data);
}
static int getMessageId(final Set<StageEdge> stageEdges) {
// Here we simply use findFirst() for now...
// TODO #345: Simplify insert
final Set<Integer> messageIds = stageEdges.stream()
.map(edge -> edge.getExecutionProperties()
.get(MessageIdEdgeProperty.class)
.<IllegalArgumentException>orElseThrow(() -> new IllegalArgumentException(edge.getId())))
.findFirst().<IllegalArgumentException>orElseThrow(IllegalArgumentException::new);
// Type casting is needed. See: https://stackoverflow.com/a/40865318
return messageIds.iterator().next();
}
////////////////////////////////////////////////////////////////////// Helper methods
static void retryTasksAndRequiredParents(final PlanStateManager planStateManager,
final BlockManagerMaster blockManagerMaster,
final Set<String> tasks) {
final Set<String> requiredParents =
recursivelyGetParentTasksForLostBlocks(planStateManager, blockManagerMaster, tasks);
final Set<String> tasksToRetry = Sets.union(tasks, requiredParents);
LOG.info("Will be retried: {}", tasksToRetry);
tasksToRetry.forEach(
taskToReExecute -> planStateManager.onTaskStateChanged(taskToReExecute, TaskState.State.SHOULD_RETRY));
}
static Set<String> recursivelyGetParentTasksForLostBlocks(final PlanStateManager planStateManager,
final BlockManagerMaster blockManagerMaster,
final Set<String> children) {
if (children.isEmpty()) {
return Collections.emptySet();
}
final DAG<Stage, StageEdge> stageDAG = planStateManager.getPhysicalPlan().getStageDAG();
final Map<String, StageEdge> idToIncomingEdges = children.stream()
.map(RuntimeIdManager::getStageIdFromTaskId)
.flatMap(stageId -> stageDAG.getIncomingEdgesOf(stageId).stream())
// Ignore duplicates with the mergeFunction in toMap(_,_,mergeFunction)
.collect(Collectors.toMap(StageEdge::getId, Function.identity(), (l, r) -> l));
final Set<String> parentsWithLostBlocks = children.stream()
.flatMap(child -> getInputBlockIds(planStateManager, child).stream()) // child task id -> parent block ids
.map(RuntimeIdManager::getWildCardFromBlockId) // parent block id -> parent block wildcard
.collect(Collectors.toSet()).stream() // remove duplicate wildcards
.filter(parentBlockWildcard -> // lost block = no matching AVAILABLE block attempt for the wildcard
blockManagerMaster.getBlockHandlers(parentBlockWildcard, BlockState.State.AVAILABLE).isEmpty())
.flatMap(lostParentBlockWildcard -> {
// COMPLETE task attempts of the lostParentBlockWildcard must become SHOULD_RETRY
final String inEdgeId = RuntimeIdManager.getRuntimeEdgeIdFromBlockId(lostParentBlockWildcard);
final String parentStageId = idToIncomingEdges.get(inEdgeId).getSrc().getId();
final int parentTaskIndex = RuntimeIdManager.getTaskIndexFromBlockId(lostParentBlockWildcard);
return planStateManager.getAllTaskAttemptsOfStage(parentStageId)
.stream()
.filter(taskId -> RuntimeIdManager.getStageIdFromTaskId(taskId).equals(parentStageId)
&& RuntimeIdManager.getIndexFromTaskId(taskId) == parentTaskIndex)
// COMPLETE -> SHOULD_RETRY
.filter(taskId -> planStateManager.getTaskState(taskId).equals(TaskState.State.COMPLETE));
})
.collect(Collectors.toSet());
// Recursive call
return Sets.union(parentsWithLostBlocks,
recursivelyGetParentTasksForLostBlocks(planStateManager, blockManagerMaster, parentsWithLostBlocks));
}
static Set<String> getOutputBlockIds(final PlanStateManager planStateManager,
final String taskId) {
return planStateManager.getPhysicalPlan().getStageDAG()
.getOutgoingEdgesOf(RuntimeIdManager.getStageIdFromTaskId(taskId))
.stream()
.map(stageEdge -> RuntimeIdManager.generateBlockId(stageEdge.getId(), taskId))
.collect(Collectors.toSet()); // ids of blocks this task will produce
}
static Set<String> getInputBlockIds(final PlanStateManager planStateManager,
final String childTaskId) {
final String stageIdOfChildTask = RuntimeIdManager.getStageIdFromTaskId(childTaskId);
return planStateManager.getPhysicalPlan().getStageDAG().getIncomingEdgesOf(stageIdOfChildTask)
.stream()
.flatMap(inStageEdge -> {
final Set<String> parentTaskIds = planStateManager.getAllTaskAttemptsOfStage(inStageEdge.getSrc().getId());
switch (inStageEdge.getDataCommunicationPattern()) {
case SHUFFLE:
case BROADCAST:
// All of the parent stage's tasks
return parentTaskIds.stream()
.map(parentTaskId -> RuntimeIdManager.generateBlockId(inStageEdge.getId(), parentTaskId));
case ONE_TO_ONE:
// Same-index tasks of the parent stage
return parentTaskIds.stream()
.filter(parentTaskId ->
RuntimeIdManager.getIndexFromTaskId(parentTaskId) == RuntimeIdManager.getIndexFromTaskId(childTaskId))
.map(parentTaskId -> RuntimeIdManager.generateBlockId(inStageEdge.getId(), parentTaskId));
default:
throw new IllegalStateException(inStageEdge.toString());
}
})
.collect(Collectors.toSet());
}
}