[BEAM-8655] Run trigger transcript tests with combiner as well as GroupByKey.
diff --git a/sdks/python/apache_beam/transforms/trigger_test.py b/sdks/python/apache_beam/transforms/trigger_test.py
index fcce845..96843db 100644
--- a/sdks/python/apache_beam/transforms/trigger_test.py
+++ b/sdks/python/apache_beam/transforms/trigger_test.py
@@ -41,6 +41,7 @@
from apache_beam.testing.test_stream import TestStream
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to
+from apache_beam.transforms import ptransform
from apache_beam.transforms import trigger
from apache_beam.transforms.core import Windowing
from apache_beam.transforms.trigger import AccumulationMode
@@ -807,19 +808,18 @@
else:
raise ValueError('Unexpected action: %s' % action)
- with TestPipeline() as p:
- # TODO(BEAM-8601): Pass this during pipeline construction.
- p.options.view_as(StandardOptions).streaming = True
+ @ptransform.ptransform_fn
+ def CheckAggregation(inputs_and_expected, aggregation):
# Split the test stream into a branch of to-be-processed elements, and
# a branch of expected results.
inputs, expected = (
- p
- | read_test_stream
+ inputs_and_expected
| beam.FlatMapTuple(
lambda tag, value: [
beam.pvalue.TaggedOutput(tag, ('key1', value)),
beam.pvalue.TaggedOutput(tag, ('key2', value)),
]).with_outputs('input', 'expect'))
+
# Process the inputs with the given windowing to produce actual outputs.
outputs = (
inputs
@@ -830,7 +830,7 @@
trigger=trigger_fn,
accumulation_mode=accumulation_mode,
timestamp_combiner=timestamp_combiner)
- | beam.GroupByKey()
+ | aggregation
| beam.MapTuple(
lambda k, vs,
window=beam.DoFn.WindowParam,
@@ -852,6 +852,21 @@
| beam.Flatten()
| beam.ParDo(Check(self.allow_out_of_order)))
+ class Concat(beam.CombineFn):
+ create_accumulator = lambda self: []
+ add_input = lambda self, acc, element: acc.append(element) or acc
+ merge_accumulators = lambda self, accs: sum(accs, [])
+ extract_output = lambda self, acc: acc
+
+ with TestPipeline() as p:
+ # TODO(BEAM-8601): Pass this during pipeline construction.
+ p.options.view_as(StandardOptions).streaming = True
+
+ # We can have at most one test stream per pipeline, so we share it.
+ inputs_and_expected = p | read_test_stream
+ _ = inputs_and_expected | CheckAggregation(beam.GroupByKey())
+ _ = inputs_and_expected | CheckAggregation(beam.CombinePerKey(Concat()))
+
class TestStreamTranscriptTest(BaseTestStreamTranscriptTest):
allow_out_of_order = False