blob: ba9e21f85567b87b60fb1ea5be2c7ce1beafad25 [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Unit tests for our libraries of combine PTransforms."""
# pytype: skip-file
import itertools
import json
import os
import random
import tempfile
import time
import unittest
from pathlib import Path
import hamcrest as hc
import pytest
import apache_beam as beam
import apache_beam.transforms.combiners as combine
from apache_beam import pvalue
from apache_beam.metrics import Metrics
from apache_beam.metrics import MetricsFilter
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.pipeline_options import StandardOptions
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.test_stream import TestStream
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to
from apache_beam.testing.util import equal_to_per_window
from apache_beam.transforms import WindowInto
from apache_beam.transforms import trigger
from apache_beam.transforms import window
from apache_beam.transforms.core import CombineGlobally
from apache_beam.transforms.core import Create
from apache_beam.transforms.core import Map
from apache_beam.transforms.display import DisplayData
from apache_beam.transforms.display_test import DisplayDataItemMatcher
from apache_beam.transforms.periodicsequence import PeriodicImpulse
from apache_beam.transforms.ptransform import PTransform
from apache_beam.transforms.trigger import AfterAll
from apache_beam.transforms.trigger import AfterCount
from apache_beam.transforms.trigger import AfterWatermark
from apache_beam.transforms.window import FixedWindows
from apache_beam.transforms.window import GlobalWindows
from apache_beam.transforms.window import TimestampCombiner
from apache_beam.transforms.window import TimestampedValue
from apache_beam.typehints import TypeCheckError
from apache_beam.utils.timestamp import Timestamp
class SortedConcatWithCounters(beam.CombineFn):
  """Concatenates string inputs and emits their characters in sorted order.

  Every added input also updates one metric of each kind: a counter of words
  seen, a counter of total word length, a distribution of word lengths and a
  gauge holding the length of the most recently added word.
  """
  def __init__(self):
    super().__init__()
    owner = self.__class__
    self.word_counter = Metrics.counter(owner, 'word_counter')
    self.word_lengths_counter = Metrics.counter(owner, 'word_lengths')
    self.word_lengths_dist = Metrics.distribution(owner, 'word_len_dist')
    self.last_word_len = Metrics.gauge(owner, 'last_word_len')

  def create_accumulator(self):
    # Accumulator is the plain concatenation of everything added so far.
    return ''

  def add_input(self, acc, element):
    length = len(element)
    self.word_counter.inc(1)
    self.word_lengths_counter.inc(length)
    self.word_lengths_dist.update(length)
    self.last_word_len.set(length)
    return acc + element

  def merge_accumulators(self, accs):
    return ''.join(accs)

  def extract_output(self, acc):
    # sorted() over a str yields a list of single characters; glue them back.
    ordered_chars = sorted(acc)
    return ''.join(ordered_chars)
class CombineTest(unittest.TestCase):
  """Tests for the combiners in apache_beam.transforms.combiners and for
  CombineGlobally / CombinePerKey behaviour (defaults, fanout, windowing,
  display data and metrics inside CombineFns)."""
  def test_builtin_combines(self):
    with TestPipeline() as pipeline:
      vals = [6, 3, 1, 1, 9, 1, 5, 2, 0, 6]
      mean = sum(vals) / float(len(vals))
      size = len(vals)
      timestamp = 0

      # First for global combines.
      pcoll = pipeline | 'start' >> Create(vals)
      result_mean = pcoll | 'mean' >> combine.Mean.Globally()
      result_count = pcoll | 'count' >> combine.Count.Globally()
      assert_that(result_mean, equal_to([mean]), label='assert:mean')
      assert_that(result_count, equal_to([size]), label='assert:size')

      # Now for global combines without default. All elements share one
      # timestamp, so they land in a single fixed window.
      timestamped = pcoll | Map(lambda x: TimestampedValue(x, timestamp))
      windowed = timestamped | 'window' >> WindowInto(FixedWindows(60))
      result_windowed_mean = (
          windowed
          | 'mean-wo-defaults' >> combine.Mean.Globally().without_defaults())
      assert_that(
          result_windowed_mean,
          equal_to([mean]),
          label='assert:mean-wo-defaults')
      result_windowed_count = (
          windowed
          | 'count-wo-defaults' >> combine.Count.Globally().without_defaults())
      assert_that(
          result_windowed_count,
          equal_to([size]),
          label='assert:count-wo-defaults')

      # Again for per-key combines.
      pcoll = pipeline | 'start-perkey' >> Create([('a', x) for x in vals])
      result_key_mean = pcoll | 'mean-perkey' >> combine.Mean.PerKey()
      result_key_count = pcoll | 'count-perkey' >> combine.Count.PerKey()
      assert_that(result_key_mean, equal_to([('a', mean)]), label='key:mean')
      assert_that(result_key_count, equal_to([('a', size)]), label='key:size')

  def test_top(self):
    with TestPipeline() as pipeline:
      timestamp = 0

      # First for global combines.
      pcoll = pipeline | 'start' >> Create([6, 3, 1, 1, 9, 1, 5, 2, 0, 6])
      result_top = pcoll | 'top' >> combine.Top.Largest(5)
      result_bot = pcoll | 'bot' >> combine.Top.Smallest(4)
      assert_that(result_top, equal_to([[9, 6, 6, 5, 3]]), label='assert:top')
      assert_that(result_bot, equal_to([[0, 1, 1, 1]]), label='assert:bot')

      # Now for global combines without default
      timestamped = pcoll | Map(lambda x: TimestampedValue(x, timestamp))
      windowed = timestamped | 'window' >> WindowInto(FixedWindows(60))
      result_windowed_top = windowed | 'top-wo-defaults' >> combine.Top.Largest(
          5, has_defaults=False)
      result_windowed_bot = (
          windowed
          | 'bot-wo-defaults' >> combine.Top.Smallest(4, has_defaults=False))
      assert_that(
          result_windowed_top,
          equal_to([[9, 6, 6, 5, 3]]),
          label='assert:top-wo-defaults')
      assert_that(
          result_windowed_bot,
          equal_to([[0, 1, 1, 1]]),
          label='assert:bot-wo-defaults')

      # Again for per-key combines.
      pcoll = pipeline | 'start-perkey' >> Create(
          [('a', x) for x in [6, 3, 1, 1, 9, 1, 5, 2, 0, 6]])
      result_key_top = pcoll | 'top-perkey' >> combine.Top.LargestPerKey(5)
      result_key_bot = pcoll | 'bot-perkey' >> combine.Top.SmallestPerKey(4)
      assert_that(
          result_key_top, equal_to([('a', [9, 6, 6, 5, 3])]), label='key:top')
      assert_that(
          result_key_bot, equal_to([('a', [0, 1, 1, 1])]), label='key:bot')

  def test_empty_global_top(self):
    with TestPipeline() as p:
      assert_that(p | beam.Create([]) | combine.Top.Largest(10), equal_to([[]]))

  def test_sharded_top(self):
    elements = list(range(100))
    random.shuffle(elements)

    with TestPipeline() as pipeline:
      # Split the input across 7 Create transforms so the combine has to merge
      # partial results coming from several sources.
      shards = [
          pipeline | 'Shard%s' % shard >> beam.Create(elements[shard::7])
          for shard in range(7)
      ]
      assert_that(
          shards | beam.Flatten() | combine.Top.Largest(10),
          equal_to([[99, 98, 97, 96, 95, 94, 93, 92, 91, 90]]))

  def test_top_key(self):
    self.assertEqual(['aa', 'bbb', 'c', 'dddd'] | combine.Top.Of(3, key=len),
                     [['dddd', 'bbb', 'aa']])
    self.assertEqual(['aa', 'bbb', 'c', 'dddd']
                     | combine.Top.Of(3, key=len, reverse=True),
                     [['c', 'aa', 'bbb']])

    self.assertEqual(['xc', 'zb', 'yd', 'wa']
                     | combine.Top.Largest(3, key=lambda x: x[-1]),
                     [['yd', 'xc', 'zb']])
    self.assertEqual(['xc', 'zb', 'yd', 'wa']
                     | combine.Top.Smallest(3, key=lambda x: x[-1]),
                     [['wa', 'zb', 'xc']])

    self.assertEqual([('a', x) for x in [1, 2, 3, 4, 1, 1]]
                     | combine.Top.LargestPerKey(3, key=lambda x: -x),
                     [('a', [1, 1, 1])])
    self.assertEqual([('a', x) for x in [1, 2, 3, 4, 1, 1]]
                     | combine.Top.SmallestPerKey(3, key=lambda x: -x),
                     [('a', [4, 3, 2])])

  def test_sharded_top_combine_fn(self):
    def test_combine_fn(combine_fn, shards, expected):
      # Build one accumulator per shard, then merge them, mimicking what a
      # runner does with partial combines.
      accumulators = [
          combine_fn.add_inputs(combine_fn.create_accumulator(), shard)
          for shard in shards
      ]
      final_accumulator = combine_fn.merge_accumulators(accumulators)
      self.assertEqual(combine_fn.extract_output(final_accumulator), expected)

    test_combine_fn(combine.TopCombineFn(3), [range(10), range(10)], [9, 9, 8])
    test_combine_fn(
        combine.TopCombineFn(5), [range(1000), range(100), range(1001)],
        [1000, 999, 999, 998, 998])

  def test_combine_per_key_top_display_data(self):
    def individual_test_per_key_dd(combineFn):
      transform = beam.CombinePerKey(combineFn)
      dd = DisplayData.create_from(transform)
      expected_items = [
          DisplayDataItemMatcher('combine_fn', combineFn.__class__),
          DisplayDataItemMatcher('n', combineFn._n),
          DisplayDataItemMatcher('compare', combineFn._compare.__name__)
      ]
      hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items))

    individual_test_per_key_dd(combine.Largest(5))
    individual_test_per_key_dd(combine.Smallest(3))
    individual_test_per_key_dd(combine.TopCombineFn(8))

  def test_combine_sample_display_data(self):
    def individual_test_per_key_dd(sampleFn, n):
      transform = sampleFn(n)
      dd = DisplayData.create_from(transform)
      hc.assert_that(
          dd.items,
          hc.contains_inanyorder(DisplayDataItemMatcher('n', transform._n)))

    individual_test_per_key_dd(combine.Sample.FixedSizePerKey, 5)
    individual_test_per_key_dd(combine.Sample.FixedSizeGlobally, 5)

  def test_combine_globally_display_data(self):
    transform = beam.CombineGlobally(combine.Smallest(5))
    dd = DisplayData.create_from(transform)
    expected_items = [
        DisplayDataItemMatcher('combine_fn', combine.Smallest),
        DisplayDataItemMatcher('n', 5),
        DisplayDataItemMatcher('compare', 'gt')
    ]
    hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items))

  def test_basic_combiners_display_data(self):
    transform = beam.CombineGlobally(
        combine.TupleCombineFn(max, combine.MeanCombineFn(), sum))
    dd = DisplayData.create_from(transform)
    expected_items = [
        DisplayDataItemMatcher('combine_fn', combine.TupleCombineFn),
        DisplayDataItemMatcher('combiners', "['max', 'MeanCombineFn', 'sum']"),
        DisplayDataItemMatcher('merge_accumulators_batch_size', 333),
    ]
    hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items))

  def test_top_shorthands(self):
    with TestPipeline() as pipeline:

      pcoll = pipeline | 'start' >> Create([6, 3, 1, 1, 9, 1, 5, 2, 0, 6])
      result_top = pcoll | 'top' >> beam.CombineGlobally(combine.Largest(5))
      result_bot = pcoll | 'bot' >> beam.CombineGlobally(combine.Smallest(4))
      assert_that(result_top, equal_to([[9, 6, 6, 5, 3]]), label='assert:top')
      assert_that(result_bot, equal_to([[0, 1, 1, 1]]), label='assert:bot')

      pcoll = pipeline | 'start-perkey' >> Create(
          [('a', x) for x in [6, 3, 1, 1, 9, 1, 5, 2, 0, 6]])
      result_ktop = pcoll | 'top-perkey' >> beam.CombinePerKey(
          combine.Largest(5))
      result_kbot = pcoll | 'bot-perkey' >> beam.CombinePerKey(
          combine.Smallest(4))
      assert_that(result_ktop, equal_to([('a', [9, 6, 6, 5, 3])]), label='ktop')
      assert_that(result_kbot, equal_to([('a', [0, 1, 1, 1])]), label='kbot')

  def test_top_no_compact(self):
    # A TopCombineFn whose compact() is a no-op, so results must still be
    # correct when accumulators are never trimmed.
    class TopCombineFnNoCompact(combine.TopCombineFn):
      def compact(self, accumulator):
        return accumulator

    with TestPipeline() as pipeline:
      pcoll = pipeline | 'Start' >> Create([6, 3, 1, 1, 9, 1, 5, 2, 0, 6])
      result_top = pcoll | 'Top' >> beam.CombineGlobally(
          TopCombineFnNoCompact(5, key=lambda x: x))
      result_bot = pcoll | 'Bot' >> beam.CombineGlobally(
          TopCombineFnNoCompact(4, reverse=True))
      assert_that(result_top, equal_to([[9, 6, 6, 5, 3]]), label='Assert:Top')
      assert_that(result_bot, equal_to([[0, 1, 1, 1]]), label='Assert:Bot')

      pcoll = pipeline | 'Start-Perkey' >> Create(
          [('a', x) for x in [6, 3, 1, 1, 9, 1, 5, 2, 0, 6]])
      result_ktop = pcoll | 'Top-PerKey' >> beam.CombinePerKey(
          TopCombineFnNoCompact(5, key=lambda x: x))
      result_kbot = pcoll | 'Bot-PerKey' >> beam.CombinePerKey(
          TopCombineFnNoCompact(4, reverse=True))
      assert_that(result_ktop, equal_to([('a', [9, 6, 6, 5, 3])]), label='KTop')
      assert_that(result_kbot, equal_to([('a', [0, 1, 1, 1])]), label='KBot')

  def test_global_sample(self):
    def is_good_sample(actual):
      assert len(actual) == 1
      assert sorted(actual[0]) in [[1, 1, 2], [1, 2, 2]], actual

    with TestPipeline() as pipeline:
      timestamp = 0
      pcoll = pipeline | 'start' >> Create([1, 1, 2, 2])

      # Now for global combines without default
      timestamped = pcoll | Map(lambda x: TimestampedValue(x, timestamp))
      windowed = timestamped | 'window' >> WindowInto(FixedWindows(60))

      # Sampling is random, so take several independent samples to make a bad
      # sample more likely to show up.
      for ix in range(9):
        assert_that(
            pcoll | 'sample-%d' % ix >> combine.Sample.FixedSizeGlobally(3),
            is_good_sample,
            label='check-%d' % ix)
        result_windowed = (
            windowed
            | 'sample-wo-defaults-%d' % ix >>
            combine.Sample.FixedSizeGlobally(3).without_defaults())
        assert_that(
            result_windowed, is_good_sample, label='check-wo-defaults-%d' % ix)

  def test_per_key_sample(self):
    with TestPipeline() as pipeline:
      pcoll = pipeline | 'start-perkey' >> Create(
          sum(([(i, 1), (i, 1), (i, 2), (i, 2)] for i in range(9)), []))
      result = pcoll | 'sample' >> combine.Sample.FixedSizePerKey(3)

      def matcher():
        def match(actual):
          for _, samples in actual:
            equal_to([3])([len(samples)])
            num_ones = sum(1 for x in samples if x == 1)
            num_twos = sum(1 for x in samples if x == 2)
            # equal_to is order-insensitive: accepts (1 one, 2 twos) or
            # (2 ones, 1 two).
            equal_to([1, 2])([num_ones, num_twos])

        return match

      assert_that(result, matcher())

  def test_tuple_combine_fn(self):
    with TestPipeline() as p:
      result = (
          p
          | Create([('a', 100, 0.0), ('b', 10, -1), ('c', 1, 100)])
          | beam.CombineGlobally(
              combine.TupleCombineFn(max, combine.MeanCombineFn(),
                                     sum)).without_defaults())
      assert_that(result, equal_to([('c', 111.0 / 3, 99.0)]))

  def test_tuple_combine_fn_without_defaults(self):
    with TestPipeline() as p:
      result = (
          p
          | Create([1, 1, 2, 3])
          | beam.CombineGlobally(
              combine.TupleCombineFn(
                  min, combine.MeanCombineFn(),
                  max).with_common_input()).without_defaults())
      assert_that(result, equal_to([(1, 7.0 / 4, 3)]))

  def test_empty_tuple_combine_fn(self):
    with TestPipeline() as p:
      result = (
          p
          | Create([(), (), ()])
          | beam.CombineGlobally(combine.TupleCombineFn()))
      assert_that(result, equal_to([()]))

  def test_tuple_combine_fn_batched_merge(self):
    num_combine_fns = 10
    max_num_accumulators_in_memory = 30
    # Maximum number of accumulator tuples in memory - 1 for the merge result.
    merge_accumulators_batch_size = (
        max_num_accumulators_in_memory // num_combine_fns - 1)
    num_accumulator_tuples_to_merge = 20

    # Tracks how many accumulators are "live"; flags oom once creation would
    # exceed the simulated memory budget.
    class CountedAccumulator:
      count = 0
      oom = False

      def __init__(self):
        if CountedAccumulator.count > max_num_accumulators_in_memory:
          CountedAccumulator.oom = True
        else:
          CountedAccumulator.count += 1

    class CountedAccumulatorCombineFn(beam.CombineFn):
      def create_accumulator(self):
        return CountedAccumulator()

      def merge_accumulators(self, accumulators):
        # One new (merged) accumulator becomes live; every input is released.
        # The merged value itself is irrelevant to this test.
        CountedAccumulator.count += 1
        for _ in accumulators:
          CountedAccumulator.count -= 1

    combine_fn = combine.TupleCombineFn(
        *[CountedAccumulatorCombineFn() for _ in range(num_combine_fns)],
        merge_accumulators_batch_size=merge_accumulators_batch_size)
    # Pass a generator so accumulators are created lazily, as the batched merge
    # consumes them.
    combine_fn.merge_accumulators(
        combine_fn.create_accumulator()
        for _ in range(num_accumulator_tuples_to_merge))
    assert not CountedAccumulator.oom

  def test_to_list_and_to_dict1(self):
    with TestPipeline() as pipeline:
      the_list = [6, 3, 1, 1, 9, 1, 5, 2, 0, 6]
      timestamp = 0
      pcoll = pipeline | 'start' >> Create(the_list)
      result = pcoll | 'to list' >> combine.ToList()

      # Now for global combines without default
      timestamped = pcoll | Map(lambda x: TimestampedValue(x, timestamp))
      windowed = timestamped | 'window' >> WindowInto(FixedWindows(60))
      result_windowed = (
          windowed
          | 'to list wo defaults' >> combine.ToList().without_defaults())

      def matcher(expected):
        def match(actual):
          equal_to(expected[0])(actual[0])

        return match

      assert_that(result, matcher([the_list]))
      assert_that(
          result_windowed, matcher([the_list]), label='to-list-wo-defaults')

  def test_to_list_and_to_dict2(self):
    with TestPipeline() as pipeline:
      pairs = [(1, 2), (3, 4), (5, 6)]
      timestamp = 0
      pcoll = pipeline | 'start-pairs' >> Create(pairs)
      result = pcoll | 'to dict' >> combine.ToDict()

      # Now for global combines without default
      timestamped = pcoll | Map(lambda x: TimestampedValue(x, timestamp))
      windowed = timestamped | 'window' >> WindowInto(FixedWindows(60))
      result_windowed = (
          windowed
          | 'to dict wo defaults' >> combine.ToDict().without_defaults())

      def matcher():
        def match(actual):
          equal_to([1])([len(actual)])
          equal_to(pairs)(actual[0].items())

        return match

      assert_that(result, matcher())
      assert_that(result_windowed, matcher(), label='to-dict-wo-defaults')

  def test_to_set(self):
    # Use the pipeline as a context manager so it actually runs; otherwise the
    # assert_that checks below are only constructed and never evaluated.
    with TestPipeline() as pipeline:
      the_list = [6, 3, 1, 1, 9, 1, 5, 2, 0, 6]
      timestamp = 0
      pcoll = pipeline | 'start' >> Create(the_list)
      result = pcoll | 'to set' >> combine.ToSet()

      # Now for global combines without default
      timestamped = pcoll | Map(lambda x: TimestampedValue(x, timestamp))
      windowed = timestamped | 'window' >> WindowInto(FixedWindows(60))
      result_windowed = (
          windowed
          | 'to set wo defaults' >> combine.ToSet().without_defaults())

      def matcher(expected):
        # `expected` is a set, so it cannot be indexed; compare the single
        # output set against it as an unordered collection instead.
        def match(actual):
          equal_to([1])([len(actual)])
          equal_to(expected)(actual[0])

        return match

      assert_that(result, matcher(set(the_list)))
      assert_that(
          result_windowed, matcher(set(the_list)), label='to-set-wo-defaults')

  def test_combine_globally_with_default(self):
    with TestPipeline() as p:
      assert_that(p | Create([]) | CombineGlobally(sum), equal_to([0]))

  def test_combine_globally_without_default(self):
    with TestPipeline() as p:
      result = p | Create([]) | CombineGlobally(sum).without_defaults()
      assert_that(result, equal_to([]))

  def test_combine_globally_with_default_side_input(self):
    # Emits the global sum of its input via a singleton side input, so an
    # empty input must still produce the default (0).
    class SideInputCombine(PTransform):
      def expand(self, pcoll):
        side = pcoll | CombineGlobally(sum).as_singleton_view()
        main = pcoll.pipeline | Create([None])
        return main | Map(lambda _, s: s, side)

    with TestPipeline() as p:
      result1 = p | 'i1' >> Create([]) | 'c1' >> SideInputCombine()
      result2 = p | 'i2' >> Create([1, 2, 3, 4]) | 'c2' >> SideInputCombine()
      assert_that(result1, equal_to([0]), label='r1')
      assert_that(result2, equal_to([10]), label='r2')

  def test_hot_key_fanout(self):
    with TestPipeline() as p:
      result = (
          p
          | beam.Create(itertools.product(['hot', 'cold'], range(10)))
          | beam.CombinePerKey(combine.MeanCombineFn()).with_hot_key_fanout(
              lambda key: (key == 'hot') * 5))
      assert_that(result, equal_to([('hot', 4.5), ('cold', 4.5)]))

  def test_hot_key_fanout_sharded(self):
    # Lots of elements with the same key with varying/no fanout.
    with TestPipeline() as p:
      elements = [(None, e) for e in range(1000)]
      random.shuffle(elements)
      shards = [
          p | "Shard%s" % shard >> beam.Create(elements[shard::20])
          for shard in range(20)
      ]
      result = (
          shards
          | beam.Flatten()
          | beam.CombinePerKey(combine.MeanCombineFn()).with_hot_key_fanout(
              lambda key: random.randrange(0, 5)))
      assert_that(result, equal_to([(None, 499.5)]))

  def test_global_fanout(self):
    with TestPipeline() as p:
      result = (
          p
          | beam.Create(range(100))
          | beam.CombineGlobally(combine.MeanCombineFn()).with_fanout(11))
      assert_that(result, equal_to([49.5]))

  def test_combining_with_accumulation_mode_and_fanout(self):
    # PCollection will contain elements from 1 to 5.
    elements = list(range(1, 6))

    ts = TestStream().advance_watermark_to(0)
    for i in elements:
      ts.add_elements([i])
    ts.advance_watermark_to_infinity()

    options = PipelineOptions()
    options.view_as(StandardOptions).streaming = True
    with TestPipeline(options=options) as p:
      result = (
          p
          | ts
          | beam.WindowInto(
              GlobalWindows(),
              accumulation_mode=trigger.AccumulationMode.ACCUMULATING,
              trigger=AfterWatermark(early=AfterAll(AfterCount(1))))
          | beam.CombineGlobally(sum).without_defaults().with_fanout(2))

      def has_expected_values(actual):
        from hamcrest.core import assert_that as hamcrest_assert
        from hamcrest.library.collection import contains
        from hamcrest.library.collection import only_contains
        ordered = sorted(actual)
        # Early firings: running sums in ACCUMULATING mode.
        hamcrest_assert(ordered[:4], contains(1, 3, 6, 10))
        # Different runners have different number of 15s, but there should
        # be at least one 15.
        hamcrest_assert(ordered[4:], only_contains(15))

      assert_that(result, has_expected_values)

  def test_combining_with_sliding_windows_and_fanout_raises_error(self):
    options = PipelineOptions()
    options.view_as(StandardOptions).streaming = True
    with self.assertRaises(ValueError):
      with TestPipeline(options=options) as p:
        _ = (
            p
            | beam.Create([
                window.TimestampedValue(0, Timestamp(seconds=1666707510)),
                window.TimestampedValue(1, Timestamp(seconds=1666707511)),
                window.TimestampedValue(2, Timestamp(seconds=1666707512)),
                window.TimestampedValue(3, Timestamp(seconds=1666707513)),
                window.TimestampedValue(5, Timestamp(seconds=1666707515)),
                window.TimestampedValue(6, Timestamp(seconds=1666707516)),
                window.TimestampedValue(7, Timestamp(seconds=1666707517)),
                window.TimestampedValue(8, Timestamp(seconds=1666707518))
            ])
            | beam.WindowInto(window.SlidingWindows(10, 5))
            | beam.CombineGlobally(beam.combiners.ToListCombineFn()).
            without_defaults().with_fanout(7))

  def test_MeanCombineFn_combine(self):
    with TestPipeline() as p:
      inputs = (
          p
          | beam.Create([('a', 1), ('a', 1), ('a', 4), ('b', 1), ('b', 13)]))
      # The mean of all values regardless of key.
      global_mean = (
          inputs
          | beam.Values()
          | beam.CombineGlobally(combine.MeanCombineFn()))

      # The (key, mean) pairs for all keys.
      mean_per_key = (inputs | beam.CombinePerKey(combine.MeanCombineFn()))

      expected_mean_per_key = [('a', 2), ('b', 7)]
      assert_that(global_mean, equal_to([4]), label='global mean')
      assert_that(
          mean_per_key, equal_to(expected_mean_per_key), label='mean per key')

  def test_MeanCombineFn_combine_empty(self):
    # On an empty PCollection the global MeanCombineFn produces a single NaN,
    # while the per-key combine produces no output at all.
    with TestPipeline() as p:
      inputs = (p | beam.Create([]))

      # Compute the global mean and convert it with str(). NaN never compares
      # equal to itself, but str(float('nan')) == 'nan' can be compared.
      global_mean = (
          inputs
          | beam.Values()
          | beam.CombineGlobally(combine.MeanCombineFn())
          | beam.Map(str))

      mean_per_key = (inputs | beam.CombinePerKey(combine.MeanCombineFn()))

      assert_that(global_mean, equal_to(['nan']), label='global mean')
      assert_that(mean_per_key, equal_to([]), label='mean per key')

  def test_sessions_combine(self):
    with TestPipeline() as p:
      inputs = (
          p
          | beam.Create([('c', 1), ('c', 9), ('c', 12), ('d', 2), ('d', 4)])
          | beam.MapTuple(lambda k, v: window.TimestampedValue((k, v), v))
          | beam.WindowInto(window.Sessions(4)))

      global_sum = (
          inputs
          | beam.Values()
          | beam.CombineGlobally(sum).without_defaults())
      sum_per_key = inputs | beam.CombinePerKey(sum)

      # The first window has 3 elements: ('c', 1), ('d', 2), ('d', 4).
      # The second window has 2 elements: ('c', 9), ('c', 12).
      assert_that(global_sum, equal_to([7, 21]), label='global sum')
      assert_that(
          sum_per_key,
          equal_to([('c', 1), ('c', 21), ('d', 6)]),
          label='sum per key')

  def test_fixed_windows_combine(self):
    with TestPipeline() as p:
      inputs = (
          p
          | beam.Create([('c', 1), ('c', 2), ('c', 10), ('d', 5), ('d', 8),
                         ('d', 9)])
          | beam.MapTuple(lambda k, v: window.TimestampedValue((k, v), v))
          | beam.WindowInto(window.FixedWindows(4)))

      global_sum = (
          inputs
          | beam.Values()
          | beam.CombineGlobally(sum).without_defaults())
      sum_per_key = inputs | beam.CombinePerKey(sum)

      # The first window has 2 elements: ('c', 1), ('c', 2).
      # The second window has 1 elements: ('d', 5).
      # The third window has 3 elements: ('c', 10), ('d', 8), ('d', 9).
      assert_that(global_sum, equal_to([3, 5, 27]), label='global sum')
      assert_that(
          sum_per_key,
          equal_to([('c', 3), ('c', 10), ('d', 5), ('d', 17)]),
          label='sum per key')

  # Test that three different kinds of metrics work with a customized
  # SortedConcatWithCounters CombineFn.
  def test_custormized_counters_in_combine_fn(self):
    p = TestPipeline()
    inputs = (
        p
        | beam.Create([('key1', 'a'), ('key1', 'ab'), ('key1', 'abc'),
                       ('key2', 'uvxy'), ('key2', 'uvxyz')]))

    # The result of concatenating all values regardless of key.
    global_concat = (
        inputs
        | beam.Values()
        | beam.CombineGlobally(SortedConcatWithCounters()))

    # The (key, concatenated_string) pairs for all keys.
    concat_per_key = (inputs | beam.CombinePerKey(SortedConcatWithCounters()))

    # Verify the concatenated strings are correct.
    expected_concat_per_key = [('key1', 'aaabbc'), ('key2', 'uuvvxxyyz')]
    assert_that(
        global_concat, equal_to(['aaabbcuuvvxxyyz']), label='global concat')
    assert_that(
        concat_per_key,
        equal_to(expected_concat_per_key),
        label='concat per key')

    result = p.run()
    result.wait_until_finish()

    # Verify the values of metrics are correct. Each check is guarded,
    # presumably because not every runner reports every metric kind.
    word_counter_filter = MetricsFilter().with_name('word_counter')
    query_result = result.metrics().query(word_counter_filter)
    if query_result['counters']:
      word_counter = query_result['counters'][0]
      self.assertEqual(word_counter.result, 5)

    word_lengths_filter = MetricsFilter().with_name('word_lengths')
    query_result = result.metrics().query(word_lengths_filter)
    if query_result['counters']:
      word_lengths = query_result['counters'][0]
      self.assertEqual(word_lengths.result, 15)

    word_len_dist_filter = MetricsFilter().with_name('word_len_dist')
    query_result = result.metrics().query(word_len_dist_filter)
    if query_result['distributions']:
      word_len_dist = query_result['distributions'][0]
      self.assertEqual(word_len_dist.result.mean, 3)

    last_word_len_filter = MetricsFilter().with_name('last_word_len')
    query_result = result.metrics().query(last_word_len_filter)
    if query_result['gauges']:
      last_word_len = query_result['gauges'][0]
      self.assertIn(last_word_len.result.value, [1, 2, 3, 4, 5])

  # Test that three different kinds of metrics work with the customized
  # SortedConcatWithCounters CombineFn when the PCollection is empty.
  def test_custormized_counters_in_combine_fn_empty(self):
    p = TestPipeline()
    inputs = p | beam.Create([])

    # The result of concatenating all values regardless of key.
    global_concat = (
        inputs
        | beam.Values()
        | beam.CombineGlobally(SortedConcatWithCounters()))

    # The (key, concatenated_string) pairs for all keys.
    concat_per_key = (inputs | beam.CombinePerKey(SortedConcatWithCounters()))

    # Verify the concatenated strings are correct.
    assert_that(global_concat, equal_to(['']), label='global concat')
    assert_that(concat_per_key, equal_to([]), label='concat per key')

    result = p.run()
    result.wait_until_finish()

    # Verify the values of metrics are correct.
    word_counter_filter = MetricsFilter().with_name('word_counter')
    query_result = result.metrics().query(word_counter_filter)
    if query_result['counters']:
      word_counter = query_result['counters'][0]
      self.assertEqual(word_counter.result, 0)

    word_lengths_filter = MetricsFilter().with_name('word_lengths')
    query_result = result.metrics().query(word_lengths_filter)
    if query_result['counters']:
      word_lengths = query_result['counters'][0]
      self.assertEqual(word_lengths.result, 0)

    word_len_dist_filter = MetricsFilter().with_name('word_len_dist')
    query_result = result.metrics().query(word_len_dist_filter)
    if query_result['distributions']:
      word_len_dist = query_result['distributions'][0]
      self.assertEqual(word_len_dist.result.count, 0)

    last_word_len_filter = MetricsFilter().with_name('last_word_len')
    query_result = result.metrics().query(last_word_len_filter)

    # No element has ever been recorded.
    self.assertFalse(query_result['gauges'])
class LatestTest(unittest.TestCase):
  """Pipeline tests for combine.Latest.Globally and combine.Latest.PerKey."""
  def test_globally(self):
    stamped_values = [
        window.TimestampedValue(3, 100),
        window.TimestampedValue(1, 200),
        window.TimestampedValue(2, 300)
    ]
    with TestPipeline() as p:
      # Create alone leaves the elements as TimestampedValue objects rather
      # than assigning their timestamps. Passing them through an identity Map
      # (i.e. a DoFn) turns them into plain elements that carry the timestamp.
      pcoll = p | Create(stamped_values) | Map(lambda x: x)
      assert_that(pcoll | combine.Latest.Globally(), equal_to([2]))

      # Global combine without a default, evaluated per fixed window.
      per_window_latest = (
          pcoll
          | 'window' >> WindowInto(FixedWindows(180))
          |
          'latest wo defaults' >> combine.Latest.Globally().without_defaults())
      assert_that(
          per_window_latest, equal_to([3, 2]), label='latest-wo-defaults')

  def test_globally_empty(self):
    no_values = []
    with TestPipeline() as p:
      latest = (
          p
          | Create(no_values)
          | Map(lambda x: x)
          | combine.Latest.Globally())
      assert_that(latest, equal_to([None]))

  def test_per_key(self):
    stamped_pairs = [
        window.TimestampedValue(('a', 1), 300),
        window.TimestampedValue(('b', 3), 100),
        window.TimestampedValue(('a', 2), 200)
    ]
    with TestPipeline() as p:
      latest = (
          p
          | Create(stamped_pairs)
          | Map(lambda x: x)
          | combine.Latest.PerKey())
      assert_that(latest, equal_to([('a', 1), ('b', 3)]))

  def test_per_key_empty(self):
    no_pairs = []
    with TestPipeline() as p:
      latest = (
          p
          | Create(no_pairs)
          | Map(lambda x: x)
          | combine.Latest.PerKey())
      assert_that(latest, equal_to([]))
class LatestCombineFnTest(unittest.TestCase):
  """Direct (non-pipeline) tests of LatestCombineFn's CombineFn methods."""
  def setUp(self):
    self.fn = combine.LatestCombineFn()

  def test_create_accumulator(self):
    self.assertEqual(
        self.fn.create_accumulator(), (None, window.MIN_TIMESTAMP))

  def test_add_input(self):
    empty = self.fn.create_accumulator()
    self.assertEqual(self.fn.add_input(empty, (1, 100)), (1, 100))

  def test_merge_accumulators(self):
    # The accumulator with the largest second field (400) wins.
    merged = self.fn.merge_accumulators([(2, 400), (5, 100), (9, 200)])
    self.assertEqual(merged, (2, 400))

  def test_extract_output(self):
    self.assertEqual(self.fn.extract_output((1, 100)), 1)

  def test_with_input_types_decorator_violation(self):
    # None of these element types (ints, dicts, 3-tuples) should satisfy the
    # input type LatestCombineFn declares, so each pipeline must raise
    # TypeCheckError.
    mistyped_inputs = (
        [1, 2, 3],
        [{'a': 3}, {'g': 5}, {'r': 8}],
        [(12, 31, 41), (12, 34, 34), (84, 92, 74)],
    )
    for elements in mistyped_inputs:
      with self.assertRaises(TypeCheckError):
        with TestPipeline() as p:
          _ = p | Create(elements) | beam.CombineGlobally(self.fn)
@pytest.mark.it_validatesrunner
class CombineValuesTest(unittest.TestCase):
  """CombineValues applied directly to the output of a GroupByKey."""
  def test_gbk_immediately_followed_by_combine(self):
    def merge(vals):
      return "".join(vals)

    with TestPipeline() as p:
      grouped = (
          p
          | Create([("key1", "foo"), ("key2", "bar"), ("key1", "foo")],
                   reshuffle=False)
          | beam.GroupByKey())
      formatted = (
          grouped
          | beam.CombineValues(merge)
          | beam.MapTuple(lambda k, v: '{}: {}'.format(k, v)))
      assert_that(formatted, equal_to(['key1: foofoo', 'key2: bar']))
#
# Test cases for streaming.
#
@pytest.mark.it_validatesrunner
class TimestampCombinerTest(unittest.TestCase):
  """Checks which output timestamp CombinePerKey assigns under a given
  TimestampCombiner, using a streaming TestStream with two elements for the
  same key at timestamps 2 and 7 inside one 10-second fixed window."""
  def _check_output_timestamp(self, timestamp_combiner, expected_seconds):
    options = PipelineOptions(streaming=True)
    with TestPipeline(options=options) as p:
      source = (
          TestStream()
          .add_elements([window.TimestampedValue(('k', 100), 2)])
          .add_elements([window.TimestampedValue(('k', 400), 7)])
          .advance_watermark_to_infinity())
      summed = (
          p
          | source
          | beam.WindowInto(
              window.FixedWindows(10), timestamp_combiner=timestamp_combiner)
          | beam.CombinePerKey(sum))
      # Pair every combined element with the timestamp it was emitted at.
      with_timestamps = (
          summed
          | beam.Map(lambda e, ts=beam.DoFn.TimestampParam: (e, ts)))

      expected_by_window = {
          window.IntervalWindow(0, 10): [
              (('k', 500), Timestamp(expected_seconds)),
          ],
      }
      assert_that(
          with_timestamps,
          equal_to_per_window(expected_by_window),
          use_global_window=False,
          label='assert per window')

  def test_combiner_earliest(self):
    """Test TimestampCombiner with EARLIEST."""
    # Both values for 'k' are combined and emitted at the earliest timestamp.
    self._check_output_timestamp(TimestampCombiner.OUTPUT_AT_EARLIEST, 2)

  def test_combiner_latest(self):
    """Test TimestampCombiner with LATEST."""
    # Both values for 'k' are combined and emitted at the latest timestamp.
    self._check_output_timestamp(TimestampCombiner.OUTPUT_AT_LATEST, 7)
class CombineGloballyTest(unittest.TestCase):
  """CombineGlobally applied to an unbounded, globally windowed input."""

  @staticmethod
  def _triggered_ticks(p, to_element):
    """Builds the shared unbounded input for both tests.

    Emits a PeriodicImpulse tick every second for about 4 seconds, maps each
    tick through ``to_element``, then re-windows into GlobalWindows with a
    trigger that fires after every 2 elements in DISCARDING mode.
    """
    ticks = p | PeriodicImpulse(
        start_timestamp=time.time(),
        stop_timestamp=time.time() + 4,
        fire_interval=1,
        apply_windowing=False,
    )
    return (
        ticks
        | beam.Map(to_element)
        | beam.WindowInto(
            window.GlobalWindows(),
            trigger=trigger.Repeatedly(trigger.AfterCount(2)),
            accumulation_mode=trigger.AccumulationMode.DISCARDING,
        ))

  def test_combine_globally_for_unbounded_source_with_default(self):
    # Count.Globally() keeps its default value, which is ill-defined on an
    # unbounded input; the test expects a log line mentioning it.
    with self.assertLogs() as captured_logs:
      with TestPipeline() as p:
        _ = (
            self._triggered_ticks(p, lambda x: ('c', 1))
            | beam.combiners.Count.Globally())
    self.assertIn('unbounded collections', '\n'.join(captured_logs.output))

  def test_combine_globally_for_unbounded_source_without_defaults(self):
    # Dropping the default with without_defaults() is the supported form.
    with TestPipeline() as p:
      _ = (
          self._triggered_ticks(p, lambda x: 1)
          | beam.CombineGlobally(sum).without_defaults())
def get_common_items(sets, excluded_chars=""):
  """Returns the items present in every input set, minus ``excluded_chars``.

  Args:
    sets: An iterable of sets (a list, or a lazy iterable such as a
      generator). An empty iterable yields an empty set.
    excluded_chars: An iterable of items to remove from the result. A string
      is iterated character by character.

  Returns:
    A new set of the items common to all of ``sets`` that are not in
    ``excluded_chars``.
  """
  # Materialize first: a lazy iterable is always truthy, even when empty,
  # so a plain `sets or [...]` test cannot detect an empty generator.
  sets = list(sets)
  if not sets:
    # Intersecting zero sets is undefined; treat it as "nothing in common".
    return set()
  # set(...) copies the first element, so frozensets and other iterables are
  # accepted too; intersection() takes the remaining sets as separate args.
  common = set(sets[0]).intersection(*sets[1:])
  return common.difference(excluded_chars)
class CombinerWithSideInputs(unittest.TestCase):
  """Combiners that receive extra positional/keyword args and side inputs."""

  # Input shared by the CombinePerKey / CombineGlobally side-input checks.
  # '🥕' and '🍅' are the only items in every set; the side input excludes
  # '🍅', so the expected combined result is {'🥕'}.
  _PRODUCE = [
      {'🍓', '🥕', '🍌', '🍅', '🌶️'},
      {'🍇', '🥕', '🥝', '🍅', '🥔'},
      {'🍉', '🥕', '🍆', '🍅', '🍍'},
      {'🥑', '🥕', '🌽', '🍅', '🥥'},
  ]

  def test_cpk_with_side_input(self):
    # Plain callables and CombineFn.from_callable wrappers, each with the
    # side input passed both as a keyword and as a positional argument.
    test_cases = [(get_common_items, True),
                  (beam.CombineFn.from_callable(get_common_items), True),
                  (get_common_items, False),
                  (beam.CombineFn.from_callable(get_common_items), False)]
    for combiner, with_kwarg in test_cases:
      # subTest reports which (combiner, with_kwarg) pair failed.
      with self.subTest(combiner=combiner, with_kwarg=with_kwarg):
        self._check_combineperkey_with_side_input(combiner, with_kwarg)
        self._check_combineglobally_with_side_input(combiner, with_kwarg)

  def _check_combineperkey_with_side_input(self, combiner, with_kwarg):
    """Runs CombinePerKey(combiner) over _PRODUCE with '🍅' as a side input.

    Args:
      combiner: The combining callable or CombineFn under test.
      with_kwarg: If True, pass the side input as ``excluded_chars=...``;
        otherwise pass it positionally.
    """
    with beam.Pipeline() as pipeline:
      pc = (pipeline | beam.Create(['🍅']))
      if with_kwarg:
        cpk = beam.CombinePerKey(
            combiner, excluded_chars=beam.pvalue.AsSingleton(pc))
      else:
        cpk = beam.CombinePerKey(combiner, beam.pvalue.AsSingleton(pc))
      common_items = (
          pipeline
          | 'Create produce' >> beam.Create(self._PRODUCE)
          | beam.WithKeys(lambda x: None)
          | cpk)
      assert_that(common_items, equal_to([(None, {'🥕'})]))

  def _check_combineglobally_with_side_input(self, combiner, with_kwarg):
    """Runs CombineGlobally(combiner) over _PRODUCE with '🍅' as a side input.

    Args:
      combiner: The combining callable or CombineFn under test.
      with_kwarg: If True, pass the side input as ``excluded_chars=...``;
        otherwise pass it positionally.
    """
    with beam.Pipeline() as pipeline:
      pc = (pipeline | beam.Create(['🍅']))
      if with_kwarg:
        combine_globally = beam.CombineGlobally(
            combiner, excluded_chars=beam.pvalue.AsSingleton(pc))
      else:
        combine_globally = beam.CombineGlobally(
            combiner, beam.pvalue.AsSingleton(pc))
      common_items = (
          pipeline
          | 'Create produce' >> beam.Create(self._PRODUCE)
          | combine_globally)
      assert_that(common_items, equal_to([{'🥕'}]))

  def test_combinefn_methods_with_side_input(self):
    # Test that the expected combinefn methods are called with the
    # expected arguments when using side inputs in CombinePerKey.
    with tempfile.TemporaryDirectory() as tmp_dirname:
      # The CombineFn records its call arguments in a JSON file, so the
      # assertions below can read them back after the pipeline has run.
      fname = str(Path(tmp_dirname) / "combinefn_calls.json")
      with open(fname, "w") as f:
        json.dump({}, f)

      def set_in_json(key, values):
        """Stores ``values`` under ``key`` in the JSON file at ``fname``."""
        current_json = {}
        if os.path.exists(fname):
          with open(fname, "r") as f:
            current_json = json.load(f)
        current_json[key] = values
        with open(fname, "w") as f:
          json.dump(current_json, f)

      class MyCombiner(beam.CombineFn):
        """Logs the extra args of each method; outputs (args, kwargs)."""
        def create_accumulator(self, *args, **kwargs):
          set_in_json("create_accumulator_args", args)
          set_in_json("create_accumulator_kwargs", kwargs)
          return args, kwargs

        def add_input(self, accumulator, input, *args, **kwargs):
          set_in_json("add_input_args", args)
          set_in_json("add_input_kwargs", kwargs)
          return accumulator

        def merge_accumulators(self, accumulators, *args, **kwargs):
          set_in_json("merge_accumulators_args", args)
          set_in_json("merge_accumulators_kwargs", kwargs)
          return args, kwargs

        def compact(self, accumulator, *args, **kwargs):
          set_in_json("compact_args", args)
          set_in_json("compact_kwargs", kwargs)
          return accumulator

        def extract_output(self, accumulator, *args, **kwargs):
          set_in_json("extract_output_args", args)
          set_in_json("extract_output_kwargs", kwargs)
          return accumulator

      with beam.Pipeline() as p:
        # Mix static values with deferred (side input) values, both
        # positionally and as keyword arguments.
        static_pos_arg = 0
        deferred_pos_arg = beam.pvalue.AsSingleton(
            p | "CreateDeferredSideInput" >> beam.Create([1]))
        static_kwarg = 2
        deferred_kwarg = beam.pvalue.AsSingleton(
            p | "CreateDeferredSideInputKwarg" >> beam.Create([3]))
        res = (
            p
            | "CreateInputs" >> beam.Create([(None, None)])
            | beam.CombinePerKey(
                MyCombiner(),
                static_pos_arg,
                deferred_pos_arg,
                static_kwarg=static_kwarg,
                deferred_kwarg=deferred_kwarg))
        assert_that(
            res,
            equal_to([
                (None, ((0, 1), {
                    'static_kwarg': 2, 'deferred_kwarg': 3
                }))
            ]))

      # Check that the combinefn was called with the expected arguments.
      # JSON round-trips the positional-args tuple as a list.
      with open(fname, "r") as f:
        data = json.load(f)
      expected_args = [0, 1]
      expected_kwargs = {"static_kwarg": 2, "deferred_kwarg": 3}
      method_names = [
          "create_accumulator",
          "compact",
          "add_input",
          "merge_accumulators",
          "extract_output"
      ]
      for key in method_names:
        # subTest names the failing method instead of a debug print.
        with self.subTest(method=key):
          self.assertEqual(data[key + "_args"], expected_args)
          self.assertEqual(data[key + "_kwargs"], expected_kwargs)

  def test_cpk_with_windows(self):
    # Shared by both pipelines below.
    def sum_with_floor(vals, min_value=0):
      """Sums ``vals``, adding ``min_value`` when the sum falls below it."""
      vals_sum = sum(vals)
      if vals_sum < min_value:
        vals_sum += min_value
      return vals_sum

    # With a global-window side input, every FixedWindows(99) window sees 100:
    # [0, 99) sums to 3 -> 103; [99, 198) sums to 303 -> 303.
    with TestPipeline() as p:
      res = (
          p
          | "CreateInputs" >> beam.Create([1, 2, 100, 101, 102])
          | beam.Map(lambda x: window.TimestampedValue(('k', x), x))
          | beam.WindowInto(FixedWindows(99))
          | beam.CombinePerKey(
              sum_with_floor,
              min_value=pvalue.AsSingleton(p | beam.Create([100]))))
      assert_that(res, equal_to([('k', 103), ('k', 303)]))

    # With a matching-window side input, each window sees the side value
    # timestamped in it: 50 for [0, 99) -> 53, 1000 for [99, 198) -> 1303.
    with TestPipeline() as p:
      min_value = (
          p
          | "CreateMinValue" >> beam.Create([
              window.TimestampedValue(50, 5),
              window.TimestampedValue(1000, 100)
          ])
          | "WindowSideInputs" >> beam.WindowInto(FixedWindows(99)))
      res = (
          p
          | "CreateInputs" >> beam.Create([1, 2, 100, 101, 102])
          | beam.Map(lambda x: window.TimestampedValue(('k', x), x))
          | beam.WindowInto(FixedWindows(99))
          | beam.CombinePerKey(
              sum_with_floor, min_value=pvalue.AsSingleton(min_value)))
      assert_that(res, equal_to([('k', 53), ('k', 1303)]))
if __name__ == '__main__':
  # When this module is executed as a script, run its test cases via unittest.
  unittest.main()