blob: fec108bfe1a99fad9a4fcbcaf48a479ce7890dfc [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Translation of user graph's API to the execution graph
"""
from .. import api
from ..executor import base
from . import task as core_task
def build_execution_graph(
task_graph,
execution_graph,
default_executor,
start_cls=core_task.StartWorkflowTask,
end_cls=core_task.EndWorkflowTask,
depends_on=()):
"""
Translates the user graph to the execution graph
:param task_graph: The user's graph
:param workflow_context: The workflow
:param execution_graph: The execution graph that is being built
:param start_cls: internal use
:param end_cls: internal use
:param depends_on: internal use
"""
# Insert start marker
start_task = start_cls(id=_start_graph_suffix(task_graph.id), executor=base.StubTaskExecutor())
_add_task_and_dependencies(execution_graph, start_task, depends_on)
for api_task in task_graph.topological_order(reverse=True):
dependencies = task_graph.get_dependencies(api_task)
operation_dependencies = _get_tasks_from_dependencies(
execution_graph, dependencies, default=[start_task])
if isinstance(api_task, api.task.OperationTask):
operation_task = core_task.OperationTask(api_task, executor=default_executor)
_add_task_and_dependencies(execution_graph, operation_task, operation_dependencies)
elif isinstance(api_task, api.task.WorkflowTask):
# Build the graph recursively while adding start and end markers
build_execution_graph(
task_graph=api_task,
execution_graph=execution_graph,
default_executor=default_executor,
start_cls=core_task.StartSubWorkflowTask,
end_cls=core_task.EndSubWorkflowTask,
depends_on=operation_dependencies
)
elif isinstance(api_task, api.task.StubTask):
stub_task = core_task.StubTask(id=api_task.id, executor=base.StubTaskExecutor())
_add_task_and_dependencies(execution_graph, stub_task, operation_dependencies)
else:
raise RuntimeError('Undefined state')
# Insert end marker
workflow_dependencies = _get_tasks_from_dependencies(
execution_graph,
_get_non_dependency_tasks(task_graph),
default=[start_task])
end_task = end_cls(id=_end_graph_suffix(task_graph.id), executor=base.StubTaskExecutor())
_add_task_and_dependencies(execution_graph, end_task, workflow_dependencies)
def _add_task_and_dependencies(execution_graph, operation_task, operation_dependencies=()):
execution_graph.add_node(operation_task.id, task=operation_task)
for dependency in operation_dependencies:
execution_graph.add_edge(dependency.id, operation_task.id)
def _get_tasks_from_dependencies(execution_graph, dependencies, default=()):
"""
Returns task list from dependencies.
"""
tasks = []
for dependency in dependencies:
if isinstance(dependency, (api.task.OperationTask, api.task.StubTask)):
dependency_id = dependency.id
else:
dependency_id = _end_graph_suffix(dependency.id)
tasks.append(execution_graph.node[dependency_id]['task'])
return tasks or default
def _start_graph_suffix(id):
return '{0}-Start'.format(id)
def _end_graph_suffix(id):
return '{0}-End'.format(id)
def _get_non_dependency_tasks(graph):
for task in graph.tasks:
if len(list(graph.get_dependents(task))) == 0:
yield task