Merge pull request #10043 [BEAM-8597] Allow TestStream trigger tests to run on other runners.
diff --git a/sdks/python/apache_beam/testing/test_stream.py b/sdks/python/apache_beam/testing/test_stream.py
index 610a1a8..9d9284c 100644
--- a/sdks/python/apache_beam/testing/test_stream.py
+++ b/sdks/python/apache_beam/testing/test_stream.py
@@ -31,6 +31,8 @@
from apache_beam import coders
from apache_beam import core
from apache_beam import pvalue
+from apache_beam.portability import common_urns
+from apache_beam.portability.api import beam_runner_api_pb2
from apache_beam.transforms import PTransform
from apache_beam.transforms import window
from apache_beam.transforms.window import TimestampedValue
@@ -66,6 +68,28 @@
# TODO(BEAM-5949): Needed for Python 2 compatibility.
return not self == other
+ @abstractmethod
+ def to_runner_api(self, element_coder):
+ raise NotImplementedError
+
+ @staticmethod
+ def from_runner_api(proto, element_coder):
+ if proto.HasField('element_event'):
+ return ElementEvent(
+ [TimestampedValue(
+ element_coder.decode(tv.encoded_element),
+ timestamp.Timestamp(micros=1000 * tv.timestamp))
+ for tv in proto.element_event.elements])
+ elif proto.HasField('watermark_event'):
+ return WatermarkEvent(timestamp.Timestamp(
+ micros=1000 * proto.watermark_event.new_watermark))
+ elif proto.HasField('processing_time_event'):
+ return ProcessingTimeEvent(timestamp.Duration(
+ micros=1000 * proto.processing_time_event.advance_duration))
+ else:
+ raise ValueError(
+ 'Unknown TestStream Event type: %s' % proto.WhichOneof('event'))
+
class ElementEvent(Event):
"""Element-producing test stream event."""
@@ -82,6 +106,15 @@
def __lt__(self, other):
return self.timestamped_values < other.timestamped_values
+ def to_runner_api(self, element_coder):
+ return beam_runner_api_pb2.TestStreamPayload.Event(
+ element_event=beam_runner_api_pb2.TestStreamPayload.Event.AddElements(
+ elements=[
+ beam_runner_api_pb2.TestStreamPayload.TimestampedElement(
+ encoded_element=element_coder.encode(tv.value),
+ timestamp=tv.timestamp.micros // 1000)
+ for tv in self.timestamped_values]))
+
class WatermarkEvent(Event):
"""Watermark-advancing test stream event."""
@@ -98,6 +131,11 @@
def __lt__(self, other):
return self.new_watermark < other.new_watermark
+ def to_runner_api(self, unused_element_coder):
+ return beam_runner_api_pb2.TestStreamPayload.Event(
+ watermark_event
+ =beam_runner_api_pb2.TestStreamPayload.Event.AdvanceWatermark(
+ new_watermark=self.new_watermark.micros // 1000))
class ProcessingTimeEvent(Event):
"""Processing time-advancing test stream event."""
@@ -114,6 +152,12 @@
def __lt__(self, other):
return self.advance_by < other.advance_by
+ def to_runner_api(self, unused_element_coder):
+ return beam_runner_api_pb2.TestStreamPayload.Event(
+ processing_time_event
+ =beam_runner_api_pb2.TestStreamPayload.Event.AdvanceProcessingTime(
+ advance_duration=self.advance_by.micros // 1000))
+
class TestStream(PTransform):
"""Test stream that generates events on an unbounded PCollection of elements.
@@ -123,11 +167,12 @@
output.
"""
- def __init__(self, coder=coders.FastPrimitivesCoder()):
+ def __init__(self, coder=coders.FastPrimitivesCoder(), events=()):
+ super(TestStream, self).__init__()
assert coder is not None
self.coder = coder
self.current_watermark = timestamp.MIN_TIMESTAMP
- self.events = []
+ self.events = list(events)
def get_windowing(self, unused_inputs):
return core.Windowing(window.GlobalWindows())
@@ -206,3 +251,19 @@
"""
self._add(ProcessingTimeEvent(advance_by))
return self
+
+ def to_runner_api_parameter(self, context):
+ return (
+ common_urns.primitives.TEST_STREAM.urn,
+ beam_runner_api_pb2.TestStreamPayload(
+ coder_id=context.coders.get_id(self.coder),
+ events=[e.to_runner_api(self.coder) for e in self.events]))
+
+ @PTransform.register_urn(
+ common_urns.primitives.TEST_STREAM.urn,
+ beam_runner_api_pb2.TestStreamPayload)
+ def from_runner_api_parameter(payload, context):
+ coder = context.coders.get_by_id(payload.coder_id)
+ return TestStream(
+ coder=coder,
+ events=[Event.from_runner_api(e, coder) for e in payload.events])
diff --git a/sdks/python/apache_beam/transforms/trigger_test.py b/sdks/python/apache_beam/transforms/trigger_test.py
index dbc4bcd..22ecda3 100644
--- a/sdks/python/apache_beam/transforms/trigger_test.py
+++ b/sdks/python/apache_beam/transforms/trigger_test.py
@@ -20,6 +20,7 @@
from __future__ import absolute_import
import collections
+import json
import os.path
import pickle
import unittest
@@ -31,6 +32,7 @@
import yaml
import apache_beam as beam
+from apache_beam import coders
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.pipeline_options import StandardOptions
from apache_beam.runners import pipeline_context
@@ -502,7 +504,10 @@
while hasattr(cls, unique_name):
counter += 1
unique_name = 'test_%s_%d' % (name, counter)
- setattr(cls, unique_name, lambda self: self._run_log_test(spec))
+ test_method = lambda self: self._run_log_test(spec)
+ test_method.__name__ = unique_name
+ test_method.__test__ = True
+ setattr(cls, unique_name, test_method)
# We must prepend an underscore to this name so that the open-source unittest
# runner does not execute this method directly as a test.
@@ -606,24 +611,25 @@
window_fn, trigger_fn, accumulation_mode, timestamp_combiner,
transcript, spec)
- def _windowed_value_info(self, windowed_value):
- # Currently some runners operate at the millisecond level, and some at the
- # microsecond level. Trigger transcript timestamps are expressed as
- # integral units of the finest granularity, whatever that may be.
- # In these tests we interpret them as integral seconds and then truncate
- # the results to integral seconds to allow for portability across
- # different sub-second resolutions.
- window, = windowed_value.windows
- return {
- 'window': [int(window.start), int(window.max_timestamp())],
- 'values': sorted(windowed_value.value),
- 'timestamp': int(windowed_value.timestamp),
- 'index': windowed_value.pane_info.index,
- 'nonspeculative_index': windowed_value.pane_info.nonspeculative_index,
- 'early': windowed_value.pane_info.timing == PaneInfoTiming.EARLY,
- 'late': windowed_value.pane_info.timing == PaneInfoTiming.LATE,
- 'final': windowed_value.pane_info.is_last,
- }
+
+def _windowed_value_info(windowed_value):
+ # Currently some runners operate at the millisecond level, and some at the
+ # microsecond level. Trigger transcript timestamps are expressed as
+ # integral units of the finest granularity, whatever that may be.
+ # In these tests we interpret them as integral seconds and then truncate
+ # the results to integral seconds to allow for portability across
+ # different sub-second resolutions.
+ window, = windowed_value.windows
+ return {
+ 'window': [int(window.start), int(window.max_timestamp())],
+ 'values': sorted(windowed_value.value),
+ 'timestamp': int(windowed_value.timestamp),
+ 'index': windowed_value.pane_info.index,
+ 'nonspeculative_index': windowed_value.pane_info.nonspeculative_index,
+ 'early': windowed_value.pane_info.timing == PaneInfoTiming.EARLY,
+ 'late': windowed_value.pane_info.timing == PaneInfoTiming.LATE,
+ 'final': windowed_value.pane_info.is_last,
+ }
class TriggerDriverTranscriptTest(TranscriptTest):
@@ -645,7 +651,7 @@
for timer_window, (name, time_domain, t_timestamp) in to_fire:
for wvalue in driver.process_timer(
timer_window, name, time_domain, t_timestamp, state):
- output.append(self._windowed_value_info(wvalue))
+ output.append(_windowed_value_info(wvalue))
to_fire = state.get_and_clear_timers(watermark)
for action, params in transcript:
@@ -661,7 +667,7 @@
WindowedValue(t, t, window_fn.assign(WindowFn.AssignContext(t, t)))
for t in params]
output = [
- self._windowed_value_info(wv)
+ _windowed_value_info(wv)
for wv in driver.process_elements(state, bundle, watermark)]
fire_timers()
@@ -690,7 +696,7 @@
self.assertEqual([], output, msg='Unexpected output: %s' % output)
-class TestStreamTranscriptTest(TranscriptTest):
+class BaseTestStreamTranscriptTest(TranscriptTest):
"""A suite of TestStream-based tests based on trigger transcript entries.
"""
@@ -702,14 +708,17 @@
if runner_name in spec.get('broken_on', ()):
self.skipTest('Known to be broken on %s' % runner_name)
- test_stream = TestStream()
+ # Elements are encoded as json strings to allow other languages to
+ # decode elements while executing the test stream.
+ # TODO(BEAM-8600): Eliminate these gymnastics.
+ test_stream = TestStream(coder=coders.StrUtf8Coder()).with_output_types(str)
for action, params in transcript:
if action == 'expect':
- test_stream.add_elements([('expect', params)])
+ test_stream.add_elements([json.dumps(('expect', params))])
else:
- test_stream.add_elements([('expect', [])])
+ test_stream.add_elements([json.dumps(('expect', []))])
if action == 'input':
- test_stream.add_elements([('input', e) for e in params])
+ test_stream.add_elements([json.dumps(('input', e)) for e in params])
elif action == 'watermark':
test_stream.advance_watermark_to(params)
elif action == 'clock':
@@ -718,7 +727,9 @@
pass # Requires inspection of implementation details.
else:
raise ValueError('Unexpected action: %s' % action)
- test_stream.add_elements([('expect', [])])
+ test_stream.add_elements([json.dumps(('expect', []))])
+
+ read_test_stream = test_stream | beam.Map(json.loads)
class Check(beam.DoFn):
"""A StatefulDoFn that verifies outputs are produced as expected.
@@ -731,12 +742,40 @@
The key is ignored, but all items must be on the same key to share state.
"""
+ def __init__(self, allow_out_of_order=True):
+ # Some runners don't support cross-stage TestStream semantics.
+ self.allow_out_of_order = allow_out_of_order
+
def process(
- self, element, seen=beam.DoFn.StateParam(
+ self,
+ element,
+ seen=beam.DoFn.StateParam(
beam.transforms.userstate.BagStateSpec(
'seen',
+ beam.coders.FastPrimitivesCoder())),
+ expected=beam.DoFn.StateParam(
+ beam.transforms.userstate.BagStateSpec(
+ 'expected',
beam.coders.FastPrimitivesCoder()))):
_, (action, data) = element
+
+ if self.allow_out_of_order:
+ if action == 'expect' and not list(seen.read()):
+ if data:
+ expected.add(data)
+ return
+ elif action == 'actual' and list(expected.read()):
+ seen.add(data)
+ all_data = list(seen.read())
+ all_expected = list(expected.read())
+ if len(all_data) == len(all_expected[0]):
+ expected.clear()
+ for expect in all_expected[1:]:
+ expected.add(expect)
+ action, data = 'expect', all_expected[0]
+ else:
+ return
+
if action == 'actual':
seen.add(data)
@@ -768,12 +807,14 @@
else:
raise ValueError('Unexpected action: %s' % action)
- with TestPipeline(options=PipelineOptions(streaming=True)) as p:
+ with TestPipeline() as p:
+ # TODO(BEAM-8601): Pass this during pipeline construction.
+ p.options.view_as(StandardOptions).streaming = True
# Split the test stream into a branch of to-be-processed elements, and
# a branch of expected results.
inputs, expected = (
p
- | test_stream
+ | read_test_stream
| beam.MapTuple(
lambda tag, value: beam.pvalue.TaggedOutput(tag, ('key', value))
).with_outputs('input', 'expect'))
@@ -794,7 +835,7 @@
t=beam.DoFn.TimestampParam,
p=beam.DoFn.PaneInfoParam: (
k,
- self._windowed_value_info(WindowedValue(
+ _windowed_value_info(WindowedValue(
vs, windows=[window], timestamp=t, pane_info=p))))
# Place outputs back into the global window to allow flattening
# and share a single state in Check.
@@ -805,7 +846,17 @@
tagged_outputs = (
outputs | beam.MapTuple(lambda key, value: (key, ('actual', value))))
# pylint: disable=expression-not-assigned
- (tagged_expected, tagged_outputs) | beam.Flatten() | beam.ParDo(Check())
+ ([tagged_expected, tagged_outputs]
+ | beam.Flatten()
+ | beam.ParDo(Check(self.allow_out_of_order)))
+
+
+class TestStreamTranscriptTest(BaseTestStreamTranscriptTest):
+ allow_out_of_order = False
+
+
+class WeakTestStreamTranscriptTest(BaseTestStreamTranscriptTest):
+ allow_out_of_order = True
TRANSCRIPT_TEST_FILE = os.path.join(
@@ -814,6 +865,7 @@
if os.path.exists(TRANSCRIPT_TEST_FILE):
TriggerDriverTranscriptTest._create_tests(TRANSCRIPT_TEST_FILE)
TestStreamTranscriptTest._create_tests(TRANSCRIPT_TEST_FILE)
+ WeakTestStreamTranscriptTest._create_tests(TRANSCRIPT_TEST_FILE)
if __name__ == '__main__':
diff --git a/sdks/python/test-suites/portable/common.gradle b/sdks/python/test-suites/portable/common.gradle
index f04d28d..3cb4362 100644
--- a/sdks/python/test-suites/portable/common.gradle
+++ b/sdks/python/test-suites/portable/common.gradle
@@ -79,3 +79,22 @@
task flinkValidatesRunner() {
dependsOn 'flinkCompatibilityMatrixLoopback'
}
+
+// TODO(BEAM-8598): Enable on pre-commit.
+task flinkTriggerTranscript() {
+ dependsOn 'setupVirtualenv'
+ dependsOn ':runners:flink:1.9:job-server:shadowJar'
+ doLast {
+ exec {
+ executable 'sh'
+ args '-c', """
+ . ${envdir}/bin/activate \\
+ && cd ${pythonRootDir} \\
+ && pip install -e .[test] \\
+ && python setup.py nosetests \\
+ --tests apache_beam.transforms.trigger_test:WeakTestStreamTranscriptTest \\
+ --test-pipeline-options='--runner=FlinkRunner --environment_type=LOOPBACK --flink_job_server_jar=${project(":runners:flink:1.9:job-server:").shadowJar.archivePath}'
+ """
+ }
+ }
+}