[BEAM-3419] Support iterable side inputs on the Dataflow runner when using the unified worker.
Note that all other portable runners already use iterable side inputs.
diff --git a/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py b/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py
index ada700c..762b2a2 100644
--- a/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py
+++ b/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py
@@ -302,7 +302,7 @@
return SetPDoneVisitor(pipeline)
@staticmethod
- def side_input_visitor():
+ def side_input_visitor(use_unified_worker=False):
# Imported here to avoid circular dependencies.
# pylint: disable=wrong-import-order, wrong-import-position
from apache_beam.pipeline import PipelineVisitor
@@ -320,24 +320,32 @@
for ix, side_input in enumerate(transform_node.side_inputs):
access_pattern = side_input._side_input_data().access_pattern
if access_pattern == common_urns.side_inputs.ITERABLE.urn:
- # Add a map to ('', value) as Dataflow currently only handles
- # keyed side inputs.
- pipeline = side_input.pvalue.pipeline
- new_side_input = _DataflowIterableSideInput(side_input)
- new_side_input.pvalue = beam.pvalue.PCollection(
- pipeline,
- element_type=typehints.KV[
- bytes, side_input.pvalue.element_type],
- is_bounded=side_input.pvalue.is_bounded)
- parent = transform_node.parent or pipeline._root_transform()
- map_to_void_key = beam.pipeline.AppliedPTransform(
- pipeline,
- beam.Map(lambda x: (b'', x)),
- transform_node.full_label + '/MapToVoidKey%s' % ix,
- (side_input.pvalue,))
- new_side_input.pvalue.producer = map_to_void_key
- map_to_void_key.add_output(new_side_input.pvalue)
- parent.add_part(map_to_void_key)
+ if use_unified_worker:
+ # Patch up the access pattern to appease Dataflow when using
+ # the unified worker, and hardcode the output type to Any since
+ # the Dataflow JSON and pipeline proto can differ in coders,
+ # which leads to encoding/decoding issues within the runner.
+ side_input.pvalue.element_type = typehints.Any
+ new_side_input = _DataflowIterableSideInput(side_input)
+ else:
+ # Add a map to ('', value) as Dataflow currently only handles
+ # keyed side inputs when using the Java runner harness (JRH).
+ pipeline = side_input.pvalue.pipeline
+ new_side_input = _DataflowIterableAsMultimapSideInput(
+ side_input)
+ new_side_input.pvalue = beam.pvalue.PCollection(
+ pipeline,
+ element_type=typehints.KV[bytes,
+ side_input.pvalue.element_type],
+ is_bounded=side_input.pvalue.is_bounded)
+ parent = transform_node.parent or pipeline._root_transform()
+ map_to_void_key = beam.pipeline.AppliedPTransform(
+ pipeline, beam.Map(lambda x: (b'', x)),
+ transform_node.full_label + '/MapToVoidKey%s' % ix,
+ (side_input.pvalue,))
+ new_side_input.pvalue.producer = map_to_void_key
+ map_to_void_key.add_output(new_side_input.pvalue)
+ parent.add_part(map_to_void_key)
elif access_pattern == common_urns.side_inputs.MULTIMAP.urn:
# Ensure the input coder is a KV coder and patch up the
# access pattern to appease Dataflow.
@@ -397,7 +405,8 @@
# Convert all side inputs into a form acceptable to Dataflow.
if apiclient._use_fnapi(options):
- pipeline.visit(self.side_input_visitor())
+ pipeline.visit(
+ self.side_input_visitor(apiclient._use_unified_worker(options)))
# Performing configured PTransform overrides. Note that this is currently
# done before Runner API serialization, since the new proto needs to contain
@@ -1320,12 +1329,12 @@
return self._data
-class _DataflowIterableSideInput(_DataflowSideInput):
+class _DataflowIterableAsMultimapSideInput(_DataflowSideInput):
"""Wraps an iterable side input as dataflow-compatible side input."""
- def __init__(self, iterable_side_input):
+ def __init__(self, side_input):
# pylint: disable=protected-access
- side_input_data = iterable_side_input._side_input_data()
+ side_input_data = side_input._side_input_data()
assert (
side_input_data.access_pattern == common_urns.side_inputs.ITERABLE.urn)
iterable_view_fn = side_input_data.view_fn
@@ -1335,6 +1344,20 @@
lambda multimap: iterable_view_fn(multimap[b'']))
+class _DataflowIterableSideInput(_DataflowSideInput):
+ """Wraps an iterable side input as a dataflow-compatible iterable side input."""
+
+ def __init__(self, side_input):
+ # pylint: disable=protected-access
+ self.pvalue = side_input.pvalue
+ side_input_data = side_input._side_input_data()
+ assert (
+ side_input_data.access_pattern == common_urns.side_inputs.ITERABLE.urn)
+ self._data = beam.pvalue.SideInputData(common_urns.side_inputs.ITERABLE.urn,
+ side_input_data.window_mapping_fn,
+ side_input_data.view_fn)
+
+
class _DataflowMultimapSideInput(_DataflowSideInput):
"""Wraps a multimap side input as dataflow-compatible side input."""