[BEAM-1833] Preserve input names at graph construction and through proto translation. (#15202)
diff --git a/sdks/python/apache_beam/pipeline.py b/sdks/python/apache_beam/pipeline.py
index 69362a4..d0aaa2f 100644
--- a/sdks/python/apache_beam/pipeline.py
+++ b/sdks/python/apache_beam/pipeline.py
@@ -61,6 +61,7 @@
from typing import FrozenSet
from typing import Iterable
from typing import List
+from typing import Mapping
from typing import Optional
from typing import Sequence
from typing import Set
@@ -271,7 +272,7 @@
output_replacements = {
} # type: Dict[AppliedPTransform, List[Tuple[pvalue.PValue, Optional[str]]]]
input_replacements = {
- } # type: Dict[AppliedPTransform, Sequence[Union[pvalue.PBegin, pvalue.PCollection]]]
+ } # type: Dict[AppliedPTransform, Mapping[str, Union[pvalue.PBegin, pvalue.PCollection]]]
side_input_replacements = {
} # type: Dict[AppliedPTransform, List[pvalue.AsSideInput]]
@@ -297,7 +298,7 @@
original_transform_node.parent,
replacement_transform,
original_transform_node.full_label,
- original_transform_node.inputs)
+ original_transform_node.main_inputs)
replacement_transform_node.resource_hints = (
original_transform_node.resource_hints)
@@ -437,11 +438,11 @@
output_replacements[transform_node].append((tag, replacement))
if replace_input:
- new_input = [
- input if not input in output_map else output_map[input]
- for input in transform_node.inputs
- ]
- input_replacements[transform_node] = new_input
+ new_inputs = {
+ tag: input if not input in output_map else output_map[input]
+ for (tag, input) in transform_node.main_inputs.items()
+ }
+ input_replacements[transform_node] = new_inputs
if replace_side_inputs:
new_side_inputs = []
@@ -670,15 +671,18 @@
pvalueish, inputs = transform._extract_input_pvalues(pvalueish)
try:
- inputs = tuple(inputs)
- for leaf_input in inputs:
- if not isinstance(leaf_input, pvalue.PValue):
- raise TypeError
+ if not isinstance(inputs, dict):
+ inputs = {str(ix): input for (ix, input) in enumerate(inputs)}
except TypeError:
raise NotImplementedError(
'Unable to extract PValue inputs from %s; either %s does not accept '
'inputs of this format, or it does not properly override '
'_extract_input_pvalues' % (pvalueish, transform))
+ for t, leaf_input in inputs.items():
+ if not isinstance(leaf_input, pvalue.PValue) or not isinstance(t, str):
+ raise NotImplementedError(
+ '%s does not properly override _extract_input_pvalues, '
+ 'returned %s from %s' % (transform, inputs, pvalueish))
current = AppliedPTransform(
self._current_transform(), transform, full_label, inputs)
@@ -705,7 +709,8 @@
if result.producer is None:
result.producer = current
- self._infer_result_type(transform, inputs, result)
+ # TODO(BEAM-1833): Pass full tuples dict.
+ self._infer_result_type(transform, tuple(inputs.values()), result)
assert isinstance(result.producer.inputs, tuple)
# The DoOutputsTuple adds the PCollection to the outputs when accessed
@@ -940,7 +945,7 @@
for id in proto.components.transforms:
transform = context.transforms.get_by_id(id)
if not transform.inputs and transform.transform.__class__ in has_pbegin:
- transform.inputs = (pvalue.PBegin(p), )
+ transform.main_inputs = {'None': pvalue.PBegin(p)}
if return_context:
return p, context # type: ignore # too complicated for now
@@ -1030,7 +1035,7 @@
parent, # type: Optional[AppliedPTransform]
transform, # type: Optional[ptransform.PTransform]
full_label, # type: str
- inputs, # type: Optional[Sequence[Union[pvalue.PBegin, pvalue.PCollection]]]
+ main_inputs, # type: Optional[Mapping[str, Union[pvalue.PBegin, pvalue.PCollection]]]
environment_id=None, # type: Optional[str]
annotations=None, # type: Optional[Dict[str, bytes]]
):
@@ -1043,7 +1048,7 @@
# reusing PTransform instances in different contexts (apply() calls) without
# any interference. This is particularly useful for composite transforms.
self.full_label = full_label
- self.inputs = inputs or ()
+ self.main_inputs = dict(main_inputs or {})
self.side_inputs = tuple() if transform is None else transform.side_inputs
self.outputs = {} # type: Dict[Union[str, int, None], pvalue.PValue]
@@ -1076,6 +1081,10 @@
}
self.annotations = annotations
+ @property
+ def inputs(self):
+ return tuple(self.main_inputs.values())
+
def __repr__(self):
# type: () -> str
return "%s(%s, %s)" % (
@@ -1109,8 +1118,8 @@
if isinstance(self.transform, external.ExternalTransform):
self.transform.replace_named_outputs(self.named_outputs())
- def replace_inputs(self, inputs):
- self.inputs = inputs
+ def replace_inputs(self, main_inputs):
+ self.main_inputs = main_inputs
# Importing locally to prevent circular dependency issues.
from apache_beam.transforms import external
@@ -1215,12 +1224,11 @@
def named_inputs(self):
# type: () -> Dict[str, pvalue.PValue]
- # TODO(BEAM-1833): Push names up into the sdk construction.
if self.transform is None:
- assert not self.inputs and not self.side_inputs
+ assert not self.main_inputs and not self.side_inputs
return {}
else:
- return self.transform._named_inputs(self.inputs, self.side_inputs)
+ return self.transform._named_inputs(self.main_inputs, self.side_inputs)
def named_outputs(self):
# type: () -> Dict[str, pvalue.PCollection]
@@ -1309,10 +1317,10 @@
pardo_payload = None
side_input_tags = []
- main_inputs = [
- context.pcollections.get_by_id(id) for tag,
- id in proto.inputs.items() if tag not in side_input_tags
- ]
+ main_inputs = {
+ tag: context.pcollections.get_by_id(id)
+ for (tag, id) in proto.inputs.items() if tag not in side_input_tags
+ }
transform = ptransform.PTransform.from_runner_api(proto, context)
if transform and proto.environment_id:
@@ -1334,7 +1342,7 @@
parent=None,
transform=transform,
full_label=proto.unique_name,
- inputs=main_inputs,
+ main_inputs=main_inputs,
environment_id=None,
annotations=proto.annotations)
diff --git a/sdks/python/apache_beam/pipeline_test.py b/sdks/python/apache_beam/pipeline_test.py
index c2a23fd..731d5e3 100644
--- a/sdks/python/apache_beam/pipeline_test.py
+++ b/sdks/python/apache_beam/pipeline_test.py
@@ -972,6 +972,24 @@
for transform_id in runner_api_proto.components.transforms:
self.assertRegex(transform_id, r'[a-zA-Z0-9-_]+')
+ def test_input_names(self):
+ class MyPTransform(beam.PTransform):
+ def expand(self, pcolls):
+ return pcolls.values() | beam.Flatten()
+
+ p = beam.Pipeline()
+ input_names = set('ABC')
+ inputs = {x: p | x >> beam.Create([x]) for x in input_names}
+ inputs | MyPTransform() # pylint: disable=expression-not-assigned
+ runner_api_proto = Pipeline.to_runner_api(p)
+
+ for transform_proto in runner_api_proto.components.transforms.values():
+ if transform_proto.unique_name == 'MyPTransform':
+ self.assertEqual(set(transform_proto.inputs.keys()), input_names)
+ break
+ else:
+ self.fail('Unable to find transform.')
+
def test_display_data(self):
class MyParentTransform(beam.PTransform):
def expand(self, p):
diff --git a/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py b/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py
index c112bdd..bbaf52c 100644
--- a/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py
+++ b/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py
@@ -305,7 +305,7 @@
parent,
beam.Map(lambda x: (b'', x)),
transform_node.full_label + '/MapToVoidKey%s' % ix,
- (side_input.pvalue, ))
+ {'input': side_input.pvalue})
new_side_input.pvalue.producer = map_to_void_key
map_to_void_key.add_output(new_side_input.pvalue, None)
parent.add_part(map_to_void_key)
diff --git a/sdks/python/apache_beam/runners/dataflow/dataflow_runner_test.py b/sdks/python/apache_beam/runners/dataflow/dataflow_runner_test.py
index ec4edc9..e7ce71b 100644
--- a/sdks/python/apache_beam/runners/dataflow/dataflow_runner_test.py
+++ b/sdks/python/apache_beam/runners/dataflow/dataflow_runner_test.py
@@ -348,7 +348,8 @@
pcoll2.element_type = typehints.Any
pcoll3.element_type = typehints.KV[typehints.Any, typehints.Any]
for pcoll in [pcoll1, pcoll2, pcoll3]:
- applied = AppliedPTransform(None, beam.GroupByKey(), "label", [pcoll])
+ applied = AppliedPTransform(
+ None, beam.GroupByKey(), "label", {'pcoll': pcoll})
applied.outputs[None] = PCollection(None)
common.group_by_key_input_visitor().visit_transform(applied)
self.assertEqual(
@@ -367,7 +368,7 @@
for pcoll in [pcoll1, pcoll2]:
with self.assertRaisesRegex(ValueError, err_msg):
common.group_by_key_input_visitor().visit_transform(
- AppliedPTransform(None, beam.GroupByKey(), "label", [pcoll]))
+ AppliedPTransform(None, beam.GroupByKey(), "label", {'in': pcoll}))
def test_group_by_key_input_visitor_for_non_gbk_transforms(self):
p = TestPipeline()
@@ -375,7 +376,7 @@
for transform in [beam.Flatten(), beam.Map(lambda x: x)]:
pcoll.element_type = typehints.Any
common.group_by_key_input_visitor().visit_transform(
- AppliedPTransform(None, transform, "label", [pcoll]))
+ AppliedPTransform(None, transform, "label", {'in': pcoll}))
self.assertEqual(pcoll.element_type, typehints.Any)
def test_flatten_input_with_visitor_with_single_input(self):
@@ -387,11 +388,11 @@
def _test_flatten_input_visitor(self, input_type, output_type, num_inputs):
p = TestPipeline()
- inputs = []
- for _ in range(num_inputs):
+ inputs = {}
+ for ix in range(num_inputs):
input_pcoll = PCollection(p)
input_pcoll.element_type = input_type
- inputs.append(input_pcoll)
+ inputs[str(ix)] = input_pcoll
output_pcoll = PCollection(p)
output_pcoll.element_type = output_type
@@ -399,7 +400,7 @@
flatten.add_output(output_pcoll, None)
DataflowRunner.flatten_input_visitor().visit_transform(flatten)
for _ in range(num_inputs):
- self.assertEqual(inputs[0].element_type, output_type)
+ self.assertEqual(inputs['0'].element_type, output_type)
def test_gbk_then_flatten_input_visitor(self):
p = TestPipeline(
@@ -442,7 +443,7 @@
z: (x, y, z),
beam.pvalue.AsSingleton(pc),
beam.pvalue.AsMultiMap(pc))
- applied_transform = AppliedPTransform(None, transform, "label", [pc])
+ applied_transform = AppliedPTransform(None, transform, "label", {'pc': pc})
DataflowRunner.side_input_visitor(
use_fn_api=True).visit_transform(applied_transform)
self.assertEqual(2, len(applied_transform.side_inputs))
diff --git a/sdks/python/apache_beam/runners/interactive/pipeline_instrument.py b/sdks/python/apache_beam/runners/interactive/pipeline_instrument.py
index 456636c..448fca7 100644
--- a/sdks/python/apache_beam/runners/interactive/pipeline_instrument.py
+++ b/sdks/python/apache_beam/runners/interactive/pipeline_instrument.py
@@ -685,8 +685,8 @@
def visit_transform(self, transform_node):
if transform_node.inputs:
- input_list = list(transform_node.inputs)
- for i, input_pcoll in enumerate(input_list):
+ main_inputs = dict(transform_node.main_inputs)
+ for tag, input_pcoll in main_inputs.items():
key = self._pin.cache_key(input_pcoll)
# Replace the input pcollection with the cached pcollection (if it
@@ -694,9 +694,9 @@
if key in self._pin._cached_pcoll_read:
# Ignore this pcoll in the final pruned instrumented pipeline.
self._pin._ignored_targets.add(input_pcoll)
- input_list[i] = self._pin._cached_pcoll_read[key]
+ main_inputs[tag] = self._pin._cached_pcoll_read[key]
# Update the transform with its new inputs.
- transform_node.inputs = tuple(input_list)
+ transform_node.main_inputs = main_inputs
v = ReadCacheWireVisitor(self)
pipeline.visit(v)
diff --git a/sdks/python/apache_beam/runners/interactive/pipeline_instrument_test.py b/sdks/python/apache_beam/runners/interactive/pipeline_instrument_test.py
index 40d19a7..a3f91c0 100644
--- a/sdks/python/apache_beam/runners/interactive/pipeline_instrument_test.py
+++ b/sdks/python/apache_beam/runners/interactive/pipeline_instrument_test.py
@@ -297,11 +297,11 @@
def visit_transform(self, transform_node):
if transform_node.inputs:
- input_list = list(transform_node.inputs)
- for i in range(len(input_list)):
- if input_list[i] == init_pcoll:
- input_list[i] = cached_init_pcoll
- transform_node.inputs = tuple(input_list)
+ main_inputs = dict(transform_node.main_inputs)
+ for tag in main_inputs.keys():
+ if main_inputs[tag] == init_pcoll:
+ main_inputs[tag] = cached_init_pcoll
+ transform_node.main_inputs = main_inputs
v = TestReadCacheWireVisitor()
p_origin.visit(v)
diff --git a/sdks/python/apache_beam/transforms/ptransform.py b/sdks/python/apache_beam/transforms/ptransform.py
index 788485cd..50f7715 100644
--- a/sdks/python/apache_beam/transforms/ptransform.py
+++ b/sdks/python/apache_beam/transforms/ptransform.py
@@ -50,6 +50,7 @@
from typing import Callable
from typing import Dict
from typing import List
+from typing import Mapping
from typing import Optional
from typing import Sequence
from typing import Tuple
@@ -253,7 +254,7 @@
return self.visit_nested(node)
-def get_named_nested_pvalues(pvalueish):
+def get_named_nested_pvalues(pvalueish, as_inputs=False):
if isinstance(pvalueish, tuple):
# Check to see if it's a named tuple.
fields = getattr(pvalueish, '_fields', None)
@@ -262,16 +263,22 @@
else:
tagged_values = enumerate(pvalueish)
elif isinstance(pvalueish, list):
+ if as_inputs:
+      # Full list treated as a list of values for eager evaluation.
+ yield None, pvalueish
+ return
tagged_values = enumerate(pvalueish)
elif isinstance(pvalueish, dict):
tagged_values = pvalueish.items()
else:
- if isinstance(pvalueish, (pvalue.PValue, pvalue.DoOutputsTuple)):
+ if as_inputs or isinstance(pvalueish,
+ (pvalue.PValue, pvalue.DoOutputsTuple)):
yield None, pvalueish
return
for tag, subvalue in tagged_values:
- for subtag, subsubvalue in get_named_nested_pvalues(subvalue):
+ for subtag, subsubvalue in get_named_nested_pvalues(
+ subvalue, as_inputs=as_inputs):
if subtag is None:
yield tag, subsubvalue
else:
@@ -569,6 +576,8 @@
def __ror__(self, left, label=None):
"""Used to apply this PTransform to non-PValues, e.g., a tuple."""
pvalueish, pvalues = self._extract_input_pvalues(left)
+ if isinstance(pvalues, dict):
+ pvalues = tuple(pvalues.values())
pipelines = [v.pipeline for v in pvalues if isinstance(v, pvalue.PValue)]
if pvalues and not pipelines:
deferred = False
@@ -597,8 +606,7 @@
# pylint: enable=wrong-import-order, wrong-import-position
replacements = {
id(v): p | 'CreatePInput%s' % ix >> Create(v, reshuffle=False)
- for ix,
- v in enumerate(pvalues)
+ for (ix, v) in enumerate(pvalues)
if not isinstance(v, pvalue.PValue) and v is not None
}
pvalueish = _SetInputPValues().visit(pvalueish, replacements)
@@ -628,19 +636,11 @@
if isinstance(pvalueish, pipeline.Pipeline):
pvalueish = pvalue.PBegin(pvalueish)
- def _dict_tuple_leaves(pvalueish):
- if isinstance(pvalueish, tuple):
- for a in pvalueish:
- for p in _dict_tuple_leaves(a):
- yield p
- elif isinstance(pvalueish, dict):
- for a in pvalueish.values():
- for p in _dict_tuple_leaves(a):
- yield p
- else:
- yield pvalueish
-
- return pvalueish, tuple(_dict_tuple_leaves(pvalueish))
+ return pvalueish, {
+ str(tag): value
+ for (tag, value) in get_named_nested_pvalues(
+ pvalueish, as_inputs=True)
+ }
def _pvaluish_from_dict(self, input_dict):
if len(input_dict) == 1:
@@ -648,16 +648,15 @@
else:
return input_dict
- def _named_inputs(self, inputs, side_inputs):
- # type: (Sequence[pvalue.PValue], Sequence[Any]) -> Dict[str, pvalue.PValue]
+ def _named_inputs(self, main_inputs, side_inputs):
+ # type: (Mapping[str, pvalue.PValue], Sequence[Any]) -> Dict[str, pvalue.PValue]
"""Returns the dictionary of named inputs (including side inputs) as they
should be named in the beam proto.
"""
- # TODO(BEAM-1833): Push names up into the sdk construction.
main_inputs = {
- str(ix): input
- for (ix, input) in enumerate(inputs)
+ tag: input
+ for (tag, input) in main_inputs.items()
if isinstance(input, pvalue.PCollection)
}
named_side_inputs = {(SIDE_INPUT_PREFIX + '%s') % ix: si.pvalue