#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Pipeline transformations for the FnApiRunner.
"""
# pytype: skip-file
# mypy: check-untyped-defs
import collections
import copy
import functools
import itertools
import logging
import operator
from builtins import object
from typing import TYPE_CHECKING
from typing import Any
from typing import Callable
from typing import Collection
from typing import Container
from typing import DefaultDict
from typing import Dict
from typing import FrozenSet
from typing import Iterable
from typing import Iterator
from typing import List
from typing import MutableMapping
from typing import NamedTuple
from typing import Optional
from typing import Set
from typing import Tuple
from typing import TypeVar
from typing import Union
from apache_beam import coders
from apache_beam.internal import pickler
from apache_beam.portability import common_urns
from apache_beam.portability import python_urns
from apache_beam.portability.api import beam_fn_api_pb2
from apache_beam.portability.api import beam_runner_api_pb2
from apache_beam.runners.worker import bundle_processor
from apache_beam.transforms import combiners
from apache_beam.transforms import core
from apache_beam.utils import proto_utils
from apache_beam.utils import timestamp
if TYPE_CHECKING:
from apache_beam.runners.portability.fn_api_runner.execution import ListBuffer
from apache_beam.runners.portability.fn_api_runner.execution import PartitionableBuffer
T = TypeVar('T')
# This module is experimental. No backwards-compatibility guarantees.
_LOGGER = logging.getLogger(__name__)
KNOWN_COMPOSITES = frozenset([
common_urns.primitives.GROUP_BY_KEY.urn,
common_urns.composites.COMBINE_PER_KEY.urn,
common_urns.primitives.PAR_DO.urn, # After SDF expansion.
])
COMBINE_URNS = frozenset([
common_urns.composites.COMBINE_PER_KEY.urn,
])
PAR_DO_URNS = frozenset([
common_urns.primitives.PAR_DO.urn,
common_urns.sdf_components.PAIR_WITH_RESTRICTION.urn,
common_urns.sdf_components.SPLIT_AND_SIZE_RESTRICTIONS.urn,
common_urns.sdf_components.PROCESS_SIZED_ELEMENTS_AND_RESTRICTIONS.urn,
common_urns.sdf_components.TRUNCATE_SIZED_RESTRICTION.urn,
])
IMPULSE_BUFFER = b'impulse'
# TimerFamilyId is identified by transform name + timer family
# TODO(pabloem): Rename this type to express this name is unique per pipeline.
TimerFamilyId = Tuple[str, str]
BufferId = bytes
# SideInputId is identified by a consumer ParDo + tag.
SideInputId = Tuple[str, str]
SideInputAccessPattern = beam_runner_api_pb2.FunctionSpec
# A map from a PCollection coder ID to a Safe Coder ID
# A safe coder is a coder that can be used on the runner side of the FnApi.
# A safe coder receives a byte string and returns a type that the runner can
# understand when deserializing.
SafeCoderMapping = Dict[str, str]
# DataSideInput maps SideInputIds to a tuple of the encoded bytes of the side
# input content, and a payload specification regarding the type of side input
# (MultiMap / Iterable).
DataSideInput = Dict[SideInputId, Tuple[bytes, SideInputAccessPattern]]
DataOutput = Dict[str, BufferId]
# A map of [Transform ID, Timer Family ID] to [Buffer ID, Time Domain for timer]
# The time domain comes from beam_runner_api_pb2.TimeDomain. It may be
# EVENT_TIME or PROCESSING_TIME.
OutputTimers = MutableMapping[TimerFamilyId, Tuple[BufferId, Any]]
# A map of [Transform ID, Timer Family ID] to [Buffer CONTENTS, Timestamp]
OutputTimerData = MutableMapping[TimerFamilyId,
Tuple['PartitionableBuffer',
timestamp.Timestamp]]
BundleProcessResult = Tuple[beam_fn_api_pb2.InstructionResponse,
List[beam_fn_api_pb2.ProcessBundleSplitResponse]]
# TODO(pabloem): Change the name to a more representative one
class DataInput(NamedTuple):
data: MutableMapping[str, 'PartitionableBuffer']
timers: MutableMapping[TimerFamilyId, 'PartitionableBuffer']
class Stage(object):
"""A set of Transforms that can be sent to the worker for processing."""
def __init__(
self,
name, # type: str
transforms, # type: List[beam_runner_api_pb2.PTransform]
downstream_side_inputs=None, # type: Optional[FrozenSet[str]]
must_follow=frozenset(), # type: FrozenSet[Stage]
parent=None, # type: Optional[str]
environment=None, # type: Optional[str]
forced_root=False):
self.name = name
self.transforms = transforms
self.downstream_side_inputs = downstream_side_inputs
self.must_follow = must_follow
self.timers = set() # type: Set[TimerFamilyId]
self.parent = parent
if environment is None:
environment = functools.reduce(
self._merge_environments,
(self._extract_environment(t) for t in transforms))
self.environment = environment
self.forced_root = forced_root
def __repr__(self):
must_follow = ', '.join(prev.name for prev in self.must_follow)
if self.downstream_side_inputs is None:
downstream_side_inputs = '<unknown>'
else:
downstream_side_inputs = ', '.join(
str(si) for si in self.downstream_side_inputs)
return "%s\n %s\n must follow: %s\n downstream_side_inputs: %s" % (
self.name,
'\n'.join([
"%s:%s" % (transform.unique_name, transform.spec.urn)
for transform in self.transforms
]),
must_follow,
downstream_side_inputs)
@staticmethod
def _extract_environment(transform):
# type: (beam_runner_api_pb2.PTransform) -> Optional[str]
environment = transform.environment_id
return environment if environment else None
@staticmethod
def _merge_environments(env1, env2):
# type: (Optional[str], Optional[str]) -> Optional[str]
if env1 is None:
return env2
elif env2 is None:
return env1
else:
if env1 != env2:
raise ValueError(
"Incompatible environments: '%s' != '%s'" %
(str(env1).replace('\n', ' '), str(env2).replace('\n', ' ')))
return env1
def can_fuse(self, consumer, context):
# type: (Stage, TransformContext) -> bool
try:
self._merge_environments(self.environment, consumer.environment)
except ValueError:
return False
def no_overlap(a, b):
return not a or not b or not a.intersection(b)
return (
not consumer.forced_root and not self in consumer.must_follow and
self.is_all_sdk_urns(context) and consumer.is_all_sdk_urns(context) and
no_overlap(self.downstream_side_inputs, consumer.side_inputs()))
def fuse(self, other, context):
# type: (Stage, TransformContext) -> Stage
return Stage(
"(%s)+(%s)" % (self.name, other.name),
self.transforms + other.transforms,
union(self.downstream_side_inputs, other.downstream_side_inputs),
union(self.must_follow, other.must_follow),
environment=self._merge_environments(
self.environment, other.environment),
parent=_parent_for_fused_stages([self, other], context),
forced_root=self.forced_root or other.forced_root)
def is_runner_urn(self, context):
# type: (TransformContext) -> bool
return any(
transform.spec.urn in context.known_runner_urns
for transform in self.transforms)
def is_all_sdk_urns(self, context):
def is_sdk_transform(transform):
# Execute multi-input flattens in the runner.
if transform.spec.urn == common_urns.primitives.FLATTEN.urn and len(
transform.inputs) > 1:
return False
else:
return transform.spec.urn not in context.runner_only_urns
return all(is_sdk_transform(transform) for transform in self.transforms)
def is_stateful(self):
for transform in self.transforms:
if transform.spec.urn in PAR_DO_URNS:
payload = proto_utils.parse_Bytes(
transform.spec.payload, beam_runner_api_pb2.ParDoPayload)
if payload.state_specs or payload.timer_family_specs:
return True
return False
def side_inputs(self):
# type: () -> Iterator[str]
for transform in self.transforms:
if transform.spec.urn in PAR_DO_URNS:
payload = proto_utils.parse_Bytes(
transform.spec.payload, beam_runner_api_pb2.ParDoPayload)
for side_input in payload.side_inputs:
yield transform.inputs[side_input]
def has_as_main_input(self, pcoll):
for transform in self.transforms:
if transform.spec.urn in PAR_DO_URNS:
payload = proto_utils.parse_Bytes(
transform.spec.payload, beam_runner_api_pb2.ParDoPayload)
local_side_inputs = payload.side_inputs
else:
local_side_inputs = {}
for local_id, pipeline_id in transform.inputs.items():
if pcoll == pipeline_id and local_id not in local_side_inputs:
return True
def deduplicate_read(self):
# type: () -> None
seen_pcolls = set() # type: Set[str]
new_transforms = []
for transform in self.transforms:
if transform.spec.urn == bundle_processor.DATA_INPUT_URN:
pcoll = only_element(list(transform.outputs.items()))[1]
if pcoll in seen_pcolls:
continue
seen_pcolls.add(pcoll)
new_transforms.append(transform)
self.transforms = new_transforms
def executable_stage_transform(
self,
known_runner_urns, # type: FrozenSet[str]
all_consumers,
components # type: beam_runner_api_pb2.Components
):
# type: (...) -> beam_runner_api_pb2.PTransform
if (len(self.transforms) == 1 and
self.transforms[0].spec.urn in known_runner_urns):
result = copy.copy(self.transforms[0])
del result.subtransforms[:]
return result
else:
all_inputs = set(
pcoll for t in self.transforms for pcoll in t.inputs.values())
all_outputs = set(
pcoll for t in self.transforms for pcoll in t.outputs.values())
internal_transforms = set(id(t) for t in self.transforms)
external_outputs = [
pcoll for pcoll in all_outputs
if all_consumers[pcoll] - internal_transforms
]
stage_components = beam_runner_api_pb2.Components()
stage_components.CopyFrom(components)
# Only keep the PCollections referenced in this stage.
stage_components.pcollections.clear()
for pcoll_id in all_inputs.union(all_outputs):
stage_components.pcollections[pcoll_id].CopyFrom(
components.pcollections[pcoll_id])
# Only keep the transforms in this stage.
# Also gather up payload data as we iterate over the transforms.
stage_components.transforms.clear()
main_inputs = set() # type: Set[str]
side_inputs = []
user_states = []
timers = []
for ix, transform in enumerate(self.transforms):
transform_id = 'transform_%d' % ix
if transform.spec.urn == common_urns.primitives.PAR_DO.urn:
payload = proto_utils.parse_Bytes(
transform.spec.payload, beam_runner_api_pb2.ParDoPayload)
for tag in payload.side_inputs.keys():
side_inputs.append(
beam_runner_api_pb2.ExecutableStagePayload.SideInputId(
transform_id=transform_id, local_name=tag))
for tag in payload.state_specs.keys():
user_states.append(
beam_runner_api_pb2.ExecutableStagePayload.UserStateId(
transform_id=transform_id, local_name=tag))
for tag in payload.timer_family_specs.keys():
timers.append(
beam_runner_api_pb2.ExecutableStagePayload.TimerId(
transform_id=transform_id, local_name=tag))
main_inputs.update(
pcoll_id for tag,
pcoll_id in transform.inputs.items()
if tag not in payload.side_inputs)
else:
main_inputs.update(transform.inputs.values())
stage_components.transforms[transform_id].CopyFrom(transform)
main_input_id = only_element(main_inputs - all_outputs)
named_inputs = dict({
'%s:%s' % (side.transform_id, side.local_name):
stage_components.transforms[side.transform_id].inputs[side.local_name]
for side in side_inputs
},
main_input=main_input_id)
# at this point we should have resolved an environment, as the key of
# components.environments cannot be None.
assert self.environment is not None
exec_payload = beam_runner_api_pb2.ExecutableStagePayload(
environment=components.environments[self.environment],
input=main_input_id,
outputs=external_outputs,
transforms=stage_components.transforms.keys(),
components=stage_components,
side_inputs=side_inputs,
user_states=user_states,
timers=timers)
return beam_runner_api_pb2.PTransform(
unique_name=unique_name(None, self.name),
spec=beam_runner_api_pb2.FunctionSpec(
urn='beam:runner:executable_stage:v1',
payload=exec_payload.SerializeToString()),
inputs=named_inputs,
outputs={
'output_%d' % ix: pcoll
for ix,
pcoll in enumerate(external_outputs)
},
)
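# Memoizes an instance method per (instance, args) pair by stashing a cache
# dict on the instance itself. Illustrative use (hypothetical class, not part
# of this module):
#
#   class Catalog(object):
#     @memoize_on_instance
#     def lookup(self, key):
#       ...  # the body runs at most once per distinct key per instance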
def memoize_on_instance(f):
missing = object()
def wrapper(self, *args):
try:
cache = getattr(self, '_cache_%s' % f.__name__)
except AttributeError:
cache = {}
setattr(self, '_cache_%s' % f.__name__, cache)
result = cache.get(args, missing)
if result is missing:
result = cache[args] = f(self, *args)
return result
return wrapper
class TransformContext(object):
_COMMON_CODER_URNS = set(
value.urn for (key, value) in common_urns.coders.__dict__.items()
if not key.startswith('_')
# Length prefix Rows rather than re-coding them.
) - set([common_urns.coders.ROW.urn])
_REQUIRED_CODER_URNS = set([
common_urns.coders.WINDOWED_VALUE.urn,
# For impulse.
common_urns.coders.BYTES.urn,
common_urns.coders.GLOBAL_WINDOW.urn,
# For GBK.
common_urns.coders.KV.urn,
common_urns.coders.ITERABLE.urn,
# For SDF.
common_urns.coders.DOUBLE.urn,
# For timers.
common_urns.coders.TIMER.urn,
# For everything else.
common_urns.coders.LENGTH_PREFIX.urn,
common_urns.coders.CUSTOM_WINDOW.urn,
])
def __init__(
self,
components, # type: beam_runner_api_pb2.Components
known_runner_urns, # type: FrozenSet[str]
use_state_iterables=False,
is_drain=False):
self.components = components
self.known_runner_urns = known_runner_urns
self.runner_only_urns = known_runner_urns - frozenset(
[common_urns.primitives.FLATTEN.urn])
self._known_coder_urns = set.union(
# Those which are required.
self._REQUIRED_CODER_URNS,
# Those common coders which are understood by all environments.
self._COMMON_CODER_URNS.intersection(
*(
set(env.capabilities)
for env in self.components.environments.values())))
self.use_state_iterables = use_state_iterables
self.is_drain = is_drain
# ok to pass None for context because BytesCoder has no components
coder_proto = coders.BytesCoder().to_runner_api(
None) # type: ignore[arg-type]
self.bytes_coder_id = self.add_or_get_coder_id(coder_proto, 'bytes_coder')
self.safe_coders: SafeCoderMapping = {
self.bytes_coder_id: self.bytes_coder_id
}
# A map of PCollection ID to Coder ID.
self.data_channel_coders = {} # type: Dict[str, str]
def add_or_get_coder_id(
self,
coder_proto, # type: beam_runner_api_pb2.Coder
coder_prefix='coder'):
# type: (...) -> str
for coder_id, coder in self.components.coders.items():
if coder == coder_proto:
return coder_id
new_coder_id = unique_name(self.components.coders, coder_prefix)
self.components.coders[new_coder_id].CopyFrom(coder_proto)
return new_coder_id
def add_data_channel_coder(self, pcoll_id):
pcoll = self.components.pcollections[pcoll_id]
proto = beam_runner_api_pb2.Coder(
spec=beam_runner_api_pb2.FunctionSpec(
urn=common_urns.coders.WINDOWED_VALUE.urn),
component_coder_ids=[
pcoll.coder_id,
self.components.windowing_strategies[
pcoll.windowing_strategy_id].window_coder_id
])
self.data_channel_coders[pcoll_id] = self.maybe_length_prefixed_coder(
self.add_or_get_coder_id(proto, pcoll.coder_id + '_windowed'))
@memoize_on_instance
def with_state_iterables(self, coder_id):
# type: (str) -> str
coder = self.components.coders[coder_id]
if coder.spec.urn == common_urns.coders.ITERABLE.urn:
new_coder_id = unique_name(
self.components.coders, coder_id + '_state_backed')
new_coder = self.components.coders[new_coder_id]
new_coder.CopyFrom(coder)
new_coder.spec.urn = common_urns.coders.STATE_BACKED_ITERABLE.urn
new_coder.spec.payload = b'1'
new_coder.component_coder_ids[0] = self.with_state_iterables(
coder.component_coder_ids[0])
return new_coder_id
else:
new_component_ids = [
self.with_state_iterables(c) for c in coder.component_coder_ids
]
if new_component_ids == coder.component_coder_ids:
return coder_id
else:
new_coder_id = unique_name(
self.components.coders, coder_id + '_state_backed')
self.components.coders[new_coder_id].CopyFrom(
beam_runner_api_pb2.Coder(
spec=coder.spec, component_coder_ids=new_component_ids))
return new_coder_id
@memoize_on_instance
def maybe_length_prefixed_coder(self, coder_id):
# type: (str) -> str
if coder_id in self.safe_coders:
return coder_id
(maybe_length_prefixed_id,
safe_id) = self.maybe_length_prefixed_and_safe_coder(coder_id)
self.safe_coders[maybe_length_prefixed_id] = safe_id
return maybe_length_prefixed_id
@memoize_on_instance
def maybe_length_prefixed_and_safe_coder(self, coder_id):
# type: (str) -> Tuple[str, str]
coder = self.components.coders[coder_id]
if coder.spec.urn == common_urns.coders.LENGTH_PREFIX.urn:
return coder_id, self.bytes_coder_id
elif coder.spec.urn in self._known_coder_urns:
new_component_ids = [
self.maybe_length_prefixed_coder(c) for c in coder.component_coder_ids
]
if new_component_ids == coder.component_coder_ids:
new_coder_id = coder_id
else:
new_coder_id = unique_name(
self.components.coders, coder_id + '_length_prefixed')
self.components.coders[new_coder_id].CopyFrom(
beam_runner_api_pb2.Coder(
spec=coder.spec, component_coder_ids=new_component_ids))
safe_component_ids = [self.safe_coders[c] for c in new_component_ids]
if safe_component_ids == coder.component_coder_ids:
safe_coder_id = coder_id
else:
safe_coder_id = unique_name(self.components.coders, coder_id + '_safe')
self.components.coders[safe_coder_id].CopyFrom(
beam_runner_api_pb2.Coder(
spec=coder.spec, component_coder_ids=safe_component_ids))
return new_coder_id, safe_coder_id
else:
new_coder_id = unique_name(
self.components.coders, coder_id + '_length_prefixed')
self.components.coders[new_coder_id].CopyFrom(
beam_runner_api_pb2.Coder(
spec=beam_runner_api_pb2.FunctionSpec(
urn=common_urns.coders.LENGTH_PREFIX.urn),
component_coder_ids=[coder_id]))
return new_coder_id, self.bytes_coder_id
def length_prefix_pcoll_coders(self, pcoll_id):
# type: (str) -> None
self.components.pcollections[pcoll_id].coder_id = (
self.maybe_length_prefixed_coder(
self.components.pcollections[pcoll_id].coder_id))
@memoize_on_instance
def parents_map(self):
return {
child: parent
for (parent, transform) in self.components.transforms.items()
for child in transform.subtransforms
}
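# Walks the transform graph from the given roots, yielding a singleton Stage
# for each known composite and for each leaf transform whose outputs are not
# all among its inputs.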
def leaf_transform_stages(
root_ids, # type: Iterable[str]
components, # type: beam_runner_api_pb2.Components
parent=None, # type: Optional[str]
known_composites=KNOWN_COMPOSITES # type: FrozenSet[str]
):
# type: (...) -> Iterator[Stage]
for root_id in root_ids:
root = components.transforms[root_id]
if root.spec.urn in known_composites:
yield Stage(root_id, [root], parent=parent)
elif not root.subtransforms:
# Make sure its outputs are not a subset of its inputs.
if set(root.outputs.values()) - set(root.inputs.values()):
yield Stage(root_id, [root], parent=parent)
else:
for stage in leaf_transform_stages(root.subtransforms,
components,
root_id,
known_composites):
yield stage
def pipeline_from_stages(
pipeline_proto, # type: beam_runner_api_pb2.Pipeline
stages, # type: Iterable[Stage]
known_runner_urns, # type: FrozenSet[str]
partial # type: bool
):
# type: (...) -> beam_runner_api_pb2.Pipeline
# In case it was a generator that mutates components as it
# produces outputs (as is the case with most transformations).
stages = list(stages)
new_proto = beam_runner_api_pb2.Pipeline()
new_proto.CopyFrom(pipeline_proto)
components = new_proto.components
components.transforms.clear()
components.pcollections.clear()
roots = set()
parents = {
child: parent
for parent,
proto in pipeline_proto.components.transforms.items()
for child in proto.subtransforms
}
def copy_output_pcollections(transform):
for pcoll_id in transform.outputs.values():
components.pcollections[pcoll_id].CopyFrom(
pipeline_proto.components.pcollections[pcoll_id])
def add_parent(child, parent):
if parent is None:
roots.add(child)
else:
if (parent not in components.transforms and
parent in pipeline_proto.components.transforms):
components.transforms[parent].CopyFrom(
pipeline_proto.components.transforms[parent])
copy_output_pcollections(components.transforms[parent])
del components.transforms[parent].subtransforms[:]
# Ensure that child is the last item in the parent's subtransforms.
# If the stages were previously sorted into topological order using
# sort_stages, this ensures that the parent transforms are also
# added in topological order.
if child in components.transforms[parent].subtransforms:
components.transforms[parent].subtransforms.remove(child)
components.transforms[parent].subtransforms.append(child)
add_parent(parent, parents.get(parent))
def copy_subtransforms(transform):
for subtransform_id in transform.subtransforms:
if subtransform_id not in pipeline_proto.components.transforms:
raise RuntimeError(
'Could not find subtransform to copy: ' + subtransform_id)
subtransform = pipeline_proto.components.transforms[subtransform_id]
components.transforms[subtransform_id].CopyFrom(subtransform)
copy_output_pcollections(components.transforms[subtransform_id])
copy_subtransforms(subtransform)
all_consumers = collections.defaultdict(
set) # type: DefaultDict[str, Set[int]]
for stage in stages:
for transform in stage.transforms:
for pcoll in transform.inputs.values():
all_consumers[pcoll].add(id(transform))
for stage in stages:
if partial:
transform = only_element(stage.transforms)
copy_subtransforms(transform)
else:
transform = stage.executable_stage_transform(
known_runner_urns, all_consumers, pipeline_proto.components)
transform_id = unique_name(components.transforms, stage.name)
components.transforms[transform_id].CopyFrom(transform)
copy_output_pcollections(transform)
add_parent(transform_id, stage.parent)
del new_proto.root_transform_ids[:]
new_proto.root_transform_ids.extend(roots)
return new_proto
def create_and_optimize_stages(
pipeline_proto, # type: beam_runner_api_pb2.Pipeline
phases,
known_runner_urns, # type: FrozenSet[str]
use_state_iterables=False,
is_drain=False):
# type: (...) -> Tuple[TransformContext, List[Stage]]
"""Create a set of stages given a pipeline proto, and set of optimizations.
Args:
pipeline_proto (beam_runner_api_pb2.Pipeline): A pipeline defined by a user.
phases (list of callables): Each phase identifies a specific transformation
to be applied to the pipeline graph. Existing phases are defined in this
file and receive a list of stages and a pipeline context. Some available
transformations are ``lift_combiners``, ``expand_sdf``, ``expand_gbk``,
etc.
Returns:
A tuple with a pipeline context, and a list of stages (i.e. an optimized
graph).
"""
pipeline_context = TransformContext(
pipeline_proto.components,
known_runner_urns,
use_state_iterables=use_state_iterables,
is_drain=is_drain)
# Initial set of stages are singleton leaf transforms.
stages = list(
leaf_transform_stages(
pipeline_proto.root_transform_ids,
pipeline_proto.components,
known_composites=union(known_runner_urns, KNOWN_COMPOSITES)))
# Apply each phase in order.
for phase in phases:
_LOGGER.info('%s %s %s', '=' * 20, phase, '=' * 20)
stages = list(phase(stages, pipeline_context))
_LOGGER.debug('%s %s' % (len(stages), [len(s.transforms) for s in stages]))
_LOGGER.debug('Stages: %s', [str(s) for s in stages])
# Return the (possibly mutated) context and ordered set of stages.
return pipeline_context, stages
def optimize_pipeline(
pipeline_proto, # type: beam_runner_api_pb2.Pipeline
phases,
known_runner_urns, # type: FrozenSet[str]
partial=False,
**kwargs):
unused_context, stages = create_and_optimize_stages(
pipeline_proto,
phases,
known_runner_urns,
**kwargs)
return pipeline_from_stages(
pipeline_proto, stages, known_runner_urns, partial)
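# Illustrative runner-side usage (the phase list below is one plausible
# ordering, not the canonical one used by the FnApiRunner; runner_urns stands
# for whatever URN set the embedding runner executes itself):
#
#   optimized_proto = optimize_pipeline(
#       pipeline_proto,
#       phases=[
#           annotate_downstream_side_inputs,
#           fix_side_input_pcoll_coders,
#           lift_combiners,
#           expand_sdf,
#           expand_gbk,
#           sink_flattens,
#           greedily_fuse,
#       ],
#       known_runner_urns=runner_urns,
#       partial=False)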
# Optimization stages.
def annotate_downstream_side_inputs(stages, pipeline_context):
# type: (Iterable[Stage], TransformContext) -> Iterable[Stage]
"""Annotate each stage with fusion-prohibiting information.
Each stage is annotated with the (transitive) set of PCollections that
depend on this stage and are also used later in the pipeline as a
side input.
While theoretically this could result in O(n^2) annotations, the size of
each set is bounded by the number of side inputs (typically much smaller
than the number of total nodes) and the number of *distinct* side-input
sets is also generally small (and shared due to the use of union
defined above).
This representation is also amenable to simple recomputation on fusion.
"""
consumers = collections.defaultdict(
list) # type: DefaultDict[str, List[Stage]]
def get_all_side_inputs():
# type: () -> Set[str]
all_side_inputs = set() # type: Set[str]
for stage in stages:
for transform in stage.transforms:
for input in transform.inputs.values():
consumers[input].append(stage)
for si in stage.side_inputs():
all_side_inputs.add(si)
return all_side_inputs
all_side_inputs = frozenset(get_all_side_inputs())
downstream_side_inputs_by_stage = {} # type: Dict[Stage, FrozenSet[str]]
def compute_downstream_side_inputs(stage):
# type: (Stage) -> FrozenSet[str]
if stage not in downstream_side_inputs_by_stage:
downstream_side_inputs = frozenset() # type: FrozenSet[str]
for transform in stage.transforms:
for output in transform.outputs.values():
if output in all_side_inputs:
downstream_side_inputs = union(
downstream_side_inputs, frozenset([output]))
for consumer in consumers[output]:
downstream_side_inputs = union(
downstream_side_inputs,
compute_downstream_side_inputs(consumer))
downstream_side_inputs_by_stage[stage] = downstream_side_inputs
return downstream_side_inputs_by_stage[stage]
for stage in stages:
stage.downstream_side_inputs = compute_downstream_side_inputs(stage)
return stages
def annotate_stateful_dofns_as_roots(stages, pipeline_context):
# type: (Iterable[Stage], TransformContext) -> Iterable[Stage]
for stage in stages:
for transform in stage.transforms:
if transform.spec.urn == common_urns.primitives.PAR_DO.urn:
pardo_payload = proto_utils.parse_Bytes(
transform.spec.payload, beam_runner_api_pb2.ParDoPayload)
if pardo_payload.state_specs or pardo_payload.timer_family_specs:
stage.forced_root = True
yield stage
def fix_side_input_pcoll_coders(stages, pipeline_context):
# type: (Iterable[Stage], TransformContext) -> Iterable[Stage]
"""Length prefix side input PCollection coders.
"""
for stage in stages:
for si in stage.side_inputs():
pipeline_context.length_prefix_pcoll_coders(si)
return stages
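# Buckets stages by get_stage_key(stage); stages for which the key is None are
# returned separately in a second list.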
def _group_stages_by_key(stages, get_stage_key):
grouped_stages = collections.defaultdict(list)
stages_with_none_key = []
for stage in stages:
stage_key = get_stage_key(stage)
if stage_key is None:
stages_with_none_key.append(stage)
else:
grouped_stages[stage_key].append(stage)
return (grouped_stages, stages_with_none_key)
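# Splits stages into groups whose sizes are bounded by get_limit(stage.name),
# visiting stages in increasing order of that limit. Illustrative example:
# with limits a=2, b=2, c=3, the stages a, b, c are grouped as [a, b], [c].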
def _group_stages_with_limit(stages, get_limit):
# type: (Iterable[Stage], Callable[[str], int]) -> Iterable[Collection[Stage]]
stages_with_limit = [(stage, get_limit(stage.name)) for stage in stages]
group: List[Stage] = []
group_limit = 0
for stage, limit in sorted(stages_with_limit, key=operator.itemgetter(1)):
if limit < 1:
raise Exception(
'expected get_limit to return an integer >= 1, '
'instead got: %d for stage: %s' % (limit, stage))
if not group:
group_limit = limit
assert len(group) < group_limit
group.append(stage)
if len(group) >= group_limit:
yield group
group = []
if group:
yield group
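# Rewrites, in place, every input PCollection id of the given transform that
# appears as a key in pcoll_id_remap.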
def _remap_input_pcolls(transform, pcoll_id_remap):
for input_key in list(transform.inputs.keys()):
if transform.inputs[input_key] in pcoll_id_remap:
transform.inputs[input_key] = pcoll_id_remap[transform.inputs[input_key]]
def _make_pack_name(names):
"""Return the packed Transform or Stage name.
The output name will contain the input names' common prefix, the infix
'/Packed', and the input names' suffixes in square brackets.
For example, if the input names are 'a/b/c1/d1' and 'a/b/c2/d2', then
the output name is 'a/b/Packed[c1_d1, c2_d2]'.
"""
assert names
tokens_in_names = [name.split('/') for name in names]
common_prefix_tokens = []
# Find the longest common prefix of tokens.
while True:
first_token_in_names = set()
for tokens in tokens_in_names:
if not tokens:
break
first_token_in_names.add(tokens[0])
if len(first_token_in_names) != 1:
break
common_prefix_tokens.append(next(iter(first_token_in_names)))
for tokens in tokens_in_names:
tokens.pop(0)
common_prefix_tokens.append('Packed')
common_prefix = '/'.join(common_prefix_tokens)
suffixes = ['_'.join(tokens) for tokens in tokens_in_names]
return '%s[%s]' % (common_prefix, ', '.join(suffixes))
def _eliminate_common_key_with_none(stages, context, can_pack=lambda s: True):
# type: (Iterable[Stage], TransformContext, Callable[[str], Union[bool, int]]) -> Iterable[Stage]
"""Runs common subexpression elimination for sibling KeyWithNone stages.
If multiple KeyWithNone stages share a common input, then all but one of the
stages will be eliminated along with their output PCollections. Transforms that
originally read input from the output PCollection of the eliminated
KeyWithNone stages will be remapped to read input from the output PCollection
of the remaining KeyWithNone stage.
"""
# Partition stages by whether they are eligible for common KeyWithNone
# elimination, and group eligible KeyWithNone stages by parent and
# environment.
def get_stage_key(stage):
if len(stage.transforms) == 1 and can_pack(stage.name):
transform = only_transform(stage.transforms)
if (transform.spec.urn == common_urns.primitives.PAR_DO.urn and
len(transform.inputs) == 1 and len(transform.outputs) == 1):
pardo_payload = proto_utils.parse_Bytes(
transform.spec.payload, beam_runner_api_pb2.ParDoPayload)
if pardo_payload.do_fn.urn == python_urns.KEY_WITH_NONE_DOFN:
return (only_element(transform.inputs.values()), stage.environment)
return None
grouped_eligible_stages, ineligible_stages = _group_stages_by_key(
stages, get_stage_key)
# Eliminate stages and build the PCollection remapping dictionary.
pcoll_id_remap = {}
remaining_stages = []
for sibling_stages in grouped_eligible_stages.values():
if len(sibling_stages) > 1:
output_pcoll_ids = [
only_element(stage.transforms[0].outputs.values())
for stage in sibling_stages
]
parent = _parent_for_fused_stages(sibling_stages, context)
for to_delete_pcoll_id in output_pcoll_ids[1:]:
pcoll_id_remap[to_delete_pcoll_id] = output_pcoll_ids[0]
del context.components.pcollections[to_delete_pcoll_id]
sibling_stages[0].parent = parent
sibling_stages[0].name = _make_pack_name(
stage.name for stage in sibling_stages)
only_transform(
sibling_stages[0].transforms).unique_name = _make_pack_name(
only_transform(stage.transforms).unique_name
for stage in sibling_stages)
remaining_stages.append(sibling_stages[0])
# Remap all transforms in components.
for transform in context.components.transforms.values():
_remap_input_pcolls(transform, pcoll_id_remap)
# Yield stages while remapping input PCollections if needed.
stages_to_yield = itertools.chain(ineligible_stages, remaining_stages)
for stage in stages_to_yield:
transform = only_transform(stage.transforms)
_remap_input_pcolls(transform, pcoll_id_remap)
yield stage
_DEFAULT_PACK_COMBINERS_LIMIT = 128
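# Illustrative effect of combiner packing (names and tags are schematic):
#
#   pcoll -> CombinePerKey(sum)  -> out1
#   pcoll -> CombinePerKey(mean) -> out2
#
# becomes
#
#   pcoll  -> CombinePerKey(SingleInputTupleCombineFn(sum, mean)) -> packed
#   packed -> ParDo(_UnpackFn(('0', '1'))) -> out1 (tag '0'), out2 (tag '1')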
def pack_per_key_combiners(stages, context, can_pack=lambda s: True):
# type: (Iterable[Stage], TransformContext, Callable[[str], Union[bool, int]]) -> Iterator[Stage]
"""Packs sibling CombinePerKey stages into a single CombinePerKey.
If CombinePerKey stages have a common input, one input each, and one output
each, pack the stages into a single stage that runs all CombinePerKeys and
outputs resulting tuples to a new PCollection. A subsequent stage unpacks
tuples from this PCollection and sends them to the original output
PCollections.
"""
class _UnpackFn(core.DoFn):
"""A DoFn that unpacks a packed to multiple tagged outputs.
Example:
tags = (T1, T2, ...)
input = (K, (V1, V2, ...))
output = TaggedOutput(T1, (K, V1)), TaggedOutput(T2, (K, V2)), ...
"""
def __init__(self, tags):
self._tags = tags
def process(self, element):
key, values = element
return [
core.pvalue.TaggedOutput(tag, (key, value)) for tag,
value in zip(self._tags, values)
]
def _get_fallback_coder_id():
return context.add_or_get_coder_id(
# passing None works here because there are no component coders
coders.registry.get_coder(object).to_runner_api(None)) # type: ignore[arg-type]
def _get_component_coder_id_from_kv_coder(coder, index):
assert index < 2
if coder.spec.urn == common_urns.coders.KV.urn and len(
coder.component_coder_ids) == 2:
return coder.component_coder_ids[index]
return _get_fallback_coder_id()
def _get_key_coder_id_from_kv_coder(coder):
return _get_component_coder_id_from_kv_coder(coder, 0)
def _get_value_coder_id_from_kv_coder(coder):
return _get_component_coder_id_from_kv_coder(coder, 1)
def _try_fuse_stages(a, b):
if a.can_fuse(b, context):
return a.fuse(b, context)
else:
raise ValueError
def _get_limit(stage_name):
result = can_pack(stage_name)
if result is True:
return _DEFAULT_PACK_COMBINERS_LIMIT
else:
return int(result)
# Partition stages by whether they are eligible for CombinePerKey packing
# and group eligible CombinePerKey stages by parent and environment.
def get_stage_key(stage):
if (len(stage.transforms) == 1 and can_pack(stage.name) and
stage.environment is not None and python_urns.PACKED_COMBINE_FN in
context.components.environments[stage.environment].capabilities):
transform = only_transform(stage.transforms)
if (transform.spec.urn == common_urns.composites.COMBINE_PER_KEY.urn and
len(transform.inputs) == 1 and len(transform.outputs) == 1):
combine_payload = proto_utils.parse_Bytes(
transform.spec.payload, beam_runner_api_pb2.CombinePayload)
if combine_payload.combine_fn.urn == python_urns.PICKLED_COMBINE_FN:
return (only_element(transform.inputs.values()), stage.environment)
return None
grouped_eligible_stages, ineligible_stages = _group_stages_by_key(
stages, get_stage_key)
for stage in ineligible_stages:
yield stage
grouped_packable_stages = [(stage_key, subgrouped_stages) for stage_key,
grouped_stages in grouped_eligible_stages.items()
for subgrouped_stages in _group_stages_with_limit(
grouped_stages, _get_limit)]
for stage_key, packable_stages in grouped_packable_stages:
input_pcoll_id, _ = stage_key
try:
if not len(packable_stages) > 1:
raise ValueError('Only one stage in this group: Skipping stage packing')
# The fused stage is used as a template and is not yielded.
fused_stage = functools.reduce(_try_fuse_stages, packable_stages)
except ValueError:
# Skip packing stages in this group.
# Yield the stages unmodified, and then continue to the next group.
for stage in packable_stages:
yield stage
continue
transforms = [only_transform(stage.transforms) for stage in packable_stages]
combine_payloads = [
proto_utils.parse_Bytes(
transform.spec.payload, beam_runner_api_pb2.CombinePayload)
for transform in transforms
]
output_pcoll_ids = [
only_element(transform.outputs.values()) for transform in transforms
]
# Build accumulator coder for (acc1, acc2, ...)
accumulator_coder_ids = [
combine_payload.accumulator_coder_id
for combine_payload in combine_payloads
]
tuple_accumulator_coder_id = context.add_or_get_coder_id(
beam_runner_api_pb2.Coder(
spec=beam_runner_api_pb2.FunctionSpec(urn=python_urns.TUPLE_CODER),
component_coder_ids=accumulator_coder_ids))
# Build packed output coder for (key, (out1, out2, ...))
input_kv_coder_id = context.components.pcollections[input_pcoll_id].coder_id
key_coder_id = _get_key_coder_id_from_kv_coder(
context.components.coders[input_kv_coder_id])
output_kv_coder_ids = [
context.components.pcollections[output_pcoll_id].coder_id
for output_pcoll_id in output_pcoll_ids
]
output_value_coder_ids = [
_get_value_coder_id_from_kv_coder(
context.components.coders[output_kv_coder_id])
for output_kv_coder_id in output_kv_coder_ids
]
pack_output_value_coder = beam_runner_api_pb2.Coder(
spec=beam_runner_api_pb2.FunctionSpec(urn=python_urns.TUPLE_CODER),
component_coder_ids=output_value_coder_ids)
pack_output_value_coder_id = context.add_or_get_coder_id(
pack_output_value_coder)
pack_output_kv_coder = beam_runner_api_pb2.Coder(
spec=beam_runner_api_pb2.FunctionSpec(urn=common_urns.coders.KV.urn),
component_coder_ids=[key_coder_id, pack_output_value_coder_id])
pack_output_kv_coder_id = context.add_or_get_coder_id(pack_output_kv_coder)
pack_stage_name = _make_pack_name([stage.name for stage in packable_stages])
pack_transform_name = _make_pack_name([
only_transform(stage.transforms).unique_name
for stage in packable_stages
])
pack_pcoll_id = unique_name(context.components.pcollections, 'pcollection')
input_pcoll = context.components.pcollections[input_pcoll_id]
context.components.pcollections[pack_pcoll_id].CopyFrom(
beam_runner_api_pb2.PCollection(
unique_name=pack_transform_name + '/Pack.out',
coder_id=pack_output_kv_coder_id,
windowing_strategy_id=input_pcoll.windowing_strategy_id,
is_bounded=input_pcoll.is_bounded))
# Set up Pack stage.
# TODO(https://github.com/apache/beam/issues/19737): classes that inherit
# from RunnerApiFn are expected to accept a PipelineContext for
# from_runner_api/to_runner_api. Determine how to accommodate this.
pack_combine_fn = combiners.SingleInputTupleCombineFn(
*[
core.CombineFn.from_runner_api(combine_payload.combine_fn, context) # type: ignore[arg-type]
for combine_payload in combine_payloads
]).to_runner_api(context) # type: ignore[arg-type]
pack_transform = beam_runner_api_pb2.PTransform(
unique_name=pack_transform_name + '/Pack',
spec=beam_runner_api_pb2.FunctionSpec(
urn=common_urns.composites.COMBINE_PER_KEY.urn,
payload=beam_runner_api_pb2.CombinePayload(
combine_fn=pack_combine_fn,
accumulator_coder_id=tuple_accumulator_coder_id).
SerializeToString()),
inputs={'in': input_pcoll_id},
# 'None' single output key follows convention for CombinePerKey.
outputs={'None': pack_pcoll_id},
environment_id=fused_stage.environment)
pack_stage = Stage(
pack_stage_name + '/Pack', [pack_transform],
downstream_side_inputs=fused_stage.downstream_side_inputs,
must_follow=fused_stage.must_follow,
parent=fused_stage.parent,
environment=fused_stage.environment)
yield pack_stage
# Set up Unpack stage
tags = [str(i) for i in range(len(output_pcoll_ids))]
pickled_do_fn_data = pickler.dumps((_UnpackFn(tags), (), {}, [], None))
unpack_transform = beam_runner_api_pb2.PTransform(
unique_name=pack_transform_name + '/Unpack',
spec=beam_runner_api_pb2.FunctionSpec(
urn=common_urns.primitives.PAR_DO.urn,
payload=beam_runner_api_pb2.ParDoPayload(
do_fn=beam_runner_api_pb2.FunctionSpec(
urn=python_urns.PICKLED_DOFN_INFO,
payload=pickled_do_fn_data)).SerializeToString()),
inputs={'in': pack_pcoll_id},
outputs=dict(zip(tags, output_pcoll_ids)),
environment_id=fused_stage.environment)
unpack_stage = Stage(
pack_stage_name + '/Unpack', [unpack_transform],
downstream_side_inputs=fused_stage.downstream_side_inputs,
must_follow=fused_stage.must_follow,
parent=fused_stage.parent,
environment=fused_stage.environment)
yield unpack_stage
def pack_combiners(stages, context, can_pack=None):
# type: (Iterable[Stage], TransformContext, Optional[Callable[[str], Union[bool, int]]]) -> Iterator[Stage]
if can_pack is None:
can_pack_names = {} # type: Dict[str, Union[bool, int]]
parents = context.parents_map()
def can_pack_fn(name: str) -> Union[bool, int]:
if name in can_pack_names:
return can_pack_names[name]
else:
transform = context.components.transforms[name]
if python_urns.APPLY_COMBINER_PACKING in transform.annotations:
try:
result = int(
transform.annotations[python_urns.APPLY_COMBINER_PACKING])
except ValueError:
result = True
elif name in parents:
result = can_pack_fn(parents[name])
else:
result = False
can_pack_names[name] = result
return result
can_pack = can_pack_fn
yield from pack_per_key_combiners(
_eliminate_common_key_with_none(stages, context, can_pack),
context,
can_pack)
def lift_combiners(stages, context):
# type: (List[Stage], TransformContext) -> Iterator[Stage]
"""Expands CombinePerKey into pre- and post-grouping stages.
... -> CombinePerKey -> ...
becomes
... -> PreCombine -> GBK -> MergeAccumulators -> ExtractOutput -> ...
"""
def is_compatible_with_combiner_lifting(trigger):
'''Returns whether this trigger is compatible with combiner lifting.
Certain triggers, such as those that fire after a certain number of
elements, need to observe every element, and as such are incompatible
with combiner lifting (which may aggregate several elements into one
before they reach the triggering code after shuffle).
'''
if trigger is None:
return True
elif trigger.WhichOneof('trigger') in (
'default',
'always',
'never',
'after_processing_time',
'after_synchronized_processing_time'):
return True
elif trigger.HasField('element_count'):
return trigger.element_count.element_count == 1
elif trigger.HasField('after_end_of_window'):
return is_compatible_with_combiner_lifting(
trigger.after_end_of_window.early_firings
) and is_compatible_with_combiner_lifting(
trigger.after_end_of_window.late_firings)
elif trigger.HasField('after_any'):
return all(
is_compatible_with_combiner_lifting(t)
for t in trigger.after_any.subtriggers)
elif trigger.HasField('repeat'):
return is_compatible_with_combiner_lifting(trigger.repeat.subtrigger)
else:
return False
def can_lift(combine_per_key_transform):
windowing = context.components.windowing_strategies[
context.components.pcollections[only_element(
list(combine_per_key_transform.inputs.values())
)].windowing_strategy_id]
return is_compatible_with_combiner_lifting(windowing.trigger)
def make_stage(base_stage, transform):
# type: (Stage, beam_runner_api_pb2.PTransform) -> Stage
return Stage(
transform.unique_name, [transform],
downstream_side_inputs=base_stage.downstream_side_inputs,
must_follow=base_stage.must_follow,
parent=base_stage.name,
environment=base_stage.environment)
def lifted_stages(stage):
transform = stage.transforms[0]
combine_payload = proto_utils.parse_Bytes(
transform.spec.payload, beam_runner_api_pb2.CombinePayload)
input_pcoll = context.components.pcollections[only_element(
list(transform.inputs.values()))]
output_pcoll = context.components.pcollections[only_element(
list(transform.outputs.values()))]
element_coder_id = input_pcoll.coder_id
element_coder = context.components.coders[element_coder_id]
key_coder_id, _ = element_coder.component_coder_ids
accumulator_coder_id = combine_payload.accumulator_coder_id
key_accumulator_coder = beam_runner_api_pb2.Coder(
spec=beam_runner_api_pb2.FunctionSpec(urn=common_urns.coders.KV.urn),
component_coder_ids=[key_coder_id, accumulator_coder_id])
key_accumulator_coder_id = context.add_or_get_coder_id(
key_accumulator_coder)
accumulator_iter_coder = beam_runner_api_pb2.Coder(
spec=beam_runner_api_pb2.FunctionSpec(
urn=common_urns.coders.ITERABLE.urn),
component_coder_ids=[accumulator_coder_id])
accumulator_iter_coder_id = context.add_or_get_coder_id(
accumulator_iter_coder)
key_accumulator_iter_coder = beam_runner_api_pb2.Coder(
spec=beam_runner_api_pb2.FunctionSpec(urn=common_urns.coders.KV.urn),
component_coder_ids=[key_coder_id, accumulator_iter_coder_id])
key_accumulator_iter_coder_id = context.add_or_get_coder_id(
key_accumulator_iter_coder)
precombined_pcoll_id = unique_name(
context.components.pcollections, 'pcollection')
context.components.pcollections[precombined_pcoll_id].CopyFrom(
beam_runner_api_pb2.PCollection(
unique_name=transform.unique_name + '/Precombine.out',
coder_id=key_accumulator_coder_id,
windowing_strategy_id=input_pcoll.windowing_strategy_id,
is_bounded=input_pcoll.is_bounded))
grouped_pcoll_id = unique_name(
context.components.pcollections, 'pcollection')
context.components.pcollections[grouped_pcoll_id].CopyFrom(
beam_runner_api_pb2.PCollection(
unique_name=transform.unique_name + '/Group.out',
coder_id=key_accumulator_iter_coder_id,
windowing_strategy_id=output_pcoll.windowing_strategy_id,
is_bounded=output_pcoll.is_bounded))
merged_pcoll_id = unique_name(
context.components.pcollections, 'pcollection')
context.components.pcollections[merged_pcoll_id].CopyFrom(
beam_runner_api_pb2.PCollection(
unique_name=transform.unique_name + '/Merge.out',
coder_id=key_accumulator_coder_id,
windowing_strategy_id=output_pcoll.windowing_strategy_id,
is_bounded=output_pcoll.is_bounded))
yield make_stage(
stage,
beam_runner_api_pb2.PTransform(
unique_name=transform.unique_name + '/Precombine',
spec=beam_runner_api_pb2.FunctionSpec(
urn=common_urns.combine_components.COMBINE_PER_KEY_PRECOMBINE.
urn,
payload=transform.spec.payload),
inputs=transform.inputs,
outputs={'out': precombined_pcoll_id},
environment_id=transform.environment_id))
yield make_stage(
stage,
beam_runner_api_pb2.PTransform(
unique_name=transform.unique_name + '/Group',
spec=beam_runner_api_pb2.FunctionSpec(
urn=common_urns.primitives.GROUP_BY_KEY.urn),
inputs={'in': precombined_pcoll_id},
outputs={'out': grouped_pcoll_id}))
yield make_stage(
stage,
beam_runner_api_pb2.PTransform(
unique_name=transform.unique_name + '/Merge',
spec=beam_runner_api_pb2.FunctionSpec(
urn=common_urns.combine_components.
COMBINE_PER_KEY_MERGE_ACCUMULATORS.urn,
payload=transform.spec.payload),
inputs={'in': grouped_pcoll_id},
outputs={'out': merged_pcoll_id},
environment_id=transform.environment_id))
yield make_stage(
stage,
beam_runner_api_pb2.PTransform(
unique_name=transform.unique_name + '/ExtractOutputs',
spec=beam_runner_api_pb2.FunctionSpec(
urn=common_urns.combine_components.
COMBINE_PER_KEY_EXTRACT_OUTPUTS.urn,
payload=transform.spec.payload),
inputs={'in': merged_pcoll_id},
outputs=transform.outputs,
environment_id=transform.environment_id))
def unlifted_stages(stage):
transform = stage.transforms[0]
for sub in transform.subtransforms:
yield make_stage(stage, context.components.transforms[sub])
for stage in stages:
transform = only_transform(stage.transforms)
if transform.spec.urn == common_urns.composites.COMBINE_PER_KEY.urn:
expansion = lifted_stages if can_lift(transform) else unlifted_stages
for substage in expansion(stage):
yield substage
else:
yield stage
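# For example (illustrative), with parents {'a/b/c': 'a/b', 'a/b': 'a',
# 'a/d': 'a'}, _lowest_common_ancestor('a/b/c', 'a/d', parents) returns 'a'.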
def _lowest_common_ancestor(a, b, parents):
# type: (str, str, Dict[str, str]) -> Optional[str]
'''Returns the name of the lowest common ancestor of the two named stages.
The map of stage names to their parents' stage names should be provided
in parents. Note that stages are considered to be ancestors of themselves.
'''
assert a != b
def get_ancestors(name):
ancestor = name
while ancestor is not None:
yield ancestor
ancestor = parents.get(ancestor)
a_ancestors = set(get_ancestors(a))
for b_ancestor in get_ancestors(b):
if b_ancestor in a_ancestors:
return b_ancestor
return None
def _parent_for_fused_stages(stages, context):
# type: (Iterable[Stage], TransformContext) -> Optional[str]
'''Returns the name of the new parent for the fused stages.
The new parent is the lowest common ancestor of the fused stages that is not
contained in the set of stages to be fused. The provided context is used to
compute ancestors of stages.
'''
parents = context.parents_map()
# If any of the input stages were produced by fusion or an optimizer phase,
# or had its parent modified by an optimizer phase, its parent will not be
# reflected in the PipelineContext yet, so we need to add it to the
# parents map.
for stage in stages:
parents[stage.name] = stage.parent
def reduce_fn(a, b):
# type: (Optional[str], Optional[str]) -> Optional[str]
if a is None or b is None:
return None
return _lowest_common_ancestor(a, b, parents)
stage_names = [stage.name for stage in stages] # type: List[Optional[str]]
result = functools.reduce(reduce_fn, stage_names)
if result in stage_names:
result = parents.get(result)
return result
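# Illustrative expansion of a single splittable ParDo stage (Reshuffle appears
# only when the runner understands it; Truncate only when draining):
#
#   ParDo(sdf)
#     -> PairWithRestriction
#     -> SplitAndSizeRestrictions
#     -> [Reshuffle]
#     -> [TruncateSizedRestriction]
#     -> ProcessSizedElementsAndRestrictions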
def expand_sdf(stages, context):
# type: (Iterable[Stage], TransformContext) -> Iterator[Stage]
"""Transforms splitable DoFns into pair+split+read."""
for stage in stages:
transform = only_transform(stage.transforms)
if transform.spec.urn == common_urns.primitives.PAR_DO.urn:
pardo_payload = proto_utils.parse_Bytes(
transform.spec.payload, beam_runner_api_pb2.ParDoPayload)
if pardo_payload.restriction_coder_id:
def copy_like(protos, original, suffix='_copy', **kwargs):
if isinstance(original, str):
key = original
original = protos[original]
else:
key = 'component'
new_id = unique_name(protos, key + suffix)
protos[new_id].CopyFrom(original)
proto = protos[new_id]
for name, value in kwargs.items():
if isinstance(value, dict):
getattr(proto, name).clear()
getattr(proto, name).update(value)
elif isinstance(value, list):
del getattr(proto, name)[:]
getattr(proto, name).extend(value)
elif name == 'urn':
proto.spec.urn = value
elif name == 'payload':
proto.spec.payload = value
else:
setattr(proto, name, value)
if 'unique_name' not in kwargs and hasattr(proto, 'unique_name'):
proto.unique_name = unique_name(
{p.unique_name
for p in protos.values()},
original.unique_name + suffix)
return new_id
def make_stage(base_stage, transform_id, extra_must_follow=()):
# type: (Stage, str, Iterable[Stage]) -> Stage
transform = context.components.transforms[transform_id]
return Stage(
transform.unique_name, [transform],
base_stage.downstream_side_inputs,
union(base_stage.must_follow, frozenset(extra_must_follow)),
parent=base_stage.name,
environment=base_stage.environment)
main_input_tag = only_element(
tag for tag in transform.inputs.keys()
if tag not in pardo_payload.side_inputs)
main_input_id = transform.inputs[main_input_tag]
element_coder_id = context.components.pcollections[
main_input_id].coder_id
# Tuple[element, restriction]
paired_coder_id = context.add_or_get_coder_id(
beam_runner_api_pb2.Coder(
spec=beam_runner_api_pb2.FunctionSpec(
urn=common_urns.coders.KV.urn),
component_coder_ids=[
element_coder_id, pardo_payload.restriction_coder_id
]))
# Tuple[Tuple[element, restriction], double]
sized_coder_id = context.add_or_get_coder_id(
beam_runner_api_pb2.Coder(
spec=beam_runner_api_pb2.FunctionSpec(
urn=common_urns.coders.KV.urn),
component_coder_ids=[
paired_coder_id,
context.add_or_get_coder_id(
# context can be None here only because FloatCoder does
# not have components
coders.FloatCoder().to_runner_api(None), # type: ignore
'doubles_coder')
]))
paired_pcoll_id = copy_like(
context.components.pcollections,
main_input_id,
'_paired',
coder_id=paired_coder_id)
pair_transform_id = copy_like(
context.components.transforms,
transform,
unique_name=transform.unique_name + '/PairWithRestriction',
urn=common_urns.sdf_components.PAIR_WITH_RESTRICTION.urn,
outputs={'out': paired_pcoll_id})
split_pcoll_id = copy_like(
context.components.pcollections,
main_input_id,
'_split',
coder_id=sized_coder_id)
split_transform_id = copy_like(
context.components.transforms,
transform,
unique_name=transform.unique_name + '/SplitAndSizeRestriction',
urn=common_urns.sdf_components.SPLIT_AND_SIZE_RESTRICTIONS.urn,
inputs=dict(transform.inputs, **{main_input_tag: paired_pcoll_id}),
outputs={'out': split_pcoll_id})
reshuffle_stage = None
if common_urns.composites.RESHUFFLE.urn in context.known_runner_urns:
reshuffle_pcoll_id = copy_like(
context.components.pcollections,
main_input_id,
'_reshuffle',
coder_id=sized_coder_id)
reshuffle_transform_id = copy_like(
context.components.transforms,
transform,
unique_name=transform.unique_name + '/Reshuffle',
urn=common_urns.composites.RESHUFFLE.urn,
payload=b'',
inputs=dict(transform.inputs, **{main_input_tag: split_pcoll_id}),
outputs={'out': reshuffle_pcoll_id})
reshuffle_stage = make_stage(stage, reshuffle_transform_id)
else:
reshuffle_pcoll_id = split_pcoll_id
reshuffle_transform_id = None
if context.is_drain:
truncate_pcoll_id = copy_like(
context.components.pcollections,
main_input_id,
'_truncate_restriction',
coder_id=sized_coder_id)
# Length-prefix the truncate output.
context.length_prefix_pcoll_coders(truncate_pcoll_id)
truncate_transform_id = copy_like(
context.components.transforms,
transform,
unique_name=transform.unique_name + '/TruncateAndSizeRestriction',
urn=common_urns.sdf_components.TRUNCATE_SIZED_RESTRICTION.urn,
inputs=dict(
transform.inputs, **{main_input_tag: reshuffle_pcoll_id}),
outputs={'out': truncate_pcoll_id})
process_transform_id = copy_like(
context.components.transforms,
transform,
unique_name=transform.unique_name + '/Process',
urn=common_urns.sdf_components.
PROCESS_SIZED_ELEMENTS_AND_RESTRICTIONS.urn,
inputs=dict(
transform.inputs, **{main_input_tag: truncate_pcoll_id}))
else:
process_transform_id = copy_like(
context.components.transforms,
transform,
unique_name=transform.unique_name + '/Process',
urn=common_urns.sdf_components.
PROCESS_SIZED_ELEMENTS_AND_RESTRICTIONS.urn,
inputs=dict(
transform.inputs, **{main_input_tag: reshuffle_pcoll_id}))
yield make_stage(stage, pair_transform_id)
split_stage = make_stage(stage, split_transform_id)
yield split_stage
if reshuffle_stage:
yield reshuffle_stage
if context.is_drain:
yield make_stage(
stage, truncate_transform_id, extra_must_follow=[split_stage])
yield make_stage(stage, process_transform_id)
else:
yield make_stage(
stage, process_transform_id, extra_must_follow=[split_stage])
else:
yield stage
else:
yield stage
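# Illustrative expansion (the grouping buffer id is schematic):
#
#   pcoll_in -> GroupByKey -> pcoll_out
#
# becomes
#
#   pcoll_in -> GroupByKey/Write   (DATA_OUTPUT into a 'group' buffer)
#   GroupByKey/Read -> pcoll_out   (DATA_INPUT from the same buffer)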
def expand_gbk(stages, pipeline_context):
# type: (Iterable[Stage], TransformContext) -> Iterator[Stage]
"""Transforms each GBK into a write followed by a read."""
for stage in stages:
transform = only_transform(stage.transforms)
if transform.spec.urn == common_urns.primitives.GROUP_BY_KEY.urn:
for pcoll_id in transform.inputs.values():
pipeline_context.length_prefix_pcoll_coders(pcoll_id)
for pcoll_id in transform.outputs.values():
if pipeline_context.use_state_iterables:
pipeline_context.components.pcollections[
pcoll_id].coder_id = pipeline_context.with_state_iterables(
pipeline_context.components.pcollections[pcoll_id].coder_id)
pipeline_context.length_prefix_pcoll_coders(pcoll_id)
# This is used later to correlate the read and write.
transform_id = stage.name
if transform != pipeline_context.components.transforms.get(transform_id):
transform_id = unique_name(
pipeline_context.components.transforms, stage.name)
pipeline_context.components.transforms[transform_id].CopyFrom(transform)
grouping_buffer = create_buffer_id(transform_id, kind='group')
gbk_write = Stage(
transform.unique_name + '/Write',
[
beam_runner_api_pb2.PTransform(
unique_name=transform.unique_name + '/Write',
inputs=transform.inputs,
spec=beam_runner_api_pb2.FunctionSpec(
urn=bundle_processor.DATA_OUTPUT_URN,
payload=grouping_buffer))
],
downstream_side_inputs=frozenset(),
must_follow=stage.must_follow)
yield gbk_write
yield Stage(
transform.unique_name + '/Read',
[
beam_runner_api_pb2.PTransform(
unique_name=transform.unique_name + '/Read',
outputs=transform.outputs,
spec=beam_runner_api_pb2.FunctionSpec(
urn=bundle_processor.DATA_INPUT_URN,
payload=grouping_buffer))
],
downstream_side_inputs=stage.downstream_side_inputs,
must_follow=union(frozenset([gbk_write]), stage.must_follow))
else:
yield stage
def fix_flatten_coders(
stages, pipeline_context, identity_urn=bundle_processor.IDENTITY_DOFN_URN):
# type: (Iterable[Stage], TransformContext, str) -> Iterator[Stage]
"""Ensures that the inputs of Flatten have the same coders as the output.
"""
pcollections = pipeline_context.components.pcollections
for stage in stages:
transform = only_element(stage.transforms)
if transform.spec.urn == common_urns.primitives.FLATTEN.urn:
output_pcoll_id = only_element(transform.outputs.values())
output_coder_id = pcollections[output_pcoll_id].coder_id
for local_in, pcoll_in in list(transform.inputs.items()):
if pcollections[pcoll_in].coder_id != output_coder_id:
# Flatten requires that all its inputs be materialized with the
# same coder as its output. Add stages to transcode flatten
# inputs that use different coders.
transcoded_pcollection = unique_name(
pcollections,
transform.unique_name + '/Transcode/' + local_in + '/out')
transcode_name = unique_name(
pipeline_context.components.transforms,
transform.unique_name + '/Transcode/' + local_in)
yield Stage(
transcode_name,
[
beam_runner_api_pb2.PTransform(
unique_name=transcode_name,
inputs={local_in: pcoll_in},
outputs={'out': transcoded_pcollection},
spec=beam_runner_api_pb2.FunctionSpec(urn=identity_urn),
environment_id=transform.environment_id)
],
downstream_side_inputs=frozenset(),
must_follow=stage.must_follow)
pcollections[transcoded_pcollection].CopyFrom(pcollections[pcoll_in])
pcollections[transcoded_pcollection].unique_name = (
transcoded_pcollection)
pcollections[transcoded_pcollection].coder_id = output_coder_id
transform.inputs[local_in] = transcoded_pcollection
yield stage
def sink_flattens(stages, pipeline_context):
# type: (Iterable[Stage], TransformContext) -> Iterator[Stage]
"""Sink flattens and remove them from the graph.
A flatten that cannot be sunk/fused away becomes multiple writes (to the
same logical sink) followed by a read.
"""
# TODO(robertwb): Actually attempt to sink rather than always materialize.
# TODO(robertwb): Possibly fuse multi-input flattens into one of the stages.
for stage in fix_flatten_coders(stages,
pipeline_context,
common_urns.primitives.FLATTEN.urn):
transform = only_element(stage.transforms)
if (transform.spec.urn == common_urns.primitives.FLATTEN.urn and
len(transform.inputs) > 1):
# This is used later to correlate the read and writes.
buffer_id = create_buffer_id(transform.unique_name)
flatten_writes = [] # type: List[Stage]
for local_in, pcoll_in in transform.inputs.items():
flatten_write = Stage(
transform.unique_name + '/Write/' + local_in,
[
beam_runner_api_pb2.PTransform(
unique_name=transform.unique_name + '/Write/' + local_in,
inputs={local_in: pcoll_in},
spec=beam_runner_api_pb2.FunctionSpec(
urn=bundle_processor.DATA_OUTPUT_URN,
payload=buffer_id),
environment_id=transform.environment_id)
],
downstream_side_inputs=frozenset(),
must_follow=stage.must_follow)
flatten_writes.append(flatten_write)
yield flatten_write
yield Stage(
transform.unique_name + '/Read',
[
beam_runner_api_pb2.PTransform(
unique_name=transform.unique_name + '/Read',
outputs=transform.outputs,
spec=beam_runner_api_pb2.FunctionSpec(
urn=bundle_processor.DATA_INPUT_URN, payload=buffer_id),
environment_id=transform.environment_id)
],
downstream_side_inputs=stage.downstream_side_inputs,
must_follow=union(frozenset(flatten_writes), stage.must_follow))
else:
yield stage
def greedily_fuse(stages, pipeline_context):
# type: (Iterable[Stage], TransformContext) -> FrozenSet[Stage]
"""Places transforms sharing an edge in the same stage, whenever possible.
"""
producers_by_pcoll = {} # type: Dict[str, Stage]
consumers_by_pcoll = collections.defaultdict(
list) # type: DefaultDict[str, List[Stage]]
# Maps fused-away stages to their replacements, so we can always resolve
# to the current stage; the producer and consumer maps themselves are not
# updated when stages are fused away.
replacements = {} # type: Dict[Stage, Stage]
def replacement(s):
old_ss = []
while s in replacements:
old_ss.append(s)
s = replacements[s]
for old_s in old_ss[:-1]:
replacements[old_s] = s
return s
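# Note: replacement() follows chains of fused-away stages and applies a
# simple path compression. For example, if A was fused into B and B into
# C (replacements == {A: B, B: C}), replacement(A) returns C and rewrites
# the A entry to point directly at C, shortening later lookups.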
def fuse(producer, consumer):
fused = producer.fuse(consumer, pipeline_context)
replacements[producer] = fused
replacements[consumer] = fused
# First record the producers and consumers of each PCollection.
for stage in stages:
for transform in stage.transforms:
for input in transform.inputs.values():
consumers_by_pcoll[input].append(stage)
for output in transform.outputs.values():
producers_by_pcoll[output] = stage
# Now try to fuse away all pcollections.
for pcoll, producer in producers_by_pcoll.items():
write_pcoll = None
for consumer in consumers_by_pcoll[pcoll]:
producer = replacement(producer)
consumer = replacement(consumer)
# Update consumer.must_follow set, as it's used in can_fuse.
consumer.must_follow = frozenset(
replacement(s) for s in consumer.must_follow)
if producer.can_fuse(consumer, pipeline_context):
fuse(producer, consumer)
else:
# If we can't fuse, materialize the PCollection: write it out here and
# read it back where it is consumed as a main input.
pipeline_context.length_prefix_pcoll_coders(pcoll)
buffer_id = create_buffer_id(pcoll)
if write_pcoll is None:
write_pcoll = Stage(
pcoll + '/Write',
[
beam_runner_api_pb2.PTransform(
unique_name=pcoll + '/Write',
inputs={'in': pcoll},
spec=beam_runner_api_pb2.FunctionSpec(
urn=bundle_processor.DATA_OUTPUT_URN,
payload=buffer_id))
],
downstream_side_inputs=producer.downstream_side_inputs)
fuse(producer, write_pcoll)
if consumer.has_as_main_input(pcoll):
read_pcoll = Stage(
pcoll + '/Read',
[
beam_runner_api_pb2.PTransform(
unique_name=pcoll + '/Read',
outputs={'out': pcoll},
spec=beam_runner_api_pb2.FunctionSpec(
urn=bundle_processor.DATA_INPUT_URN,
payload=buffer_id))
],
downstream_side_inputs=consumer.downstream_side_inputs,
must_follow=frozenset([write_pcoll]))
fuse(read_pcoll, consumer)
else:
consumer.must_follow = union(
consumer.must_follow, frozenset([write_pcoll]))
# Everything that was originally a stage or a replacement, but wasn't
# replaced, should be in the final graph.
final_stages = frozenset(stages).union(list(replacements.values()))\
.difference(list(replacements))
for stage in final_stages:
# Update all references to their final values before throwing
# the replacement data away.
stage.must_follow = frozenset(replacement(s) for s in stage.must_follow)
# Two reads of the same upstream stage may have been fused together;
# such duplicate reads are unneeded, so remove them.
stage.deduplicate_read()
return final_stages
def read_to_impulse(stages, pipeline_context):
# type: (Iterable[Stage], TransformContext) -> Iterator[Stage]
"""Translates Read operations into Impulse operations."""
for stage in stages:
# First map Reads, if any, to Impulse + triggered read op.
for transform in list(stage.transforms):
if transform.spec.urn == common_urns.deprecated_primitives.READ.urn:
read_pc = only_element(transform.outputs.values())
read_pc_proto = pipeline_context.components.pcollections[read_pc]
impulse_pc = unique_name(
pipeline_context.components.pcollections, 'Impulse')
pipeline_context.components.pcollections[impulse_pc].CopyFrom(
beam_runner_api_pb2.PCollection(
unique_name=impulse_pc,
coder_id=pipeline_context.bytes_coder_id,
windowing_strategy_id=read_pc_proto.windowing_strategy_id,
is_bounded=read_pc_proto.is_bounded))
stage.transforms.remove(transform)
# TODO(robertwb): If this goes multi-process before fn-api
# read is default, expand into split + reshuffle + read.
stage.transforms.append(
beam_runner_api_pb2.PTransform(
unique_name=transform.unique_name + '/Impulse',
spec=beam_runner_api_pb2.FunctionSpec(
urn=common_urns.primitives.IMPULSE.urn),
outputs={'out': impulse_pc}))
stage.transforms.append(
beam_runner_api_pb2.PTransform(
unique_name=transform.unique_name,
spec=beam_runner_api_pb2.FunctionSpec(
urn=python_urns.IMPULSE_READ_TRANSFORM,
payload=transform.spec.payload),
inputs={'in': impulse_pc},
outputs={'out': read_pc}))
yield stage
def impulse_to_input(stages, pipeline_context):
# type: (Iterable[Stage], TransformContext) -> Iterator[Stage]
"""Translates Impulse operations into GRPC reads."""
for stage in stages:
for transform in list(stage.transforms):
if transform.spec.urn == common_urns.primitives.IMPULSE.urn:
stage.transforms.remove(transform)
stage.transforms.append(
beam_runner_api_pb2.PTransform(
unique_name=transform.unique_name,
spec=beam_runner_api_pb2.FunctionSpec(
urn=bundle_processor.DATA_INPUT_URN,
payload=IMPULSE_BUFFER),
outputs=transform.outputs))
yield stage
def extract_impulse_stages(stages, pipeline_context):
# type: (Iterable[Stage], TransformContext) -> Iterator[Stage]
"""Splits fused Impulse operations into their own stage."""
for stage in stages:
for transform in list(stage.transforms):
if transform.spec.urn == common_urns.primitives.IMPULSE.urn:
stage.transforms.remove(transform)
yield Stage(
transform.unique_name,
transforms=[transform],
downstream_side_inputs=stage.downstream_side_inputs,
must_follow=stage.must_follow,
parent=stage.parent)
if stage.transforms:
yield stage
def remove_data_plane_ops(stages, pipeline_context):
# type: (Iterable[Stage], TransformContext) -> Iterator[Stage]
for stage in stages:
for transform in list(stage.transforms):
if transform.spec.urn in (bundle_processor.DATA_INPUT_URN,
bundle_processor.DATA_OUTPUT_URN):
stage.transforms.remove(transform)
if stage.transforms:
yield stage
def setup_timer_mapping(stages, pipeline_context):
# type: (Iterable[Stage], TransformContext) -> Iterator[Stage]
"""Set up a mapping of {transform_id: [timer_ids]} for each stage.
"""
for stage in stages:
for transform in stage.transforms:
if transform.spec.urn in PAR_DO_URNS:
payload = proto_utils.parse_Bytes(
transform.spec.payload, beam_runner_api_pb2.ParDoPayload)
for timer_family_id in payload.timer_family_specs.keys():
stage.timers.add((transform.unique_name, timer_family_id))
yield stage
def sort_stages(stages, pipeline_context):
# type: (Iterable[Stage], TransformContext) -> List[Stage]
"""Order stages suitable for sequential execution.
"""
all_stages = set(stages)
seen = set() # type: Set[Stage]
ordered = []
producers = {
pcoll: stage
for stage in all_stages for t in stage.transforms
for pcoll in t.outputs.values()
}
def process(stage):
if stage not in seen:
seen.add(stage)
if stage not in all_stages:
return
for prev in stage.must_follow:
process(prev)
stage_outputs = set(
pcoll for transform in stage.transforms
for pcoll in transform.outputs.values())
for transform in stage.transforms:
for pcoll in transform.inputs.values():
if pcoll not in stage_outputs:
process(producers[pcoll])
ordered.append(stage)
for stage in stages:
process(stage)
return ordered
def populate_data_channel_coders(stages, pipeline_context):
# type: (Iterable[Stage], TransformContext) -> Iterable[Stage]
"""Populate coders for GRPC input and output ports."""
for stage in stages:
for transform in stage.transforms:
if transform.spec.urn in (bundle_processor.DATA_INPUT_URN,
bundle_processor.DATA_OUTPUT_URN):
if transform.spec.urn == bundle_processor.DATA_INPUT_URN:
sdk_pcoll_id = only_element(transform.outputs.values())
else:
sdk_pcoll_id = only_element(transform.inputs.values())
pipeline_context.add_data_channel_coder(sdk_pcoll_id)
return stages
def add_impulse_to_dangling_transforms(stages, pipeline_context):
# type: (Iterable[Stage], TransformContext) -> Iterable[Stage]
"""Populate coders for GRPC input and output ports."""
for stage in stages:
for transform in stage.transforms:
if len(transform.inputs
) == 0 and transform.spec.urn != bundle_processor.DATA_INPUT_URN:
# Walk the stages and transforms in the DAG. Any transform that has no
# inputs at all (and is not already a data input) is given a new Impulse
# PCollection as its input.
impulse_pc = unique_name(
pipeline_context.components.pcollections, 'Impulse')
output_pcoll = pipeline_context.components.pcollections[next(
iter(transform.outputs.values()))]
pipeline_context.components.pcollections[impulse_pc].CopyFrom(
beam_runner_api_pb2.PCollection(
unique_name=impulse_pc,
coder_id=pipeline_context.bytes_coder_id,
windowing_strategy_id=output_pcoll.windowing_strategy_id,
is_bounded=output_pcoll.is_bounded))
transform.inputs['in'] = impulse_pc
stage.transforms.append(
beam_runner_api_pb2.PTransform(
unique_name=transform.unique_name,
spec=beam_runner_api_pb2.FunctionSpec(
urn=bundle_processor.DATA_INPUT_URN,
payload=IMPULSE_BUFFER),
outputs={'out': impulse_pc}))
yield stage
def union(a, b):
# Minimize the number of distinct sets.
if not a or a == b:
return b
elif not b:
return a
else:
return frozenset.union(a, b)
_global_counter = 0
def unique_name(existing, prefix):
# type: (Optional[Container[str]], str) -> str
if existing is None:
global _global_counter
_global_counter += 1
return '%s_%d' % (prefix, _global_counter)
elif prefix in existing:
counter = 0
while True:
counter += 1
prefix_counter = prefix + "_%s" % counter
if prefix_counter not in existing:
return prefix_counter
else:
return prefix
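# Examples (illustrative):
#   unique_name({'foo'}, 'bar')           -> 'bar'
#   unique_name({'foo'}, 'foo')           -> 'foo_1'
#   unique_name({'foo', 'foo_1'}, 'foo')  -> 'foo_2'
#   unique_name(None, 'foo')              -> 'foo_<n>' for a module-global
#                                            counter n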
def only_element(iterable):
# type: (Iterable[T]) -> T
element, = iterable
return element
def only_transform(transforms):
# type: (List[beam_runner_api_pb2.PTransform]) -> beam_runner_api_pb2.PTransform
assert len(transforms) == 1
return transforms[0]
def create_buffer_id(name, kind='materialize'):
# type: (str, str) -> bytes
return ('%s:%s' % (kind, name)).encode('utf-8')
def split_buffer_id(buffer_id):
# type: (bytes) -> Tuple[str, str]
"""A buffer id is "kind:pcollection_id". Split into (kind, pcoll_id). """
kind, pcoll_id = buffer_id.decode('utf-8').split(':', 1)
return kind, pcoll_id
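# Round-trip example (illustrative):
#   create_buffer_id('my_pcoll')              # b'materialize:my_pcoll'
#   create_buffer_id('MyGBK', kind='group')   # b'group:MyGBK'
#   split_buffer_id(b'materialize:my_pcoll')  # ('materialize', 'my_pcoll')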