Merge pull request #10149 [BEAM-8739] Consistently use with Pipeline(...) syntax
diff --git a/sdks/python/apache_beam/examples/cookbook/filters_test.py b/sdks/python/apache_beam/examples/cookbook/filters_test.py
index c0d8e12..eebbc10 100644
--- a/sdks/python/apache_beam/examples/cookbook/filters_test.py
+++ b/sdks/python/apache_beam/examples/cookbook/filters_test.py
@@ -41,33 +41,31 @@
{'year': 2011, 'month': 3, 'day': 3, 'mean_temp': 5, 'removed': 'a'},
]
- def _get_result_for_month(self, month):
- p = TestPipeline()
- rows = (p | 'create' >> beam.Create(self.input_data))
-
+ def _get_result_for_month(self, pipeline, month):
+ rows = (pipeline | 'create' >> beam.Create(self.input_data))
results = filters.filter_cold_days(rows, month)
return results
def test_basic(self):
"""Test that the correct result is returned for a simple dataset."""
- results = self._get_result_for_month(1)
- assert_that(
- results,
- equal_to([{'year': 2010, 'month': 1, 'day': 1, 'mean_temp': 3},
- {'year': 2012, 'month': 1, 'day': 2, 'mean_temp': 3}]))
- results.pipeline.run()
+ with TestPipeline() as p:
+ results = self._get_result_for_month(p, 1)
+ assert_that(
+ results,
+ equal_to([{'year': 2010, 'month': 1, 'day': 1, 'mean_temp': 3},
+ {'year': 2012, 'month': 1, 'day': 2, 'mean_temp': 3}]))
def test_basic_empty(self):
"""Test that the correct empty result is returned for a simple dataset."""
- results = self._get_result_for_month(3)
- assert_that(results, equal_to([]))
- results.pipeline.run()
+ with TestPipeline() as p:
+ results = self._get_result_for_month(p, 3)
+ assert_that(results, equal_to([]))
def test_basic_empty_missing(self):
"""Test that the correct empty result is returned for a missing month."""
- results = self._get_result_for_month(4)
- assert_that(results, equal_to([]))
- results.pipeline.run()
+ with TestPipeline() as p:
+ results = self._get_result_for_month(p, 4)
+ assert_that(results, equal_to([]))
if __name__ == '__main__':
diff --git a/sdks/python/apache_beam/examples/fastavro_it_test.py b/sdks/python/apache_beam/examples/fastavro_it_test.py
index 5b1a3c5..35afa52 100644
--- a/sdks/python/apache_beam/examples/fastavro_it_test.py
+++ b/sdks/python/apache_beam/examples/fastavro_it_test.py
@@ -163,43 +163,42 @@
result.wait_until_finish()
assert result.state == PipelineState.DONE
- fastavro_read_pipeline = TestPipeline(is_integration_test=True)
+ with TestPipeline(is_integration_test=True) as fastavro_read_pipeline:
- fastavro_records = \
- fastavro_read_pipeline \
- | 'create-fastavro' >> Create(['%s*' % fastavro_output]) \
- | 'read-fastavro' >> ReadAllFromAvro(use_fastavro=True) \
- | Map(lambda rec: (rec['number'], rec))
+ fastavro_records = \
+ fastavro_read_pipeline \
+ | 'create-fastavro' >> Create(['%s*' % fastavro_output]) \
+ | 'read-fastavro' >> ReadAllFromAvro(use_fastavro=True) \
+ | Map(lambda rec: (rec['number'], rec))
- avro_records = \
- fastavro_read_pipeline \
- | 'create-avro' >> Create(['%s*' % avro_output]) \
- | 'read-avro' >> ReadAllFromAvro(use_fastavro=False) \
- | Map(lambda rec: (rec['number'], rec))
+ avro_records = \
+ fastavro_read_pipeline \
+ | 'create-avro' >> Create(['%s*' % avro_output]) \
+ | 'read-avro' >> ReadAllFromAvro(use_fastavro=False) \
+ | Map(lambda rec: (rec['number'], rec))
- def check(elem):
- v = elem[1]
+ def check(elem):
+ v = elem[1]
- def assertEqual(l, r):
- if l != r:
- raise BeamAssertException('Assertion failed: %s == %s' % (l, r))
+ def assertEqual(l, r):
+ if l != r:
+ raise BeamAssertException('Assertion failed: %s == %s' % (l, r))
- assertEqual(v.keys(), ['avro', 'fastavro'])
- avro_values = v['avro']
- fastavro_values = v['fastavro']
- assertEqual(avro_values, fastavro_values)
- assertEqual(len(avro_values), 1)
+ assertEqual(v.keys(), ['avro', 'fastavro'])
+ avro_values = v['avro']
+ fastavro_values = v['fastavro']
+ assertEqual(avro_values, fastavro_values)
+ assertEqual(len(avro_values), 1)
- # pylint: disable=expression-not-assigned
- {
- 'avro': avro_records,
- 'fastavro': fastavro_records
- } \
- | CoGroupByKey() \
- | Map(check)
+ # pylint: disable=expression-not-assigned
+ {
+ 'avro': avro_records,
+ 'fastavro': fastavro_records
+ } \
+ | CoGroupByKey() \
+ | Map(check)
- self.addCleanup(delete_files, [self.output])
- fastavro_read_pipeline.run().wait_until_finish()
+ self.addCleanup(delete_files, [self.output])
assert result.state == PipelineState.DONE
diff --git a/sdks/python/apache_beam/examples/flink/flink_streaming_impulse.py b/sdks/python/apache_beam/examples/flink/flink_streaming_impulse.py
index 24ca510..47c3836 100644
--- a/sdks/python/apache_beam/examples/flink/flink_streaming_impulse.py
+++ b/sdks/python/apache_beam/examples/flink/flink_streaming_impulse.py
@@ -75,24 +75,22 @@
pipeline_options = PipelineOptions(pipeline_args)
- p = beam.Pipeline(options=pipeline_options)
+ with beam.Pipeline(options=pipeline_options) as p:
- messages = (p | FlinkStreamingImpulseSource()
- .set_message_count(known_args.count)
- .set_interval_ms(known_args.interval_ms))
+ messages = (p | FlinkStreamingImpulseSource()
+ .set_message_count(known_args.count)
+ .set_interval_ms(known_args.interval_ms))
- _ = (messages | 'decode' >> beam.Map(lambda x: ('', 1))
- | 'window' >> beam.WindowInto(window.GlobalWindows(),
- trigger=Repeatedly(
- AfterProcessingTime(5 * 1000)),
- accumulation_mode=
- AccumulationMode.DISCARDING)
- | 'group' >> beam.GroupByKey()
- | 'count' >> beam.Map(count)
- | 'log' >> beam.Map(lambda x: logging.info("%d" % x[1])))
+ _ = (messages | 'decode' >> beam.Map(lambda x: ('', 1))
+ | 'window' >> beam.WindowInto(window.GlobalWindows(),
+ trigger=Repeatedly(
+ AfterProcessingTime(5 * 1000)),
+ accumulation_mode=
+ AccumulationMode.DISCARDING)
+ | 'group' >> beam.GroupByKey()
+ | 'count' >> beam.Map(count)
+ | 'log' >> beam.Map(lambda x: logging.info("%d" % x[1])))
- result = p.run()
- result.wait_until_finish()
if __name__ == '__main__':
diff --git a/sdks/python/apache_beam/examples/snippets/snippets.py b/sdks/python/apache_beam/examples/snippets/snippets.py
index 76521c5de..aaa99e9 100644
--- a/sdks/python/apache_beam/examples/snippets/snippets.py
+++ b/sdks/python/apache_beam/examples/snippets/snippets.py
@@ -111,31 +111,29 @@
import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions
- p = beam.Pipeline(options=PipelineOptions())
- # [END pipelines_constructing_creating]
+ with beam.Pipeline(options=PipelineOptions()) as p:
+ pass # build your pipeline here
+ # [END pipelines_constructing_creating]
- p = TestPipeline() # Use TestPipeline for testing.
+ with TestPipeline() as p: # Use TestPipeline for testing.
+ # pylint: disable=line-too-long
- # [START pipelines_constructing_reading]
- lines = p | 'ReadMyFile' >> beam.io.ReadFromText('gs://some/inputData.txt')
- # [END pipelines_constructing_reading]
+ # [START pipelines_constructing_reading]
+ lines = p | 'ReadMyFile' >> beam.io.ReadFromText('gs://some/inputData.txt')
+ # [END pipelines_constructing_reading]
- # [START pipelines_constructing_applying]
- words = lines | beam.FlatMap(lambda x: re.findall(r'[A-Za-z\']+', x))
- reversed_words = words | ReverseWords()
- # [END pipelines_constructing_applying]
+ # [START pipelines_constructing_applying]
+ words = lines | beam.FlatMap(lambda x: re.findall(r'[A-Za-z\']+', x))
+ reversed_words = words | ReverseWords()
+ # [END pipelines_constructing_applying]
- # [START pipelines_constructing_writing]
- filtered_words = reversed_words | 'FilterWords' >> beam.Filter(filter_words)
- filtered_words | 'WriteMyFile' >> beam.io.WriteToText(
- 'gs://some/outputData.txt')
- # [END pipelines_constructing_writing]
+ # [START pipelines_constructing_writing]
+ filtered_words = reversed_words | 'FilterWords' >> beam.Filter(filter_words)
+ filtered_words | 'WriteMyFile' >> beam.io.WriteToText(
+ 'gs://some/outputData.txt')
+ # [END pipelines_constructing_writing]
- p.visit(SnippetUtils.RenameFiles(renames))
-
- # [START pipelines_constructing_running]
- p.run()
- # [END pipelines_constructing_running]
+ p.visit(SnippetUtils.RenameFiles(renames))
def model_pipelines(argv):
@@ -249,12 +247,10 @@
my_input = my_options.input
my_output = my_options.output
- p = TestPipeline() # Use TestPipeline for testing.
+ with TestPipeline() as p: # Use TestPipeline for testing.
- lines = p | beam.io.ReadFromText(my_input)
- lines | beam.io.WriteToText(my_output)
-
- p.run()
+ lines = p | beam.io.ReadFromText(my_input)
+ lines | beam.io.WriteToText(my_output)
def pipeline_options_local(argv):
@@ -286,13 +282,12 @@
# [START pipeline_options_local]
# Create and set your Pipeline Options.
options = PipelineOptions()
- p = Pipeline(options=options)
- # [END pipeline_options_local]
+ with Pipeline(options=options) as p:
+ # [END pipeline_options_local]
- p = TestPipeline() # Use TestPipeline for testing.
- lines = p | beam.io.ReadFromText(my_input)
- lines | beam.io.WriteToText(my_output)
- p.run()
+ with TestPipeline() as p: # Use TestPipeline for testing.
+ lines = p | beam.io.ReadFromText(my_input)
+ lines | beam.io.WriteToText(my_output)
def pipeline_options_command_line(argv):
@@ -541,30 +536,28 @@
required=True,
help='Output file to write results to.')
pipeline_options = PipelineOptions(['--output', 'some/output_path'])
- p = beam.Pipeline(options=pipeline_options)
+ with beam.Pipeline(options=pipeline_options) as p:
- wordcount_options = pipeline_options.view_as(WordcountTemplatedOptions)
- lines = p | 'Read' >> ReadFromText(wordcount_options.input)
- # [END example_wordcount_templated]
+ wordcount_options = pipeline_options.view_as(WordcountTemplatedOptions)
+ lines = p | 'Read' >> ReadFromText(wordcount_options.input)
+ # [END example_wordcount_templated]
- def format_result(word_count):
- (word, count) = word_count
- return '%s: %s' % (word, count)
+ def format_result(word_count):
+ (word, count) = word_count
+ return '%s: %s' % (word, count)
- (
- lines
- | 'ExtractWords' >> beam.FlatMap(
- lambda x: re.findall(r'[A-Za-z\']+', x))
- | 'PairWithOnes' >> beam.Map(lambda x: (x, 1))
- | 'Group' >> beam.GroupByKey()
- | 'Sum' >> beam.Map(lambda word_ones: (word_ones[0], sum(word_ones[1])))
- | 'Format' >> beam.Map(format_result)
- | 'Write' >> WriteToText(wordcount_options.output)
- )
+ (
+ lines
+ | 'ExtractWords' >> beam.FlatMap(
+ lambda x: re.findall(r'[A-Za-z\']+', x))
+ | 'PairWithOnes' >> beam.Map(lambda x: (x, 1))
+ | 'Group' >> beam.GroupByKey()
+ | 'Sum' >> beam.Map(lambda word_ones: (word_ones[0], sum(word_ones[1])))
+ | 'Format' >> beam.Map(format_result)
+ | 'Write' >> WriteToText(wordcount_options.output)
+ )
- p.visit(SnippetUtils.RenameFiles(renames))
- result = p.run()
- result.wait_until_finish()
+ p.visit(SnippetUtils.RenameFiles(renames))
def examples_wordcount_debugging(renames):
@@ -713,25 +706,23 @@
yield self.templated_int.get() + an_int
pipeline_options = PipelineOptions()
- p = beam.Pipeline(options=pipeline_options)
+ with beam.Pipeline(options=pipeline_options) as p:
- user_options = pipeline_options.view_as(TemplatedUserOptions)
- my_sum_fn = MySumFn(user_options.templated_int)
- sum = (p
- | 'ReadCollection' >> beam.io.ReadFromText(
- 'gs://some/integer_collection')
- | 'StringToInt' >> beam.Map(lambda w: int(w))
- | 'AddGivenInt' >> beam.ParDo(my_sum_fn)
- | 'WriteResultingCollection' >> WriteToText('some/output_path'))
- # [END examples_ptransforms_templated]
+ user_options = pipeline_options.view_as(TemplatedUserOptions)
+ my_sum_fn = MySumFn(user_options.templated_int)
+ sum = (p
+ | 'ReadCollection' >> beam.io.ReadFromText(
+ 'gs://some/integer_collection')
+ | 'StringToInt' >> beam.Map(lambda w: int(w))
+ | 'AddGivenInt' >> beam.ParDo(my_sum_fn)
+ | 'WriteResultingCollection' >> WriteToText('some/output_path'))
+ # [END examples_ptransforms_templated]
- # Templates are not supported by DirectRunner (only by DataflowRunner)
- # so a value must be provided at graph-construction time
- my_sum_fn.templated_int = StaticValueProvider(int, 10)
+ # Templates are not supported by DirectRunner (only by DataflowRunner)
+ # so a value must be provided at graph-construction time
+ my_sum_fn.templated_int = StaticValueProvider(int, 10)
- p.visit(SnippetUtils.RenameFiles(renames))
- result = p.run()
- result.wait_until_finish()
+ p.visit(SnippetUtils.RenameFiles(renames))
# Defining a new source.
@@ -835,16 +826,15 @@
['line ' + str(number) for number in range(0, count)]))
# [START model_custom_source_use_ptransform]
- p = beam.Pipeline(options=PipelineOptions())
- numbers = p | 'ProduceNumbers' >> ReadFromCountingSource(count)
- # [END model_custom_source_use_ptransform]
+ with beam.Pipeline(options=PipelineOptions()) as p:
+ numbers = p | 'ProduceNumbers' >> ReadFromCountingSource(count)
+ # [END model_custom_source_use_ptransform]
- lines = numbers | beam.core.Map(lambda number: 'line %d' % number)
- assert_that(
- lines, equal_to(
- ['line ' + str(number) for number in range(0, count)]))
+ lines = numbers | beam.core.Map(lambda number: 'line %d' % number)
+ assert_that(
+ lines, equal_to(
+ ['line ' + str(number) for number in range(0, count)]))
- p.run().wait_until_finish()
# Defining the new sink.
@@ -1402,20 +1392,19 @@
pipeline_options = PipelineOptions()
# Create pipeline.
- p = beam.Pipeline(options=pipeline_options)
+ with beam.Pipeline(options=pipeline_options) as p:
- my_options = pipeline_options.view_as(MyOptions)
- # Add a branch for logging the ValueProvider value.
- _ = (p
- | beam.Create([None])
- | 'LogValueProvs' >> beam.ParDo(
- LogValueProvidersFn(my_options.string_value)))
+ my_options = pipeline_options.view_as(MyOptions)
+ # Add a branch for logging the ValueProvider value.
+ _ = (p
+ | beam.Create([None])
+ | 'LogValueProvs' >> beam.ParDo(
+ LogValueProvidersFn(my_options.string_value)))
- # The main pipeline.
- result_pc = (p
- | "main_pc" >> beam.Create([1, 2, 3])
- | beam.combiners.Sum.Globally())
+ # The main pipeline.
+ result_pc = (p
+ | "main_pc" >> beam.Create([1, 2, 3])
+ | beam.combiners.Sum.Globally())
- p.run().wait_until_finish()
# [END AccessingValueProviderInfoAfterRunSnip1]
diff --git a/sdks/python/apache_beam/examples/streaming_wordcount.py b/sdks/python/apache_beam/examples/streaming_wordcount.py
index 461e073..cdfb6a1 100644
--- a/sdks/python/apache_beam/examples/streaming_wordcount.py
+++ b/sdks/python/apache_beam/examples/streaming_wordcount.py
@@ -58,50 +58,48 @@
pipeline_options = PipelineOptions(pipeline_args)
pipeline_options.view_as(SetupOptions).save_main_session = save_main_session
pipeline_options.view_as(StandardOptions).streaming = True
- p = beam.Pipeline(options=pipeline_options)
+ with beam.Pipeline(options=pipeline_options) as p:
- # Read from PubSub into a PCollection.
- if known_args.input_subscription:
- messages = (p
- | beam.io.ReadFromPubSub(
- subscription=known_args.input_subscription)
- .with_output_types(bytes))
- else:
- messages = (p
- | beam.io.ReadFromPubSub(topic=known_args.input_topic)
- .with_output_types(bytes))
+ # Read from PubSub into a PCollection.
+ if known_args.input_subscription:
+ messages = (p
+ | beam.io.ReadFromPubSub(
+ subscription=known_args.input_subscription)
+ .with_output_types(bytes))
+ else:
+ messages = (p
+ | beam.io.ReadFromPubSub(topic=known_args.input_topic)
+ .with_output_types(bytes))
- lines = messages | 'decode' >> beam.Map(lambda x: x.decode('utf-8'))
+ lines = messages | 'decode' >> beam.Map(lambda x: x.decode('utf-8'))
- # Count the occurrences of each word.
- def count_ones(word_ones):
- (word, ones) = word_ones
- return (word, sum(ones))
+ # Count the occurrences of each word.
+ def count_ones(word_ones):
+ (word, ones) = word_ones
+ return (word, sum(ones))
- counts = (lines
- | 'split' >> (beam.ParDo(WordExtractingDoFn())
- .with_output_types(unicode))
- | 'pair_with_one' >> beam.Map(lambda x: (x, 1))
- | beam.WindowInto(window.FixedWindows(15, 0))
- | 'group' >> beam.GroupByKey()
- | 'count' >> beam.Map(count_ones))
+ counts = (lines
+ | 'split' >> (beam.ParDo(WordExtractingDoFn())
+ .with_output_types(unicode))
+ | 'pair_with_one' >> beam.Map(lambda x: (x, 1))
+ | beam.WindowInto(window.FixedWindows(15, 0))
+ | 'group' >> beam.GroupByKey()
+ | 'count' >> beam.Map(count_ones))
- # Format the counts into a PCollection of strings.
- def format_result(word_count):
- (word, count) = word_count
- return '%s: %d' % (word, count)
+ # Format the counts into a PCollection of strings.
+ def format_result(word_count):
+ (word, count) = word_count
+ return '%s: %d' % (word, count)
- output = (counts
- | 'format' >> beam.Map(format_result)
- | 'encode' >> beam.Map(lambda x: x.encode('utf-8'))
- .with_output_types(bytes))
+ output = (counts
+ | 'format' >> beam.Map(format_result)
+ | 'encode' >> beam.Map(lambda x: x.encode('utf-8'))
+ .with_output_types(bytes))
- # Write to PubSub.
- # pylint: disable=expression-not-assigned
- output | beam.io.WriteToPubSub(known_args.output_topic)
+ # Write to PubSub.
+ # pylint: disable=expression-not-assigned
+ output | beam.io.WriteToPubSub(known_args.output_topic)
- result = p.run()
- result.wait_until_finish()
if __name__ == '__main__':
diff --git a/sdks/python/apache_beam/examples/streaming_wordcount_debugging.py b/sdks/python/apache_beam/examples/streaming_wordcount_debugging.py
index db5304d..79eecea 100644
--- a/sdks/python/apache_beam/examples/streaming_wordcount_debugging.py
+++ b/sdks/python/apache_beam/examples/streaming_wordcount_debugging.py
@@ -103,82 +103,80 @@
pipeline_options = PipelineOptions(pipeline_args)
pipeline_options.view_as(SetupOptions).save_main_session = True
pipeline_options.view_as(StandardOptions).streaming = True
- p = beam.Pipeline(options=pipeline_options)
+ with beam.Pipeline(options=pipeline_options) as p:
- # Read from PubSub into a PCollection.
- if known_args.input_subscription:
- lines = p | beam.io.ReadFromPubSub(
- subscription=known_args.input_subscription)
- else:
- lines = p | beam.io.ReadFromPubSub(topic=known_args.input_topic)
+ # Read from PubSub into a PCollection.
+ if known_args.input_subscription:
+ lines = p | beam.io.ReadFromPubSub(
+ subscription=known_args.input_subscription)
+ else:
+ lines = p | beam.io.ReadFromPubSub(topic=known_args.input_topic)
- # Count the occurrences of each word.
- def count_ones(word_ones):
- (word, ones) = word_ones
- return (word, sum(ones))
+ # Count the occurrences of each word.
+ def count_ones(word_ones):
+ (word, ones) = word_ones
+ return (word, sum(ones))
- counts = (lines
- | 'AddTimestampFn' >> beam.ParDo(AddTimestampFn())
- | 'After AddTimestampFn' >> ParDo(PrintFn('After AddTimestampFn'))
- | 'Split' >> (beam.ParDo(WordExtractingDoFn())
- .with_output_types(unicode))
- | 'PairWithOne' >> beam.Map(lambda x: (x, 1))
- | beam.WindowInto(window.FixedWindows(5, 0))
- | 'GroupByKey' >> beam.GroupByKey()
- | 'CountOnes' >> beam.Map(count_ones))
+ counts = (lines
+ | 'AddTimestampFn' >> beam.ParDo(AddTimestampFn())
+ | 'After AddTimestampFn' >> ParDo(PrintFn('After AddTimestampFn'))
+ | 'Split' >> (beam.ParDo(WordExtractingDoFn())
+ .with_output_types(unicode))
+ | 'PairWithOne' >> beam.Map(lambda x: (x, 1))
+ | beam.WindowInto(window.FixedWindows(5, 0))
+ | 'GroupByKey' >> beam.GroupByKey()
+ | 'CountOnes' >> beam.Map(count_ones))
- # Format the counts into a PCollection of strings.
- def format_result(word_count):
- (word, count) = word_count
- return '%s: %d' % (word, count)
+ # Format the counts into a PCollection of strings.
+ def format_result(word_count):
+ (word, count) = word_count
+ return '%s: %d' % (word, count)
- output = counts | 'format' >> beam.Map(format_result)
+ output = counts | 'format' >> beam.Map(format_result)
- # Write to PubSub.
- # pylint: disable=expression-not-assigned
- output | beam.io.WriteStringsToPubSub(known_args.output_topic)
+ # Write to PubSub.
+ # pylint: disable=expression-not-assigned
+ output | beam.io.WriteStringsToPubSub(known_args.output_topic)
- def check_gbk_format():
- # A matcher that checks that the output of GBK is of the form word: count.
- def matcher(elements):
- # pylint: disable=unused-variable
- actual_elements_in_window, window = elements
- for elm in actual_elements_in_window:
- assert re.match(r'\S+:\s+\d+', elm) is not None
- return matcher
+ def check_gbk_format():
+ # A matcher that checks that the output of GBK is of the form word: count.
+ def matcher(elements):
+ # pylint: disable=unused-variable
+ actual_elements_in_window, window = elements
+ for elm in actual_elements_in_window:
+ assert re.match(r'\S+:\s+\d+', elm) is not None
+ return matcher
- # Check that the format of the output is correct.
- assert_that(
- output,
- check_gbk_format(),
- use_global_window=False,
- label='Assert word:count format.')
+ # Check that the format of the output is correct.
+ assert_that(
+ output,
+ check_gbk_format(),
+ use_global_window=False,
+ label='Assert word:count format.')
- # Check also that elements are ouput in the right window.
- # This expects exactly 1 occurrence of any subset of the elements
- # 150, 151, 152, 153, 154 in the window [150, 155)
- # or exactly 1 occurrence of any subset of the elements
- # 210, 211, 212, 213, 214 in the window [210, 215).
- expected_window_to_elements = {
- window.IntervalWindow(150, 155): [
- ('150: 1'), ('151: 1'), ('152: 1'), ('153: 1'), ('154: 1'),
- ],
- window.IntervalWindow(210, 215): [
- ('210: 1'), ('211: 1'), ('212: 1'), ('213: 1'), ('214: 1'),
- ],
- }
+ # Check also that elements are output in the right window.
+ # This expects exactly 1 occurrence of any subset of the elements
+ # 150, 151, 152, 153, 154 in the window [150, 155)
+ # or exactly 1 occurrence of any subset of the elements
+ # 210, 211, 212, 213, 214 in the window [210, 215).
+ expected_window_to_elements = {
+ window.IntervalWindow(150, 155): [
+ ('150: 1'), ('151: 1'), ('152: 1'), ('153: 1'), ('154: 1'),
+ ],
+ window.IntervalWindow(210, 215): [
+ ('210: 1'), ('211: 1'), ('212: 1'), ('213: 1'), ('214: 1'),
+ ],
+ }
- # To make it pass, publish numbers in [150-155) or [210-215) with no repeats.
- # To make it fail, publish a repeated number in the range above range.
- # For example: '210 213 151 213'
- assert_that(
- output,
- equal_to_per_window(expected_window_to_elements),
- use_global_window=False,
- label='Assert correct streaming windowing.')
+ # To pass, publish numbers in [150-155) or [210-215) with no repeats.
+ # To fail, publish a repeated number in the ranges above.
+ # For example: '210 213 151 213'
+ assert_that(
+ output,
+ equal_to_per_window(expected_window_to_elements),
+ use_global_window=False,
+ label='Assert correct streaming windowing.')
- result = p.run()
- result.wait_until_finish()
if __name__ == '__main__':
diff --git a/sdks/python/apache_beam/examples/wordcount_xlang.py b/sdks/python/apache_beam/examples/wordcount_xlang.py
index 5cb3303..b8353bb 100644
--- a/sdks/python/apache_beam/examples/wordcount_xlang.py
+++ b/sdks/python/apache_beam/examples/wordcount_xlang.py
@@ -59,7 +59,7 @@
return re.findall(r'[\w\']+', text_line)
-def run(p, input_file, output_file):
+def build_pipeline(p, input_file, output_file):
# Read the text file[pattern] into a PCollection.
lines = p | 'read' >> ReadFromText(input_file)
@@ -80,9 +80,6 @@
# pylint: disable=expression-not-assigned
output | 'write' >> WriteToText(output_file)
- result = p.run()
- result.wait_until_finish()
-
def main():
logging.getLogger().setLevel(logging.INFO)
@@ -112,10 +109,6 @@
# workflow rely on global context (e.g., a module imported at module level).
pipeline_options.view_as(SetupOptions).save_main_session = True
- p = beam.Pipeline(options=pipeline_options)
- # Preemptively start due to BEAM-6666.
- p.runner.create_job_service(pipeline_options)
-
try:
server = subprocess.Popen([
'java', '-jar', known_args.expansion_service_jar,
@@ -124,7 +117,11 @@
with grpc.insecure_channel(EXPANSION_SERVICE_ADDR) as channel:
grpc.channel_ready_future(channel).result()
- run(p, known_args.input, known_args.output)
+ with beam.Pipeline(options=pipeline_options) as p:
+ # Preemptively start due to BEAM-6666.
+ p.runner.create_job_service(pipeline_options)
+
+ build_pipeline(p, known_args.input, known_args.output)
finally:
server.kill()
diff --git a/sdks/python/apache_beam/io/concat_source_test.py b/sdks/python/apache_beam/io/concat_source_test.py
index eea44e0..41e4e63 100644
--- a/sdks/python/apache_beam/io/concat_source_test.py
+++ b/sdks/python/apache_beam/io/concat_source_test.py
@@ -225,11 +225,10 @@
RangeSource(10, 100),
RangeSource(100, 1000),
])
- pipeline = TestPipeline()
- pcoll = pipeline | beam.io.Read(source)
- assert_that(pcoll, equal_to(list(range(1000))))
+ with TestPipeline() as pipeline:
+ pcoll = pipeline | beam.io.Read(source)
+ assert_that(pcoll, equal_to(list(range(1000))))
- pipeline.run()
def test_conact_source_exhaustive(self):
source = ConcatSource([RangeSource(0, 10),
diff --git a/sdks/python/apache_beam/io/external/generate_sequence_test.py b/sdks/python/apache_beam/io/external/generate_sequence_test.py
index 652e47b..060ce28 100644
--- a/sdks/python/apache_beam/io/external/generate_sequence_test.py
+++ b/sdks/python/apache_beam/io/external/generate_sequence_test.py
@@ -41,12 +41,11 @@
"EXPANSION_PORT environment var is not provided.")
class XlangGenerateSequenceTest(unittest.TestCase):
def test_generate_sequence(self):
- test_pipeline = TestPipeline()
port = os.environ.get('EXPANSION_PORT')
address = 'localhost:%s' % port
try:
- with test_pipeline as p:
+ with TestPipeline() as p:
res = (
p
| GenerateSequence(start=1, stop=10,
diff --git a/sdks/python/apache_beam/io/external/xlang_parquetio_test.py b/sdks/python/apache_beam/io/external/xlang_parquetio_test.py
index ed49a58..aee35a0 100644
--- a/sdks/python/apache_beam/io/external/xlang_parquetio_test.py
+++ b/sdks/python/apache_beam/io/external/xlang_parquetio_test.py
@@ -53,11 +53,10 @@
port = os.environ.get('EXPANSION_PORT')
address = 'localhost:%s' % port
try:
- test_pipeline = TestPipeline()
- test_pipeline.get_pipeline_options().view_as(
- DebugOptions).experiments.append('jar_packages='+expansion_jar)
- test_pipeline.not_use_test_runner_api = True
- with test_pipeline as p:
+ with TestPipeline() as p:
+ p.get_pipeline_options().view_as(
+ DebugOptions).experiments.append('jar_packages='+expansion_jar)
+ p.not_use_test_runner_api = True
_ = p \
| beam.Create([
AvroRecord({"name": "abc"}), AvroRecord({"name": "def"}),
diff --git a/sdks/python/apache_beam/io/filebasedsource_test.py b/sdks/python/apache_beam/io/filebasedsource_test.py
index 2c5bd98..3c9adbd 100644
--- a/sdks/python/apache_beam/io/filebasedsource_test.py
+++ b/sdks/python/apache_beam/io/filebasedsource_test.py
@@ -436,11 +436,10 @@
self.assertCountEqual(expected_data, read_data)
def _run_source_test(self, pattern, expected_data, splittable=True):
- pipeline = TestPipeline()
- pcoll = pipeline | 'Read' >> beam.io.Read(LineSource(
- pattern, splittable=splittable))
- assert_that(pcoll, equal_to(expected_data))
- pipeline.run()
+ with TestPipeline() as pipeline:
+ pcoll = pipeline | 'Read' >> beam.io.Read(LineSource(
+ pattern, splittable=splittable))
+ assert_that(pcoll, equal_to(expected_data))
def test_source_file(self):
file_name, expected_data = write_data(100)
@@ -476,13 +475,12 @@
with bz2.BZ2File(filename, 'wb') as f:
f.write(b'\n'.join(lines))
- pipeline = TestPipeline()
- pcoll = pipeline | 'Read' >> beam.io.Read(LineSource(
- filename,
- splittable=False,
- compression_type=CompressionTypes.BZIP2))
- assert_that(pcoll, equal_to(lines))
- pipeline.run()
+ with TestPipeline() as pipeline:
+ pcoll = pipeline | 'Read' >> beam.io.Read(LineSource(
+ filename,
+ splittable=False,
+ compression_type=CompressionTypes.BZIP2))
+ assert_that(pcoll, equal_to(lines))
def test_read_file_gzip(self):
_, lines = write_data(10)
@@ -491,13 +489,12 @@
with gzip.GzipFile(filename, 'wb') as f:
f.write(b'\n'.join(lines))
- pipeline = TestPipeline()
- pcoll = pipeline | 'Read' >> beam.io.Read(LineSource(
- filename,
- splittable=False,
- compression_type=CompressionTypes.GZIP))
- assert_that(pcoll, equal_to(lines))
- pipeline.run()
+ with TestPipeline() as pipeline:
+ pcoll = pipeline | 'Read' >> beam.io.Read(LineSource(
+ filename,
+ splittable=False,
+ compression_type=CompressionTypes.GZIP))
+ assert_that(pcoll, equal_to(lines))
def test_read_pattern_bzip2(self):
_, lines = write_data(200)
@@ -509,13 +506,12 @@
compressed_chunks.append(
compressobj.compress(b'\n'.join(c)) + compressobj.flush())
file_pattern = write_prepared_pattern(compressed_chunks)
- pipeline = TestPipeline()
- pcoll = pipeline | 'Read' >> beam.io.Read(LineSource(
- file_pattern,
- splittable=False,
- compression_type=CompressionTypes.BZIP2))
- assert_that(pcoll, equal_to(lines))
- pipeline.run()
+ with TestPipeline() as pipeline:
+ pcoll = pipeline | 'Read' >> beam.io.Read(LineSource(
+ file_pattern,
+ splittable=False,
+ compression_type=CompressionTypes.BZIP2))
+ assert_that(pcoll, equal_to(lines))
def test_read_pattern_gzip(self):
_, lines = write_data(200)
@@ -528,13 +524,12 @@
f.write(b'\n'.join(c))
compressed_chunks.append(out.getvalue())
file_pattern = write_prepared_pattern(compressed_chunks)
- pipeline = TestPipeline()
- pcoll = pipeline | 'Read' >> beam.io.Read(LineSource(
- file_pattern,
- splittable=False,
- compression_type=CompressionTypes.GZIP))
- assert_that(pcoll, equal_to(lines))
- pipeline.run()
+ with TestPipeline() as pipeline:
+ pcoll = pipeline | 'Read' >> beam.io.Read(LineSource(
+ file_pattern,
+ splittable=False,
+ compression_type=CompressionTypes.GZIP))
+ assert_that(pcoll, equal_to(lines))
def test_read_auto_single_file_bzip2(self):
_, lines = write_data(10)
@@ -543,12 +538,11 @@
with bz2.BZ2File(filename, 'wb') as f:
f.write(b'\n'.join(lines))
- pipeline = TestPipeline()
- pcoll = pipeline | 'Read' >> beam.io.Read(LineSource(
- filename,
- compression_type=CompressionTypes.AUTO))
- assert_that(pcoll, equal_to(lines))
- pipeline.run()
+ with TestPipeline() as pipeline:
+ pcoll = pipeline | 'Read' >> beam.io.Read(LineSource(
+ filename,
+ compression_type=CompressionTypes.AUTO))
+ assert_that(pcoll, equal_to(lines))
def test_read_auto_single_file_gzip(self):
_, lines = write_data(10)
@@ -557,12 +551,11 @@
with gzip.GzipFile(filename, 'wb') as f:
f.write(b'\n'.join(lines))
- pipeline = TestPipeline()
- pcoll = pipeline | 'Read' >> beam.io.Read(LineSource(
- filename,
- compression_type=CompressionTypes.AUTO))
- assert_that(pcoll, equal_to(lines))
- pipeline.run()
+ with TestPipeline() as pipeline:
+ pcoll = pipeline | 'Read' >> beam.io.Read(LineSource(
+ filename,
+ compression_type=CompressionTypes.AUTO))
+ assert_that(pcoll, equal_to(lines))
def test_read_auto_pattern(self):
_, lines = write_data(200)
@@ -576,12 +569,11 @@
compressed_chunks.append(out.getvalue())
file_pattern = write_prepared_pattern(
compressed_chunks, suffixes=['.gz']*len(chunks))
- pipeline = TestPipeline()
- pcoll = pipeline | 'Read' >> beam.io.Read(LineSource(
- file_pattern,
- compression_type=CompressionTypes.AUTO))
- assert_that(pcoll, equal_to(lines))
- pipeline.run()
+ with TestPipeline() as pipeline:
+ pcoll = pipeline | 'Read' >> beam.io.Read(LineSource(
+ file_pattern,
+ compression_type=CompressionTypes.AUTO))
+ assert_that(pcoll, equal_to(lines))
def test_read_auto_pattern_compressed_and_uncompressed(self):
_, lines = write_data(200)
@@ -598,12 +590,11 @@
chunks_to_write.append(b'\n'.join(c))
file_pattern = write_prepared_pattern(chunks_to_write,
suffixes=(['.gz', '']*3))
- pipeline = TestPipeline()
- pcoll = pipeline | 'Read' >> beam.io.Read(LineSource(
- file_pattern,
- compression_type=CompressionTypes.AUTO))
- assert_that(pcoll, equal_to(lines))
- pipeline.run()
+ with TestPipeline() as pipeline:
+ pcoll = pipeline | 'Read' >> beam.io.Read(LineSource(
+ file_pattern,
+ compression_type=CompressionTypes.AUTO))
+ assert_that(pcoll, equal_to(lines))
def test_splits_get_coder_from_fbs(self):
class DummyCoder(object):
diff --git a/sdks/python/apache_beam/io/gcp/bigquery_io_read_pipeline.py b/sdks/python/apache_beam/io/gcp/bigquery_io_read_pipeline.py
index dc6223f..7e1dc94 100644
--- a/sdks/python/apache_beam/io/gcp/bigquery_io_read_pipeline.py
+++ b/sdks/python/apache_beam/io/gcp/bigquery_io_read_pipeline.py
@@ -68,24 +68,21 @@
known_args, pipeline_args = parser.parse_known_args(argv)
options = PipelineOptions(pipeline_args)
- p = TestPipeline(options=options)
+ with TestPipeline(options=options) as p:
+ if known_args.beam_bq_source:
+ reader = _ReadFromBigQuery(
+ table='%s:%s' % (options.view_as(GoogleCloudOptions).project,
+ known_args.input_table))
+ else:
+ reader = beam.io.Read(beam.io.BigQuerySource(known_args.input_table))
- if known_args.beam_bq_source:
- reader = _ReadFromBigQuery(
- table='%s:%s' % (options.view_as(GoogleCloudOptions).project,
- known_args.input_table))
- else:
- reader = beam.io.Read(beam.io.BigQuerySource(known_args.input_table))
+ # pylint: disable=expression-not-assigned
+ count = (p | 'read' >> reader
+ | 'row to string' >> beam.ParDo(RowToStringWithSlowDown(),
+ num_slow=known_args.num_slow)
+ | 'count' >> beam.combiners.Count.Globally())
- # pylint: disable=expression-not-assigned
- count = (p | 'read' >> reader
- | 'row to string' >> beam.ParDo(RowToStringWithSlowDown(),
- num_slow=known_args.num_slow)
- | 'count' >> beam.combiners.Count.Globally())
-
- assert_that(count, equal_to([known_args.num_records]))
-
- p.run()
+ assert_that(count, equal_to([known_args.num_records]))
if __name__ == '__main__':
diff --git a/sdks/python/apache_beam/io/gcp/bigquery_read_perf_test.py b/sdks/python/apache_beam/io/gcp/bigquery_read_perf_test.py
index 5d1fc98..d93cd88 100644
--- a/sdks/python/apache_beam/io/gcp/bigquery_read_perf_test.py
+++ b/sdks/python/apache_beam/io/gcp/bigquery_read_perf_test.py
@@ -117,17 +117,17 @@
# of the part
return {'data': base64.b64encode(record[1])}
- p = TestPipeline()
- # pylint: disable=expression-not-assigned
- (p
- | 'Produce rows' >> Read(SyntheticSource(self.parseTestPipelineOptions()))
- | 'Format' >> Map(format_record)
- | 'Write to BigQuery' >> WriteToBigQuery(
- dataset=self.input_dataset, table=self.input_table,
- schema=SCHEMA,
- create_disposition=BigQueryDisposition.CREATE_IF_NEEDED,
- write_disposition=BigQueryDisposition.WRITE_EMPTY))
- p.run().wait_until_finish()
+ with TestPipeline() as p:
+ # pylint: disable=expression-not-assigned
+ (p
+ | 'Produce rows' >> Read(SyntheticSource(
+ self.parseTestPipelineOptions()))
+ | 'Format' >> Map(format_record)
+ | 'Write to BigQuery' >> WriteToBigQuery(
+ dataset=self.input_dataset, table=self.input_table,
+ schema=SCHEMA,
+ create_disposition=BigQueryDisposition.CREATE_IF_NEEDED,
+ write_disposition=BigQueryDisposition.WRITE_EMPTY))
def test(self):
output = (self.pipeline
diff --git a/sdks/python/apache_beam/io/gcp/pubsub_test.py b/sdks/python/apache_beam/io/gcp/pubsub_test.py
index 8912a3f..a4b9e62 100644
--- a/sdks/python/apache_beam/io/gcp/pubsub_test.py
+++ b/sdks/python/apache_beam/io/gcp/pubsub_test.py
@@ -354,12 +354,11 @@
options = PipelineOptions([])
options.view_as(StandardOptions).streaming = True
- p = TestPipeline(options=options)
- pcoll = (p
- | ReadFromPubSub('projects/fakeprj/topics/a_topic',
- None, None, with_attributes=True))
- assert_that(pcoll, equal_to(expected_elements), reify_windows=True)
- p.run()
+ with TestPipeline(options=options) as p:
+ pcoll = (p
+ | ReadFromPubSub('projects/fakeprj/topics/a_topic',
+ None, None, with_attributes=True))
+ assert_that(pcoll, equal_to(expected_elements), reify_windows=True)
mock_pubsub.return_value.acknowledge.assert_has_calls([
mock.call(mock.ANY, [ack_id])])
@@ -378,12 +377,11 @@
options = PipelineOptions([])
options.view_as(StandardOptions).streaming = True
- p = TestPipeline(options=options)
- pcoll = (p
- | ReadStringsFromPubSub('projects/fakeprj/topics/a_topic',
- None, None))
- assert_that(pcoll, equal_to(expected_elements))
- p.run()
+ with TestPipeline(options=options) as p:
+ pcoll = (p
+ | ReadStringsFromPubSub('projects/fakeprj/topics/a_topic',
+ None, None))
+ assert_that(pcoll, equal_to(expected_elements))
mock_pubsub.return_value.acknowledge.assert_has_calls([
mock.call(mock.ANY, [ack_id])])
@@ -400,11 +398,10 @@
options = PipelineOptions([])
options.view_as(StandardOptions).streaming = True
- p = TestPipeline(options=options)
- pcoll = (p
- | ReadFromPubSub('projects/fakeprj/topics/a_topic', None, None))
- assert_that(pcoll, equal_to(expected_elements))
- p.run()
+ with TestPipeline(options=options) as p:
+ pcoll = (p
+ | ReadFromPubSub('projects/fakeprj/topics/a_topic', None, None))
+ assert_that(pcoll, equal_to(expected_elements))
mock_pubsub.return_value.acknowledge.assert_has_calls([
mock.call(mock.ANY, [ack_id])])
@@ -431,13 +428,12 @@
options = PipelineOptions([])
options.view_as(StandardOptions).streaming = True
- p = TestPipeline(options=options)
- pcoll = (p
- | ReadFromPubSub(
- 'projects/fakeprj/topics/a_topic', None, None,
- with_attributes=True, timestamp_attribute='time'))
- assert_that(pcoll, equal_to(expected_elements), reify_windows=True)
- p.run()
+ with TestPipeline(options=options) as p:
+ pcoll = (p
+ | ReadFromPubSub(
+ 'projects/fakeprj/topics/a_topic', None, None,
+ with_attributes=True, timestamp_attribute='time'))
+ assert_that(pcoll, equal_to(expected_elements), reify_windows=True)
mock_pubsub.return_value.acknowledge.assert_has_calls([
mock.call(mock.ANY, [ack_id])])
@@ -464,13 +460,12 @@
options = PipelineOptions([])
options.view_as(StandardOptions).streaming = True
- p = TestPipeline(options=options)
- pcoll = (p
- | ReadFromPubSub(
- 'projects/fakeprj/topics/a_topic', None, None,
- with_attributes=True, timestamp_attribute='time'))
- assert_that(pcoll, equal_to(expected_elements), reify_windows=True)
- p.run()
+ with TestPipeline(options=options) as p:
+ pcoll = (p
+ | ReadFromPubSub(
+ 'projects/fakeprj/topics/a_topic', None, None,
+ with_attributes=True, timestamp_attribute='time'))
+ assert_that(pcoll, equal_to(expected_elements), reify_windows=True)
mock_pubsub.return_value.acknowledge.assert_has_calls([
mock.call(mock.ANY, [ack_id])])
@@ -498,13 +493,12 @@
options = PipelineOptions([])
options.view_as(StandardOptions).streaming = True
- p = TestPipeline(options=options)
- pcoll = (p
- | ReadFromPubSub(
- 'projects/fakeprj/topics/a_topic', None, None,
- with_attributes=True, timestamp_attribute='nonexistent'))
- assert_that(pcoll, equal_to(expected_elements), reify_windows=True)
- p.run()
+ with TestPipeline(options=options) as p:
+ pcoll = (p
+ | ReadFromPubSub(
+ 'projects/fakeprj/topics/a_topic', None, None,
+ with_attributes=True, timestamp_attribute='nonexistent'))
+ assert_that(pcoll, equal_to(expected_elements), reify_windows=True)
mock_pubsub.return_value.acknowledge.assert_has_calls([
mock.call(mock.ANY, [ack_id])])
@@ -541,11 +535,11 @@
# id_label is unsupported in DirectRunner.
options = PipelineOptions([])
options.view_as(StandardOptions).streaming = True
- p = TestPipeline(options=options)
- _ = (p | ReadFromPubSub('projects/fakeprj/topics/a_topic', None, 'a_label'))
with self.assertRaisesRegex(NotImplementedError,
r'id_label is not supported'):
- p.run()
+ with TestPipeline(options=options) as p:
+ _ = (p | ReadFromPubSub(
+ 'projects/fakeprj/topics/a_topic', None, 'a_label'))
@unittest.skipIf(pubsub is None, 'GCP dependencies are not installed')
@@ -558,12 +552,11 @@
options = PipelineOptions([])
options.view_as(StandardOptions).streaming = True
- p = TestPipeline(options=options)
- _ = (p
- | Create(payloads)
- | WriteToPubSub('projects/fakeprj/topics/a_topic',
- with_attributes=False))
- p.run()
+ with TestPipeline(options=options) as p:
+ _ = (p
+ | Create(payloads)
+ | WriteToPubSub('projects/fakeprj/topics/a_topic',
+ with_attributes=False))
mock_pubsub.return_value.publish.assert_has_calls([
mock.call(mock.ANY, data)])
@@ -573,11 +566,10 @@
options = PipelineOptions([])
options.view_as(StandardOptions).streaming = True
- p = TestPipeline(options=options)
- _ = (p
- | Create(payloads)
- | WriteStringsToPubSub('projects/fakeprj/topics/a_topic'))
- p.run()
+ with TestPipeline(options=options) as p:
+ _ = (p
+ | Create(payloads)
+ | WriteStringsToPubSub('projects/fakeprj/topics/a_topic'))
mock_pubsub.return_value.publish.assert_has_calls([
mock.call(mock.ANY, data)])
@@ -588,12 +580,11 @@
options = PipelineOptions([])
options.view_as(StandardOptions).streaming = True
- p = TestPipeline(options=options)
- _ = (p
- | Create(payloads)
- | WriteToPubSub('projects/fakeprj/topics/a_topic',
- with_attributes=True))
- p.run()
+ with TestPipeline(options=options) as p:
+ _ = (p
+ | Create(payloads)
+ | WriteToPubSub('projects/fakeprj/topics/a_topic',
+ with_attributes=True))
mock_pubsub.return_value.publish.assert_has_calls([
mock.call(mock.ANY, data, **attributes)])
@@ -604,14 +595,13 @@
options = PipelineOptions([])
options.view_as(StandardOptions).streaming = True
- p = TestPipeline(options=options)
- _ = (p
- | Create(payloads)
- | WriteToPubSub('projects/fakeprj/topics/a_topic',
- with_attributes=True))
with self.assertRaisesRegex(AttributeError,
r'str.*has no attribute.*data'):
- p.run()
+ with TestPipeline(options=options) as p:
+ _ = (p
+ | Create(payloads)
+ | WriteToPubSub('projects/fakeprj/topics/a_topic',
+ with_attributes=True))
def test_write_messages_unsupported_features(self, mock_pubsub):
data = b'data'
@@ -620,24 +610,23 @@
options = PipelineOptions([])
options.view_as(StandardOptions).streaming = True
- p = TestPipeline(options=options)
- _ = (p
- | Create(payloads)
- | WriteToPubSub('projects/fakeprj/topics/a_topic',
- id_label='a_label'))
with self.assertRaisesRegex(NotImplementedError,
r'id_label is not supported'):
- p.run()
+ with TestPipeline(options=options) as p:
+ _ = (p
+ | Create(payloads)
+ | WriteToPubSub('projects/fakeprj/topics/a_topic',
+ id_label='a_label'))
+
options = PipelineOptions([])
options.view_as(StandardOptions).streaming = True
- p = TestPipeline(options=options)
- _ = (p
- | Create(payloads)
- | WriteToPubSub('projects/fakeprj/topics/a_topic',
- timestamp_attribute='timestamp'))
with self.assertRaisesRegex(NotImplementedError,
r'timestamp_attribute is not supported'):
- p.run()
+ with TestPipeline(options=options) as p:
+ _ = (p
+ | Create(payloads)
+ | WriteToPubSub('projects/fakeprj/topics/a_topic',
+ timestamp_attribute='timestamp'))
if __name__ == '__main__':
diff --git a/sdks/python/apache_beam/io/parquetio_it_test.py b/sdks/python/apache_beam/io/parquetio_it_test.py
index cc85cd2..ecc1957 100644
--- a/sdks/python/apache_beam/io/parquetio_it_test.py
+++ b/sdks/python/apache_beam/io/parquetio_it_test.py
@@ -68,12 +68,10 @@
file_prefix = "parquet_it_test"
init_size = 10
data_size = 20000
- p = TestPipeline(is_integration_test=True)
- pcol = self._generate_data(
- p, file_prefix, init_size, data_size)
- self._verify_data(pcol, init_size, data_size)
- result = p.run()
- result.wait_until_finish()
+ with TestPipeline(is_integration_test=True) as p:
+ pcol = self._generate_data(
+ p, file_prefix, init_size, data_size)
+ self._verify_data(pcol, init_size, data_size)
@staticmethod
def _sum_verifier(init_size, data_size, x):
diff --git a/sdks/python/apache_beam/io/sources_test.py b/sdks/python/apache_beam/io/sources_test.py
index e210a5b..fb34a98 100644
--- a/sdks/python/apache_beam/io/sources_test.py
+++ b/sdks/python/apache_beam/io/sources_test.py
@@ -125,11 +125,10 @@
def test_run_direct(self):
file_name = self._create_temp_file(b'aaaa\nbbbb\ncccc\ndddd')
- pipeline = TestPipeline()
- pcoll = pipeline | beam.io.Read(LineSource(file_name))
- assert_that(pcoll, equal_to([b'aaaa', b'bbbb', b'cccc', b'dddd']))
+ with TestPipeline() as pipeline:
+ pcoll = pipeline | beam.io.Read(LineSource(file_name))
+ assert_that(pcoll, equal_to([b'aaaa', b'bbbb', b'cccc', b'dddd']))
- pipeline.run()
if __name__ == '__main__':
diff --git a/sdks/python/apache_beam/io/textio_test.py b/sdks/python/apache_beam/io/textio_test.py
index 0761e13..32765f0 100644
--- a/sdks/python/apache_beam/io/textio_test.py
+++ b/sdks/python/apache_beam/io/textio_test.py
@@ -424,28 +424,25 @@
def test_read_from_text_single_file(self):
file_name, expected_data = write_data(5)
assert len(expected_data) == 5
- pipeline = TestPipeline()
- pcoll = pipeline | 'Read' >> ReadFromText(file_name)
- assert_that(pcoll, equal_to(expected_data))
- pipeline.run()
+ with TestPipeline() as pipeline:
+ pcoll = pipeline | 'Read' >> ReadFromText(file_name)
+ assert_that(pcoll, equal_to(expected_data))
def test_read_from_text_with_file_name_single_file(self):
file_name, data = write_data(5)
expected_data = [(file_name, el) for el in data]
assert len(expected_data) == 5
- pipeline = TestPipeline()
- pcoll = pipeline | 'Read' >> ReadFromTextWithFilename(file_name)
- assert_that(pcoll, equal_to(expected_data))
- pipeline.run()
+ with TestPipeline() as pipeline:
+ pcoll = pipeline | 'Read' >> ReadFromTextWithFilename(file_name)
+ assert_that(pcoll, equal_to(expected_data))
def test_read_all_single_file(self):
file_name, expected_data = write_data(5)
assert len(expected_data) == 5
- pipeline = TestPipeline()
- pcoll = pipeline | 'Create' >> Create(
- [file_name]) |'ReadAll' >> ReadAllFromText()
- assert_that(pcoll, equal_to(expected_data))
- pipeline.run()
+ with TestPipeline() as pipeline:
+ pcoll = pipeline | 'Create' >> Create(
+ [file_name]) |'ReadAll' >> ReadAllFromText()
+ assert_that(pcoll, equal_to(expected_data))
def test_read_all_many_single_files(self):
file_name1, expected_data1 = write_data(5)
@@ -458,11 +455,10 @@
expected_data.extend(expected_data1)
expected_data.extend(expected_data2)
expected_data.extend(expected_data3)
- pipeline = TestPipeline()
- pcoll = pipeline | 'Create' >> Create(
- [file_name1, file_name2, file_name3]) |'ReadAll' >> ReadAllFromText()
- assert_that(pcoll, equal_to(expected_data))
- pipeline.run()
+ with TestPipeline() as pipeline:
+ pcoll = pipeline | 'Create' >> Create(
+ [file_name1, file_name2, file_name3]) |'ReadAll' >> ReadAllFromText()
+ assert_that(pcoll, equal_to(expected_data))
def test_read_all_unavailable_files_ignored(self):
file_name1, expected_data1 = write_data(5)
@@ -476,13 +472,12 @@
expected_data.extend(expected_data1)
expected_data.extend(expected_data2)
expected_data.extend(expected_data3)
- pipeline = TestPipeline()
- pcoll = (pipeline
- | 'Create' >> Create(
- [file_name1, file_name2, file_name3, file_name4])
- |'ReadAll' >> ReadAllFromText())
- assert_that(pcoll, equal_to(expected_data))
- pipeline.run()
+ with TestPipeline() as pipeline:
+ pcoll = (pipeline
+ | 'Create' >> Create(
+ [file_name1, file_name2, file_name3, file_name4])
+ |'ReadAll' >> ReadAllFromText())
+ assert_that(pcoll, equal_to(expected_data))
def test_read_from_text_single_file_with_coder(self):
class DummyCoder(coders.Coder):
@@ -494,37 +489,33 @@
file_name, expected_data = write_data(5)
assert len(expected_data) == 5
- pipeline = TestPipeline()
- pcoll = pipeline | 'Read' >> ReadFromText(file_name, coder=DummyCoder())
- assert_that(pcoll, equal_to([record * 2 for record in expected_data]))
- pipeline.run()
+ with TestPipeline() as pipeline:
+ pcoll = pipeline | 'Read' >> ReadFromText(file_name, coder=DummyCoder())
+ assert_that(pcoll, equal_to([record * 2 for record in expected_data]))
def test_read_from_text_file_pattern(self):
pattern, expected_data = write_pattern([5, 3, 12, 8, 8, 4])
assert len(expected_data) == 40
- pipeline = TestPipeline()
- pcoll = pipeline | 'Read' >> ReadFromText(pattern)
- assert_that(pcoll, equal_to(expected_data))
- pipeline.run()
+ with TestPipeline() as pipeline:
+ pcoll = pipeline | 'Read' >> ReadFromText(pattern)
+ assert_that(pcoll, equal_to(expected_data))
def test_read_from_text_with_file_name_file_pattern(self):
pattern, expected_data = write_pattern(
lines_per_file=[5, 5], return_filenames=True)
assert len(expected_data) == 10
- pipeline = TestPipeline()
- pcoll = pipeline | 'Read' >> ReadFromTextWithFilename(pattern)
- assert_that(pcoll, equal_to(expected_data))
- pipeline.run()
+ with TestPipeline() as pipeline:
+ pcoll = pipeline | 'Read' >> ReadFromTextWithFilename(pattern)
+ assert_that(pcoll, equal_to(expected_data))
def test_read_all_file_pattern(self):
pattern, expected_data = write_pattern([5, 3, 12, 8, 8, 4])
assert len(expected_data) == 40
- pipeline = TestPipeline()
- pcoll = (pipeline
- | 'Create' >> Create([pattern])
- |'ReadAll' >> ReadAllFromText())
- assert_that(pcoll, equal_to(expected_data))
- pipeline.run()
+ with TestPipeline() as pipeline:
+ pcoll = (pipeline
+ | 'Create' >> Create([pattern])
+ |'ReadAll' >> ReadAllFromText())
+ assert_that(pcoll, equal_to(expected_data))
def test_read_all_many_file_patterns(self):
pattern1, expected_data1 = write_pattern([5, 3, 12, 8, 8, 4])
@@ -537,11 +528,10 @@
expected_data.extend(expected_data1)
expected_data.extend(expected_data2)
expected_data.extend(expected_data3)
- pipeline = TestPipeline()
- pcoll = pipeline | 'Create' >> Create(
- [pattern1, pattern2, pattern3]) |'ReadAll' >> ReadAllFromText()
- assert_that(pcoll, equal_to(expected_data))
- pipeline.run()
+ with TestPipeline() as pipeline:
+ pcoll = pipeline | 'Create' >> Create(
+ [pattern1, pattern2, pattern3]) |'ReadAll' >> ReadAllFromText()
+ assert_that(pcoll, equal_to(expected_data))
def test_read_auto_bzip2(self):
_, lines = write_data(15)
@@ -550,10 +540,9 @@
with bz2.BZ2File(file_name, 'wb') as f:
f.write('\n'.join(lines).encode('utf-8'))
- pipeline = TestPipeline()
- pcoll = pipeline | 'Read' >> ReadFromText(file_name)
- assert_that(pcoll, equal_to(lines))
- pipeline.run()
+ with TestPipeline() as pipeline:
+ pcoll = pipeline | 'Read' >> ReadFromText(file_name)
+ assert_that(pcoll, equal_to(lines))
def test_read_auto_deflate(self):
_, lines = write_data(15)
@@ -562,10 +551,9 @@
with open(file_name, 'wb') as f:
f.write(zlib.compress('\n'.join(lines).encode('utf-8')))
- pipeline = TestPipeline()
- pcoll = pipeline | 'Read' >> ReadFromText(file_name)
- assert_that(pcoll, equal_to(lines))
- pipeline.run()
+ with TestPipeline() as pipeline:
+ pcoll = pipeline | 'Read' >> ReadFromText(file_name)
+ assert_that(pcoll, equal_to(lines))
def test_read_auto_gzip(self):
_, lines = write_data(15)
@@ -575,10 +563,9 @@
with gzip.GzipFile(file_name, 'wb') as f:
f.write('\n'.join(lines).encode('utf-8'))
- pipeline = TestPipeline()
- pcoll = pipeline | 'Read' >> ReadFromText(file_name)
- assert_that(pcoll, equal_to(lines))
- pipeline.run()
+ with TestPipeline() as pipeline:
+ pcoll = pipeline | 'Read' >> ReadFromText(file_name)
+ assert_that(pcoll, equal_to(lines))
def test_read_bzip2(self):
_, lines = write_data(15)
@@ -587,12 +574,11 @@
with bz2.BZ2File(file_name, 'wb') as f:
f.write('\n'.join(lines).encode('utf-8'))
- pipeline = TestPipeline()
- pcoll = pipeline | 'Read' >> ReadFromText(
- file_name,
- compression_type=CompressionTypes.BZIP2)
- assert_that(pcoll, equal_to(lines))
- pipeline.run()
+ with TestPipeline() as pipeline:
+ pcoll = pipeline | 'Read' >> ReadFromText(
+ file_name,
+ compression_type=CompressionTypes.BZIP2)
+ assert_that(pcoll, equal_to(lines))
def test_read_corrupted_bzip2_fails(self):
_, lines = write_data(15)
@@ -604,13 +590,12 @@
with open(file_name, 'wb') as f:
f.write(b'corrupt')
- pipeline = TestPipeline()
- pcoll = pipeline | 'Read' >> ReadFromText(
- file_name,
- compression_type=CompressionTypes.BZIP2)
- assert_that(pcoll, equal_to(lines))
with self.assertRaises(Exception):
- pipeline.run()
+ with TestPipeline() as pipeline:
+ pcoll = pipeline | 'Read' >> ReadFromText(
+ file_name,
+ compression_type=CompressionTypes.BZIP2)
+ assert_that(pcoll, equal_to(lines))
def test_read_bzip2_concat(self):
with TempDir() as tempdir:
@@ -645,14 +630,13 @@
final_bzip2_file, 'ab') as dst:
dst.writelines(src.readlines())
- pipeline = TestPipeline()
- lines = pipeline | 'ReadFromText' >> beam.io.ReadFromText(
- final_bzip2_file,
- compression_type=beam.io.filesystem.CompressionTypes.BZIP2)
+ with TestPipeline() as pipeline:
+ lines = pipeline | 'ReadFromText' >> beam.io.ReadFromText(
+ final_bzip2_file,
+ compression_type=beam.io.filesystem.CompressionTypes.BZIP2)
- expected = ['a', 'b', 'c', 'p', 'q', 'r', 'x', 'y', 'z']
- assert_that(lines, equal_to(expected))
- pipeline.run()
+ expected = ['a', 'b', 'c', 'p', 'q', 'r', 'x', 'y', 'z']
+ assert_that(lines, equal_to(expected))
def test_read_deflate(self):
_, lines = write_data(15)
@@ -661,13 +645,12 @@
with open(file_name, 'wb') as f:
f.write(zlib.compress('\n'.join(lines).encode('utf-8')))
- pipeline = TestPipeline()
- pcoll = pipeline | 'Read' >> ReadFromText(
- file_name,
- 0, CompressionTypes.DEFLATE,
- True, coders.StrUtf8Coder())
- assert_that(pcoll, equal_to(lines))
- pipeline.run()
+ with TestPipeline() as pipeline:
+ pcoll = pipeline | 'Read' >> ReadFromText(
+ file_name,
+ 0, CompressionTypes.DEFLATE,
+ True, coders.StrUtf8Coder())
+ assert_that(pcoll, equal_to(lines))
def test_read_corrupted_deflate_fails(self):
_, lines = write_data(15)
@@ -679,15 +662,13 @@
with open(file_name, 'wb') as f:
f.write(b'corrupt')
- pipeline = TestPipeline()
- pcoll = pipeline | 'Read' >> ReadFromText(
- file_name,
- 0, CompressionTypes.DEFLATE,
- True, coders.StrUtf8Coder())
- assert_that(pcoll, equal_to(lines))
-
with self.assertRaises(Exception):
- pipeline.run()
+ with TestPipeline() as pipeline:
+ pcoll = pipeline | 'Read' >> ReadFromText(
+ file_name,
+ 0, CompressionTypes.DEFLATE,
+ True, coders.StrUtf8Coder())
+ assert_that(pcoll, equal_to(lines))
def test_read_deflate_concat(self):
with TempDir() as tempdir:
@@ -722,13 +703,13 @@
open(final_deflate_file, 'ab') as dst:
dst.writelines(src.readlines())
- pipeline = TestPipeline()
- lines = pipeline | 'ReadFromText' >> beam.io.ReadFromText(
- final_deflate_file,
- compression_type=beam.io.filesystem.CompressionTypes.DEFLATE)
+ with TestPipeline() as pipeline:
+ lines = pipeline | 'ReadFromText' >> beam.io.ReadFromText(
+ final_deflate_file,
+ compression_type=beam.io.filesystem.CompressionTypes.DEFLATE)
- expected = ['a', 'b', 'c', 'p', 'q', 'r', 'x', 'y', 'z']
- assert_that(lines, equal_to(expected))
+ expected = ['a', 'b', 'c', 'p', 'q', 'r', 'x', 'y', 'z']
+ assert_that(lines, equal_to(expected))
def test_read_gzip(self):
_, lines = write_data(15)
@@ -737,13 +718,12 @@
with gzip.GzipFile(file_name, 'wb') as f:
f.write('\n'.join(lines).encode('utf-8'))
- pipeline = TestPipeline()
- pcoll = pipeline | 'Read' >> ReadFromText(
- file_name,
- 0, CompressionTypes.GZIP,
- True, coders.StrUtf8Coder())
- assert_that(pcoll, equal_to(lines))
- pipeline.run()
+ with TestPipeline() as pipeline:
+ pcoll = pipeline | 'Read' >> ReadFromText(
+ file_name,
+ 0, CompressionTypes.GZIP,
+ True, coders.StrUtf8Coder())
+ assert_that(pcoll, equal_to(lines))
def test_read_corrupted_gzip_fails(self):
_, lines = write_data(15)
@@ -755,15 +735,13 @@
with open(file_name, 'wb') as f:
f.write(b'corrupt')
- pipeline = TestPipeline()
- pcoll = pipeline | 'Read' >> ReadFromText(
- file_name,
- 0, CompressionTypes.GZIP,
- True, coders.StrUtf8Coder())
- assert_that(pcoll, equal_to(lines))
-
with self.assertRaises(Exception):
- pipeline.run()
+ with TestPipeline() as pipeline:
+ pcoll = pipeline | 'Read' >> ReadFromText(
+ file_name,
+ 0, CompressionTypes.GZIP,
+ True, coders.StrUtf8Coder())
+ assert_that(pcoll, equal_to(lines))
def test_read_gzip_concat(self):
with TempDir() as tempdir:
@@ -798,13 +776,13 @@
open(final_gzip_file, 'ab') as dst:
dst.writelines(src.readlines())
- pipeline = TestPipeline()
- lines = pipeline | 'ReadFromText' >> beam.io.ReadFromText(
- final_gzip_file,
- compression_type=beam.io.filesystem.CompressionTypes.GZIP)
+ with TestPipeline() as pipeline:
+ lines = pipeline | 'ReadFromText' >> beam.io.ReadFromText(
+ final_gzip_file,
+ compression_type=beam.io.filesystem.CompressionTypes.GZIP)
- expected = ['a', 'b', 'c', 'p', 'q', 'r', 'x', 'y', 'z']
- assert_that(lines, equal_to(expected))
+ expected = ['a', 'b', 'c', 'p', 'q', 'r', 'x', 'y', 'z']
+ assert_that(lines, equal_to(expected))
def test_read_all_gzip(self):
_, lines = write_data(100)
@@ -812,13 +790,12 @@
file_name = tempdir.create_temp_file()
with gzip.GzipFile(file_name, 'wb') as f:
f.write('\n'.join(lines).encode('utf-8'))
- pipeline = TestPipeline()
- pcoll = (pipeline
- | Create([file_name])
- | 'ReadAll' >> ReadAllFromText(
- compression_type=CompressionTypes.GZIP))
- assert_that(pcoll, equal_to(lines))
- pipeline.run()
+ with TestPipeline() as pipeline:
+ pcoll = (pipeline
+ | Create([file_name])
+ | 'ReadAll' >> ReadAllFromText(
+ compression_type=CompressionTypes.GZIP))
+ assert_that(pcoll, equal_to(lines))
def test_read_gzip_large(self):
_, lines = write_data(10000)
@@ -828,13 +805,12 @@
with gzip.GzipFile(file_name, 'wb') as f:
f.write('\n'.join(lines).encode('utf-8'))
- pipeline = TestPipeline()
- pcoll = pipeline | 'Read' >> ReadFromText(
- file_name,
- 0, CompressionTypes.GZIP,
- True, coders.StrUtf8Coder())
- assert_that(pcoll, equal_to(lines))
- pipeline.run()
+ with TestPipeline() as pipeline:
+ pcoll = pipeline | 'Read' >> ReadFromText(
+ file_name,
+ 0, CompressionTypes.GZIP,
+ True, coders.StrUtf8Coder())
+ assert_that(pcoll, equal_to(lines))
def test_read_gzip_large_after_splitting(self):
_, lines = write_data(10000)
@@ -861,13 +837,12 @@
def test_read_gzip_empty_file(self):
with TempDir() as tempdir:
file_name = tempdir.create_temp_file()
- pipeline = TestPipeline()
- pcoll = pipeline | 'Read' >> ReadFromText(
- file_name,
- 0, CompressionTypes.GZIP,
- True, coders.StrUtf8Coder())
- assert_that(pcoll, equal_to([]))
- pipeline.run()
+ with TestPipeline() as pipeline:
+ pcoll = pipeline | 'Read' >> ReadFromText(
+ file_name,
+ 0, CompressionTypes.GZIP,
+ True, coders.StrUtf8Coder())
+ assert_that(pcoll, equal_to([]))
def _remove_lines(self, lines, sublist_lengths, num_to_remove):
"""Utility function to remove num_to_remove lines from each sublist.
@@ -950,12 +925,11 @@
with gzip.GzipFile(file_name, 'wb') as f:
f.write('\n'.join(lines).encode('utf-8'))
- pipeline = TestPipeline()
- pcoll = pipeline | 'Read' >> ReadFromText(
- file_name, 0, CompressionTypes.GZIP,
- True, coders.StrUtf8Coder(), skip_header_lines=2)
- assert_that(pcoll, equal_to(lines[2:]))
- pipeline.run()
+ with TestPipeline() as pipeline:
+ pcoll = pipeline | 'Read' >> ReadFromText(
+ file_name, 0, CompressionTypes.GZIP,
+ True, coders.StrUtf8Coder(), skip_header_lines=2)
+ assert_that(pcoll, equal_to(lines[2:]))
def test_read_after_splitting_skip_header(self):
file_name, expected_data = write_data(100)
@@ -1105,10 +1079,9 @@
self.assertEqual(f.read().splitlines(), header.splitlines())
def test_write_dataflow(self):
- pipeline = TestPipeline()
- pcoll = pipeline | beam.core.Create(self.lines)
- pcoll | 'Write' >> WriteToText(self.path) # pylint: disable=expression-not-assigned
- pipeline.run()
+ with TestPipeline() as pipeline:
+ pcoll = pipeline | beam.core.Create(self.lines)
+ pcoll | 'Write' >> WriteToText(self.path) # pylint: disable=expression-not-assigned
read_result = []
for file_name in glob.glob(self.path + '*'):
@@ -1118,10 +1091,9 @@
self.assertEqual(sorted(read_result), sorted(self.lines))
def test_write_dataflow_auto_compression(self):
- pipeline = TestPipeline()
- pcoll = pipeline | beam.core.Create(self.lines)
- pcoll | 'Write' >> WriteToText(self.path, file_name_suffix='.gz') # pylint: disable=expression-not-assigned
- pipeline.run()
+ with TestPipeline() as pipeline:
+ pcoll = pipeline | beam.core.Create(self.lines)
+ pcoll | 'Write' >> WriteToText(self.path, file_name_suffix='.gz') # pylint: disable=expression-not-assigned
read_result = []
for file_name in glob.glob(self.path + '*'):
@@ -1131,13 +1103,12 @@
self.assertEqual(sorted(read_result), sorted(self.lines))
def test_write_dataflow_auto_compression_unsharded(self):
- pipeline = TestPipeline()
- pcoll = pipeline | 'Create' >> beam.core.Create(self.lines)
- pcoll | 'Write' >> WriteToText( # pylint: disable=expression-not-assigned
- self.path + '.gz',
- shard_name_template='')
+ with TestPipeline() as pipeline:
+ pcoll = pipeline | 'Create' >> beam.core.Create(self.lines)
+ pcoll | 'Write' >> WriteToText( # pylint: disable=expression-not-assigned
+ self.path + '.gz',
+ shard_name_template='')
- pipeline.run()
read_result = []
for file_name in glob.glob(self.path + '*'):
@@ -1147,14 +1118,13 @@
self.assertEqual(sorted(read_result), sorted(self.lines))
def test_write_dataflow_header(self):
- pipeline = TestPipeline()
- pcoll = pipeline | 'Create' >> beam.core.Create(self.lines)
- header_text = 'foo'
- pcoll | 'Write' >> WriteToText( # pylint: disable=expression-not-assigned
- self.path + '.gz',
- shard_name_template='',
- header=header_text)
- pipeline.run()
+ with TestPipeline() as pipeline:
+ pcoll = pipeline | 'Create' >> beam.core.Create(self.lines)
+ header_text = 'foo'
+ pcoll | 'Write' >> WriteToText( # pylint: disable=expression-not-assigned
+ self.path + '.gz',
+ shard_name_template='',
+ header=header_text)
read_result = []
for file_name in glob.glob(self.path + '*'):
diff --git a/sdks/python/apache_beam/io/vcfio_test.py b/sdks/python/apache_beam/io/vcfio_test.py
index 0c820ab..8091ab6 100644
--- a/sdks/python/apache_beam/io/vcfio_test.py
+++ b/sdks/python/apache_beam/io/vcfio_test.py
@@ -496,26 +496,23 @@
with TempDir() as tempdir:
file_name = self._create_temp_vcf_file(_SAMPLE_HEADER_LINES +
_SAMPLE_TEXT_LINES, tempdir)
- pipeline = TestPipeline()
- pcoll = pipeline | 'Read' >> ReadFromVcf(file_name)
- assert_that(pcoll, _count_equals_to(len(_SAMPLE_TEXT_LINES)))
- pipeline.run()
+ with TestPipeline() as pipeline:
+ pcoll = pipeline | 'Read' >> ReadFromVcf(file_name)
+ assert_that(pcoll, _count_equals_to(len(_SAMPLE_TEXT_LINES)))
@unittest.skipIf(VCF_FILE_DIR_MISSING, 'VCF test file directory is missing')
def test_pipeline_read_single_file_large(self):
- pipeline = TestPipeline()
- pcoll = pipeline | 'Read' >> ReadFromVcf(
- get_full_file_path('valid-4.0.vcf'))
- assert_that(pcoll, _count_equals_to(5))
- pipeline.run()
+ with TestPipeline() as pipeline:
+ pcoll = pipeline | 'Read' >> ReadFromVcf(
+ get_full_file_path('valid-4.0.vcf'))
+ assert_that(pcoll, _count_equals_to(5))
@unittest.skipIf(VCF_FILE_DIR_MISSING, 'VCF test file directory is missing')
def test_pipeline_read_file_pattern_large(self):
- pipeline = TestPipeline()
- pcoll = pipeline | 'Read' >> ReadFromVcf(
- os.path.join(get_full_dir(), 'valid-*.vcf'))
- assert_that(pcoll, _count_equals_to(9900))
- pipeline.run()
+ with TestPipeline() as pipeline:
+ pcoll = pipeline | 'Read' >> ReadFromVcf(
+ os.path.join(get_full_dir(), 'valid-*.vcf'))
+ assert_that(pcoll, _count_equals_to(9900))
def test_read_reentrant_without_splitting(self):
with TempDir() as tempdir:
diff --git a/sdks/python/apache_beam/pipeline_test.py b/sdks/python/apache_beam/pipeline_test.py
index 48d3b0d..7415baf 100644
--- a/sdks/python/apache_beam/pipeline_test.py
+++ b/sdks/python/apache_beam/pipeline_test.py
@@ -154,86 +154,83 @@
self.leave_composite.append(transform_node)
def test_create(self):
- pipeline = TestPipeline()
- pcoll = pipeline | 'label1' >> Create([1, 2, 3])
- assert_that(pcoll, equal_to([1, 2, 3]))
+ with TestPipeline() as pipeline:
+ pcoll = pipeline | 'label1' >> Create([1, 2, 3])
+ assert_that(pcoll, equal_to([1, 2, 3]))
- # Test if initial value is an iterator object.
- pcoll2 = pipeline | 'label2' >> Create(iter((4, 5, 6)))
- pcoll3 = pcoll2 | 'do' >> FlatMap(lambda x: [x + 10])
- assert_that(pcoll3, equal_to([14, 15, 16]), label='pcoll3')
- pipeline.run()
+ # Test if initial value is an iterator object.
+ pcoll2 = pipeline | 'label2' >> Create(iter((4, 5, 6)))
+ pcoll3 = pcoll2 | 'do' >> FlatMap(lambda x: [x + 10])
+ assert_that(pcoll3, equal_to([14, 15, 16]), label='pcoll3')
def test_flatmap_builtin(self):
- pipeline = TestPipeline()
- pcoll = pipeline | 'label1' >> Create([1, 2, 3])
- assert_that(pcoll, equal_to([1, 2, 3]))
+ with TestPipeline() as pipeline:
+ pcoll = pipeline | 'label1' >> Create([1, 2, 3])
+ assert_that(pcoll, equal_to([1, 2, 3]))
- pcoll2 = pcoll | 'do' >> FlatMap(lambda x: [x + 10])
- assert_that(pcoll2, equal_to([11, 12, 13]), label='pcoll2')
+ pcoll2 = pcoll | 'do' >> FlatMap(lambda x: [x + 10])
+ assert_that(pcoll2, equal_to([11, 12, 13]), label='pcoll2')
- pcoll3 = pcoll2 | 'm1' >> Map(lambda x: [x, 12])
- assert_that(pcoll3,
- equal_to([[11, 12], [12, 12], [13, 12]]), label='pcoll3')
+ pcoll3 = pcoll2 | 'm1' >> Map(lambda x: [x, 12])
+ assert_that(pcoll3,
+ equal_to([[11, 12], [12, 12], [13, 12]]), label='pcoll3')
- pcoll4 = pcoll3 | 'do2' >> FlatMap(set)
- assert_that(pcoll4, equal_to([11, 12, 12, 12, 13]), label='pcoll4')
- pipeline.run()
+ pcoll4 = pcoll3 | 'do2' >> FlatMap(set)
+ assert_that(pcoll4, equal_to([11, 12, 12, 12, 13]), label='pcoll4')
def test_maptuple_builtin(self):
- pipeline = TestPipeline()
- pcoll = pipeline | Create([('e1', 'e2')])
- side1 = beam.pvalue.AsSingleton(pipeline | 'side1' >> Create(['s1']))
- side2 = beam.pvalue.AsSingleton(pipeline | 'side2' >> Create(['s2']))
+ with TestPipeline() as pipeline:
+ pcoll = pipeline | Create([('e1', 'e2')])
+ side1 = beam.pvalue.AsSingleton(pipeline | 'side1' >> Create(['s1']))
+ side2 = beam.pvalue.AsSingleton(pipeline | 'side2' >> Create(['s2']))
- # A test function with a tuple input, an auxiliary parameter,
- # and some side inputs.
- fn = lambda e1, e2, t=DoFn.TimestampParam, s1=None, s2=None: (
- e1, e2, t, s1, s2)
- assert_that(pcoll | 'NoSides' >> beam.core.MapTuple(fn),
- equal_to([('e1', 'e2', MIN_TIMESTAMP, None, None)]),
- label='NoSidesCheck')
- assert_that(pcoll | 'StaticSides' >> beam.core.MapTuple(fn, 's1', 's2'),
- equal_to([('e1', 'e2', MIN_TIMESTAMP, 's1', 's2')]),
- label='StaticSidesCheck')
- assert_that(pcoll | 'DynamicSides' >> beam.core.MapTuple(fn, side1, side2),
- equal_to([('e1', 'e2', MIN_TIMESTAMP, 's1', 's2')]),
- label='DynamicSidesCheck')
- assert_that(pcoll | 'MixedSides' >> beam.core.MapTuple(fn, s2=side2),
- equal_to([('e1', 'e2', MIN_TIMESTAMP, None, 's2')]),
- label='MixedSidesCheck')
- pipeline.run()
+ # A test function with a tuple input, an auxiliary parameter,
+ # and some side inputs.
+ fn = lambda e1, e2, t=DoFn.TimestampParam, s1=None, s2=None: (
+ e1, e2, t, s1, s2)
+ assert_that(pcoll | 'NoSides' >> beam.core.MapTuple(fn),
+ equal_to([('e1', 'e2', MIN_TIMESTAMP, None, None)]),
+ label='NoSidesCheck')
+ assert_that(pcoll | 'StaticSides' >> beam.core.MapTuple(fn, 's1', 's2'),
+ equal_to([('e1', 'e2', MIN_TIMESTAMP, 's1', 's2')]),
+ label='StaticSidesCheck')
+ assert_that(pcoll | 'DynamicSides' >> beam.core.MapTuple(
+ fn, side1, side2),
+ equal_to([('e1', 'e2', MIN_TIMESTAMP, 's1', 's2')]),
+ label='DynamicSidesCheck')
+ assert_that(pcoll | 'MixedSides' >> beam.core.MapTuple(fn, s2=side2),
+ equal_to([('e1', 'e2', MIN_TIMESTAMP, None, 's2')]),
+ label='MixedSidesCheck')
def test_flatmaptuple_builtin(self):
- pipeline = TestPipeline()
- pcoll = pipeline | Create([('e1', 'e2')])
- side1 = beam.pvalue.AsSingleton(pipeline | 'side1' >> Create(['s1']))
- side2 = beam.pvalue.AsSingleton(pipeline | 'side2' >> Create(['s2']))
+ with TestPipeline() as pipeline:
+ pcoll = pipeline | Create([('e1', 'e2')])
+ side1 = beam.pvalue.AsSingleton(pipeline | 'side1' >> Create(['s1']))
+ side2 = beam.pvalue.AsSingleton(pipeline | 'side2' >> Create(['s2']))
- # A test function with a tuple input, an auxiliary parameter,
- # and some side inputs.
- fn = lambda e1, e2, t=DoFn.TimestampParam, s1=None, s2=None: (
- e1, e2, t, s1, s2)
- assert_that(pcoll | 'NoSides' >> beam.core.FlatMapTuple(fn),
- equal_to(['e1', 'e2', MIN_TIMESTAMP, None, None]),
- label='NoSidesCheck')
- assert_that(pcoll | 'StaticSides' >> beam.core.FlatMapTuple(fn, 's1', 's2'),
- equal_to(['e1', 'e2', MIN_TIMESTAMP, 's1', 's2']),
- label='StaticSidesCheck')
- assert_that(pcoll
- | 'DynamicSides' >> beam.core.FlatMapTuple(fn, side1, side2),
- equal_to(['e1', 'e2', MIN_TIMESTAMP, 's1', 's2']),
- label='DynamicSidesCheck')
- assert_that(pcoll | 'MixedSides' >> beam.core.FlatMapTuple(fn, s2=side2),
- equal_to(['e1', 'e2', MIN_TIMESTAMP, None, 's2']),
- label='MixedSidesCheck')
- pipeline.run()
+ # A test function with a tuple input, an auxiliary parameter,
+ # and some side inputs.
+ fn = lambda e1, e2, t=DoFn.TimestampParam, s1=None, s2=None: (
+ e1, e2, t, s1, s2)
+ assert_that(pcoll | 'NoSides' >> beam.core.FlatMapTuple(fn),
+ equal_to(['e1', 'e2', MIN_TIMESTAMP, None, None]),
+ label='NoSidesCheck')
+ assert_that(pcoll | 'StaticSides' >> beam.core.FlatMapTuple(
+ fn, 's1', 's2'),
+ equal_to(['e1', 'e2', MIN_TIMESTAMP, 's1', 's2']),
+ label='StaticSidesCheck')
+ assert_that(pcoll
+ | 'DynamicSides' >> beam.core.FlatMapTuple(fn, side1, side2),
+ equal_to(['e1', 'e2', MIN_TIMESTAMP, 's1', 's2']),
+ label='DynamicSidesCheck')
+ assert_that(pcoll | 'MixedSides' >> beam.core.FlatMapTuple(fn, s2=side2),
+ equal_to(['e1', 'e2', MIN_TIMESTAMP, None, 's2']),
+ label='MixedSidesCheck')
def test_create_singleton_pcollection(self):
- pipeline = TestPipeline()
- pcoll = pipeline | 'label' >> Create([[1, 2, 3]])
- assert_that(pcoll, equal_to([[1, 2, 3]]))
- pipeline.run()
+ with TestPipeline() as pipeline:
+ pcoll = pipeline | 'label' >> Create([[1, 2, 3]])
+ assert_that(pcoll, equal_to([[1, 2, 3]]))
# TODO(BEAM-1555): Test is failing on the service, with FakeSource.
# @attr('ValidatesRunner')
@@ -249,10 +246,9 @@
self.assertEqual(outputs_counter.committed, 6)
def test_fake_read(self):
- pipeline = TestPipeline()
- pcoll = pipeline | 'read' >> Read(FakeSource([1, 2, 3]))
- assert_that(pcoll, equal_to([1, 2, 3]))
- pipeline.run()
+ with TestPipeline() as pipeline:
+ pcoll = pipeline | 'read' >> Read(FakeSource([1, 2, 3]))
+ assert_that(pcoll, equal_to([1, 2, 3]))
def test_visit_entire_graph(self):
pipeline = Pipeline()
@@ -274,11 +270,10 @@
self.assertEqual(visitor.leave_composite[0].transform, transform)
def test_apply_custom_transform(self):
- pipeline = TestPipeline()
- pcoll = pipeline | 'pcoll' >> Create([1, 2, 3])
- result = pcoll | PipelineTest.CustomTransform()
- assert_that(result, equal_to([2, 3, 4]))
- pipeline.run()
+ with TestPipeline() as pipeline:
+ pcoll = pipeline | 'pcoll' >> Create([1, 2, 3])
+ result = pcoll | PipelineTest.CustomTransform()
+ assert_that(result, equal_to([2, 3, 4]))
def test_reuse_custom_transform_instance(self):
pipeline = Pipeline()
@@ -295,15 +290,14 @@
'pvalue | "label" >> transform')
def test_reuse_cloned_custom_transform_instance(self):
- pipeline = TestPipeline()
- pcoll1 = pipeline | 'pc1' >> Create([1, 2, 3])
- pcoll2 = pipeline | 'pc2' >> Create([4, 5, 6])
- transform = PipelineTest.CustomTransform()
- result1 = pcoll1 | transform
- result2 = pcoll2 | 'new_label' >> transform
- assert_that(result1, equal_to([2, 3, 4]), label='r1')
- assert_that(result2, equal_to([5, 6, 7]), label='r2')
- pipeline.run()
+ with TestPipeline() as pipeline:
+ pcoll1 = pipeline | 'pc1' >> Create([1, 2, 3])
+ pcoll2 = pipeline | 'pc2' >> Create([4, 5, 6])
+ transform = PipelineTest.CustomTransform()
+ result1 = pcoll1 | transform
+ result2 = pcoll2 | 'new_label' >> transform
+ assert_that(result1, equal_to([2, 3, 4]), label='r1')
+ assert_that(result2, equal_to([5, 6, 7]), label='r2')
def test_transform_no_super_init(self):
class AddSuffix(PTransform):
@@ -347,24 +341,23 @@
# TODO(robertwb): reduce memory usage of FnApiRunner so that this test
# passes.
- pipeline = TestPipeline(runner='BundleBasedDirectRunner')
+ with TestPipeline(runner='BundleBasedDirectRunner') as pipeline:
- # Consumed memory should not be proportional to the number of maps.
- memory_threshold = (
- get_memory_usage_in_bytes() + (5 * len_elements * num_elements))
+ # Consumed memory should not be proportional to the number of maps.
+ memory_threshold = (
+ get_memory_usage_in_bytes() + (5 * len_elements * num_elements))
- # Plus small additional slack for memory fluctuations during the test.
- memory_threshold += 10 * (2 ** 20)
+ # Plus small additional slack for memory fluctuations during the test.
+ memory_threshold += 10 * (2 ** 20)
- biglist = pipeline | 'oom:create' >> Create(
- ['x' * len_elements] * num_elements)
- for i in range(num_maps):
- biglist = biglist | ('oom:addone-%d' % i) >> Map(lambda x: x + 'y')
- result = biglist | 'oom:check' >> Map(check_memory, memory_threshold)
- assert_that(result, equal_to(
- ['x' * len_elements + 'y' * num_maps] * num_elements))
+ biglist = pipeline | 'oom:create' >> Create(
+ ['x' * len_elements] * num_elements)
+ for i in range(num_maps):
+ biglist = biglist | ('oom:addone-%d' % i) >> Map(lambda x: x + 'y')
+ result = biglist | 'oom:check' >> Map(check_memory, memory_threshold)
+ assert_that(result, equal_to(
+ ['x' * len_elements + 'y' * num_maps] * num_elements))
- pipeline.run()
def test_aggregator_empty_input(self):
actual = [] | CombineGlobally(max).without_defaults()
@@ -473,27 +466,27 @@
else:
yield TaggedOutput('letters', x)
- p = TestPipeline()
- multi = (p
- | beam.Create([1, 2, 3, 'a', 'b', 'c'])
- | 'MyMultiOutput' >> beam.ParDo(mux_input).with_outputs())
- letters = multi.letters | 'MyLetters' >> beam.Map(lambda x: x)
- numbers = multi.numbers | 'MyNumbers' >> beam.Map(lambda x: x)
+ with TestPipeline() as p:
+ multi = (p
+ | beam.Create([1, 2, 3, 'a', 'b', 'c'])
+ | 'MyMultiOutput' >> beam.ParDo(mux_input).with_outputs())
+ letters = multi.letters | 'MyLetters' >> beam.Map(lambda x: x)
+ numbers = multi.numbers | 'MyNumbers' >> beam.Map(lambda x: x)
- # Assert that the PCollection replacement worked correctly and that elements
- # are flowing through. The replacement transform first multiples by 2 then
- # the leaf nodes inside the composite multiply by an additional 3 and 5. Use
- # prime numbers to ensure that each transform is getting executed once.
- assert_that(letters,
- equal_to(['a'*2*3, 'b'*2*3, 'c'*2*3]),
- label='assert letters')
- assert_that(numbers,
- equal_to([1*2*5, 2*2*5, 3*2*5]),
- label='assert numbers')
+ # Assert that the PCollection replacement worked correctly and that
+ # elements are flowing through. The replacement transform first
+ # multiplies by 2 then the leaf nodes inside the composite multiply by
+ # an additional 3 and 5. Use prime numbers to ensure that each
+ # transform is getting executed once.
+ assert_that(letters,
+ equal_to(['a'*2*3, 'b'*2*3, 'c'*2*3]),
+ label='assert letters')
+ assert_that(numbers,
+ equal_to([1*2*5, 2*2*5, 3*2*5]),
+ label='assert numbers')
- # Do the replacement and run the element assertions.
- p.replace_all([MultiOutputOverride()])
- p.run()
+ # Do the replacement and run the element assertions.
+ p.replace_all([MultiOutputOverride()])
# The following checks the graph to make sure the replacement occurred.
visitor = PipelineTest.Visitor(visited=[])
@@ -535,20 +528,18 @@
def process(self, element, counter=DoFn.StateParam(BYTES_STATE)):
return self.return_recursive(1)
- p = TestPipeline()
- pcoll = (p
- | beam.Create([(1, 1), (2, 2), (3, 3)])
- | beam.GroupByKey()
- | beam.ParDo(StatefulDoFn()))
- p.run()
+ with TestPipeline() as p:
+ pcoll = (p
+ | beam.Create([(1, 1), (2, 2), (3, 3)])
+ | beam.GroupByKey()
+ | beam.ParDo(StatefulDoFn()))
self.assertEqual(pcoll.element_type, typehints.Any)
- p = TestPipeline()
- pcoll = (p
- | beam.Create([(1, 1), (2, 2), (3, 3)])
- | beam.GroupByKey()
- | beam.ParDo(StatefulDoFn()).with_output_types(str))
- p.run()
+ with TestPipeline() as p:
+ pcoll = (p
+ | beam.Create([(1, 1), (2, 2), (3, 3)])
+ | beam.GroupByKey()
+ | beam.ParDo(StatefulDoFn()).with_output_types(str))
self.assertEqual(pcoll.element_type, str)
def test_track_pcoll_unbounded(self):
@@ -609,40 +600,37 @@
def process(self, element):
yield element + 10
- pipeline = TestPipeline()
- pcoll = pipeline | 'Create' >> Create([1, 2]) | 'Do' >> ParDo(TestDoFn())
- assert_that(pcoll, equal_to([11, 12]))
- pipeline.run()
+ with TestPipeline() as pipeline:
+ pcoll = pipeline | 'Create' >> Create([1, 2]) | 'Do' >> ParDo(TestDoFn())
+ assert_that(pcoll, equal_to([11, 12]))
def test_side_input_no_tag(self):
class TestDoFn(DoFn):
def process(self, element, prefix, suffix):
return ['%s-%s-%s' % (prefix, element, suffix)]
- pipeline = TestPipeline()
- words_list = ['aa', 'bb', 'cc']
- words = pipeline | 'SomeWords' >> Create(words_list)
- prefix = 'zyx'
- suffix = pipeline | 'SomeString' >> Create(['xyz']) # side in
- result = words | 'DecorateWordsDoFnNoTag' >> ParDo(
- TestDoFn(), prefix, suffix=AsSingleton(suffix))
- assert_that(result, equal_to(['zyx-%s-xyz' % x for x in words_list]))
- pipeline.run()
+ with TestPipeline() as pipeline:
+ words_list = ['aa', 'bb', 'cc']
+ words = pipeline | 'SomeWords' >> Create(words_list)
+ prefix = 'zyx'
+ suffix = pipeline | 'SomeString' >> Create(['xyz']) # side in
+ result = words | 'DecorateWordsDoFnNoTag' >> ParDo(
+ TestDoFn(), prefix, suffix=AsSingleton(suffix))
+ assert_that(result, equal_to(['zyx-%s-xyz' % x for x in words_list]))
def test_side_input_tagged(self):
class TestDoFn(DoFn):
def process(self, element, prefix, suffix=DoFn.SideInputParam):
return ['%s-%s-%s' % (prefix, element, suffix)]
- pipeline = TestPipeline()
- words_list = ['aa', 'bb', 'cc']
- words = pipeline | 'SomeWords' >> Create(words_list)
- prefix = 'zyx'
- suffix = pipeline | 'SomeString' >> Create(['xyz']) # side in
- result = words | 'DecorateWordsDoFnNoTag' >> ParDo(
- TestDoFn(), prefix, suffix=AsSingleton(suffix))
- assert_that(result, equal_to(['zyx-%s-xyz' % x for x in words_list]))
- pipeline.run()
+ with TestPipeline() as pipeline:
+ words_list = ['aa', 'bb', 'cc']
+ words = pipeline | 'SomeWords' >> Create(words_list)
+ prefix = 'zyx'
+ suffix = pipeline | 'SomeString' >> Create(['xyz']) # side in
+ result = words | 'DecorateWordsDoFnNoTag' >> ParDo(
+ TestDoFn(), prefix, suffix=AsSingleton(suffix))
+ assert_that(result, equal_to(['zyx-%s-xyz' % x for x in words_list]))
@attr('ValidatesRunner')
def test_element_param(self):
@@ -668,32 +656,30 @@
def process(self, element, window=DoFn.WindowParam):
yield (element, (float(window.start), float(window.end)))
- pipeline = TestPipeline()
- pcoll = (pipeline
- | Create([1, 7])
- | Map(lambda x: TimestampedValue(x, x))
- | WindowInto(windowfn=SlidingWindows(10, 5))
- | ParDo(TestDoFn()))
- assert_that(pcoll, equal_to([(1, (-5, 5)), (1, (0, 10)),
- (7, (0, 10)), (7, (5, 15))]))
- pcoll2 = pcoll | 'Again' >> ParDo(TestDoFn())
- assert_that(
- pcoll2,
- equal_to([
- ((1, (-5, 5)), (-5, 5)), ((1, (0, 10)), (0, 10)),
- ((7, (0, 10)), (0, 10)), ((7, (5, 15)), (5, 15))]),
- label='doubled windows')
- pipeline.run()
+ with TestPipeline() as pipeline:
+ pcoll = (pipeline
+ | Create([1, 7])
+ | Map(lambda x: TimestampedValue(x, x))
+ | WindowInto(windowfn=SlidingWindows(10, 5))
+ | ParDo(TestDoFn()))
+ assert_that(pcoll, equal_to([(1, (-5, 5)), (1, (0, 10)),
+ (7, (0, 10)), (7, (5, 15))]))
+ pcoll2 = pcoll | 'Again' >> ParDo(TestDoFn())
+ assert_that(
+ pcoll2,
+ equal_to([
+ ((1, (-5, 5)), (-5, 5)), ((1, (0, 10)), (0, 10)),
+ ((7, (0, 10)), (0, 10)), ((7, (5, 15)), (5, 15))]),
+ label='doubled windows')
def test_timestamp_param(self):
class TestDoFn(DoFn):
def process(self, element, timestamp=DoFn.TimestampParam):
yield timestamp
- pipeline = TestPipeline()
- pcoll = pipeline | 'Create' >> Create([1, 2]) | 'Do' >> ParDo(TestDoFn())
- assert_that(pcoll, equal_to([MIN_TIMESTAMP, MIN_TIMESTAMP]))
- pipeline.run()
+ with TestPipeline() as pipeline:
+ pcoll = pipeline | 'Create' >> Create([1, 2]) | 'Do' >> ParDo(TestDoFn())
+ assert_that(pcoll, equal_to([MIN_TIMESTAMP, MIN_TIMESTAMP]))
def test_timestamp_param_map(self):
with TestPipeline() as p:
@@ -733,12 +719,11 @@
# Ensure that we don't use default values in a context where they must be
# comparable (see BEAM-8301).
- pipeline = TestPipeline()
- pcoll = (pipeline
- | beam.Create([None])
- | Map(lambda e, x=IncomparableType(): (e, type(x).__name__)))
- assert_that(pcoll, equal_to([(None, 'IncomparableType')]))
- pipeline.run()
+ with TestPipeline() as pipeline:
+ pcoll = (pipeline
+ | beam.Create([None])
+ | Map(lambda e, x=IncomparableType(): (e, type(x).__name__)))
+ assert_that(pcoll, equal_to([(None, 'IncomparableType')]))
class Bacon(PipelineOptions):
diff --git a/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py b/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py
index db35635..69b1fb8 100644
--- a/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py
+++ b/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py
@@ -500,7 +500,9 @@
test_options = options.view_as(TestOptions)
# If it is a dry run, return without submitting the job.
if test_options.dry_run:
- return None
+ result = PipelineResult(PipelineState.DONE)
+ result.wait_until_finish = lambda duration=None: None
+ return result
# Get a Dataflow API client and set its options
self.dataflow_client = apiclient.DataflowApplicationClient(options)
diff --git a/sdks/python/apache_beam/runners/dataflow/dataflow_runner_test.py b/sdks/python/apache_beam/runners/dataflow/dataflow_runner_test.py
index 74d6e57..d00066c 100644
--- a/sdks/python/apache_beam/runners/dataflow/dataflow_runner_test.py
+++ b/sdks/python/apache_beam/runners/dataflow/dataflow_runner_test.py
@@ -210,12 +210,12 @@
self.default_properties.append('--experiments=beam_fn_api')
self.default_properties.append('--worker_harness_container_image=FOO')
remote_runner = DataflowRunner()
- p = Pipeline(remote_runner,
- options=PipelineOptions(self.default_properties))
- (p | ptransform.Create([1, 2, 3]) # pylint: disable=expression-not-assigned
- | 'Do' >> ptransform.FlatMap(lambda x: [(x, x)])
- | ptransform.GroupByKey())
- p.run()
+ with Pipeline(
+ remote_runner,
+ options=PipelineOptions(self.default_properties)) as p:
+ (p | ptransform.Create([1, 2, 3]) # pylint: disable=expression-not-assigned
+ | 'Do' >> ptransform.FlatMap(lambda x: [(x, x)])
+ | ptransform.GroupByKey())
self.assertEqual(
list(remote_runner.proto_pipeline.components.environments.values()),
[beam_runner_api_pb2.Environment(
@@ -225,20 +225,19 @@
def test_remote_runner_translation(self):
remote_runner = DataflowRunner()
- p = Pipeline(remote_runner,
- options=PipelineOptions(self.default_properties))
+ with Pipeline(
+ remote_runner,
+ options=PipelineOptions(self.default_properties)) as p:
- (p | ptransform.Create([1, 2, 3]) # pylint: disable=expression-not-assigned
- | 'Do' >> ptransform.FlatMap(lambda x: [(x, x)])
- | ptransform.GroupByKey())
- p.run()
+ (p | ptransform.Create([1, 2, 3]) # pylint: disable=expression-not-assigned
+ | 'Do' >> ptransform.FlatMap(lambda x: [(x, x)])
+ | ptransform.GroupByKey())
def test_streaming_create_translation(self):
remote_runner = DataflowRunner()
self.default_properties.append("--streaming")
- p = Pipeline(remote_runner, PipelineOptions(self.default_properties))
- p | ptransform.Create([1]) # pylint: disable=expression-not-assigned
- p.run()
+ with Pipeline(remote_runner, PipelineOptions(self.default_properties)) as p:
+ p | ptransform.Create([1]) # pylint: disable=expression-not-assigned
job_dict = json.loads(str(remote_runner.job))
self.assertEqual(len(job_dict[u'steps']), 3)
@@ -252,11 +251,12 @@
def test_biqquery_read_streaming_fail(self):
remote_runner = DataflowRunner()
self.default_properties.append("--streaming")
- p = Pipeline(remote_runner, PipelineOptions(self.default_properties))
- _ = p | beam.io.Read(beam.io.BigQuerySource('some.table'))
with self.assertRaisesRegex(ValueError,
r'source is not currently available'):
- p.run()
+ with Pipeline(
+ remote_runner,
+ PipelineOptions(self.default_properties)) as p:
+ _ = p | beam.io.Read(beam.io.BigQuerySource('some.table'))
# TODO(BEAM-8095): Segfaults in Python 3.7 with xdist.
@pytest.mark.no_xdist
@@ -422,9 +422,8 @@
remote_runner = DataflowRunner()
self.default_properties.append('--min_cpu_platform=Intel Haswell')
- p = Pipeline(remote_runner, PipelineOptions(self.default_properties))
- p | ptransform.Create([1]) # pylint: disable=expression-not-assigned
- p.run()
+ with Pipeline(remote_runner, PipelineOptions(self.default_properties)) as p:
+ p | ptransform.Create([1]) # pylint: disable=expression-not-assigned
self.assertIn('min_cpu_platform=Intel Haswell',
remote_runner.job.options.view_as(DebugOptions).experiments)
@@ -434,9 +433,8 @@
self.default_properties.append('--enable_streaming_engine')
self.default_properties.append('--experiment=some_other_experiment')
- p = Pipeline(remote_runner, PipelineOptions(self.default_properties))
- p | ptransform.Create([1]) # pylint: disable=expression-not-assigned
- p.run()
+ with Pipeline(remote_runner, PipelineOptions(self.default_properties)) as p:
+ p | ptransform.Create([1]) # pylint: disable=expression-not-assigned
experiments_for_job = (
remote_runner.job.options.view_as(DebugOptions).experiments)
@@ -449,9 +447,8 @@
self.default_properties.append('--experiment=some_other_experiment')
self.default_properties.append('--dataflow_worker_jar=test.jar')
- p = Pipeline(remote_runner, PipelineOptions(self.default_properties))
- p | ptransform.Create([1]) # pylint: disable=expression-not-assigned
- p.run()
+ with Pipeline(remote_runner, PipelineOptions(self.default_properties)) as p:
+ p | ptransform.Create([1]) # pylint: disable=expression-not-assigned
experiments_for_job = (
remote_runner.job.options.view_as(DebugOptions).experiments)
@@ -463,9 +460,8 @@
self.default_properties.append('--experiment=beam_fn_api')
self.default_properties.append('--dataflow_worker_jar=test.jar')
- p = Pipeline(remote_runner, PipelineOptions(self.default_properties))
- p | ptransform.Create([1]) # pylint: disable=expression-not-assigned
- p.run()
+ with Pipeline(remote_runner, PipelineOptions(self.default_properties)) as p:
+ p | ptransform.Create([1]) # pylint: disable=expression-not-assigned
experiments_for_job = (
remote_runner.job.options.view_as(DebugOptions).experiments)
@@ -475,9 +471,8 @@
def test_use_fastavro_experiment_is_added_on_py3_and_onwards(self):
remote_runner = DataflowRunner()
- p = Pipeline(remote_runner, PipelineOptions(self.default_properties))
- p | ptransform.Create([1]) # pylint: disable=expression-not-assigned
- p.run()
+ with Pipeline(remote_runner, PipelineOptions(self.default_properties)) as p:
+ p | ptransform.Create([1]) # pylint: disable=expression-not-assigned
self.assertEqual(
sys.version_info[0] > 2,
@@ -488,9 +483,8 @@
remote_runner = DataflowRunner()
self.default_properties.append('--experiment=use_avro')
- p = Pipeline(remote_runner, PipelineOptions(self.default_properties))
- p | ptransform.Create([1]) # pylint: disable=expression-not-assigned
- p.run()
+ with Pipeline(remote_runner, PipelineOptions(self.default_properties)) as p:
+ p | ptransform.Create([1]) # pylint: disable=expression-not-assigned
debug_options = remote_runner.job.options.view_as(DebugOptions)
diff --git a/sdks/python/apache_beam/runners/dataflow/native_io/iobase_test.py b/sdks/python/apache_beam/runners/dataflow/native_io/iobase_test.py
index f3f2e75..ec69ee1 100644
--- a/sdks/python/apache_beam/runners/dataflow/native_io/iobase_test.py
+++ b/sdks/python/apache_beam/runners/dataflow/native_io/iobase_test.py
@@ -185,10 +185,9 @@
def Write(self, value):
self.written_values.append(value)
- p = TestPipeline()
- sink = FakeSink()
- p | Create(['a', 'b', 'c']) | _NativeWrite(sink) # pylint: disable=expression-not-assigned
- p.run()
+ with TestPipeline() as p:
+ sink = FakeSink()
+ p | Create(['a', 'b', 'c']) | _NativeWrite(sink) # pylint: disable=expression-not-assigned
self.assertEqual(['a', 'b', 'c'], sorted(sink.written_values))
diff --git a/sdks/python/apache_beam/runners/dataflow/template_runner_test.py b/sdks/python/apache_beam/runners/dataflow/template_runner_test.py
index e6d0d66..9edb2db 100644
--- a/sdks/python/apache_beam/runners/dataflow/template_runner_test.py
+++ b/sdks/python/apache_beam/runners/dataflow/template_runner_test.py
@@ -54,19 +54,19 @@
dummy_dir = tempfile.mkdtemp()
remote_runner = DataflowRunner()
- pipeline = Pipeline(remote_runner,
- options=PipelineOptions([
- '--dataflow_endpoint=ignored',
- '--sdk_location=' + dummy_file_name,
- '--job_name=test-job',
- '--project=test-project',
- '--staging_location=' + dummy_dir,
- '--temp_location=/dev/null',
- '--template_location=' + dummy_file_name,
- '--no_auth']))
+ with Pipeline(remote_runner,
+ options=PipelineOptions([
+ '--dataflow_endpoint=ignored',
+ '--sdk_location=' + dummy_file_name,
+ '--job_name=test-job',
+ '--project=test-project',
+ '--staging_location=' + dummy_dir,
+ '--temp_location=/dev/null',
+ '--template_location=' + dummy_file_name,
+ '--no_auth'])) as pipeline:
- pipeline | beam.Create([1, 2, 3]) | beam.Map(lambda x: x) # pylint: disable=expression-not-assigned
- pipeline.run().wait_until_finish()
+ pipeline | beam.Create([1, 2, 3]) | beam.Map(lambda x: x) # pylint: disable=expression-not-assigned
+
with open(dummy_file_name) as template_file:
saved_job_dict = json.load(template_file)
self.assertEqual(
diff --git a/sdks/python/apache_beam/runners/interactive/interactive_beam.py b/sdks/python/apache_beam/runners/interactive/interactive_beam.py
index 3769dd7..d35c1d8 100644
--- a/sdks/python/apache_beam/runners/interactive/interactive_beam.py
+++ b/sdks/python/apache_beam/runners/interactive/interactive_beam.py
@@ -65,10 +65,9 @@
class Foo(object)
def run_pipeline(self):
- p = beam.Pipeline()
- init_pcoll = p | 'Init Create' >> beam.Create(range(10))
- watch(locals())
- p.run()
+ with beam.Pipeline() as p:
+ init_pcoll = p | 'Init Create' >> beam.Create(range(10))
+ watch(locals())
return init_pcoll
init_pcoll = Foo().run_pipeline()
diff --git a/sdks/python/apache_beam/testing/test_pipeline.py b/sdks/python/apache_beam/testing/test_pipeline.py
index 26819e5..0ca81ec 100644
--- a/sdks/python/apache_beam/testing/test_pipeline.py
+++ b/sdks/python/apache_beam/testing/test_pipeline.py
@@ -57,10 +57,9 @@
For example, use assert_that for test validation::
- pipeline = TestPipeline()
- pcoll = ...
- assert_that(pcoll, equal_to(...))
- pipeline.run()
+ with TestPipeline() as pipeline:
+ pcoll = ...
+ assert_that(pcoll, equal_to(...))
"""
def __init__(self,
diff --git a/sdks/python/apache_beam/testing/test_pipeline_test.py b/sdks/python/apache_beam/testing/test_pipeline_test.py
index 8cd4c88..59779cd 100644
--- a/sdks/python/apache_beam/testing/test_pipeline_test.py
+++ b/sdks/python/apache_beam/testing/test_pipeline_test.py
@@ -106,16 +106,16 @@
self.assertEqual(test_pipeline.get_option(name), value)
def test_skip_IT(self):
- test_pipeline = TestPipeline(is_integration_test=True)
- test_pipeline.run()
- # Note that this will never be reached since it should be skipped above.
+ with TestPipeline(is_integration_test=True) as _:
+ # Note that this will never be reached since it should be skipped above.
+ pass
self.fail()
@mock.patch('apache_beam.testing.test_pipeline.Pipeline.run', autospec=True)
def test_not_use_test_runner_api(self, mock_run):
- test_pipeline = TestPipeline(argv=['--not-use-test-runner-api'],
- blocking=False)
- test_pipeline.run()
+ with TestPipeline(argv=['--not-use-test-runner-api'],
+ blocking=False) as test_pipeline:
+ pass
mock_run.assert_called_once_with(test_pipeline, test_runner_api=False)
diff --git a/sdks/python/apache_beam/testing/test_stream_test.py b/sdks/python/apache_beam/testing/test_stream_test.py
index 0aefbcb..ba599bd 100644
--- a/sdks/python/apache_beam/testing/test_stream_test.py
+++ b/sdks/python/apache_beam/testing/test_stream_test.py
@@ -114,20 +114,19 @@
options = PipelineOptions()
options.view_as(StandardOptions).streaming = True
- p = TestPipeline(options=options)
- my_record_fn = RecordFn()
- records = p | test_stream | beam.ParDo(my_record_fn)
+ with TestPipeline(options=options) as p:
+ my_record_fn = RecordFn()
+ records = p | test_stream | beam.ParDo(my_record_fn)
- assert_that(records, equal_to([
- ('a', timestamp.Timestamp(10)),
- ('b', timestamp.Timestamp(10)),
- ('c', timestamp.Timestamp(10)),
- ('d', timestamp.Timestamp(20)),
- ('e', timestamp.Timestamp(20)),
- ('late', timestamp.Timestamp(12)),
- ('last', timestamp.Timestamp(310)),]))
+ assert_that(records, equal_to([
+ ('a', timestamp.Timestamp(10)),
+ ('b', timestamp.Timestamp(10)),
+ ('c', timestamp.Timestamp(10)),
+ ('d', timestamp.Timestamp(20)),
+ ('e', timestamp.Timestamp(20)),
+ ('late', timestamp.Timestamp(12)),
+ ('last', timestamp.Timestamp(310)),]))
- p.run()
def test_multiple_outputs(self):
"""Tests that the TestStream supports emitting to multiple PCollections."""
@@ -418,33 +417,31 @@
def test_basic_execution_sideinputs(self):
options = PipelineOptions()
options.view_as(StandardOptions).streaming = True
- p = TestPipeline(options=options)
+ with TestPipeline(options=options) as p:
- main_stream = (p
- | 'main TestStream' >> TestStream()
- .advance_watermark_to(10)
- .add_elements(['e']))
- side_stream = (p
- | 'side TestStream' >> TestStream()
- .add_elements([window.TimestampedValue(2, 2)])
- .add_elements([window.TimestampedValue(1, 1)])
- .add_elements([window.TimestampedValue(7, 7)])
- .add_elements([window.TimestampedValue(4, 4)])
- )
+ main_stream = (p
+ | 'main TestStream' >> TestStream()
+ .advance_watermark_to(10)
+ .add_elements(['e']))
+ side_stream = (p
+ | 'side TestStream' >> TestStream()
+ .add_elements([window.TimestampedValue(2, 2)])
+ .add_elements([window.TimestampedValue(1, 1)])
+ .add_elements([window.TimestampedValue(7, 7)])
+ .add_elements([window.TimestampedValue(4, 4)])
+ )
- class RecordFn(beam.DoFn):
- def process(self,
- elm=beam.DoFn.ElementParam,
- ts=beam.DoFn.TimestampParam,
- side=beam.DoFn.SideInputParam):
- yield (elm, ts, side)
+ class RecordFn(beam.DoFn):
+ def process(self,
+ elm=beam.DoFn.ElementParam,
+ ts=beam.DoFn.TimestampParam,
+ side=beam.DoFn.SideInputParam):
+ yield (elm, ts, side)
- records = (main_stream # pylint: disable=unused-variable
- | beam.ParDo(RecordFn(), beam.pvalue.AsList(side_stream)))
+ records = (main_stream # pylint: disable=unused-variable
+ | beam.ParDo(RecordFn(), beam.pvalue.AsList(side_stream)))
- assert_that(records, equal_to([('e', Timestamp(10), [2, 1, 7, 4])]))
-
- p.run()
+ assert_that(records, equal_to([('e', Timestamp(10), [2, 1, 7, 4])]))
def test_basic_execution_batch_sideinputs_fixed_windows(self):
options = PipelineOptions()
diff --git a/sdks/python/apache_beam/transforms/combiners_test.py b/sdks/python/apache_beam/transforms/combiners_test.py
index 1111778..412f079 100644
--- a/sdks/python/apache_beam/transforms/combiners_test.py
+++ b/sdks/python/apache_beam/transforms/combiners_test.py
@@ -95,86 +95,83 @@
class CombineTest(unittest.TestCase):
def test_builtin_combines(self):
- pipeline = TestPipeline()
+ with TestPipeline() as pipeline:
- vals = [6, 3, 1, 1, 9, 1, 5, 2, 0, 6]
- mean = sum(vals) / float(len(vals))
- size = len(vals)
+ vals = [6, 3, 1, 1, 9, 1, 5, 2, 0, 6]
+ mean = sum(vals) / float(len(vals))
+ size = len(vals)
- # First for global combines.
- pcoll = pipeline | 'start' >> Create(vals)
- result_mean = pcoll | 'mean' >> combine.Mean.Globally()
- result_count = pcoll | 'count' >> combine.Count.Globally()
- assert_that(result_mean, equal_to([mean]), label='assert:mean')
- assert_that(result_count, equal_to([size]), label='assert:size')
+ # First for global combines.
+ pcoll = pipeline | 'start' >> Create(vals)
+ result_mean = pcoll | 'mean' >> combine.Mean.Globally()
+ result_count = pcoll | 'count' >> combine.Count.Globally()
+ assert_that(result_mean, equal_to([mean]), label='assert:mean')
+ assert_that(result_count, equal_to([size]), label='assert:size')
- # Again for per-key combines.
- pcoll = pipeline | 'start-perkey' >> Create([('a', x) for x in vals])
- result_key_mean = pcoll | 'mean-perkey' >> combine.Mean.PerKey()
- result_key_count = pcoll | 'count-perkey' >> combine.Count.PerKey()
- assert_that(result_key_mean, equal_to([('a', mean)]), label='key:mean')
- assert_that(result_key_count, equal_to([('a', size)]), label='key:size')
- pipeline.run()
+ # Again for per-key combines.
+ pcoll = pipeline | 'start-perkey' >> Create([('a', x) for x in vals])
+ result_key_mean = pcoll | 'mean-perkey' >> combine.Mean.PerKey()
+ result_key_count = pcoll | 'count-perkey' >> combine.Count.PerKey()
+ assert_that(result_key_mean, equal_to([('a', mean)]), label='key:mean')
+ assert_that(result_key_count, equal_to([('a', size)]), label='key:size')
def test_top(self):
- pipeline = TestPipeline()
+ with TestPipeline() as pipeline:
- # First for global combines.
- pcoll = pipeline | 'start' >> Create([6, 3, 1, 1, 9, 1, 5, 2, 0, 6])
- result_top = pcoll | 'top' >> combine.Top.Largest(5)
- result_bot = pcoll | 'bot' >> combine.Top.Smallest(4)
- assert_that(result_top, equal_to([[9, 6, 6, 5, 3]]), label='assert:top')
- assert_that(result_bot, equal_to([[0, 1, 1, 1]]), label='assert:bot')
+ # First for global combines.
+ pcoll = pipeline | 'start' >> Create([6, 3, 1, 1, 9, 1, 5, 2, 0, 6])
+ result_top = pcoll | 'top' >> combine.Top.Largest(5)
+ result_bot = pcoll | 'bot' >> combine.Top.Smallest(4)
+ assert_that(result_top, equal_to([[9, 6, 6, 5, 3]]), label='assert:top')
+ assert_that(result_bot, equal_to([[0, 1, 1, 1]]), label='assert:bot')
- # Again for per-key combines.
- pcoll = pipeline | 'start-perkey' >> Create(
- [('a', x) for x in [6, 3, 1, 1, 9, 1, 5, 2, 0, 6]])
- result_key_top = pcoll | 'top-perkey' >> combine.Top.LargestPerKey(5)
- result_key_bot = pcoll | 'bot-perkey' >> combine.Top.SmallestPerKey(4)
- assert_that(result_key_top, equal_to([('a', [9, 6, 6, 5, 3])]),
- label='key:top')
- assert_that(result_key_bot, equal_to([('a', [0, 1, 1, 1])]),
- label='key:bot')
- pipeline.run()
+ # Again for per-key combines.
+ pcoll = pipeline | 'start-perkey' >> Create(
+ [('a', x) for x in [6, 3, 1, 1, 9, 1, 5, 2, 0, 6]])
+ result_key_top = pcoll | 'top-perkey' >> combine.Top.LargestPerKey(5)
+ result_key_bot = pcoll | 'bot-perkey' >> combine.Top.SmallestPerKey(4)
+ assert_that(result_key_top, equal_to([('a', [9, 6, 6, 5, 3])]),
+ label='key:top')
+ assert_that(result_key_bot, equal_to([('a', [0, 1, 1, 1])]),
+ label='key:bot')
@unittest.skipIf(sys.version_info[0] > 2, 'deprecated comparator')
def test_top_py2(self):
- pipeline = TestPipeline()
+ with TestPipeline() as pipeline:
- # A parameter we'll be sharing with a custom comparator.
- names = {0: 'zo',
- 1: 'one',
- 2: 'twoo',
- 3: 'three',
- 5: 'fiiive',
- 6: 'sssssix',
- 9: 'nniiinne'}
+ # A parameter we'll be sharing with a custom comparator.
+ names = {0: 'zo',
+ 1: 'one',
+ 2: 'twoo',
+ 3: 'three',
+ 5: 'fiiive',
+ 6: 'sssssix',
+ 9: 'nniiinne'}
- # First for global combines.
- pcoll = pipeline | 'start' >> Create([6, 3, 1, 1, 9, 1, 5, 2, 0, 6])
+ # First for global combines.
+ pcoll = pipeline | 'start' >> Create([6, 3, 1, 1, 9, 1, 5, 2, 0, 6])
- result_cmp = pcoll | 'cmp' >> combine.Top.Of(
- 6,
- lambda a, b, names: len(names[a]) < len(names[b]),
- names) # Note parameter passed to comparator.
- result_cmp_rev = pcoll | 'cmp_rev' >> combine.Top.Of(
- 3,
- lambda a, b, names: len(names[a]) < len(names[b]),
- names, # Note parameter passed to comparator.
- reverse=True)
- assert_that(result_cmp, equal_to([[9, 6, 6, 5, 3, 2]]), label='assert:cmp')
- assert_that(result_cmp_rev, equal_to([[0, 1, 1]]), label='assert:cmp_rev')
+ result_cmp = pcoll | 'cmp' >> combine.Top.Of(
+ 6,
+ lambda a, b, names: len(names[a]) < len(names[b]),
+ names) # Note parameter passed to comparator.
+ result_cmp_rev = pcoll | 'cmp_rev' >> combine.Top.Of(
+ 3,
+ lambda a, b, names: len(names[a]) < len(names[b]),
+ names, # Note parameter passed to comparator.
+ reverse=True)
+ assert_that(result_cmp, equal_to([[9, 6, 6, 5, 3, 2]]), label='CheckCmp')
+ assert_that(result_cmp_rev, equal_to([[0, 1, 1]]), label='CheckCmpRev')
- # Again for per-key combines.
- pcoll = pipeline | 'start-perkye' >> Create(
- [('a', x) for x in [6, 3, 1, 1, 9, 1, 5, 2, 0, 6]])
- result_key_cmp = pcoll | 'cmp-perkey' >> combine.Top.PerKey(
- 6,
- lambda a, b, names: len(names[a]) < len(names[b]),
- names) # Note parameter passed to comparator.
- assert_that(result_key_cmp, equal_to([('a', [9, 6, 6, 5, 3, 2])]),
- label='key:cmp')
- pipeline.run()
+ # Again for per-key combines.
+ pcoll = pipeline | 'start-perkye' >> Create(
+ [('a', x) for x in [6, 3, 1, 1, 9, 1, 5, 2, 0, 6]])
+ result_key_cmp = pcoll | 'cmp-perkey' >> combine.Top.PerKey(
+ 6,
+ lambda a, b, names: len(names[a]) < len(names[b]),
+ names) # Note parameter passed to comparator.
+ assert_that(result_key_cmp, equal_to([('a', [9, 6, 6, 5, 3, 2])]),
+ label='key:cmp')
def test_empty_global_top(self):
with TestPipeline() as p:
@@ -185,12 +182,11 @@
elements = list(range(100))
random.shuffle(elements)
- pipeline = TestPipeline()
- shards = [pipeline | 'Shard%s' % shard >> beam.Create(elements[shard::7])
- for shard in range(7)]
- assert_that(shards | beam.Flatten() | combine.Top.Largest(10),
- equal_to([[99, 98, 97, 96, 95, 94, 93, 92, 91, 90]]))
- pipeline.run()
+ with TestPipeline() as pipeline:
+ shards = [pipeline | 'Shard%s' % shard >> beam.Create(elements[shard::7])
+ for shard in range(7)]
+ assert_that(shards | beam.Flatten() | combine.Top.Largest(10),
+ equal_to([[99, 98, 97, 96, 95, 94, 93, 92, 91, 90]]))
def test_top_key(self):
self.assertEqual(
@@ -272,22 +268,22 @@
hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items))
def test_top_shorthands(self):
- pipeline = TestPipeline()
+ with TestPipeline() as pipeline:
- pcoll = pipeline | 'start' >> Create([6, 3, 1, 1, 9, 1, 5, 2, 0, 6])
- result_top = pcoll | 'top' >> beam.CombineGlobally(combine.Largest(5))
- result_bot = pcoll | 'bot' >> beam.CombineGlobally(combine.Smallest(4))
- assert_that(result_top, equal_to([[9, 6, 6, 5, 3]]), label='assert:top')
- assert_that(result_bot, equal_to([[0, 1, 1, 1]]), label='assert:bot')
+ pcoll = pipeline | 'start' >> Create([6, 3, 1, 1, 9, 1, 5, 2, 0, 6])
+ result_top = pcoll | 'top' >> beam.CombineGlobally(combine.Largest(5))
+ result_bot = pcoll | 'bot' >> beam.CombineGlobally(combine.Smallest(4))
+ assert_that(result_top, equal_to([[9, 6, 6, 5, 3]]), label='assert:top')
+ assert_that(result_bot, equal_to([[0, 1, 1, 1]]), label='assert:bot')
- pcoll = pipeline | 'start-perkey' >> Create(
- [('a', x) for x in [6, 3, 1, 1, 9, 1, 5, 2, 0, 6]])
- result_ktop = pcoll | 'top-perkey' >> beam.CombinePerKey(combine.Largest(5))
- result_kbot = pcoll | 'bot-perkey' >> beam.CombinePerKey(
- combine.Smallest(4))
- assert_that(result_ktop, equal_to([('a', [9, 6, 6, 5, 3])]), label='k:top')
- assert_that(result_kbot, equal_to([('a', [0, 1, 1, 1])]), label='k:bot')
- pipeline.run()
+ pcoll = pipeline | 'start-perkey' >> Create(
+ [('a', x) for x in [6, 3, 1, 1, 9, 1, 5, 2, 0, 6]])
+ result_ktop = pcoll | 'top-perkey' >> beam.CombinePerKey(
+ combine.Largest(5))
+ result_kbot = pcoll | 'bot-perkey' >> beam.CombinePerKey(
+ combine.Smallest(4))
+ assert_that(result_ktop, equal_to([('a', [9, 6, 6, 5, 3])]), label='ktop')
+ assert_that(result_kbot, equal_to([('a', [0, 1, 1, 1])]), label='kbot')
def test_top_no_compact(self):
@@ -296,24 +292,23 @@
def compact(self, accumulator):
return accumulator
- pipeline = TestPipeline()
- pcoll = pipeline | 'Start' >> Create([6, 3, 1, 1, 9, 1, 5, 2, 0, 6])
- result_top = pcoll | 'Top' >> beam.CombineGlobally(
- TopCombineFnNoCompact(5, key=lambda x: x))
- result_bot = pcoll | 'Bot' >> beam.CombineGlobally(
- TopCombineFnNoCompact(4, reverse=True))
- assert_that(result_top, equal_to([[9, 6, 6, 5, 3]]), label='Assert:Top')
- assert_that(result_bot, equal_to([[0, 1, 1, 1]]), label='Assert:Bot')
+ with TestPipeline() as pipeline:
+ pcoll = pipeline | 'Start' >> Create([6, 3, 1, 1, 9, 1, 5, 2, 0, 6])
+ result_top = pcoll | 'Top' >> beam.CombineGlobally(
+ TopCombineFnNoCompact(5, key=lambda x: x))
+ result_bot = pcoll | 'Bot' >> beam.CombineGlobally(
+ TopCombineFnNoCompact(4, reverse=True))
+ assert_that(result_top, equal_to([[9, 6, 6, 5, 3]]), label='Assert:Top')
+ assert_that(result_bot, equal_to([[0, 1, 1, 1]]), label='Assert:Bot')
- pcoll = pipeline | 'Start-Perkey' >> Create(
- [('a', x) for x in [6, 3, 1, 1, 9, 1, 5, 2, 0, 6]])
- result_ktop = pcoll | 'Top-PerKey' >> beam.CombinePerKey(
- TopCombineFnNoCompact(5, key=lambda x: x))
- result_kbot = pcoll | 'Bot-PerKey' >> beam.CombinePerKey(
- TopCombineFnNoCompact(4, reverse=True))
- assert_that(result_ktop, equal_to([('a', [9, 6, 6, 5, 3])]), label='K:Top')
- assert_that(result_kbot, equal_to([('a', [0, 1, 1, 1])]), label='K:Bot')
- pipeline.run()
+ pcoll = pipeline | 'Start-Perkey' >> Create(
+ [('a', x) for x in [6, 3, 1, 1, 9, 1, 5, 2, 0, 6]])
+ result_ktop = pcoll | 'Top-PerKey' >> beam.CombinePerKey(
+ TopCombineFnNoCompact(5, key=lambda x: x))
+ result_kbot = pcoll | 'Bot-PerKey' >> beam.CombinePerKey(
+ TopCombineFnNoCompact(4, reverse=True))
+ assert_that(result_ktop, equal_to([('a', [9, 6, 6, 5, 3])]), label='KTop')
+ assert_that(result_kbot, equal_to([('a', [0, 1, 1, 1])]), label='KBot')
def test_global_sample(self):
def is_good_sample(actual):
@@ -329,21 +324,20 @@
label='check-%d' % ix)
def test_per_key_sample(self):
- pipeline = TestPipeline()
- pcoll = pipeline | 'start-perkey' >> Create(
- sum(([(i, 1), (i, 1), (i, 2), (i, 2)] for i in range(9)), []))
- result = pcoll | 'sample' >> combine.Sample.FixedSizePerKey(3)
+ with TestPipeline() as pipeline:
+ pcoll = pipeline | 'start-perkey' >> Create(
+ sum(([(i, 1), (i, 1), (i, 2), (i, 2)] for i in range(9)), []))
+ result = pcoll | 'sample' >> combine.Sample.FixedSizePerKey(3)
- def matcher():
- def match(actual):
- for _, samples in actual:
- equal_to([3])([len(samples)])
- num_ones = sum(1 for x in samples if x == 1)
- num_twos = sum(1 for x in samples if x == 2)
- equal_to([1, 2])([num_ones, num_twos])
- return match
- assert_that(result, matcher())
- pipeline.run()
+ def matcher():
+ def match(actual):
+ for _, samples in actual:
+ equal_to([3])([len(samples)])
+ num_ones = sum(1 for x in samples if x == 1)
+ num_twos = sum(1 for x in samples if x == 2)
+ equal_to([1, 2])([num_ones, num_twos])
+ return match
+ assert_that(result, matcher())
def test_tuple_combine_fn(self):
with TestPipeline() as p:
@@ -365,30 +359,28 @@
assert_that(result, equal_to([(1, 7.0 / 4, 3)]))
def test_to_list_and_to_dict(self):
- pipeline = TestPipeline()
- the_list = [6, 3, 1, 1, 9, 1, 5, 2, 0, 6]
- pcoll = pipeline | 'start' >> Create(the_list)
- result = pcoll | 'to list' >> combine.ToList()
+ with TestPipeline() as pipeline:
+ the_list = [6, 3, 1, 1, 9, 1, 5, 2, 0, 6]
+ pcoll = pipeline | 'start' >> Create(the_list)
+ result = pcoll | 'to list' >> combine.ToList()
- def matcher(expected):
- def match(actual):
- equal_to(expected[0])(actual[0])
- return match
- assert_that(result, matcher([the_list]))
- pipeline.run()
+ def matcher(expected):
+ def match(actual):
+ equal_to(expected[0])(actual[0])
+ return match
+ assert_that(result, matcher([the_list]))
- pipeline = TestPipeline()
- pairs = [(1, 2), (3, 4), (5, 6)]
- pcoll = pipeline | 'start-pairs' >> Create(pairs)
- result = pcoll | 'to dict' >> combine.ToDict()
+ with TestPipeline() as pipeline:
+ pairs = [(1, 2), (3, 4), (5, 6)]
+ pcoll = pipeline | 'start-pairs' >> Create(pairs)
+ result = pcoll | 'to dict' >> combine.ToDict()
- def matcher():
- def match(actual):
- equal_to([1])([len(actual)])
- equal_to(pairs)(actual[0].items())
- return match
- assert_that(result, matcher())
- pipeline.run()
+ def matcher():
+ def match(actual):
+ equal_to([1])([len(actual)])
+ equal_to(pairs)(actual[0].items())
+ return match
+ assert_that(result, matcher())
def test_combine_globally_with_default(self):
with TestPipeline() as p:
diff --git a/sdks/python/apache_beam/transforms/dofn_lifecycle_test.py b/sdks/python/apache_beam/transforms/dofn_lifecycle_test.py
index 0f2cc4c..bcba20d 100644
--- a/sdks/python/apache_beam/transforms/dofn_lifecycle_test.py
+++ b/sdks/python/apache_beam/transforms/dofn_lifecycle_test.py
@@ -78,12 +78,10 @@
@attr('ValidatesRunner')
class DoFnLifecycleTest(unittest.TestCase):
def test_dofn_lifecycle(self):
- p = TestPipeline()
- _ = (p
- | 'Start' >> beam.Create([1, 2, 3])
- | 'Do' >> beam.ParDo(CallSequenceEnforcingDoFn()))
- result = p.run()
- result.wait_until_finish()
+ with TestPipeline() as p:
+ _ = (p
+ | 'Start' >> beam.Create([1, 2, 3])
+ | 'Do' >> beam.ParDo(CallSequenceEnforcingDoFn()))
# Assumes that the worker is run in the same process as the test.
diff --git a/sdks/python/apache_beam/transforms/ptransform_test.py b/sdks/python/apache_beam/transforms/ptransform_test.py
index f85a2b9..ef3932e 100644
--- a/sdks/python/apache_beam/transforms/ptransform_test.py
+++ b/sdks/python/apache_beam/transforms/ptransform_test.py
@@ -118,11 +118,10 @@
def process(self, element, addon):
return [element + addon]
- pipeline = TestPipeline()
- pcoll = pipeline | 'Start' >> beam.Create([1, 2, 3])
- result = pcoll | 'Do' >> beam.ParDo(AddNDoFn(), 10)
- assert_that(result, equal_to([11, 12, 13]))
- pipeline.run()
+ with TestPipeline() as pipeline:
+ pcoll = pipeline | 'Start' >> beam.Create([1, 2, 3])
+ result = pcoll | 'Do' >> beam.ParDo(AddNDoFn(), 10)
+ assert_that(result, equal_to([11, 12, 13]))
def test_do_with_unconstructed_do_fn(self):
class MyDoFn(beam.DoFn):
@@ -130,81 +129,74 @@
def process(self):
pass
- pipeline = TestPipeline()
- pcoll = pipeline | 'Start' >> beam.Create([1, 2, 3])
with self.assertRaises(ValueError):
- pcoll | 'Do' >> beam.ParDo(MyDoFn) # Note the lack of ()'s
+ with TestPipeline() as pipeline:
+ pcoll = pipeline | 'Start' >> beam.Create([1, 2, 3])
+ pcoll | 'Do' >> beam.ParDo(MyDoFn) # Note the lack of ()'s
def test_do_with_callable(self):
- pipeline = TestPipeline()
- pcoll = pipeline | 'Start' >> beam.Create([1, 2, 3])
- result = pcoll | 'Do' >> beam.FlatMap(lambda x, addon: [x + addon], 10)
- assert_that(result, equal_to([11, 12, 13]))
- pipeline.run()
+ with TestPipeline() as pipeline:
+ pcoll = pipeline | 'Start' >> beam.Create([1, 2, 3])
+ result = pcoll | 'Do' >> beam.FlatMap(lambda x, addon: [x + addon], 10)
+ assert_that(result, equal_to([11, 12, 13]))
def test_do_with_side_input_as_arg(self):
- pipeline = TestPipeline()
- side = pipeline | 'Side' >> beam.Create([10])
- pcoll = pipeline | 'Start' >> beam.Create([1, 2, 3])
- result = pcoll | 'Do' >> beam.FlatMap(
- lambda x, addon: [x + addon], pvalue.AsSingleton(side))
- assert_that(result, equal_to([11, 12, 13]))
- pipeline.run()
+ with TestPipeline() as pipeline:
+ side = pipeline | 'Side' >> beam.Create([10])
+ pcoll = pipeline | 'Start' >> beam.Create([1, 2, 3])
+ result = pcoll | 'Do' >> beam.FlatMap(
+ lambda x, addon: [x + addon], pvalue.AsSingleton(side))
+ assert_that(result, equal_to([11, 12, 13]))
def test_do_with_side_input_as_keyword_arg(self):
- pipeline = TestPipeline()
- side = pipeline | 'Side' >> beam.Create([10])
- pcoll = pipeline | 'Start' >> beam.Create([1, 2, 3])
- result = pcoll | 'Do' >> beam.FlatMap(
- lambda x, addon: [x + addon], addon=pvalue.AsSingleton(side))
- assert_that(result, equal_to([11, 12, 13]))
- pipeline.run()
+ with TestPipeline() as pipeline:
+ side = pipeline | 'Side' >> beam.Create([10])
+ pcoll = pipeline | 'Start' >> beam.Create([1, 2, 3])
+ result = pcoll | 'Do' >> beam.FlatMap(
+ lambda x, addon: [x + addon], addon=pvalue.AsSingleton(side))
+ assert_that(result, equal_to([11, 12, 13]))
def test_do_with_do_fn_returning_string_raises_warning(self):
- pipeline = TestPipeline()
- pipeline._options.view_as(TypeOptions).runtime_type_check = True
- pcoll = pipeline | 'Start' >> beam.Create(['2', '9', '3'])
- pcoll | 'Do' >> beam.FlatMap(lambda x: x + '1')
-
- # Since the DoFn directly returns a string we should get an error warning
- # us.
with self.assertRaises(typehints.TypeCheckError) as cm:
- pipeline.run()
+ with TestPipeline() as pipeline:
+ pipeline._options.view_as(TypeOptions).runtime_type_check = True
+ pcoll = pipeline | 'Start' >> beam.Create(['2', '9', '3'])
+ pcoll | 'Do' >> beam.FlatMap(lambda x: x + '1')
+
+ # Since the DoFn directly returns a string we should get an
+ # error warning us when the pipeline runs.
expected_error_prefix = ('Returning a str from a ParDo or FlatMap '
'is discouraged.')
self.assertStartswith(cm.exception.args[0], expected_error_prefix)
def test_do_with_do_fn_returning_dict_raises_warning(self):
- pipeline = TestPipeline()
- pipeline._options.view_as(TypeOptions).runtime_type_check = True
- pcoll = pipeline | 'Start' >> beam.Create(['2', '9', '3'])
- pcoll | 'Do' >> beam.FlatMap(lambda x: {x: '1'})
-
- # Since the DoFn directly returns a dict we should get an error warning
- # us.
with self.assertRaises(typehints.TypeCheckError) as cm:
- pipeline.run()
+ with TestPipeline() as pipeline:
+ pipeline._options.view_as(TypeOptions).runtime_type_check = True
+ pcoll = pipeline | 'Start' >> beam.Create(['2', '9', '3'])
+ pcoll | 'Do' >> beam.FlatMap(lambda x: {x: '1'})
+
+ # Since the DoFn directly returns a dict we should get an error warning
+ # us when the pipeline runs.
expected_error_prefix = ('Returning a dict from a ParDo or FlatMap '
'is discouraged.')
self.assertStartswith(cm.exception.args[0], expected_error_prefix)
def test_do_with_multiple_outputs_maintains_unique_name(self):
- pipeline = TestPipeline()
- pcoll = pipeline | 'Start' >> beam.Create([1, 2, 3])
- r1 = pcoll | 'A' >> beam.FlatMap(lambda x: [x + 1]).with_outputs(main='m')
- r2 = pcoll | 'B' >> beam.FlatMap(lambda x: [x + 2]).with_outputs(main='m')
- assert_that(r1.m, equal_to([2, 3, 4]), label='r1')
- assert_that(r2.m, equal_to([3, 4, 5]), label='r2')
- pipeline.run()
+ with TestPipeline() as pipeline:
+ pcoll = pipeline | 'Start' >> beam.Create([1, 2, 3])
+ r1 = pcoll | 'A' >> beam.FlatMap(lambda x: [x + 1]).with_outputs(main='m')
+ r2 = pcoll | 'B' >> beam.FlatMap(lambda x: [x + 2]).with_outputs(main='m')
+ assert_that(r1.m, equal_to([2, 3, 4]), label='r1')
+ assert_that(r2.m, equal_to([3, 4, 5]), label='r2')
@attr('ValidatesRunner')
def test_impulse(self):
- pipeline = TestPipeline()
- result = pipeline | beam.Impulse() | beam.Map(lambda _: 0)
- assert_that(result, equal_to([0]))
- pipeline.run()
+ with TestPipeline() as pipeline:
+ result = pipeline | beam.Impulse() | beam.Map(lambda _: 0)
+ assert_that(result, equal_to([0]))
# TODO(BEAM-3544): Disable this test in streaming temporarily.
# Remove sickbay-streaming tag after it's resolved.
@@ -246,14 +238,13 @@
else:
yield pvalue.TaggedOutput('odd', element)
- pipeline = TestPipeline()
- nums = pipeline | 'Some Numbers' >> beam.Create([1, 2, 3, 4])
- results = nums | 'ClassifyNumbers' >> beam.ParDo(
- SomeDoFn()).with_outputs('odd', 'even', main='main')
- assert_that(results.main, equal_to([1, 2, 3, 4]))
- assert_that(results.odd, equal_to([1, 3]), label='assert:odd')
- assert_that(results.even, equal_to([2, 4]), label='assert:even')
- pipeline.run()
+ with TestPipeline() as pipeline:
+ nums = pipeline | 'Some Numbers' >> beam.Create([1, 2, 3, 4])
+ results = nums | 'ClassifyNumbers' >> beam.ParDo(
+ SomeDoFn()).with_outputs('odd', 'even', main='main')
+ assert_that(results.main, equal_to([1, 2, 3, 4]))
+ assert_that(results.odd, equal_to([1, 3]), label='assert:odd')
+ assert_that(results.even, equal_to([2, 4]), label='assert:even')
@attr('ValidatesRunner')
def test_par_do_with_multiple_outputs_and_using_return(self):
@@ -262,55 +253,51 @@
return [v, pvalue.TaggedOutput('even', v)]
return [v, pvalue.TaggedOutput('odd', v)]
- pipeline = TestPipeline()
- nums = pipeline | 'Some Numbers' >> beam.Create([1, 2, 3, 4])
- results = nums | 'ClassifyNumbers' >> beam.FlatMap(
- some_fn).with_outputs('odd', 'even', main='main')
- assert_that(results.main, equal_to([1, 2, 3, 4]))
- assert_that(results.odd, equal_to([1, 3]), label='assert:odd')
- assert_that(results.even, equal_to([2, 4]), label='assert:even')
- pipeline.run()
+ with TestPipeline() as pipeline:
+ nums = pipeline | 'Some Numbers' >> beam.Create([1, 2, 3, 4])
+ results = nums | 'ClassifyNumbers' >> beam.FlatMap(
+ some_fn).with_outputs('odd', 'even', main='main')
+ assert_that(results.main, equal_to([1, 2, 3, 4]))
+ assert_that(results.odd, equal_to([1, 3]), label='assert:odd')
+ assert_that(results.even, equal_to([2, 4]), label='assert:even')
@attr('ValidatesRunner')
def test_undeclared_outputs(self):
- pipeline = TestPipeline()
- nums = pipeline | 'Some Numbers' >> beam.Create([1, 2, 3, 4])
- results = nums | 'ClassifyNumbers' >> beam.FlatMap(
- lambda x: [x,
- pvalue.TaggedOutput('even' if x % 2 == 0 else 'odd', x),
- pvalue.TaggedOutput('extra', x)]
- ).with_outputs()
- assert_that(results[None], equal_to([1, 2, 3, 4]))
- assert_that(results.odd, equal_to([1, 3]), label='assert:odd')
- assert_that(results.even, equal_to([2, 4]), label='assert:even')
- pipeline.run()
+ with TestPipeline() as pipeline:
+ nums = pipeline | 'Some Numbers' >> beam.Create([1, 2, 3, 4])
+ results = nums | 'ClassifyNumbers' >> beam.FlatMap(
+ lambda x: [x,
+ pvalue.TaggedOutput('even' if x % 2 == 0 else 'odd', x),
+ pvalue.TaggedOutput('extra', x)]
+ ).with_outputs()
+ assert_that(results[None], equal_to([1, 2, 3, 4]))
+ assert_that(results.odd, equal_to([1, 3]), label='assert:odd')
+ assert_that(results.even, equal_to([2, 4]), label='assert:even')
@attr('ValidatesRunner')
def test_multiple_empty_outputs(self):
- pipeline = TestPipeline()
- nums = pipeline | 'Some Numbers' >> beam.Create([1, 3, 5])
- results = nums | 'ClassifyNumbers' >> beam.FlatMap(
- lambda x: [x,
- pvalue.TaggedOutput('even' if x % 2 == 0 else 'odd', x)]
- ).with_outputs()
- assert_that(results[None], equal_to([1, 3, 5]))
- assert_that(results.odd, equal_to([1, 3, 5]), label='assert:odd')
- assert_that(results.even, equal_to([]), label='assert:even')
- pipeline.run()
+ with TestPipeline() as pipeline:
+ nums = pipeline | 'Some Numbers' >> beam.Create([1, 3, 5])
+ results = nums | 'ClassifyNumbers' >> beam.FlatMap(
+ lambda x: [x,
+ pvalue.TaggedOutput('even' if x % 2 == 0 else 'odd', x)]
+ ).with_outputs()
+ assert_that(results[None], equal_to([1, 3, 5]))
+ assert_that(results.odd, equal_to([1, 3, 5]), label='assert:odd')
+ assert_that(results.even, equal_to([]), label='assert:even')
def test_do_requires_do_fn_returning_iterable(self):
# This function is incorrect because it returns an object that isn't an
# iterable.
def incorrect_par_do_fn(x):
return x + 5
- pipeline = TestPipeline()
- pipeline._options.view_as(TypeOptions).runtime_type_check = True
- pcoll = pipeline | 'Start' >> beam.Create([2, 9, 3])
- pcoll | 'Do' >> beam.FlatMap(incorrect_par_do_fn)
- # It's a requirement that all user-defined functions to a ParDo return
- # an iterable.
with self.assertRaises(typehints.TypeCheckError) as cm:
- pipeline.run()
+ with TestPipeline() as pipeline:
+ pipeline._options.view_as(TypeOptions).runtime_type_check = True
+ pcoll = pipeline | 'Start' >> beam.Create([2, 9, 3])
+ pcoll | 'Do' >> beam.FlatMap(incorrect_par_do_fn)
+ # It's a requirement that all user-defined functions to a ParDo return
+ # an iterable.
expected_error_prefix = 'FlatMap and ParDo must return an iterable.'
self.assertStartswith(cm.exception.args[0], expected_error_prefix)
@@ -323,19 +310,18 @@
def finish_bundle(self):
yield WindowedValue('finish', -1, [window.GlobalWindow()])
- pipeline = TestPipeline()
- pcoll = pipeline | 'Start' >> beam.Create([1, 2, 3])
- result = pcoll | 'Do' >> beam.ParDo(MyDoFn())
+ with TestPipeline() as pipeline:
+ pcoll = pipeline | 'Start' >> beam.Create([1, 2, 3])
+ result = pcoll | 'Do' >> beam.ParDo(MyDoFn())
- # May have many bundles, but each has a start and finish.
- def matcher():
- def match(actual):
- equal_to(['finish'])(list(set(actual)))
- equal_to([1])([actual.count('finish')])
- return match
+ # May have many bundles, but each has a start and finish.
+ def matcher():
+ def match(actual):
+ equal_to(['finish'])(list(set(actual)))
+ equal_to([1])([actual.count('finish')])
+ return match
- assert_that(result, matcher())
- pipeline.run()
+ assert_that(result, matcher())
def test_do_fn_with_windowing_in_finish_bundle(self):
windowfn = window.FixedWindows(2)
@@ -375,19 +361,18 @@
yield 'started'
self.state = 'process'
- pipeline = TestPipeline()
- pcoll = pipeline | 'Start' >> beam.Create([1, 2, 3])
- result = pcoll | 'Do' >> beam.ParDo(MyDoFn())
+ with TestPipeline() as pipeline:
+ pcoll = pipeline | 'Start' >> beam.Create([1, 2, 3])
+ result = pcoll | 'Do' >> beam.ParDo(MyDoFn())
- # May have many bundles, but each has a start and finish.
- def matcher():
- def match(actual):
- equal_to(['started'])(list(set(actual)))
- equal_to([1])([actual.count('started')])
- return match
+ # May have many bundles, but each has a start and finish.
+ def matcher():
+ def match(actual):
+ equal_to(['started'])(list(set(actual)))
+ equal_to([1])([actual.count('started')])
+ return match
- assert_that(result, matcher())
- pipeline.run()
+ assert_that(result, matcher())
def test_do_fn_with_start_error(self):
class MyDoFn(beam.DoFn):
@@ -397,17 +382,15 @@
def process(self, element):
pass
- pipeline = TestPipeline()
- pipeline | 'Start' >> beam.Create([1, 2, 3]) | 'Do' >> beam.ParDo(MyDoFn())
with self.assertRaises(RuntimeError):
- pipeline.run()
+ with TestPipeline() as p:
+ p | 'Start' >> beam.Create([1, 2, 3]) | 'Do' >> beam.ParDo(MyDoFn())
def test_filter(self):
- pipeline = TestPipeline()
- pcoll = pipeline | 'Start' >> beam.Create([1, 2, 3, 4])
- result = pcoll | 'Filter' >> beam.Filter(lambda x: x % 2 == 0)
- assert_that(result, equal_to([2, 4]))
- pipeline.run()
+ with TestPipeline() as pipeline:
+ pcoll = pipeline | 'Start' >> beam.Create([1, 2, 3, 4])
+ result = pcoll | 'Filter' >> beam.Filter(lambda x: x % 2 == 0)
+ assert_that(result, equal_to([2, 4]))
class _MeanCombineFn(beam.CombineFn):
@@ -430,68 +413,62 @@
def test_combine_with_combine_fn(self):
vals = [1, 2, 3, 4, 5, 6, 7]
- pipeline = TestPipeline()
- pcoll = pipeline | 'Start' >> beam.Create(vals)
- result = pcoll | 'Mean' >> beam.CombineGlobally(self._MeanCombineFn())
- assert_that(result, equal_to([sum(vals) // len(vals)]))
- pipeline.run()
+ with TestPipeline() as pipeline:
+ pcoll = pipeline | 'Start' >> beam.Create(vals)
+ result = pcoll | 'Mean' >> beam.CombineGlobally(self._MeanCombineFn())
+ assert_that(result, equal_to([sum(vals) // len(vals)]))
def test_combine_with_callable(self):
vals = [1, 2, 3, 4, 5, 6, 7]
- pipeline = TestPipeline()
- pcoll = pipeline | 'Start' >> beam.Create(vals)
- result = pcoll | beam.CombineGlobally(sum)
- assert_that(result, equal_to([sum(vals)]))
- pipeline.run()
+ with TestPipeline() as pipeline:
+ pcoll = pipeline | 'Start' >> beam.Create(vals)
+ result = pcoll | beam.CombineGlobally(sum)
+ assert_that(result, equal_to([sum(vals)]))
def test_combine_with_side_input_as_arg(self):
values = [1, 2, 3, 4, 5, 6, 7]
- pipeline = TestPipeline()
- pcoll = pipeline | 'Start' >> beam.Create(values)
- divisor = pipeline | 'Divisor' >> beam.Create([2])
- result = pcoll | 'Max' >> beam.CombineGlobally(
- # Multiples of divisor only.
- lambda vals, d: max(v for v in vals if v % d == 0),
- pvalue.AsSingleton(divisor)).without_defaults()
- filt_vals = [v for v in values if v % 2 == 0]
- assert_that(result, equal_to([max(filt_vals)]))
- pipeline.run()
+ with TestPipeline() as pipeline:
+ pcoll = pipeline | 'Start' >> beam.Create(values)
+ divisor = pipeline | 'Divisor' >> beam.Create([2])
+ result = pcoll | 'Max' >> beam.CombineGlobally(
+ # Multiples of divisor only.
+ lambda vals, d: max(v for v in vals if v % d == 0),
+ pvalue.AsSingleton(divisor)).without_defaults()
+ filt_vals = [v for v in values if v % 2 == 0]
+ assert_that(result, equal_to([max(filt_vals)]))
def test_combine_per_key_with_combine_fn(self):
vals_1 = [1, 2, 3, 4, 5, 6, 7]
vals_2 = [2, 4, 6, 8, 10, 12, 14]
- pipeline = TestPipeline()
- pcoll = pipeline | 'Start' >> beam.Create(([('a', x) for x in vals_1] +
- [('b', x) for x in vals_2]))
- result = pcoll | 'Mean' >> beam.CombinePerKey(self._MeanCombineFn())
- assert_that(result, equal_to([('a', sum(vals_1) // len(vals_1)),
- ('b', sum(vals_2) // len(vals_2))]))
- pipeline.run()
+ with TestPipeline() as pipeline:
+ pcoll = pipeline | 'Start' >> beam.Create(([('a', x) for x in vals_1] +
+ [('b', x) for x in vals_2]))
+ result = pcoll | 'Mean' >> beam.CombinePerKey(self._MeanCombineFn())
+ assert_that(result, equal_to([('a', sum(vals_1) // len(vals_1)),
+ ('b', sum(vals_2) // len(vals_2))]))
def test_combine_per_key_with_callable(self):
vals_1 = [1, 2, 3, 4, 5, 6, 7]
vals_2 = [2, 4, 6, 8, 10, 12, 14]
- pipeline = TestPipeline()
- pcoll = pipeline | 'Start' >> beam.Create(([('a', x) for x in vals_1] +
- [('b', x) for x in vals_2]))
- result = pcoll | beam.CombinePerKey(sum)
- assert_that(result, equal_to([('a', sum(vals_1)), ('b', sum(vals_2))]))
- pipeline.run()
+ with TestPipeline() as pipeline:
+ pcoll = pipeline | 'Start' >> beam.Create(([('a', x) for x in vals_1] +
+ [('b', x) for x in vals_2]))
+ result = pcoll | beam.CombinePerKey(sum)
+ assert_that(result, equal_to([('a', sum(vals_1)), ('b', sum(vals_2))]))
def test_combine_per_key_with_side_input_as_arg(self):
vals_1 = [1, 2, 3, 4, 5, 6, 7]
vals_2 = [2, 4, 6, 8, 10, 12, 14]
- pipeline = TestPipeline()
- pcoll = pipeline | 'Start' >> beam.Create(([('a', x) for x in vals_1] +
- [('b', x) for x in vals_2]))
- divisor = pipeline | 'Divisor' >> beam.Create([2])
- result = pcoll | beam.CombinePerKey(
- lambda vals, d: max(v for v in vals if v % d == 0),
- pvalue.AsSingleton(divisor)) # Multiples of divisor only.
- m_1 = max(v for v in vals_1 if v % 2 == 0)
- m_2 = max(v for v in vals_2 if v % 2 == 0)
- assert_that(result, equal_to([('a', m_1), ('b', m_2)]))
- pipeline.run()
+ with TestPipeline() as pipeline:
+ pcoll = pipeline | 'Start' >> beam.Create(([('a', x) for x in vals_1] +
+ [('b', x) for x in vals_2]))
+ divisor = pipeline | 'Divisor' >> beam.Create([2])
+ result = pcoll | beam.CombinePerKey(
+ lambda vals, d: max(v for v in vals if v % d == 0),
+ pvalue.AsSingleton(divisor)) # Multiples of divisor only.
+ m_1 = max(v for v in vals_1 if v % 2 == 0)
+ m_2 = max(v for v in vals_2 if v % 2 == 0)
+ assert_that(result, equal_to([('a', m_1), ('b', m_2)]))
def test_group_by_key(self):
pipeline = TestPipeline()
@@ -511,13 +488,12 @@
sum_val += sum(value_list)
return [(key, sum_val)]
- pipeline = TestPipeline()
- pcoll = pipeline | 'start' >> beam.Create(
- [(1, 1), (1, 2), (1, 3), (1, 4)])
- result = (pcoll | 'Group' >> beam.GroupByKey()
- | 'Reiteration-Sum' >> beam.ParDo(MyDoFn()))
- assert_that(result, equal_to([(1, 170)]))
- pipeline.run()
+ with TestPipeline() as pipeline:
+ pcoll = pipeline | 'start' >> beam.Create(
+ [(1, 1), (1, 2), (1, 3), (1, 4)])
+ result = (pcoll | 'Group' >> beam.GroupByKey()
+ | 'Reiteration-Sum' >> beam.ParDo(MyDoFn()))
+ assert_that(result, equal_to([(1, 170)]))
def test_partition_with_partition_fn(self):
@@ -526,36 +502,33 @@
def partition_for(self, element, num_partitions, offset):
return (element % 3) + offset
- pipeline = TestPipeline()
- pcoll = pipeline | 'Start' >> beam.Create([0, 1, 2, 3, 4, 5, 6, 7, 8])
- # Attempt nominal partition operation.
- partitions = pcoll | 'Part 1' >> beam.Partition(SomePartitionFn(), 4, 1)
- assert_that(partitions[0], equal_to([]))
- assert_that(partitions[1], equal_to([0, 3, 6]), label='p1')
- assert_that(partitions[2], equal_to([1, 4, 7]), label='p2')
- assert_that(partitions[3], equal_to([2, 5, 8]), label='p3')
- pipeline.run()
+ with TestPipeline() as pipeline:
+ pcoll = pipeline | 'Start' >> beam.Create([0, 1, 2, 3, 4, 5, 6, 7, 8])
+ # Attempt nominal partition operation.
+ partitions = pcoll | 'Part 1' >> beam.Partition(SomePartitionFn(), 4, 1)
+ assert_that(partitions[0], equal_to([]))
+ assert_that(partitions[1], equal_to([0, 3, 6]), label='p1')
+ assert_that(partitions[2], equal_to([1, 4, 7]), label='p2')
+ assert_that(partitions[3], equal_to([2, 5, 8]), label='p3')
# Check that a bad partition label will yield an error. For the
# DirectRunner, this error manifests as an exception.
- pipeline = TestPipeline()
- pcoll = pipeline | 'Start' >> beam.Create([0, 1, 2, 3, 4, 5, 6, 7, 8])
- partitions = pcoll | 'Part 2' >> beam.Partition(SomePartitionFn(), 4, 10000)
with self.assertRaises(ValueError):
- pipeline.run()
+ with TestPipeline() as pipeline:
+ pcoll = pipeline | 'Start' >> beam.Create([0, 1, 2, 3, 4, 5, 6, 7, 8])
+ partitions = pcoll | beam.Partition(SomePartitionFn(), 4, 10000)
def test_partition_with_callable(self):
- pipeline = TestPipeline()
- pcoll = pipeline | 'Start' >> beam.Create([0, 1, 2, 3, 4, 5, 6, 7, 8])
- partitions = (
- pcoll | 'part' >> beam.Partition(
- lambda e, n, offset: (e % 3) + offset, 4,
- 1))
- assert_that(partitions[0], equal_to([]))
- assert_that(partitions[1], equal_to([0, 3, 6]), label='p1')
- assert_that(partitions[2], equal_to([1, 4, 7]), label='p2')
- assert_that(partitions[3], equal_to([2, 5, 8]), label='p3')
- pipeline.run()
+ with TestPipeline() as pipeline:
+ pcoll = pipeline | 'Start' >> beam.Create([0, 1, 2, 3, 4, 5, 6, 7, 8])
+ partitions = (
+ pcoll | 'part' >> beam.Partition(
+ lambda e, n, offset: (e % 3) + offset, 4,
+ 1))
+ assert_that(partitions[0], equal_to([]))
+ assert_that(partitions[1], equal_to([0, 3, 6]), label='p1')
+ assert_that(partitions[2], equal_to([1, 4, 7]), label='p2')
+ assert_that(partitions[3], equal_to([2, 5, 8]), label='p3')
def test_partition_followed_by_flatten_and_groupbykey(self):
"""Regression test for an issue with how partitions are handled."""
@@ -570,56 +543,50 @@
@attr('ValidatesRunner')
def test_flatten_pcollections(self):
- pipeline = TestPipeline()
- pcoll_1 = pipeline | 'Start 1' >> beam.Create([0, 1, 2, 3])
- pcoll_2 = pipeline | 'Start 2' >> beam.Create([4, 5, 6, 7])
- result = (pcoll_1, pcoll_2) | 'Flatten' >> beam.Flatten()
- assert_that(result, equal_to([0, 1, 2, 3, 4, 5, 6, 7]))
- pipeline.run()
+ with TestPipeline() as pipeline:
+ pcoll_1 = pipeline | 'Start 1' >> beam.Create([0, 1, 2, 3])
+ pcoll_2 = pipeline | 'Start 2' >> beam.Create([4, 5, 6, 7])
+ result = (pcoll_1, pcoll_2) | 'Flatten' >> beam.Flatten()
+ assert_that(result, equal_to([0, 1, 2, 3, 4, 5, 6, 7]))
def test_flatten_no_pcollections(self):
- pipeline = TestPipeline()
- with self.assertRaises(ValueError):
- () | 'PipelineArgMissing' >> beam.Flatten()
- result = () | 'Empty' >> beam.Flatten(pipeline=pipeline)
- assert_that(result, equal_to([]))
- pipeline.run()
+ with TestPipeline() as pipeline:
+ with self.assertRaises(ValueError):
+ () | 'PipelineArgMissing' >> beam.Flatten()
+ result = () | 'Empty' >> beam.Flatten(pipeline=pipeline)
+ assert_that(result, equal_to([]))
@attr('ValidatesRunner')
def test_flatten_one_single_pcollection(self):
- pipeline = TestPipeline()
- input = [0, 1, 2, 3]
- pcoll = pipeline | 'Input' >> beam.Create(input)
- result = (pcoll,)| 'Single Flatten' >> beam.Flatten()
- assert_that(result, equal_to(input))
- pipeline.run()
+ with TestPipeline() as pipeline:
+ input = [0, 1, 2, 3]
+ pcoll = pipeline | 'Input' >> beam.Create(input)
+ result = (pcoll,)| 'Single Flatten' >> beam.Flatten()
+ assert_that(result, equal_to(input))
# TODO(BEAM-9002): Does not work in streaming mode on Dataflow.
@attr('ValidatesRunner', 'sickbay-streaming')
def test_flatten_same_pcollections(self):
- pipeline = TestPipeline()
- pc = pipeline | beam.Create(['a', 'b'])
- assert_that((pc, pc, pc) | beam.Flatten(), equal_to(['a', 'b'] * 3))
- pipeline.run()
+ with TestPipeline() as pipeline:
+ pc = pipeline | beam.Create(['a', 'b'])
+ assert_that((pc, pc, pc) | beam.Flatten(), equal_to(['a', 'b'] * 3))
def test_flatten_pcollections_in_iterable(self):
- pipeline = TestPipeline()
- pcoll_1 = pipeline | 'Start 1' >> beam.Create([0, 1, 2, 3])
- pcoll_2 = pipeline | 'Start 2' >> beam.Create([4, 5, 6, 7])
- result = [pcoll for pcoll in (pcoll_1, pcoll_2)] | beam.Flatten()
- assert_that(result, equal_to([0, 1, 2, 3, 4, 5, 6, 7]))
- pipeline.run()
+ with TestPipeline() as pipeline:
+ pcoll_1 = pipeline | 'Start 1' >> beam.Create([0, 1, 2, 3])
+ pcoll_2 = pipeline | 'Start 2' >> beam.Create([4, 5, 6, 7])
+ result = [pcoll for pcoll in (pcoll_1, pcoll_2)] | beam.Flatten()
+ assert_that(result, equal_to([0, 1, 2, 3, 4, 5, 6, 7]))
@attr('ValidatesRunner')
def test_flatten_a_flattened_pcollection(self):
- pipeline = TestPipeline()
- pcoll_1 = pipeline | 'Start 1' >> beam.Create([0, 1, 2, 3])
- pcoll_2 = pipeline | 'Start 2' >> beam.Create([4, 5, 6, 7])
- pcoll_3 = pipeline | 'Start 3' >> beam.Create([8, 9])
- pcoll_12 = (pcoll_1, pcoll_2) | 'Flatten' >> beam.Flatten()
- pcoll_123 = (pcoll_12, pcoll_3) | 'Flatten again' >> beam.Flatten()
- assert_that(pcoll_123, equal_to([x for x in range(10)]))
- pipeline.run()
+ with TestPipeline() as pipeline:
+ pcoll_1 = pipeline | 'Start 1' >> beam.Create([0, 1, 2, 3])
+ pcoll_2 = pipeline | 'Start 2' >> beam.Create([4, 5, 6, 7])
+ pcoll_3 = pipeline | 'Start 3' >> beam.Create([8, 9])
+ pcoll_12 = (pcoll_1, pcoll_2) | 'Flatten' >> beam.Flatten()
+ pcoll_123 = (pcoll_12, pcoll_3) | 'Flatten again' >> beam.Flatten()
+ assert_that(pcoll_123, equal_to([x for x in range(10)]))
def test_flatten_input_type_must_be_iterable(self):
# Inputs to flatten *must* be an iterable.
@@ -635,21 +602,20 @@
@attr('ValidatesRunner')
def test_flatten_multiple_pcollections_having_multiple_consumers(self):
- pipeline = TestPipeline()
- input = pipeline | 'Start' >> beam.Create(['AA', 'BBB', 'CC'])
+ with TestPipeline() as pipeline:
+ input = pipeline | 'Start' >> beam.Create(['AA', 'BBB', 'CC'])
- def split_even_odd(element):
- tag = 'even_length' if len(element) % 2 == 0 else 'odd_length'
- return pvalue.TaggedOutput(tag, element)
+ def split_even_odd(element):
+ tag = 'even_length' if len(element) % 2 == 0 else 'odd_length'
+ return pvalue.TaggedOutput(tag, element)
- even_length, odd_length = (input | beam.Map(split_even_odd)
- .with_outputs('even_length', 'odd_length'))
- merged = (even_length, odd_length) | 'Flatten' >> beam.Flatten()
+ even_length, odd_length = (input | beam.Map(split_even_odd)
+ .with_outputs('even_length', 'odd_length'))
+ merged = (even_length, odd_length) | 'Flatten' >> beam.Flatten()
- assert_that(merged, equal_to(['AA', 'BBB', 'CC']))
- assert_that(even_length, equal_to(['AA', 'CC']), label='assert:even')
- assert_that(odd_length, equal_to(['BBB']), label='assert:odd')
- pipeline.run()
+ assert_that(merged, equal_to(['AA', 'BBB', 'CC']))
+ assert_that(even_length, equal_to(['AA', 'CC']), label='assert:even')
+ assert_that(odd_length, equal_to(['BBB']), label='assert:odd')
def test_co_group_by_key_on_list(self):
pipeline = TestPipeline()
@@ -690,12 +656,10 @@
pipeline.run()
def test_group_by_key_input_must_be_kv_pairs(self):
- pipeline = TestPipeline()
- pcolls = pipeline | 'A' >> beam.Create([1, 2, 3, 4, 5])
-
with self.assertRaises(typehints.TypeCheckError) as e:
- pcolls | 'D' >> beam.GroupByKey()
- pipeline.run()
+ with TestPipeline() as pipeline:
+ pcolls = pipeline | 'A' >> beam.Create([1, 2, 3, 4, 5])
+ pcolls | 'D' >> beam.GroupByKey()
self.assertStartswith(
e.exception.args[0],
@@ -703,58 +667,52 @@
'Tuple[TypeVariable[K], TypeVariable[V]]')
def test_group_by_key_only_input_must_be_kv_pairs(self):
- pipeline = TestPipeline()
- pcolls = pipeline | 'A' >> beam.Create(['a', 'b', 'f'])
with self.assertRaises(typehints.TypeCheckError) as cm:
- pcolls | 'D' >> _GroupByKeyOnly()
- pipeline.run()
+ with TestPipeline() as pipeline:
+ pcolls = pipeline | 'A' >> beam.Create(['a', 'b', 'f'])
+ pcolls | 'D' >> _GroupByKeyOnly()
expected_error_prefix = ('Input type hint violation at D: expected '
'Tuple[TypeVariable[K], TypeVariable[V]]')
self.assertStartswith(cm.exception.args[0], expected_error_prefix)
def test_keys_and_values(self):
- pipeline = TestPipeline()
- pcoll = pipeline | 'Start' >> beam.Create(
- [(3, 1), (2, 1), (1, 1), (3, 2), (2, 2), (3, 3)])
- keys = pcoll.apply(beam.Keys('keys'))
- vals = pcoll.apply(beam.Values('vals'))
- assert_that(keys, equal_to([1, 2, 2, 3, 3, 3]), label='assert:keys')
- assert_that(vals, equal_to([1, 1, 1, 2, 2, 3]), label='assert:vals')
- pipeline.run()
+ with TestPipeline() as pipeline:
+ pcoll = pipeline | 'Start' >> beam.Create(
+ [(3, 1), (2, 1), (1, 1), (3, 2), (2, 2), (3, 3)])
+ keys = pcoll.apply(beam.Keys('keys'))
+ vals = pcoll.apply(beam.Values('vals'))
+ assert_that(keys, equal_to([1, 2, 2, 3, 3, 3]), label='assert:keys')
+ assert_that(vals, equal_to([1, 1, 1, 2, 2, 3]), label='assert:vals')
def test_kv_swap(self):
- pipeline = TestPipeline()
- pcoll = pipeline | 'Start' >> beam.Create(
- [(6, 3), (1, 2), (7, 1), (5, 2), (3, 2)])
- result = pcoll.apply(beam.KvSwap(), label='swap')
- assert_that(result, equal_to([(1, 7), (2, 1), (2, 3), (2, 5), (3, 6)]))
- pipeline.run()
+ with TestPipeline() as pipeline:
+ pcoll = pipeline | 'Start' >> beam.Create(
+ [(6, 3), (1, 2), (7, 1), (5, 2), (3, 2)])
+ result = pcoll.apply(beam.KvSwap(), label='swap')
+ assert_that(result, equal_to([(1, 7), (2, 1), (2, 3), (2, 5), (3, 6)]))
def test_distinct(self):
- pipeline = TestPipeline()
- pcoll = pipeline | 'Start' >> beam.Create(
- [6, 3, 1, 1, 9, 'pleat', 'pleat', 'kazoo', 'navel'])
- result = pcoll.apply(beam.Distinct())
- assert_that(result, equal_to([1, 3, 6, 9, 'pleat', 'kazoo', 'navel']))
- pipeline.run()
+ with TestPipeline() as pipeline:
+ pcoll = pipeline | 'Start' >> beam.Create(
+ [6, 3, 1, 1, 9, 'pleat', 'pleat', 'kazoo', 'navel'])
+ result = pcoll.apply(beam.Distinct())
+ assert_that(result, equal_to([1, 3, 6, 9, 'pleat', 'kazoo', 'navel']))
def test_remove_duplicates(self):
- pipeline = TestPipeline()
- pcoll = pipeline | 'Start' >> beam.Create(
- [6, 3, 1, 1, 9, 'pleat', 'pleat', 'kazoo', 'navel'])
- result = pcoll.apply(beam.RemoveDuplicates())
- assert_that(result, equal_to([1, 3, 6, 9, 'pleat', 'kazoo', 'navel']))
- pipeline.run()
+ with TestPipeline() as pipeline:
+ pcoll = pipeline | 'Start' >> beam.Create(
+ [6, 3, 1, 1, 9, 'pleat', 'pleat', 'kazoo', 'navel'])
+ result = pcoll.apply(beam.RemoveDuplicates())
+ assert_that(result, equal_to([1, 3, 6, 9, 'pleat', 'kazoo', 'navel']))
def test_chained_ptransforms(self):
- pipeline = TestPipeline()
- t = (beam.Map(lambda x: (x, 1))
- | beam.GroupByKey()
- | beam.Map(lambda x_ones: (x_ones[0], sum(x_ones[1]))))
- result = pipeline | 'Start' >> beam.Create(['a', 'a', 'b']) | t
- assert_that(result, equal_to([('a', 2), ('b', 1)]))
- pipeline.run()
+ with TestPipeline() as pipeline:
+ t = (beam.Map(lambda x: (x, 1))
+ | beam.GroupByKey()
+ | beam.Map(lambda x_ones: (x_ones[0], sum(x_ones[1]))))
+ result = pipeline | 'Start' >> beam.Create(['a', 'a', 'b']) | t
+ assert_that(result, equal_to([('a', 2), ('b', 1)]))
def test_apply_to_list(self):
self.assertCountEqual(
@@ -850,47 +808,43 @@
def test_chained_ptransforms(self):
"""Tests that chaining gets proper nesting."""
- pipeline = TestPipeline()
- map1 = 'Map1' >> beam.Map(lambda x: (x, 1))
- gbk = 'Gbk' >> beam.GroupByKey()
- map2 = 'Map2' >> beam.Map(lambda x_ones2: (x_ones2[0], sum(x_ones2[1])))
- t = (map1 | gbk | map2)
- result = pipeline | 'Start' >> beam.Create(['a', 'a', 'b']) | t
- self.assertTrue('Map1|Gbk|Map2/Map1' in pipeline.applied_labels)
- self.assertTrue('Map1|Gbk|Map2/Gbk' in pipeline.applied_labels)
- self.assertTrue('Map1|Gbk|Map2/Map2' in pipeline.applied_labels)
- assert_that(result, equal_to([('a', 2), ('b', 1)]))
- pipeline.run()
+ with TestPipeline() as pipeline:
+ map1 = 'Map1' >> beam.Map(lambda x: (x, 1))
+ gbk = 'Gbk' >> beam.GroupByKey()
+ map2 = 'Map2' >> beam.Map(lambda x_ones2: (x_ones2[0], sum(x_ones2[1])))
+ t = (map1 | gbk | map2)
+ result = pipeline | 'Start' >> beam.Create(['a', 'a', 'b']) | t
+ self.assertTrue('Map1|Gbk|Map2/Map1' in pipeline.applied_labels)
+ self.assertTrue('Map1|Gbk|Map2/Gbk' in pipeline.applied_labels)
+ self.assertTrue('Map1|Gbk|Map2/Map2' in pipeline.applied_labels)
+ assert_that(result, equal_to([('a', 2), ('b', 1)]))
def test_apply_custom_transform_without_label(self):
- pipeline = TestPipeline()
- pcoll = pipeline | 'PColl' >> beam.Create([1, 2, 3])
- custom = PTransformLabelsTest.CustomTransform()
- result = pipeline.apply(custom, pcoll)
- self.assertTrue('CustomTransform' in pipeline.applied_labels)
- self.assertTrue('CustomTransform/*Do*' in pipeline.applied_labels)
- assert_that(result, equal_to([2, 3, 4]))
- pipeline.run()
+ with TestPipeline() as pipeline:
+ pcoll = pipeline | 'PColl' >> beam.Create([1, 2, 3])
+ custom = PTransformLabelsTest.CustomTransform()
+ result = pipeline.apply(custom, pcoll)
+ self.assertTrue('CustomTransform' in pipeline.applied_labels)
+ self.assertTrue('CustomTransform/*Do*' in pipeline.applied_labels)
+ assert_that(result, equal_to([2, 3, 4]))
def test_apply_custom_transform_with_label(self):
- pipeline = TestPipeline()
- pcoll = pipeline | 'PColl' >> beam.Create([1, 2, 3])
- custom = PTransformLabelsTest.CustomTransform('*Custom*')
- result = pipeline.apply(custom, pcoll)
- self.assertTrue('*Custom*' in pipeline.applied_labels)
- self.assertTrue('*Custom*/*Do*' in pipeline.applied_labels)
- assert_that(result, equal_to([2, 3, 4]))
- pipeline.run()
+ with TestPipeline() as pipeline:
+ pcoll = pipeline | 'PColl' >> beam.Create([1, 2, 3])
+ custom = PTransformLabelsTest.CustomTransform('*Custom*')
+ result = pipeline.apply(custom, pcoll)
+ self.assertTrue('*Custom*' in pipeline.applied_labels)
+ self.assertTrue('*Custom*/*Do*' in pipeline.applied_labels)
+ assert_that(result, equal_to([2, 3, 4]))
def test_combine_without_label(self):
vals = [1, 2, 3, 4, 5, 6, 7]
- pipeline = TestPipeline()
- pcoll = pipeline | 'Start' >> beam.Create(vals)
- combine = beam.CombineGlobally(sum)
- result = pcoll | combine
- self.assertTrue('CombineGlobally(sum)' in pipeline.applied_labels)
- assert_that(result, equal_to([sum(vals)]))
- pipeline.run()
+ with TestPipeline() as pipeline:
+ pcoll = pipeline | 'Start' >> beam.Create(vals)
+ combine = beam.CombineGlobally(sum)
+ result = pcoll | combine
+ self.assertTrue('CombineGlobally(sum)' in pipeline.applied_labels)
+ assert_that(result, equal_to([sum(vals)]))
def test_apply_ptransform_using_decorator(self):
pipeline = TestPipeline()
@@ -903,13 +857,12 @@
def test_combine_with_label(self):
vals = [1, 2, 3, 4, 5, 6, 7]
- pipeline = TestPipeline()
- pcoll = pipeline | 'Start' >> beam.Create(vals)
- combine = '*Sum*' >> beam.CombineGlobally(sum)
- result = pcoll | combine
- self.assertTrue('*Sum*' in pipeline.applied_labels)
- assert_that(result, equal_to([sum(vals)]))
- pipeline.run()
+ with TestPipeline() as pipeline:
+ pcoll = pipeline | 'Start' >> beam.Create(vals)
+ combine = '*Sum*' >> beam.CombineGlobally(sum)
+ result = pcoll | combine
+ self.assertTrue('*Sum*' in pipeline.applied_labels)
+ assert_that(result, equal_to([sum(vals)]))
def check_label(self, ptransform, expected_label):
pipeline = TestPipeline()
@@ -2226,11 +2179,10 @@
def MyTransform(pcoll):
return pcoll | beam.ParDo(lambda x: [x]).with_output_types(int)
- p = TestPipeline()
- _ = (p
- | beam.Create([1, 2])
- | MyTransform().with_output_types(int))
- p.run()
+ with TestPipeline() as p:
+ _ = (p
+ | beam.Create([1, 2])
+ | MyTransform().with_output_types(int))
def test_type_hints_arg(self):
# Tests passing type hints via the magic 'type_hints' argument name.
@@ -2241,11 +2193,10 @@
| beam.ParDo(lambda x: [x]).with_output_types(
type_hints.output_types[0][0]))
- p = TestPipeline()
- _ = (p
- | beam.Create([1, 2])
- | MyTransform('test').with_output_types(int))
- p.run()
+ with TestPipeline() as p:
+ _ = (p
+ | beam.Create([1, 2])
+ | MyTransform('test').with_output_types(int))
def _sort_lists(result):
diff --git a/sdks/python/apache_beam/transforms/stats_test.py b/sdks/python/apache_beam/transforms/stats_test.py
index 550c3f5..b7f9604 100644
--- a/sdks/python/apache_beam/transforms/stats_test.py
+++ b/sdks/python/apache_beam/transforms/stats_test.py
@@ -56,13 +56,12 @@
test_input = [random.randint(0, 1000) for _ in range(100)]
with self.assertRaises(ValueError) as e:
- pipeline = TestPipeline()
- _ = (pipeline
- | 'create'
- >> beam.Create(test_input)
- | 'get_estimate'
- >> beam.ApproximateUnique.Globally(size=sample_size))
- pipeline.run()
+ with TestPipeline() as pipeline:
+ _ = (pipeline
+ | 'create'
+ >> beam.Create(test_input)
+ | 'get_estimate'
+ >> beam.ApproximateUnique.Globally(size=sample_size))
expected_msg = beam.ApproximateUnique._INPUT_SIZE_ERR_MSG % (sample_size)
@@ -75,12 +74,11 @@
test_input = [random.randint(0, 1000) for _ in range(100)]
with self.assertRaises(ValueError) as e:
- pipeline = TestPipeline()
- _ = (pipeline
- | 'create' >> beam.Create(test_input)
- | 'get_estimate'
- >> beam.ApproximateUnique.Globally(size=sample_size))
- pipeline.run()
+ with TestPipeline() as pipeline:
+ _ = (pipeline
+ | 'create' >> beam.Create(test_input)
+ | 'get_estimate'
+ >> beam.ApproximateUnique.Globally(size=sample_size))
expected_msg = beam.ApproximateUnique._INPUT_SIZE_ERR_MSG % (sample_size)
@@ -93,12 +91,11 @@
test_input = [random.randint(0, 1000) for _ in range(100)]
with self.assertRaises(ValueError) as e:
- pipeline = TestPipeline()
- _ = (pipeline
- | 'create' >> beam.Create(test_input)
- | 'get_estimate'
- >> beam.ApproximateUnique.Globally(error=est_err))
- pipeline.run()
+ with TestPipeline() as pipeline:
+ _ = (pipeline
+ | 'create' >> beam.Create(test_input)
+ | 'get_estimate'
+ >> beam.ApproximateUnique.Globally(error=est_err))
expected_msg = beam.ApproximateUnique._INPUT_ERROR_ERR_MSG % (est_err)
@@ -111,12 +108,11 @@
test_input = [random.randint(0, 1000) for _ in range(100)]
with self.assertRaises(ValueError) as e:
- pipeline = TestPipeline()
- _ = (pipeline
- | 'create' >> beam.Create(test_input)
- | 'get_estimate'
- >> beam.ApproximateUnique.Globally(error=est_err))
- pipeline.run()
+ with TestPipeline() as pipeline:
+ _ = (pipeline
+ | 'create' >> beam.Create(test_input)
+ | 'get_estimate'
+ >> beam.ApproximateUnique.Globally(error=est_err))
expected_msg = beam.ApproximateUnique._INPUT_ERROR_ERR_MSG % (est_err)
@@ -127,12 +123,11 @@
test_input = [random.randint(0, 1000) for _ in range(100)]
with self.assertRaises(ValueError) as e:
- pipeline = TestPipeline()
- _ = (pipeline
- | 'create' >> beam.Create(test_input)
- | 'get_estimate'
- >> beam.ApproximateUnique.Globally())
- pipeline.run()
+ with TestPipeline() as pipeline:
+ _ = (pipeline
+ | 'create' >> beam.Create(test_input)
+ | 'get_estimate'
+ >> beam.ApproximateUnique.Globally())
expected_msg = beam.ApproximateUnique._NO_VALUE_ERR_MSG
assert e.exception.args[0] == expected_msg
@@ -144,12 +139,11 @@
sample_size = 30
with self.assertRaises(ValueError) as e:
- pipeline = TestPipeline()
- _ = (pipeline
- | 'create' >> beam.Create(test_input)
- | 'get_estimate'
- >> beam.ApproximateUnique.Globally(size=sample_size, error=est_err))
- pipeline.run()
+ with TestPipeline() as pipeline:
+ _ = (pipeline
+ | 'create' >> beam.Create(test_input)
+ | 'get_estimate' >> beam.ApproximateUnique.Globally(
+ size=sample_size, error=est_err))
expected_msg = beam.ApproximateUnique._MULTI_VALUE_ERR_MSG % (
sample_size, est_err)
@@ -178,18 +172,17 @@
actual_count = len(set(test_input))
- pipeline = TestPipeline()
- result = (pipeline
- | 'create' >> beam.Create(test_input)
- | 'get_estimate'
- >> beam.ApproximateUnique.Globally(size=sample_size)
- | 'compare'
- >> beam.FlatMap(lambda x: [abs(x - actual_count) * 1.0
- / actual_count <= max_err]))
+ with TestPipeline() as pipeline:
+ result = (pipeline
+ | 'create' >> beam.Create(test_input)
+ | 'get_estimate'
+ >> beam.ApproximateUnique.Globally(size=sample_size)
+ | 'compare'
+ >> beam.FlatMap(lambda x: [abs(x - actual_count) * 1.0
+ / actual_count <= max_err]))
- assert_that(result, equal_to([True]),
- label='assert:global_by_size')
- pipeline.run()
+ assert_that(result, equal_to([True]),
+ label='assert:global_by_size')
@retry(reraise=True, stop=stop_after_attempt(5))
def test_approximate_unique_global_by_sample_size_with_duplicates(self):
@@ -200,18 +193,17 @@
test_input = [10] * 50 + [20] * 50
actual_count = len(set(test_input))
- pipeline = TestPipeline()
- result = (pipeline
- | 'create' >> beam.Create(test_input)
- | 'get_estimate'
- >> beam.ApproximateUnique.Globally(size=sample_size)
- | 'compare'
- >> beam.FlatMap(lambda x: [abs(x - actual_count) * 1.0
- / actual_count <= max_err]))
+ with TestPipeline() as pipeline:
+ result = (pipeline
+ | 'create' >> beam.Create(test_input)
+ | 'get_estimate'
+ >> beam.ApproximateUnique.Globally(size=sample_size)
+ | 'compare'
+ >> beam.FlatMap(lambda x: [abs(x - actual_count) * 1.0
+ / actual_count <= max_err]))
- assert_that(result, equal_to([True]),
- label='assert:global_by_size_with_duplicates')
- pipeline.run()
+ assert_that(result, equal_to([True]),
+ label='assert:global_by_size_with_duplicates')
@retry(reraise=True, stop=stop_after_attempt(5))
def test_approximate_unique_global_by_sample_size_with_small_population(self):
@@ -223,15 +215,14 @@
221, 829, 965, 729, 35, 33, 115, 894, 827, 364]
actual_count = len(set(test_input))
- pipeline = TestPipeline()
- result = (pipeline
- | 'create' >> beam.Create(test_input)
- | 'get_estimate'
- >> beam.ApproximateUnique.Globally(size=sample_size))
+ with TestPipeline() as pipeline:
+ result = (pipeline
+ | 'create' >> beam.Create(test_input)
+ | 'get_estimate'
+ >> beam.ApproximateUnique.Globally(size=sample_size))
- assert_that(result, equal_to([actual_count]),
- label='assert:global_by_sample_size_with_small_population')
- pipeline.run()
+ assert_that(result, equal_to([actual_count]),
+ label='assert:global_by_sample_size_with_small_population')
@unittest.skip('Skip because hash function is not good enough. '
'TODO: BEAM-7654')
@@ -243,17 +234,16 @@
973, 386, 506, 546, 991, 450, 226, 889, 514, 693]
actual_count = len(set(test_input))
- pipeline = TestPipeline()
- result = (pipeline
- | 'create' >> beam.Create(test_input)
- | 'get_estimate'
- >> beam.ApproximateUnique.Globally(error=est_err)
- | 'compare'
- >> beam.FlatMap(lambda x: [abs(x - actual_count) * 1.0
- / actual_count <= est_err]))
+ with TestPipeline() as pipeline:
+ result = (pipeline
+ | 'create' >> beam.Create(test_input)
+ | 'get_estimate'
+ >> beam.ApproximateUnique.Globally(error=est_err)
+ | 'compare'
+ >> beam.FlatMap(lambda x: [abs(x - actual_count) * 1.0
+ / actual_count <= est_err]))
- assert_that(result, equal_to([True]), label='assert:global_by_error')
- pipeline.run()
+ assert_that(result, equal_to([True]), label='assert:global_by_error')
@retry(reraise=True, stop=stop_after_attempt(5))
def test_approximate_unique_global_by_error_with_small_population(self):
@@ -266,15 +256,14 @@
756, 755, 839, 79, 393]
actual_count = len(set(test_input))
- pipeline = TestPipeline()
- result = (pipeline
- | 'create' >> beam.Create(test_input)
- | 'get_estimate'
- >> beam.ApproximateUnique.Globally(error=est_err))
+ with TestPipeline() as pipeline:
+ result = (pipeline
+ | 'create' >> beam.Create(test_input)
+ | 'get_estimate'
+ >> beam.ApproximateUnique.Globally(error=est_err))
- assert_that(result, equal_to([actual_count]),
- label='assert:global_by_error_with_small_population')
- pipeline.run()
+ assert_that(result, equal_to([actual_count]),
+ label='assert:global_by_error_with_small_population')
@retry(reraise=True, stop=stop_after_attempt(5))
def test_approximate_unique_perkey_by_size(self):
@@ -292,20 +281,19 @@
for (x, y) in test_input:
actual_count_dict[x].add(y)
- pipeline = TestPipeline()
- result = (pipeline
- | 'create' >> beam.Create(test_input)
- | 'get_estimate'
- >> beam.ApproximateUnique.PerKey(size=sample_size)
- | 'compare'
- >> beam.FlatMap(lambda x: [abs(x[1]
- - len(actual_count_dict[x[0]]))
- * 1.0 / len(actual_count_dict[x[0]])
- <= max_err]))
+ with TestPipeline() as pipeline:
+ result = (pipeline
+ | 'create' >> beam.Create(test_input)
+ | 'get_estimate'
+ >> beam.ApproximateUnique.PerKey(size=sample_size)
+ | 'compare'
+ >> beam.FlatMap(lambda x: [abs(x[1]
+ - len(actual_count_dict[x[0]]))
+ * 1.0 / len(actual_count_dict[x[0]])
+ <= max_err]))
- assert_that(result, equal_to([True] * len(actual_count_dict)),
- label='assert:perkey_by_size')
- pipeline.run()
+ assert_that(result, equal_to([True] * len(actual_count_dict)),
+ label='assert:perkey_by_size')
@retry(reraise=True, stop=stop_after_attempt(5))
def test_approximate_unique_perkey_by_error(self):
@@ -318,20 +306,19 @@
for (x, y) in test_input:
actual_count_dict[x].add(y)
- pipeline = TestPipeline()
- result = (pipeline
- | 'create' >> beam.Create(test_input)
- | 'get_estimate'
- >> beam.ApproximateUnique.PerKey(error=est_err)
- | 'compare'
- >> beam.FlatMap(lambda x: [abs(x[1]
- - len(actual_count_dict[x[0]]))
- * 1.0 / len(actual_count_dict[x[0]])
- <= est_err]))
+ with TestPipeline() as pipeline:
+ result = (pipeline
+ | 'create' >> beam.Create(test_input)
+ | 'get_estimate'
+ >> beam.ApproximateUnique.PerKey(error=est_err)
+ | 'compare'
+ >> beam.FlatMap(lambda x: [abs(x[1]
+ - len(actual_count_dict[x[0]]))
+ * 1.0 / len(actual_count_dict[x[0]])
+ <= est_err]))
- assert_that(result, equal_to([True] * len(actual_count_dict)),
- label='assert:perkey_by_error')
- pipeline.run()
+ assert_that(result, equal_to([True] * len(actual_count_dict)),
+ label='assert:perkey_by_error')
@retry(reraise=True, stop=stop_after_attempt(5))
def test_approximate_unique_globally_by_error_with_skewed_data(self):
@@ -341,18 +328,17 @@
6, 55, 1, 13, 90, 4, 18, 52, 33, 0, 77, 21, 26, 5, 18]
actual_count = len(set(test_input))
- pipeline = TestPipeline()
- result = (pipeline
- | 'create' >> beam.Create(test_input)
- | 'get_estimate'
- >> beam.ApproximateUnique.Globally(error=est_err)
- | 'compare'
- >> beam.FlatMap(lambda x: [abs(x - actual_count) * 1.0
- / actual_count <= est_err]))
+ with TestPipeline() as pipeline:
+ result = (pipeline
+ | 'create' >> beam.Create(test_input)
+ | 'get_estimate'
+ >> beam.ApproximateUnique.Globally(error=est_err)
+ | 'compare'
+ >> beam.FlatMap(lambda x: [abs(x - actual_count) * 1.0
+ / actual_count <= est_err]))
- assert_that(result, equal_to([True]),
- label='assert:globally_by_error_with_skewed_data')
- pipeline.run()
+ assert_that(result, equal_to([True]),
+ label='assert:globally_by_error_with_skewed_data')
class ApproximateQuantilesTest(unittest.TestCase):
diff --git a/sdks/python/apache_beam/transforms/transforms_keyword_only_args_test_py3.py b/sdks/python/apache_beam/transforms/transforms_keyword_only_args_test_py3.py
index 661d6ac..b220373 100644
--- a/sdks/python/apache_beam/transforms/transforms_keyword_only_args_test_py3.py
+++ b/sdks/python/apache_beam/transforms/transforms_keyword_only_args_test_py3.py
@@ -36,109 +36,106 @@
_multiprocess_can_split_ = True
def test_side_input_keyword_only_args(self):
- pipeline = TestPipeline()
+ with TestPipeline() as pipeline:
- def sort_with_side_inputs(x, *s, reverse=False):
- for y in s:
- yield sorted([x] + y, reverse=reverse)
+ def sort_with_side_inputs(x, *s, reverse=False):
+ for y in s:
+ yield sorted([x] + y, reverse=reverse)
- def sort_with_side_inputs_without_default_values(x, *s, reverse):
- for y in s:
- yield sorted([x] + y, reverse=reverse)
+ def sort_with_side_inputs_without_default_values(x, *s, reverse):
+ for y in s:
+ yield sorted([x] + y, reverse=reverse)
- pcol = pipeline | 'start' >> beam.Create([1, 2])
- side = pipeline | 'side' >> beam.Create([3, 4]) # 2 values in side input.
- result1 = pcol | 'compute1' >> beam.FlatMap(
- sort_with_side_inputs,
- beam.pvalue.AsList(side), reverse=True)
- assert_that(result1, equal_to([[4, 3, 1], [4, 3, 2]]), label='assert1')
+ pcol = pipeline | 'start' >> beam.Create([1, 2])
+ side = pipeline | 'side' >> beam.Create([3, 4]) # 2 values in side input.
+ result1 = pcol | 'compute1' >> beam.FlatMap(
+ sort_with_side_inputs,
+ beam.pvalue.AsList(side), reverse=True)
+ assert_that(result1, equal_to([[4, 3, 1], [4, 3, 2]]), label='assert1')
- result2 = pcol | 'compute2' >> beam.FlatMap(
- sort_with_side_inputs,
- beam.pvalue.AsList(side))
- assert_that(result2, equal_to([[1, 3, 4], [2, 3, 4]]), label='assert2')
+ result2 = pcol | 'compute2' >> beam.FlatMap(
+ sort_with_side_inputs,
+ beam.pvalue.AsList(side))
+ assert_that(result2, equal_to([[1, 3, 4], [2, 3, 4]]), label='assert2')
- result3 = pcol | 'compute3' >> beam.FlatMap(
- sort_with_side_inputs)
- assert_that(result3, equal_to([]), label='assert3')
+ result3 = pcol | 'compute3' >> beam.FlatMap(
+ sort_with_side_inputs)
+ assert_that(result3, equal_to([]), label='assert3')
- result4 = pcol | 'compute4' >> beam.FlatMap(
- sort_with_side_inputs, reverse=True)
- assert_that(result4, equal_to([]), label='assert4')
+ result4 = pcol | 'compute4' >> beam.FlatMap(
+ sort_with_side_inputs, reverse=True)
+ assert_that(result4, equal_to([]), label='assert4')
- result5 = pcol | 'compute5' >> beam.FlatMap(
- sort_with_side_inputs_without_default_values,
- beam.pvalue.AsList(side), reverse=True)
- assert_that(result5, equal_to([[4, 3, 1], [4, 3, 2]]), label='assert5')
+ result5 = pcol | 'compute5' >> beam.FlatMap(
+ sort_with_side_inputs_without_default_values,
+ beam.pvalue.AsList(side), reverse=True)
+ assert_that(result5, equal_to([[4, 3, 1], [4, 3, 2]]), label='assert5')
- result6 = pcol | 'compute6' >> beam.FlatMap(
- sort_with_side_inputs_without_default_values,
- beam.pvalue.AsList(side), reverse=False)
- assert_that(result6, equal_to([[1, 3, 4], [2, 3, 4]]), label='assert6')
+ result6 = pcol | 'compute6' >> beam.FlatMap(
+ sort_with_side_inputs_without_default_values,
+ beam.pvalue.AsList(side), reverse=False)
+ assert_that(result6, equal_to([[1, 3, 4], [2, 3, 4]]), label='assert6')
- result7 = pcol | 'compute7' >> beam.FlatMap(
- sort_with_side_inputs_without_default_values, reverse=False)
- assert_that(result7, equal_to([]), label='assert7')
+ result7 = pcol | 'compute7' >> beam.FlatMap(
+ sort_with_side_inputs_without_default_values, reverse=False)
+ assert_that(result7, equal_to([]), label='assert7')
- result8 = pcol | 'compute8' >> beam.FlatMap(
- sort_with_side_inputs_without_default_values, reverse=True)
- assert_that(result8, equal_to([]), label='assert8')
+ result8 = pcol | 'compute8' >> beam.FlatMap(
+ sort_with_side_inputs_without_default_values, reverse=True)
+ assert_that(result8, equal_to([]), label='assert8')
- pipeline.run()
def test_combine_keyword_only_args(self):
- pipeline = TestPipeline()
+ with TestPipeline() as pipeline:
- def bounded_sum(values, *s, bound=500):
- return min(sum(values) + sum(s), bound)
+ def bounded_sum(values, *s, bound=500):
+ return min(sum(values) + sum(s), bound)
- def bounded_sum_without_default_values(values, *s, bound):
- return min(sum(values) + sum(s), bound)
+ def bounded_sum_without_default_values(values, *s, bound):
+ return min(sum(values) + sum(s), bound)
- pcoll = pipeline | 'start' >> beam.Create([6, 3, 1])
- result1 = pcoll | 'sum1' >> beam.CombineGlobally(bounded_sum, 5, 8,
- bound=20)
- result2 = pcoll | 'sum2' >> beam.CombineGlobally(bounded_sum, 0, 0)
- result3 = pcoll | 'sum3' >> beam.CombineGlobally(bounded_sum)
- result4 = pcoll | 'sum4' >> beam.CombineGlobally(bounded_sum, bound=5)
- result5 = pcoll | 'sum5' >> beam.CombineGlobally(
- bounded_sum_without_default_values, 5, 8, bound=20)
- result6 = pcoll | 'sum6' >> beam.CombineGlobally(
- bounded_sum_without_default_values, 0, 0, bound=500)
- result7 = pcoll | 'sum7' >> beam.CombineGlobally(
- bounded_sum_without_default_values, bound=500)
- result8 = pcoll | 'sum8' >> beam.CombineGlobally(
- bounded_sum_without_default_values, bound=5)
+ pcoll = pipeline | 'start' >> beam.Create([6, 3, 1])
+ result1 = pcoll | 'sum1' >> beam.CombineGlobally(bounded_sum, 5, 8,
+ bound=20)
+ result2 = pcoll | 'sum2' >> beam.CombineGlobally(bounded_sum, 0, 0)
+ result3 = pcoll | 'sum3' >> beam.CombineGlobally(bounded_sum)
+ result4 = pcoll | 'sum4' >> beam.CombineGlobally(bounded_sum, bound=5)
+ result5 = pcoll | 'sum5' >> beam.CombineGlobally(
+ bounded_sum_without_default_values, 5, 8, bound=20)
+ result6 = pcoll | 'sum6' >> beam.CombineGlobally(
+ bounded_sum_without_default_values, 0, 0, bound=500)
+ result7 = pcoll | 'sum7' >> beam.CombineGlobally(
+ bounded_sum_without_default_values, bound=500)
+ result8 = pcoll | 'sum8' >> beam.CombineGlobally(
+ bounded_sum_without_default_values, bound=5)
- assert_that(result1, equal_to([20]), label='assert1')
- assert_that(result2, equal_to([10]), label='assert2')
- assert_that(result3, equal_to([10]), label='assert3')
- assert_that(result4, equal_to([5]), label='assert4')
- assert_that(result5, equal_to([20]), label='assert5')
- assert_that(result6, equal_to([10]), label='assert6')
- assert_that(result7, equal_to([10]), label='assert7')
- assert_that(result8, equal_to([5]), label='assert8')
+ assert_that(result1, equal_to([20]), label='assert1')
+ assert_that(result2, equal_to([10]), label='assert2')
+ assert_that(result3, equal_to([10]), label='assert3')
+ assert_that(result4, equal_to([5]), label='assert4')
+ assert_that(result5, equal_to([20]), label='assert5')
+ assert_that(result6, equal_to([10]), label='assert6')
+ assert_that(result7, equal_to([10]), label='assert7')
+ assert_that(result8, equal_to([5]), label='assert8')
- pipeline.run()
def test_do_fn_keyword_only_args(self):
- pipeline = TestPipeline()
+ with TestPipeline() as pipeline:
- class MyDoFn(beam.DoFn):
- def process(self, element, *s, bound=500):
- return [min(sum(s) + element, bound)]
+ class MyDoFn(beam.DoFn):
+ def process(self, element, *s, bound=500):
+ return [min(sum(s) + element, bound)]
- pcoll = pipeline | 'start' >> beam.Create([6, 3, 1])
- result1 = pcoll | 'sum1' >> beam.ParDo(MyDoFn(), 5, 8, bound=15)
- result2 = pcoll | 'sum2' >> beam.ParDo(MyDoFn(), 5, 8)
- result3 = pcoll | 'sum3' >> beam.ParDo(MyDoFn())
- result4 = pcoll | 'sum4' >> beam.ParDo(MyDoFn(), bound=5)
+ pcoll = pipeline | 'start' >> beam.Create([6, 3, 1])
+ result1 = pcoll | 'sum1' >> beam.ParDo(MyDoFn(), 5, 8, bound=15)
+ result2 = pcoll | 'sum2' >> beam.ParDo(MyDoFn(), 5, 8)
+ result3 = pcoll | 'sum3' >> beam.ParDo(MyDoFn())
+ result4 = pcoll | 'sum4' >> beam.ParDo(MyDoFn(), bound=5)
- assert_that(result1, equal_to([15, 15, 14]), label='assert1')
- assert_that(result2, equal_to([19, 16, 14]), label='assert2')
- assert_that(result3, equal_to([6, 3, 1]), label='assert3')
- assert_that(result4, equal_to([5, 3, 1]), label='assert4')
- pipeline.run()
+ assert_that(result1, equal_to([15, 15, 14]), label='assert1')
+ assert_that(result2, equal_to([19, 16, 14]), label='assert2')
+ assert_that(result3, equal_to([6, 3, 1]), label='assert3')
+ assert_that(result4, equal_to([5, 3, 1]), label='assert4')
if __name__ == '__main__':
diff --git a/sdks/python/apache_beam/transforms/userstate_test.py b/sdks/python/apache_beam/transforms/userstate_test.py
index 9c79551..30aa88c 100644
--- a/sdks/python/apache_beam/transforms/userstate_test.py
+++ b/sdks/python/apache_beam/transforms/userstate_test.py
@@ -518,19 +518,16 @@
aggregated_value += saved_value
yield aggregated_value
- p = TestPipeline()
- values = p | beam.Create([('key', 1),
- ('key', 2),
- ('key', 3),
- ('key', 4),
- ('key', 3)], reshuffle=False)
- actual_values = (values
- | beam.ParDo(SetStatefulDoFn()))
+ with TestPipeline() as p:
+ values = p | beam.Create([('key', 1),
+ ('key', 2),
+ ('key', 3),
+ ('key', 4),
+ ('key', 3)], reshuffle=False)
+ actual_values = (values
+ | beam.ParDo(SetStatefulDoFn()))
+ assert_that(actual_values, equal_to([1, 3, 6, 10, 10]))
- assert_that(actual_values, equal_to([1, 3, 6, 10, 10]))
-
- result = p.run()
- result.wait_until_finish()
def test_stateful_set_state_clean_portably(self):
@@ -557,21 +554,19 @@
def emit_values(self, set_state=beam.DoFn.StateParam(SET_STATE)):
yield sorted(set_state.read())
- p = TestPipeline()
- values = p | beam.Create([('key', 1),
- ('key', 2),
- ('key', 3),
- ('key', 4),
- ('key', 5)])
- actual_values = (values
- | beam.Map(lambda t: window.TimestampedValue(t, 1))
- | beam.WindowInto(window.FixedWindows(1))
- | beam.ParDo(SetStateClearingStatefulDoFn()))
+ with TestPipeline() as p:
+ values = p | beam.Create([('key', 1),
+ ('key', 2),
+ ('key', 3),
+ ('key', 4),
+ ('key', 5)])
+ actual_values = (values
+ | beam.Map(lambda t: window.TimestampedValue(t, 1))
+ | beam.WindowInto(window.FixedWindows(1))
+ | beam.ParDo(SetStateClearingStatefulDoFn()))
- assert_that(actual_values, equal_to([[100]]))
+ assert_that(actual_values, equal_to([[100]]))
- result = p.run()
- result.wait_until_finish()
def test_stateful_dofn_nonkeyed_input(self):
p = TestPipeline()
diff --git a/sdks/python/apache_beam/transforms/util_test.py b/sdks/python/apache_beam/transforms/util_test.py
index d7290ce..fc32874 100644
--- a/sdks/python/apache_beam/transforms/util_test.py
+++ b/sdks/python/apache_beam/transforms/util_test.py
@@ -284,23 +284,22 @@
yield WindowedValue(
element, expected_timestamp, [expected_window])
- pipeline = TestPipeline()
- data = [(1, 1), (2, 1), (3, 1), (1, 2), (2, 2), (1, 4)]
- expected_windows = [
- TestWindowedValue(kv, expected_timestamp, [expected_window])
- for kv in data]
- before_identity = (pipeline
- | 'start' >> beam.Create(data)
- | 'add_windows' >> beam.ParDo(AddWindowDoFn()))
- assert_that(before_identity, equal_to(expected_windows),
- label='before_identity', reify_windows=True)
- after_identity = (before_identity
- | 'window' >> beam.WindowInto(
- beam.transforms.util._IdentityWindowFn(
- coders.IntervalWindowCoder())))
- assert_that(after_identity, equal_to(expected_windows),
- label='after_identity', reify_windows=True)
- pipeline.run()
+ with TestPipeline() as pipeline:
+ data = [(1, 1), (2, 1), (3, 1), (1, 2), (2, 2), (1, 4)]
+ expected_windows = [
+ TestWindowedValue(kv, expected_timestamp, [expected_window])
+ for kv in data]
+ before_identity = (pipeline
+ | 'start' >> beam.Create(data)
+ | 'add_windows' >> beam.ParDo(AddWindowDoFn()))
+ assert_that(before_identity, equal_to(expected_windows),
+ label='before_identity', reify_windows=True)
+ after_identity = (before_identity
+ | 'window' >> beam.WindowInto(
+ beam.transforms.util._IdentityWindowFn(
+ coders.IntervalWindowCoder())))
+ assert_that(after_identity, equal_to(expected_windows),
+ label='after_identity', reify_windows=True)
def test_no_window_context_fails(self):
expected_timestamp = timestamp.Timestamp(5)
@@ -311,40 +310,38 @@
def process(self, element):
yield window.TimestampedValue(element, expected_timestamp)
- pipeline = TestPipeline()
- data = [(1, 1), (2, 1), (3, 1), (1, 2), (2, 2), (1, 4)]
- expected_windows = [
- TestWindowedValue(kv, expected_timestamp, [expected_window])
- for kv in data]
- before_identity = (pipeline
- | 'start' >> beam.Create(data)
- | 'add_timestamps' >> beam.ParDo(AddTimestampDoFn()))
- assert_that(before_identity, equal_to(expected_windows),
- label='before_identity', reify_windows=True)
- after_identity = (before_identity
- | 'window' >> beam.WindowInto(
- beam.transforms.util._IdentityWindowFn(
- coders.GlobalWindowCoder()))
- # This DoFn will return TimestampedValues, making
- # WindowFn.AssignContext passed to IdentityWindowFn
- # contain a window of None. IdentityWindowFn should
- # raise an exception.
- | 'add_timestamps2' >> beam.ParDo(AddTimestampDoFn()))
- assert_that(after_identity, equal_to(expected_windows),
- label='after_identity', reify_windows=True)
with self.assertRaisesRegex(ValueError, r'window.*None.*add_timestamps2'):
- pipeline.run()
+ with TestPipeline() as pipeline:
+ data = [(1, 1), (2, 1), (3, 1), (1, 2), (2, 2), (1, 4)]
+ expected_windows = [
+ TestWindowedValue(kv, expected_timestamp, [expected_window])
+ for kv in data]
+ before_identity = (pipeline
+ | 'start' >> beam.Create(data)
+ | 'add_timestamps' >> beam.ParDo(AddTimestampDoFn()))
+ assert_that(before_identity, equal_to(expected_windows),
+ label='before_identity', reify_windows=True)
+ after_identity = (before_identity
+ | 'window' >> beam.WindowInto(
+ beam.transforms.util._IdentityWindowFn(
+ coders.GlobalWindowCoder()))
+ # This DoFn will return TimestampedValues, making
+ # WindowFn.AssignContext passed to IdentityWindowFn
+ # contain a window of None. IdentityWindowFn should
+ # raise an exception.
+ | 'add_timestamps2' >> beam.ParDo(AddTimestampDoFn()))
+ assert_that(after_identity, equal_to(expected_windows),
+ label='after_identity', reify_windows=True)
class ReshuffleTest(unittest.TestCase):
def test_reshuffle_contents_unchanged(self):
- pipeline = TestPipeline()
- data = [(1, 1), (2, 1), (3, 1), (1, 2), (2, 2), (1, 3)]
- result = (pipeline
- | beam.Create(data)
- | beam.Reshuffle())
- assert_that(result, equal_to(data))
- pipeline.run()
+ with TestPipeline() as pipeline:
+ data = [(1, 1), (2, 1), (3, 1), (1, 2), (2, 2), (1, 3)]
+ result = (pipeline
+ | beam.Create(data)
+ | beam.Reshuffle())
+ assert_that(result, equal_to(data))
def test_reshuffle_after_gbk_contents_unchanged(self):
pipeline = TestPipeline()
@@ -362,74 +359,72 @@
pipeline.run()
def test_reshuffle_timestamps_unchanged(self):
- pipeline = TestPipeline()
- timestamp = 5
- data = [(1, 1), (2, 1), (3, 1), (1, 2), (2, 2), (1, 3)]
- expected_result = [TestWindowedValue(v, timestamp, [GlobalWindow()])
- for v in data]
- before_reshuffle = (pipeline
- | 'start' >> beam.Create(data)
- | 'add_timestamp' >> beam.Map(
- lambda v: beam.window.TimestampedValue(v,
- timestamp)))
- assert_that(before_reshuffle, equal_to(expected_result),
- label='before_reshuffle', reify_windows=True)
- after_reshuffle = before_reshuffle | beam.Reshuffle()
- assert_that(after_reshuffle, equal_to(expected_result),
- label='after_reshuffle', reify_windows=True)
- pipeline.run()
+ with TestPipeline() as pipeline:
+ timestamp = 5
+ data = [(1, 1), (2, 1), (3, 1), (1, 2), (2, 2), (1, 3)]
+ expected_result = [TestWindowedValue(v, timestamp, [GlobalWindow()])
+ for v in data]
+ before_reshuffle = (pipeline
+ | 'start' >> beam.Create(data)
+ | 'add_timestamp' >> beam.Map(
+ lambda v: beam.window.TimestampedValue(
+ v, timestamp)))
+ assert_that(before_reshuffle, equal_to(expected_result),
+ label='before_reshuffle', reify_windows=True)
+ after_reshuffle = before_reshuffle | beam.Reshuffle()
+ assert_that(after_reshuffle, equal_to(expected_result),
+ label='after_reshuffle', reify_windows=True)
def test_reshuffle_windows_unchanged(self):
- pipeline = TestPipeline()
- data = [(1, 1), (2, 1), (3, 1), (1, 2), (2, 2), (1, 4)]
- expected_data = [TestWindowedValue(v, t - .001, [w]) for (v, t, w) in [
- ((1, contains_in_any_order([2, 1])), 4.0, IntervalWindow(1.0, 4.0)),
- ((2, contains_in_any_order([2, 1])), 4.0, IntervalWindow(1.0, 4.0)),
- ((3, [1]), 3.0, IntervalWindow(1.0, 3.0)),
- ((1, [4]), 6.0, IntervalWindow(4.0, 6.0))]]
- before_reshuffle = (pipeline
- | 'start' >> beam.Create(data)
- | 'add_timestamp' >> beam.Map(
- lambda v: beam.window.TimestampedValue(v, v[1]))
- | 'window' >> beam.WindowInto(Sessions(gap_size=2))
- | 'group_by_key' >> beam.GroupByKey())
- assert_that(before_reshuffle, equal_to(expected_data),
- label='before_reshuffle', reify_windows=True)
- after_reshuffle = before_reshuffle | beam.Reshuffle()
- assert_that(after_reshuffle, equal_to(expected_data),
- label='after reshuffle', reify_windows=True)
- pipeline.run()
+ with TestPipeline() as pipeline:
+ data = [(1, 1), (2, 1), (3, 1), (1, 2), (2, 2), (1, 4)]
+ expected_data = [TestWindowedValue(v, t - .001, [w]) for (v, t, w) in [
+ ((1, contains_in_any_order([2, 1])), 4.0, IntervalWindow(1.0, 4.0)),
+ ((2, contains_in_any_order([2, 1])), 4.0, IntervalWindow(1.0, 4.0)),
+ ((3, [1]), 3.0, IntervalWindow(1.0, 3.0)),
+ ((1, [4]), 6.0, IntervalWindow(4.0, 6.0))]]
+ before_reshuffle = (pipeline
+ | 'start' >> beam.Create(data)
+ | 'add_timestamp' >> beam.Map(
+ lambda v: beam.window.TimestampedValue(v, v[1]))
+ | 'window' >> beam.WindowInto(Sessions(gap_size=2))
+ | 'group_by_key' >> beam.GroupByKey())
+ assert_that(before_reshuffle, equal_to(expected_data),
+ label='before_reshuffle', reify_windows=True)
+ after_reshuffle = before_reshuffle | beam.Reshuffle()
+ assert_that(after_reshuffle, equal_to(expected_data),
+ label='after reshuffle', reify_windows=True)
def test_reshuffle_window_fn_preserved(self):
- pipeline = TestPipeline()
- data = [(1, 1), (2, 1), (3, 1), (1, 2), (2, 2), (1, 4)]
- expected_windows = [TestWindowedValue(v, t, [w]) for (v, t, w) in [
- ((1, 1), 1.0, IntervalWindow(1.0, 3.0)),
- ((2, 1), 1.0, IntervalWindow(1.0, 3.0)),
- ((3, 1), 1.0, IntervalWindow(1.0, 3.0)),
- ((1, 2), 2.0, IntervalWindow(2.0, 4.0)),
- ((2, 2), 2.0, IntervalWindow(2.0, 4.0)),
- ((1, 4), 4.0, IntervalWindow(4.0, 6.0))]]
- expected_merged_windows = [
- TestWindowedValue(v, t - .001, [w]) for (v, t, w) in [
- ((1, contains_in_any_order([2, 1])), 4.0, IntervalWindow(1.0, 4.0)),
- ((2, contains_in_any_order([2, 1])), 4.0, IntervalWindow(1.0, 4.0)),
- ((3, [1]), 3.0, IntervalWindow(1.0, 3.0)),
- ((1, [4]), 6.0, IntervalWindow(4.0, 6.0))]]
- before_reshuffle = (pipeline
- | 'start' >> beam.Create(data)
- | 'add_timestamp' >> beam.Map(
- lambda v: TimestampedValue(v, v[1]))
- | 'window' >> beam.WindowInto(Sessions(gap_size=2)))
- assert_that(before_reshuffle, equal_to(expected_windows),
- label='before_reshuffle', reify_windows=True)
- after_reshuffle = before_reshuffle | beam.Reshuffle()
- assert_that(after_reshuffle, equal_to(expected_windows),
- label='after_reshuffle', reify_windows=True)
- after_group = after_reshuffle | beam.GroupByKey()
- assert_that(after_group, equal_to(expected_merged_windows),
- label='after_group', reify_windows=True)
- pipeline.run()
+ any_order = contains_in_any_order
+ with TestPipeline() as pipeline:
+ data = [(1, 1), (2, 1), (3, 1), (1, 2), (2, 2), (1, 4)]
+ expected_windows = [TestWindowedValue(v, t, [w]) for (v, t, w) in [
+ ((1, 1), 1.0, IntervalWindow(1.0, 3.0)),
+ ((2, 1), 1.0, IntervalWindow(1.0, 3.0)),
+ ((3, 1), 1.0, IntervalWindow(1.0, 3.0)),
+ ((1, 2), 2.0, IntervalWindow(2.0, 4.0)),
+ ((2, 2), 2.0, IntervalWindow(2.0, 4.0)),
+ ((1, 4), 4.0, IntervalWindow(4.0, 6.0))]]
+ expected_merged_windows = [
+ TestWindowedValue(v, t - .001, [w]) for (v, t, w) in [
+ ((1, any_order([2, 1])), 4.0, IntervalWindow(1.0, 4.0)),
+ ((2, any_order([2, 1])), 4.0, IntervalWindow(1.0, 4.0)),
+ ((3, [1]), 3.0, IntervalWindow(1.0, 3.0)),
+ ((1, [4]), 6.0, IntervalWindow(4.0, 6.0))]]
+ before_reshuffle = (pipeline
+ | 'start' >> beam.Create(data)
+ | 'add_timestamp' >> beam.Map(
+ lambda v: TimestampedValue(v, v[1]))
+ | 'window' >> beam.WindowInto(Sessions(gap_size=2)))
+ assert_that(before_reshuffle, equal_to(expected_windows),
+ label='before_reshuffle', reify_windows=True)
+ after_reshuffle = before_reshuffle | beam.Reshuffle()
+ assert_that(after_reshuffle, equal_to(expected_windows),
+ label='after_reshuffle', reify_windows=True)
+ after_group = after_reshuffle | beam.GroupByKey()
+ assert_that(after_group, equal_to(expected_merged_windows),
+ label='after_group', reify_windows=True)
def test_reshuffle_global_window(self):
pipeline = TestPipeline()
@@ -584,16 +579,16 @@
return data
def test_in_global_window(self):
- pipeline = TestPipeline()
- collection = pipeline \
- | beam.Create(GroupIntoBatchesTest._create_test_data()) \
- | util.GroupIntoBatches(GroupIntoBatchesTest.BATCH_SIZE)
- num_batches = collection | beam.combiners.Count.Globally()
- assert_that(num_batches,
- equal_to([int(math.ceil(GroupIntoBatchesTest.NUM_ELEMENTS /
- GroupIntoBatchesTest.BATCH_SIZE))]))
- pipeline.run()
+ with TestPipeline() as pipeline:
+ collection = pipeline \
+ | beam.Create(GroupIntoBatchesTest._create_test_data()) \
+ | util.GroupIntoBatches(GroupIntoBatchesTest.BATCH_SIZE)
+ num_batches = collection | beam.combiners.Count.Globally()
+ assert_that(num_batches,
+ equal_to([int(math.ceil(GroupIntoBatchesTest.NUM_ELEMENTS /
+ GroupIntoBatchesTest.BATCH_SIZE))]))
+ @unittest.skip('BEAM-8748')
def test_in_streaming_mode(self):
timestamp_interval = 1
offset = itertools.count(0)
@@ -609,26 +604,23 @@
.advance_watermark_to(start_time +
GroupIntoBatchesTest.NUM_ELEMENTS)
.advance_watermark_to_infinity())
- pipeline = TestPipeline(options=StandardOptions(streaming=True))
- # window duration is 6 and batch size is 5, so output batch size should be
- # 5 (flush because of batchSize reached)
- expected_0 = 5
- # there is only one element left in the window so batch size should be 1
- # (flush because of end of window reached)
- expected_1 = 1
- # collection is 10 elements, there is only 4 left, so batch size should be
- # 4 (flush because end of collection reached)
- expected_2 = 4
+ with TestPipeline(options=StandardOptions(streaming=True)) as pipeline:
+ # window duration is 6 and batch size is 5, so output batch size
+ # should be 5 (flush because of batchSize reached)
+ expected_0 = 5
+ # there is only one element left in the window so batch size
+ # should be 1 (flush because of end of window reached)
+ expected_1 = 1
+ # collection is 10 elements, there is only 4 left, so batch size
+ # should be 4 (flush because end of collection reached)
+ expected_2 = 4
- collection = pipeline | test_stream \
- | WindowInto(FixedWindows(window_duration)) \
- | util.GroupIntoBatches(GroupIntoBatchesTest.BATCH_SIZE)
- num_elements_in_batches = collection | beam.Map(len)
-
- result = pipeline.run()
- result.wait_until_finish()
- assert_that(num_elements_in_batches,
- equal_to([expected_0, expected_1, expected_2]))
+ collection = pipeline | test_stream \
+ | WindowInto(FixedWindows(window_duration)) \
+ | util.GroupIntoBatches(GroupIntoBatchesTest.BATCH_SIZE)
+ num_elements_in_batches = collection | beam.Map(len)
+ assert_that(num_elements_in_batches,
+ equal_to([expected_0, expected_1, expected_2]))
class ToStringTest(unittest.TestCase):
diff --git a/sdks/python/apache_beam/typehints/typed_pipeline_test.py b/sdks/python/apache_beam/typehints/typed_pipeline_test.py
index fafd386..52cd4ee 100644
--- a/sdks/python/apache_beam/typehints/typed_pipeline_test.py
+++ b/sdks/python/apache_beam/typehints/typed_pipeline_test.py
@@ -137,19 +137,18 @@
self.assertEqual([1, 3], [1, 2, 3] | beam.Filter(filter_fn))
def test_partition(self):
- p = TestPipeline()
- even, odd = (p
- | beam.Create([1, 2, 3])
- | 'even_odd' >> beam.Partition(lambda e, _: e % 2, 2))
- self.assertIsNotNone(even.element_type)
- self.assertIsNotNone(odd.element_type)
- res_even = (even
- | 'id_even' >> beam.ParDo(lambda e: [e]).with_input_types(int))
- res_odd = (odd
- | 'id_odd' >> beam.ParDo(lambda e: [e]).with_input_types(int))
- assert_that(res_even, equal_to([2]), label='even_check')
- assert_that(res_odd, equal_to([1, 3]), label='odd_check')
- p.run()
+ with TestPipeline() as p:
+ even, odd = (p
+ | beam.Create([1, 2, 3])
+ | 'even_odd' >> beam.Partition(lambda e, _: e % 2, 2))
+ self.assertIsNotNone(even.element_type)
+ self.assertIsNotNone(odd.element_type)
+ res_even = (even
+ | 'IdEven' >> beam.ParDo(lambda e: [e]).with_input_types(int))
+ res_odd = (odd
+ | 'IdOdd' >> beam.ParDo(lambda e: [e]).with_input_types(int))
+ assert_that(res_even, equal_to([2]), label='even_check')
+ assert_that(res_odd, equal_to([1, 3]), label='odd_check')
def test_typed_dofn_multi_output(self):
class MyDoFn(beam.DoFn):