[BEAM-8657] Avoid lifting combiners for incompatible triggers.
diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner_transforms.py b/sdks/python/apache_beam/runners/portability/fn_api_runner_transforms.py
index 4f3e2f9..338ad40 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner_transforms.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner_transforms.py
@@ -588,122 +588,168 @@
... -> PreCombine -> GBK -> MergeAccumulators -> ExtractOutput -> ...
"""
+ def is_compatible_with_combiner_lifting(trigger):
+ if trigger is None:
+ return True
+ elif trigger.WhichOneof('trigger') in (
+ 'default', 'always', 'never', 'after_processing_time',
+ 'after_synchronized_processing_time'):
+ return True
+ elif trigger.HasField('element_count'):
+ return trigger.element_count.element_count == 1
+ elif trigger.HasField('after_end_of_window'):
+ return is_compatible_with_combiner_lifting(
+ trigger.after_end_of_window.early_firings
+ ) and is_compatible_with_combiner_lifting(
+ trigger.after_end_of_window.late_firings)
+ elif trigger.HasField('after_any'):
+ return all(
+ is_compatible_with_combiner_lifting(t)
+ for t in trigger.after_any.subtriggers)
+ elif trigger.HasField('repeat'):
+ return is_compatible_with_combiner_lifting(trigger.repeat.subtrigger)
+ else:
+ return False
+
+ def can_lift(combine_per_key_transform):
+ windowing = context.components.windowing_strategies[
+ context.components.pcollections[
+ only_element(list(combine_per_key_transform.inputs.values()))
+ ].windowing_strategy_id]
+ if windowing.output_time != beam_runner_api_pb2.OutputTime.END_OF_WINDOW:
+ # This depends on the spec of PartialGroupByKey.
+ return False
+ elif not is_compatible_with_combiner_lifting(windowing.trigger):
+ return False
+ else:
+ return True
+
+ def make_stage(base_stage, transform):
+ return Stage(
+ transform.unique_name,
+ [transform],
+ downstream_side_inputs=base_stage.downstream_side_inputs,
+ must_follow=base_stage.must_follow,
+ parent=base_stage,
+ environment=base_stage.environment)
+
+ def lifted_stages(stage):
+ transform = stage.transforms[0]
+ combine_payload = proto_utils.parse_Bytes(
+ transform.spec.payload, beam_runner_api_pb2.CombinePayload)
+
+ input_pcoll = context.components.pcollections[only_element(
+ list(transform.inputs.values()))]
+ output_pcoll = context.components.pcollections[only_element(
+ list(transform.outputs.values()))]
+
+ element_coder_id = input_pcoll.coder_id
+ element_coder = context.components.coders[element_coder_id]
+ key_coder_id, _ = element_coder.component_coder_ids
+ accumulator_coder_id = combine_payload.accumulator_coder_id
+
+ key_accumulator_coder = beam_runner_api_pb2.Coder(
+ spec=beam_runner_api_pb2.FunctionSpec(
+ urn=common_urns.coders.KV.urn),
+ component_coder_ids=[key_coder_id, accumulator_coder_id])
+ key_accumulator_coder_id = context.add_or_get_coder_id(
+ key_accumulator_coder)
+
+ accumulator_iter_coder = beam_runner_api_pb2.Coder(
+ spec=beam_runner_api_pb2.FunctionSpec(
+ urn=common_urns.coders.ITERABLE.urn),
+ component_coder_ids=[accumulator_coder_id])
+ accumulator_iter_coder_id = context.add_or_get_coder_id(
+ accumulator_iter_coder)
+
+ key_accumulator_iter_coder = beam_runner_api_pb2.Coder(
+ spec=beam_runner_api_pb2.FunctionSpec(
+ urn=common_urns.coders.KV.urn),
+ component_coder_ids=[key_coder_id, accumulator_iter_coder_id])
+ key_accumulator_iter_coder_id = context.add_or_get_coder_id(
+ key_accumulator_iter_coder)
+
+ precombined_pcoll_id = unique_name(
+ context.components.pcollections, 'pcollection')
+ context.components.pcollections[precombined_pcoll_id].CopyFrom(
+ beam_runner_api_pb2.PCollection(
+ unique_name=transform.unique_name + '/Precombine.out',
+ coder_id=key_accumulator_coder_id,
+ windowing_strategy_id=input_pcoll.windowing_strategy_id,
+ is_bounded=input_pcoll.is_bounded))
+
+ grouped_pcoll_id = unique_name(
+ context.components.pcollections, 'pcollection')
+ context.components.pcollections[grouped_pcoll_id].CopyFrom(
+ beam_runner_api_pb2.PCollection(
+ unique_name=transform.unique_name + '/Group.out',
+ coder_id=key_accumulator_iter_coder_id,
+ windowing_strategy_id=output_pcoll.windowing_strategy_id,
+ is_bounded=output_pcoll.is_bounded))
+
+ merged_pcoll_id = unique_name(
+ context.components.pcollections, 'pcollection')
+ context.components.pcollections[merged_pcoll_id].CopyFrom(
+ beam_runner_api_pb2.PCollection(
+ unique_name=transform.unique_name + '/Merge.out',
+ coder_id=key_accumulator_coder_id,
+ windowing_strategy_id=output_pcoll.windowing_strategy_id,
+ is_bounded=output_pcoll.is_bounded))
+
+ yield make_stage(
+ stage,
+ beam_runner_api_pb2.PTransform(
+ unique_name=transform.unique_name + '/Precombine',
+ spec=beam_runner_api_pb2.FunctionSpec(
+ urn=common_urns.combine_components
+ .COMBINE_PER_KEY_PRECOMBINE.urn,
+ payload=transform.spec.payload),
+ inputs=transform.inputs,
+ outputs={'out': precombined_pcoll_id}))
+
+ yield make_stage(
+ stage,
+ beam_runner_api_pb2.PTransform(
+ unique_name=transform.unique_name + '/Group',
+ spec=beam_runner_api_pb2.FunctionSpec(
+ urn=common_urns.primitives.GROUP_BY_KEY.urn),
+ inputs={'in': precombined_pcoll_id},
+ outputs={'out': grouped_pcoll_id}))
+
+ yield make_stage(
+ stage,
+ beam_runner_api_pb2.PTransform(
+ unique_name=transform.unique_name + '/Merge',
+ spec=beam_runner_api_pb2.FunctionSpec(
+ urn=common_urns.combine_components
+ .COMBINE_PER_KEY_MERGE_ACCUMULATORS.urn,
+ payload=transform.spec.payload),
+ inputs={'in': grouped_pcoll_id},
+ outputs={'out': merged_pcoll_id}))
+
+ yield make_stage(
+ stage,
+ beam_runner_api_pb2.PTransform(
+ unique_name=transform.unique_name + '/ExtractOutputs',
+ spec=beam_runner_api_pb2.FunctionSpec(
+ urn=common_urns.combine_components
+ .COMBINE_PER_KEY_EXTRACT_OUTPUTS.urn,
+ payload=transform.spec.payload),
+ inputs={'in': merged_pcoll_id},
+ outputs=transform.outputs))
+
+ def unlifted_stages(stage):
+ transform = stage.transforms[0]
+ for sub in transform.subtransforms:
+ yield make_stage(stage, context.components.transforms[sub])
+
for stage in stages:
assert len(stage.transforms) == 1
transform = stage.transforms[0]
if transform.spec.urn == common_urns.composites.COMBINE_PER_KEY.urn:
- combine_payload = proto_utils.parse_Bytes(
- transform.spec.payload, beam_runner_api_pb2.CombinePayload)
-
- input_pcoll = context.components.pcollections[only_element(
- list(transform.inputs.values()))]
- output_pcoll = context.components.pcollections[only_element(
- list(transform.outputs.values()))]
-
- element_coder_id = input_pcoll.coder_id
- element_coder = context.components.coders[element_coder_id]
- key_coder_id, _ = element_coder.component_coder_ids
- accumulator_coder_id = combine_payload.accumulator_coder_id
-
- key_accumulator_coder = beam_runner_api_pb2.Coder(
- spec=beam_runner_api_pb2.FunctionSpec(
- urn=common_urns.coders.KV.urn),
- component_coder_ids=[key_coder_id, accumulator_coder_id])
- key_accumulator_coder_id = context.add_or_get_coder_id(
- key_accumulator_coder)
-
- accumulator_iter_coder = beam_runner_api_pb2.Coder(
- spec=beam_runner_api_pb2.FunctionSpec(
- urn=common_urns.coders.ITERABLE.urn),
- component_coder_ids=[accumulator_coder_id])
- accumulator_iter_coder_id = context.add_or_get_coder_id(
- accumulator_iter_coder)
-
- key_accumulator_iter_coder = beam_runner_api_pb2.Coder(
- spec=beam_runner_api_pb2.FunctionSpec(
- urn=common_urns.coders.KV.urn),
- component_coder_ids=[key_coder_id, accumulator_iter_coder_id])
- key_accumulator_iter_coder_id = context.add_or_get_coder_id(
- key_accumulator_iter_coder)
-
- precombined_pcoll_id = unique_name(
- context.components.pcollections, 'pcollection')
- context.components.pcollections[precombined_pcoll_id].CopyFrom(
- beam_runner_api_pb2.PCollection(
- unique_name=transform.unique_name + '/Precombine.out',
- coder_id=key_accumulator_coder_id,
- windowing_strategy_id=input_pcoll.windowing_strategy_id,
- is_bounded=input_pcoll.is_bounded))
-
- grouped_pcoll_id = unique_name(
- context.components.pcollections, 'pcollection')
- context.components.pcollections[grouped_pcoll_id].CopyFrom(
- beam_runner_api_pb2.PCollection(
- unique_name=transform.unique_name + '/Group.out',
- coder_id=key_accumulator_iter_coder_id,
- windowing_strategy_id=output_pcoll.windowing_strategy_id,
- is_bounded=output_pcoll.is_bounded))
-
- merged_pcoll_id = unique_name(
- context.components.pcollections, 'pcollection')
- context.components.pcollections[merged_pcoll_id].CopyFrom(
- beam_runner_api_pb2.PCollection(
- unique_name=transform.unique_name + '/Merge.out',
- coder_id=key_accumulator_coder_id,
- windowing_strategy_id=output_pcoll.windowing_strategy_id,
- is_bounded=output_pcoll.is_bounded))
-
- def make_stage(base_stage, transform):
- return Stage(
- transform.unique_name,
- [transform],
- downstream_side_inputs=base_stage.downstream_side_inputs,
- must_follow=base_stage.must_follow,
- parent=base_stage,
- environment=base_stage.environment)
-
- yield make_stage(
- stage,
- beam_runner_api_pb2.PTransform(
- unique_name=transform.unique_name + '/Precombine',
- spec=beam_runner_api_pb2.FunctionSpec(
- urn=common_urns.combine_components
- .COMBINE_PER_KEY_PRECOMBINE.urn,
- payload=transform.spec.payload),
- inputs=transform.inputs,
- outputs={'out': precombined_pcoll_id}))
-
- yield make_stage(
- stage,
- beam_runner_api_pb2.PTransform(
- unique_name=transform.unique_name + '/Group',
- spec=beam_runner_api_pb2.FunctionSpec(
- urn=common_urns.primitives.GROUP_BY_KEY.urn),
- inputs={'in': precombined_pcoll_id},
- outputs={'out': grouped_pcoll_id}))
-
- yield make_stage(
- stage,
- beam_runner_api_pb2.PTransform(
- unique_name=transform.unique_name + '/Merge',
- spec=beam_runner_api_pb2.FunctionSpec(
- urn=common_urns.combine_components
- .COMBINE_PER_KEY_MERGE_ACCUMULATORS.urn,
- payload=transform.spec.payload),
- inputs={'in': grouped_pcoll_id},
- outputs={'out': merged_pcoll_id}))
-
- yield make_stage(
- stage,
- beam_runner_api_pb2.PTransform(
- unique_name=transform.unique_name + '/ExtractOutputs',
- spec=beam_runner_api_pb2.FunctionSpec(
- urn=common_urns.combine_components
- .COMBINE_PER_KEY_EXTRACT_OUTPUTS.urn,
- payload=transform.spec.payload),
- inputs={'in': merged_pcoll_id},
- outputs=transform.outputs))
-
+ expansion = lifted_stages if can_lift(transform) else unlifted_stages
+ for substage in expansion(stage):
+ yield substage
else:
yield stage