[BEAM-8657] Avoid lifting combiners for incompatible triggers.
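
Certain triggers, such as AfterCount(n) for n > 1, need to observe every
element and so cannot tolerate the pre-aggregation that combiner lifting
performs before the shuffle. CombinePerKey is now expanded into the lifted
Precombine/Group/Merge/ExtractOutputs form only when the trigger and output
time of the input's windowing strategy allow it; otherwise the composite is
expanded into its subtransforms and the combiner runs unlifted.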
diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner_transforms.py b/sdks/python/apache_beam/runners/portability/fn_api_runner_transforms.py
index 4f3e2f9..338ad40 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner_transforms.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner_transforms.py
@@ -588,122 +588,201 @@
 
   ... -> PreCombine -> GBK -> MergeAccumulators -> ExtractOutput -> ...
   """
+  def is_compatible_with_combiner_lifting(trigger):
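+    '''Returns whether this trigger is compatible with combiner lifting.
+
+    Certain triggers, such as those that fire after a certain number of
+    elements, need to observe every element, and as such are incompatible
+    with combiner lifting (which may aggregate several elements into one
+    before they reach the triggering code after the shuffle).
+    '''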
+    if trigger is None:
+      return True
+    elif trigger.WhichOneof('trigger') in (
+        'default', 'always', 'never', 'after_processing_time',
+        'after_synchronized_processing_time'):
+      return True
+    elif trigger.HasField('element_count'):
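+      # A count of exactly one is safe: pre-combining still emits at least
+      # one value per key and window, so the trigger still fires; any larger
+      # count would need to see the individual elements.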
+      return trigger.element_count.element_count == 1
+    elif trigger.HasField('after_end_of_window'):
+      return (
+          is_compatible_with_combiner_lifting(
+              trigger.after_end_of_window.early_firings)
+          and is_compatible_with_combiner_lifting(
+              trigger.after_end_of_window.late_firings))
+    elif trigger.HasField('after_any'):
+      return all(
+          is_compatible_with_combiner_lifting(t)
+          for t in trigger.after_any.subtriggers)
+    elif trigger.HasField('repeat'):
+      return is_compatible_with_combiner_lifting(trigger.repeat.subtrigger)
+    else:
+      return False
+
+  def can_lift(combine_per_key_transform):
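+    '''Returns whether it is safe to lift the combiner in this
+    CombinePerKey, based on the windowing strategy of its input.'''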
+    windowing = context.components.windowing_strategies[
+        context.components.pcollections[
+            only_element(list(combine_per_key_transform.inputs.values()))
+        ].windowing_strategy_id]
+    if windowing.output_time != beam_runner_api_pb2.OutputTime.END_OF_WINDOW:
+      # This depends on the spec of PartialGroupByKey.
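+      # Pre-combining collapses the timestamps of the elements it combines,
+      # which is only safe when output timestamps are taken from the end of
+      # the window rather than derived from the individual elements.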
+      return False
+    elif not is_compatible_with_combiner_lifting(windowing.trigger):
+      return False
+    else:
+      return True
+
+  def make_stage(base_stage, transform):
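+    '''Wraps a transform in a new Stage inheriting base_stage's properties.'''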
+    return Stage(
+        transform.unique_name,
+        [transform],
+        downstream_side_inputs=base_stage.downstream_side_inputs,
+        must_follow=base_stage.must_follow,
+        parent=base_stage,
+        environment=base_stage.environment)
+
+  def lifted_stages(stage):
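+    '''Expands a CombinePerKey stage into the lifted form
+
+      Precombine -> Group -> Merge -> ExtractOutputs
+
+    creating the intermediate coders and PCollections between them.
+    '''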
+    transform = stage.transforms[0]
+    combine_payload = proto_utils.parse_Bytes(
+        transform.spec.payload, beam_runner_api_pb2.CombinePayload)
+
+    input_pcoll = context.components.pcollections[only_element(
+        list(transform.inputs.values()))]
+    output_pcoll = context.components.pcollections[only_element(
+        list(transform.outputs.values()))]
+
+    element_coder_id = input_pcoll.coder_id
+    element_coder = context.components.coders[element_coder_id]
+    key_coder_id, _ = element_coder.component_coder_ids
+    accumulator_coder_id = combine_payload.accumulator_coder_id
+
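+    # Coders for the intermediate values: Precombine emits
+    # KV<key, accumulator>, and grouping those yields
+    # KV<key, Iterable<accumulator>>.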
+    key_accumulator_coder = beam_runner_api_pb2.Coder(
+        spec=beam_runner_api_pb2.FunctionSpec(
+            urn=common_urns.coders.KV.urn),
+        component_coder_ids=[key_coder_id, accumulator_coder_id])
+    key_accumulator_coder_id = context.add_or_get_coder_id(
+        key_accumulator_coder)
+
+    accumulator_iter_coder = beam_runner_api_pb2.Coder(
+        spec=beam_runner_api_pb2.FunctionSpec(
+            urn=common_urns.coders.ITERABLE.urn),
+        component_coder_ids=[accumulator_coder_id])
+    accumulator_iter_coder_id = context.add_or_get_coder_id(
+        accumulator_iter_coder)
+
+    key_accumulator_iter_coder = beam_runner_api_pb2.Coder(
+        spec=beam_runner_api_pb2.FunctionSpec(
+            urn=common_urns.coders.KV.urn),
+        component_coder_ids=[key_coder_id, accumulator_iter_coder_id])
+    key_accumulator_iter_coder_id = context.add_or_get_coder_id(
+        key_accumulator_iter_coder)
+
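+    # Intermediate PCollections connecting the lifted stages.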
+    precombined_pcoll_id = unique_name(
+        context.components.pcollections, 'pcollection')
+    context.components.pcollections[precombined_pcoll_id].CopyFrom(
+        beam_runner_api_pb2.PCollection(
+            unique_name=transform.unique_name + '/Precombine.out',
+            coder_id=key_accumulator_coder_id,
+            windowing_strategy_id=input_pcoll.windowing_strategy_id,
+            is_bounded=input_pcoll.is_bounded))
+
+    grouped_pcoll_id = unique_name(
+        context.components.pcollections, 'pcollection')
+    context.components.pcollections[grouped_pcoll_id].CopyFrom(
+        beam_runner_api_pb2.PCollection(
+            unique_name=transform.unique_name + '/Group.out',
+            coder_id=key_accumulator_iter_coder_id,
+            windowing_strategy_id=output_pcoll.windowing_strategy_id,
+            is_bounded=output_pcoll.is_bounded))
+
+    merged_pcoll_id = unique_name(
+        context.components.pcollections, 'pcollection')
+    context.components.pcollections[merged_pcoll_id].CopyFrom(
+        beam_runner_api_pb2.PCollection(
+            unique_name=transform.unique_name + '/Merge.out',
+            coder_id=key_accumulator_coder_id,
+            windowing_strategy_id=output_pcoll.windowing_strategy_id,
+            is_bounded=output_pcoll.is_bounded))
+
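+    # Emit the four replacement stages, passing the original combine
+    # payload to each of the combine components.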
+    yield make_stage(
+        stage,
+        beam_runner_api_pb2.PTransform(
+            unique_name=transform.unique_name + '/Precombine',
+            spec=beam_runner_api_pb2.FunctionSpec(
+                urn=common_urns.combine_components
+                .COMBINE_PER_KEY_PRECOMBINE.urn,
+                payload=transform.spec.payload),
+            inputs=transform.inputs,
+            outputs={'out': precombined_pcoll_id}))
+
+    yield make_stage(
+        stage,
+        beam_runner_api_pb2.PTransform(
+            unique_name=transform.unique_name + '/Group',
+            spec=beam_runner_api_pb2.FunctionSpec(
+                urn=common_urns.primitives.GROUP_BY_KEY.urn),
+            inputs={'in': precombined_pcoll_id},
+            outputs={'out': grouped_pcoll_id}))
+
+    yield make_stage(
+        stage,
+        beam_runner_api_pb2.PTransform(
+            unique_name=transform.unique_name + '/Merge',
+            spec=beam_runner_api_pb2.FunctionSpec(
+                urn=common_urns.combine_components
+                .COMBINE_PER_KEY_MERGE_ACCUMULATORS.urn,
+                payload=transform.spec.payload),
+            inputs={'in': grouped_pcoll_id},
+            outputs={'out': merged_pcoll_id}))
+
+    yield make_stage(
+        stage,
+        beam_runner_api_pb2.PTransform(
+            unique_name=transform.unique_name + '/ExtractOutputs',
+            spec=beam_runner_api_pb2.FunctionSpec(
+                urn=common_urns.combine_components
+                .COMBINE_PER_KEY_EXTRACT_OUTPUTS.urn,
+                payload=transform.spec.payload),
+            inputs={'in': merged_pcoll_id},
+            outputs=transform.outputs))
+
+  def unlifted_stages(stage):
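+    '''Expands a CombinePerKey stage into its component subtransforms,
+    leaving the combiner unlifted.'''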
+    transform = stage.transforms[0]
+    for sub in transform.subtransforms:
+      yield make_stage(stage, context.components.transforms[sub])
+
   for stage in stages:
     assert len(stage.transforms) == 1
     transform = stage.transforms[0]
     if transform.spec.urn == common_urns.composites.COMBINE_PER_KEY.urn:
-      combine_payload = proto_utils.parse_Bytes(
-          transform.spec.payload, beam_runner_api_pb2.CombinePayload)
-
-      input_pcoll = context.components.pcollections[only_element(
-          list(transform.inputs.values()))]
-      output_pcoll = context.components.pcollections[only_element(
-          list(transform.outputs.values()))]
-
-      element_coder_id = input_pcoll.coder_id
-      element_coder = context.components.coders[element_coder_id]
-      key_coder_id, _ = element_coder.component_coder_ids
-      accumulator_coder_id = combine_payload.accumulator_coder_id
-
-      key_accumulator_coder = beam_runner_api_pb2.Coder(
-          spec=beam_runner_api_pb2.FunctionSpec(
-              urn=common_urns.coders.KV.urn),
-          component_coder_ids=[key_coder_id, accumulator_coder_id])
-      key_accumulator_coder_id = context.add_or_get_coder_id(
-          key_accumulator_coder)
-
-      accumulator_iter_coder = beam_runner_api_pb2.Coder(
-          spec=beam_runner_api_pb2.FunctionSpec(
-              urn=common_urns.coders.ITERABLE.urn),
-          component_coder_ids=[accumulator_coder_id])
-      accumulator_iter_coder_id = context.add_or_get_coder_id(
-          accumulator_iter_coder)
-
-      key_accumulator_iter_coder = beam_runner_api_pb2.Coder(
-          spec=beam_runner_api_pb2.FunctionSpec(
-              urn=common_urns.coders.KV.urn),
-          component_coder_ids=[key_coder_id, accumulator_iter_coder_id])
-      key_accumulator_iter_coder_id = context.add_or_get_coder_id(
-          key_accumulator_iter_coder)
-
-      precombined_pcoll_id = unique_name(
-          context.components.pcollections, 'pcollection')
-      context.components.pcollections[precombined_pcoll_id].CopyFrom(
-          beam_runner_api_pb2.PCollection(
-              unique_name=transform.unique_name + '/Precombine.out',
-              coder_id=key_accumulator_coder_id,
-              windowing_strategy_id=input_pcoll.windowing_strategy_id,
-              is_bounded=input_pcoll.is_bounded))
-
-      grouped_pcoll_id = unique_name(
-          context.components.pcollections, 'pcollection')
-      context.components.pcollections[grouped_pcoll_id].CopyFrom(
-          beam_runner_api_pb2.PCollection(
-              unique_name=transform.unique_name + '/Group.out',
-              coder_id=key_accumulator_iter_coder_id,
-              windowing_strategy_id=output_pcoll.windowing_strategy_id,
-              is_bounded=output_pcoll.is_bounded))
-
-      merged_pcoll_id = unique_name(
-          context.components.pcollections, 'pcollection')
-      context.components.pcollections[merged_pcoll_id].CopyFrom(
-          beam_runner_api_pb2.PCollection(
-              unique_name=transform.unique_name + '/Merge.out',
-              coder_id=key_accumulator_coder_id,
-              windowing_strategy_id=output_pcoll.windowing_strategy_id,
-              is_bounded=output_pcoll.is_bounded))
-
-      def make_stage(base_stage, transform):
-        return Stage(
-            transform.unique_name,
-            [transform],
-            downstream_side_inputs=base_stage.downstream_side_inputs,
-            must_follow=base_stage.must_follow,
-            parent=base_stage,
-            environment=base_stage.environment)
-
-      yield make_stage(
-          stage,
-          beam_runner_api_pb2.PTransform(
-              unique_name=transform.unique_name + '/Precombine',
-              spec=beam_runner_api_pb2.FunctionSpec(
-                  urn=common_urns.combine_components
-                  .COMBINE_PER_KEY_PRECOMBINE.urn,
-                  payload=transform.spec.payload),
-              inputs=transform.inputs,
-              outputs={'out': precombined_pcoll_id}))
-
-      yield make_stage(
-          stage,
-          beam_runner_api_pb2.PTransform(
-              unique_name=transform.unique_name + '/Group',
-              spec=beam_runner_api_pb2.FunctionSpec(
-                  urn=common_urns.primitives.GROUP_BY_KEY.urn),
-              inputs={'in': precombined_pcoll_id},
-              outputs={'out': grouped_pcoll_id}))
-
-      yield make_stage(
-          stage,
-          beam_runner_api_pb2.PTransform(
-              unique_name=transform.unique_name + '/Merge',
-              spec=beam_runner_api_pb2.FunctionSpec(
-                  urn=common_urns.combine_components
-                  .COMBINE_PER_KEY_MERGE_ACCUMULATORS.urn,
-                  payload=transform.spec.payload),
-              inputs={'in': grouped_pcoll_id},
-              outputs={'out': merged_pcoll_id}))
-
-      yield make_stage(
-          stage,
-          beam_runner_api_pb2.PTransform(
-              unique_name=transform.unique_name + '/ExtractOutputs',
-              spec=beam_runner_api_pb2.FunctionSpec(
-                  urn=common_urns.combine_components
-                  .COMBINE_PER_KEY_EXTRACT_OUTPUTS.urn,
-                  payload=transform.spec.payload),
-              inputs={'in': merged_pcoll_id},
-              outputs=transform.outputs))
-
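+      # Lift only when the input windowing strategy permits it; otherwise
+      # fall back to executing the composite's subtransforms directly.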
+      expansion = lifted_stages if can_lift(transform) else unlifted_stages
+      for substage in expansion(stage):
+        yield substage
     else:
       yield stage