#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# pytype: skip-file
import collections
import gc
import logging
import os
import random
import re
import shutil
import tempfile
import threading
import time
import traceback
import typing
import unittest
import uuid
from typing import Any
from typing import Dict
from typing import Iterator
from typing import List
from typing import Tuple
from typing import no_type_check
import hamcrest # pylint: disable=ungrouped-imports
import numpy as np
import pytest
from hamcrest.core.matcher import Matcher
from hamcrest.core.string_description import StringDescription
from tenacity import retry
from tenacity import stop_after_attempt
import apache_beam as beam
from apache_beam.coders import coders
from apache_beam.coders.coders import StrUtf8Coder
from apache_beam.io import restriction_trackers
from apache_beam.io.watermark_estimators import ManualWatermarkEstimator
from apache_beam.metrics import monitoring_infos
from apache_beam.metrics.execution import MetricKey
from apache_beam.metrics.metricbase import MetricName
from apache_beam.options.pipeline_options import DebugOptions
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.pipeline_options import StandardOptions
from apache_beam.options.value_provider import RuntimeValueProvider
from apache_beam.portability import python_urns
from apache_beam.runners.portability import fn_api_runner
from apache_beam.runners.portability.fn_api_runner import fn_runner
from apache_beam.runners.sdf_utils import RestrictionTrackerView
from apache_beam.runners.worker import data_plane
from apache_beam.runners.worker import statesampler
from apache_beam.runners.worker.operations import InefficientExecutionWarning
from apache_beam.testing.synthetic_pipeline import SyntheticSDFAsSource
from apache_beam.testing.test_stream import TestStream
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to
from apache_beam.tools import utils
from apache_beam.transforms import environments
from apache_beam.transforms import userstate
from apache_beam.transforms import window
from apache_beam.utils import timestamp
from apache_beam.utils import windowed_value
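# When the compiled (Cython) state sampler is unavailable, state sampling is
# effectively disabled; using a period of 0 here lets the duration-based
# assertions below degrade gracefully to no-ops.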
if statesampler.FAST_SAMPLER:
DEFAULT_SAMPLING_PERIOD_MS = statesampler.DEFAULT_SAMPLING_PERIOD_MS
else:
DEFAULT_SAMPLING_PERIOD_MS = 0
_LOGGER = logging.getLogger(__name__)
def _matcher_or_equal_to(value_or_matcher):
"""Pass-thru for matchers, and wraps value inputs in an equal_to matcher."""
if value_or_matcher is None:
return None
if isinstance(value_or_matcher, Matcher):
return value_or_matcher
return hamcrest.equal_to(value_or_matcher)
def has_urn_and_labels(mi, urn, labels):
"""Returns true if it the monitoring_info contains the labels and urn."""
def contains_labels(mi, labels):
# Check all the labels and their values exist in the monitoring_info
return all(item in mi.labels.items() for item in labels.items())
return contains_labels(mi, labels) and mi.urn == urn
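# Illustrative example (not executed): for a monitoring info `mi` with
# urn ELEMENT_COUNT_URN and labels {'PCOLLECTION': 'pc1', 'EXTRA': 'x'},
# has_urn_and_labels(mi, ELEMENT_COUNT_URN, {'PCOLLECTION': 'pc1'}) returns
# True, since only the queried labels need to be present in mi.labels.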
class FnApiRunnerTest(unittest.TestCase):
def create_pipeline(self, is_drain=False):
return beam.Pipeline(runner=fn_api_runner.FnApiRunner(is_drain=is_drain))
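  # Subclasses at the bottom of this module override create_pipeline() to
  # re-run this suite against other configurations (gRPC environments,
  # multiple workers, repeated bundles, etc.).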
def test_assert_that(self):
# TODO: figure out a way for fn_api_runner to parse and raise the
# underlying exception.
with self.assertRaisesRegex(Exception, 'Failed assert'):
with self.create_pipeline() as p:
assert_that(p | beam.Create(['a', 'b']), equal_to(['a']))
def test_create(self):
with self.create_pipeline() as p:
assert_that(p | beam.Create(['a', 'b']), equal_to(['a', 'b']))
def test_pardo(self):
with self.create_pipeline() as p:
res = (
p
| beam.Create(['a', 'bc'])
| beam.Map(lambda e: e * 2)
| beam.Map(lambda e: e + 'x'))
assert_that(res, equal_to(['aax', 'bcbcx']))
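  # ArrayMultiplyDoFn (defined near the bottom of this module) is a batched
  # DoFn whose process_batch doubles each element of a numpy array, so the
  # batch tests below expect input * 2 (* 3 after the trailing Map).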
def test_batch_pardo(self):
with self.create_pipeline() as p:
res = (
p
| beam.Create(np.array([1, 2, 3], dtype=np.int64)).with_output_types(
np.int64)
| beam.ParDo(ArrayMultiplyDoFn())
| beam.Map(lambda x: x * 3))
assert_that(res, equal_to([6, 12, 18]))
def test_batch_pardo_override_type_inference(self):
class ArrayMultiplyDoFnOverride(beam.DoFn):
def process_batch(self, batch, *unused_args,
**unused_kwargs) -> Iterator[np.ndarray]:
assert isinstance(batch, np.ndarray)
yield batch * 2
      # infer_output_type must be defined (when there's no process method);
      # otherwise the runner can't infer the output element type.
def infer_output_type(self, input_type):
return input_type
def get_input_batch_type(self, input_element_type):
from apache_beam.typehints.batch import NumpyArray
return NumpyArray[input_element_type]
def get_output_batch_type(self, input_element_type):
return self.get_input_batch_type(input_element_type)
with self.create_pipeline() as p:
res = (
p
| beam.Create(np.array([1, 2, 3], dtype=np.int64)).with_output_types(
np.int64)
| beam.ParDo(ArrayMultiplyDoFnOverride())
| beam.Map(lambda x: x * 3))
assert_that(res, equal_to([6, 12, 18]))
def test_batch_pardo_trigger_flush(self):
try:
utils.check_compiled('apache_beam.coders.coder_impl')
except RuntimeError:
self.skipTest(
'https://github.com/apache/beam/issues/21643: FnRunnerTest with '
'non-trivial inputs flakes in non-cython environments')
with self.create_pipeline() as p:
res = (
p
# Pass more than GeneralPurposeConsumerSet.MAX_BATCH_SIZE elements
# here to make sure we exercise the batch size limit.
| beam.Create(np.array(range(5000),
dtype=np.int64)).with_output_types(np.int64)
| beam.ParDo(ArrayMultiplyDoFn())
| beam.Map(lambda x: x * 3))
assert_that(res, equal_to([i * 2 * 3 for i in range(5000)]))
def test_batch_rebatch_pardos(self):
# Should raise a warning about the rebatching that mentions:
# - The consuming DoFn
# - The output batch type of the producer
# - The input batch type of the consumer
with self.assertWarnsRegex(InefficientExecutionWarning,
r'ListPlusOneDoFn.*NumpyArray.*List\[int64\]'):
with self.create_pipeline() as p:
res = (
p
| beam.Create(np.array([1, 2, 3],
dtype=np.int64)).with_output_types(np.int64)
| beam.ParDo(ArrayMultiplyDoFn())
| beam.ParDo(ListPlusOneDoFn())
| beam.Map(lambda x: x * 3))
assert_that(res, equal_to([9, 15, 21]))
def test_batch_pardo_fusion_break(self):
class NormalizeDoFn(beam.DoFn):
@no_type_check
def process_batch(
self,
batch: np.ndarray,
mean: np.float64,
) -> Iterator[np.ndarray]:
assert isinstance(batch, np.ndarray)
yield batch - mean
      # infer_output_type must be defined (when there's no process method);
      # otherwise the runner can't infer the output element type.
def infer_output_type(self, input_type):
return np.float64
with self.create_pipeline() as p:
pc = (
p
| beam.Create(np.array([1, 2, 3], dtype=np.int64)).with_output_types(
np.int64)
| beam.ParDo(ArrayMultiplyDoFn()))
res = (
pc
| beam.ParDo(
NormalizeDoFn(),
mean=beam.pvalue.AsSingleton(
pc | beam.CombineGlobally(beam.combiners.MeanCombineFn()))))
assert_that(res, equal_to([-2, 0, 2]))
def test_batch_pardo_dofn_params(self):
class ConsumeParamsDoFn(beam.DoFn):
@no_type_check
def process_batch(
self,
batch: np.ndarray,
ts=beam.DoFn.TimestampParam,
pane_info=beam.DoFn.PaneInfoParam,
) -> Iterator[np.ndarray]:
assert isinstance(batch, np.ndarray)
assert isinstance(ts, timestamp.Timestamp)
assert isinstance(pane_info, windowed_value.PaneInfo)
yield batch * ts.seconds()
      # infer_output_type must be defined (when there's no process method);
      # otherwise the runner can't infer the output element type.
def infer_output_type(self, input_type):
return input_type
with self.create_pipeline() as p:
res = (
p
| beam.Create(np.array(range(10), dtype=np.int64)).with_output_types(
np.int64)
| beam.Map(lambda t: window.TimestampedValue(t, int(t % 2))).
with_output_types(np.int64)
| beam.ParDo(ConsumeParamsDoFn()))
assert_that(res, equal_to([0, 1, 0, 3, 0, 5, 0, 7, 0, 9]))
def test_batch_pardo_window_param(self):
class PerWindowDoFn(beam.DoFn):
@no_type_check
def process_batch(
self,
batch: np.ndarray,
window=beam.DoFn.WindowParam,
) -> Iterator[np.ndarray]:
yield batch * window.start.seconds()
      # infer_output_type must be defined (when there's no process method);
      # otherwise the runner can't infer the output element type.
def infer_output_type(self, input_type):
return input_type
with self.create_pipeline() as p:
res = (
p
| beam.Create(np.array(range(10), dtype=np.int64)).with_output_types(
np.int64)
| beam.Map(lambda t: window.TimestampedValue(t, int(t))).
with_output_types(np.int64)
| beam.WindowInto(window.FixedWindows(5))
| beam.ParDo(PerWindowDoFn()))
assert_that(res, equal_to([0, 0, 0, 0, 0, 25, 30, 35, 40, 45]))
def test_batch_pardo_overlapping_windows(self):
class PerWindowDoFn(beam.DoFn):
@no_type_check
def process_batch(self,
batch: np.ndarray,
window=beam.DoFn.WindowParam) -> Iterator[np.ndarray]:
yield batch * window.start.seconds()
      # infer_output_type must be defined (when there's no process method);
      # otherwise the runner can't infer the output element type.
def infer_output_type(self, input_type):
return input_type
with self.create_pipeline() as p:
res = (
p
| beam.Create(np.array(range(10), dtype=np.int64)).with_output_types(
np.int64)
| beam.Map(lambda t: window.TimestampedValue(t, int(t))).
with_output_types(np.int64)
| beam.WindowInto(window.SlidingWindows(size=5, period=3))
| beam.ParDo(PerWindowDoFn()))
assert_that(res, equal_to([ 0*-3, 1*-3, # [-3, 2)
0*0, 1*0, 2*0, 3* 0, 4* 0, # [ 0, 5)
3*3, 4*3, 5*3, 6* 3, 7* 3, # [ 3, 8)
6*6, 7*6, 8*6, 9* 6, # [ 6, 11)
9*9 # [ 9, 14)
]))
def test_batch_to_element_pardo(self):
class ArraySumDoFn(beam.DoFn):
@beam.DoFn.yields_elements
def process_batch(self, batch: np.ndarray, *unused_args,
**unused_kwargs) -> Iterator[np.int64]:
yield batch.sum()
def infer_output_type(self, input_type):
assert input_type == np.int64
return np.int64
with self.create_pipeline() as p:
res = (
p
| beam.Create(np.array(range(100), dtype=np.int64)).with_output_types(
np.int64)
| beam.ParDo(ArrayMultiplyDoFn())
| beam.ParDo(ArraySumDoFn())
| beam.CombineGlobally(sum))
assert_that(res, equal_to([99 * 50 * 2]))
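  # ArraySumDoFn above uses @beam.DoFn.yields_elements to mark a
  # process_batch that emits individual elements; conversely,
  # @beam.DoFn.yields_batches below marks a process method that emits whole
  # batches.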
def test_element_to_batch_pardo(self):
class ArrayProduceDoFn(beam.DoFn):
@beam.DoFn.yields_batches
def process(self, element: np.int64, *unused_args,
**unused_kwargs) -> Iterator[np.ndarray]:
yield np.array([element] * int(element))
      # infer_output_type must be defined (when there's no process method);
      # otherwise the runner can't infer the output element type.
def infer_output_type(self, input_type):
return np.int64
with self.create_pipeline() as p:
res = (
p
| beam.Create(np.array([1, 2, 3], dtype=np.int64)).with_output_types(
np.int64)
| beam.ParDo(ArrayProduceDoFn())
| beam.ParDo(ArrayMultiplyDoFn())
| beam.Map(lambda x: x * 3))
assert_that(res, equal_to([6, 12, 12, 18, 18, 18]))
def test_pardo_large_input(self):
try:
utils.check_compiled('apache_beam.coders.coder_impl')
except RuntimeError:
self.skipTest(
'https://github.com/apache/beam/issues/21643: FnRunnerTest with '
'non-trivial inputs flakes in non-cython environments')
with self.create_pipeline() as p:
res = (
p
| beam.Create(np.array(range(5000),
dtype=np.int64)).with_output_types(np.int64)
| beam.Map(lambda e: e * 2)
| beam.Map(lambda e: e + 3))
assert_that(res, equal_to([(i * 2) + 3 for i in range(5000)]))
def test_pardo_side_outputs(self):
def tee(elem, *tags):
for tag in tags:
if tag in elem:
yield beam.pvalue.TaggedOutput(tag, elem)
with self.create_pipeline() as p:
xy = (
p
| 'Create' >> beam.Create(['x', 'y', 'xy'])
| beam.FlatMap(tee, 'x', 'y').with_outputs())
assert_that(xy.x, equal_to(['x', 'xy']), label='x')
assert_that(xy.y, equal_to(['y', 'xy']), label='y')
def test_pardo_side_and_main_outputs(self):
def even_odd(elem):
yield elem
yield beam.pvalue.TaggedOutput('odd' if elem % 2 else 'even', elem)
with self.create_pipeline() as p:
ints = p | beam.Create([1, 2, 3])
named = ints | 'named' >> beam.FlatMap(even_odd).with_outputs(
'even', 'odd', main='all')
assert_that(named.all, equal_to([1, 2, 3]), label='named.all')
assert_that(named.even, equal_to([2]), label='named.even')
assert_that(named.odd, equal_to([1, 3]), label='named.odd')
unnamed = ints | 'unnamed' >> beam.FlatMap(even_odd).with_outputs()
unnamed[None] | beam.Map(id) # pylint: disable=expression-not-assigned
assert_that(unnamed[None], equal_to([1, 2, 3]), label='unnamed.all')
assert_that(unnamed.even, equal_to([2]), label='unnamed.even')
assert_that(unnamed.odd, equal_to([1, 3]), label='unnamed.odd')
def test_pardo_side_inputs(self):
def cross_product(elem, sides):
for side in sides:
yield elem, side
with self.create_pipeline() as p:
main = p | 'main' >> beam.Create(['a', 'b', 'c'])
side = p | 'side' >> beam.Create(['x', 'y'])
assert_that(
main | beam.FlatMap(cross_product, beam.pvalue.AsList(side)),
equal_to([('a', 'x'), ('b', 'x'), ('c', 'x'), ('a', 'y'), ('b', 'y'),
('c', 'y')]))
def test_pardo_side_input_dependencies(self):
    ##
    # This test surfaces a bug that occurs whenever a PCollection is
    # consumed as main input by several stages.
    #
    # The bug: a stage assumes that it has run once its upstream PCollection
    # has watermark=MAX. If multiple stages depend on a single PCollection,
    # then the first stage that runs will set that watermark to MAX, and the
    # other stages will think they've already run.
    #
    # How to resolve?
    # Option 1: Make sure that a PCollection's watermark only advances once
    # it has been consumed by all of its consumers.
    # - I tested this and it didn't seem fruitful/obvious. (YET)
    #
    # Option 2: Change the execution schema: a PCollection's watermark
    # represents its *production* watermark, not its *consumption*
    # watermark. (?)
with self.create_pipeline() as p:
inputs = [p | beam.Create([None])]
for k in range(1, 10):
inputs.append(
inputs[0] | beam.ParDo(
ExpectingSideInputsFn(f'Do{k}'),
*[beam.pvalue.AsList(inputs[s]) for s in range(1, k)]))
def test_flatmap_numpy_array(self):
with self.create_pipeline() as p:
pc = (
p
| beam.Create([np.array(range(10))])
| beam.FlatMap(lambda arr: arr))
assert_that(pc, equal_to([np.int64(i) for i in range(10)]))
@unittest.skip('https://github.com/apache/beam/issues/21228')
def test_pardo_side_input_sparse_dependencies(self):
with self.create_pipeline() as p:
inputs = []
def choose_input(s):
return inputs[(389 + s * 5077) % len(inputs)]
for k in range(20):
num_inputs = int((k * k % 16)**0.5)
if num_inputs == 0:
inputs.append(p | f'Create{k}' >> beam.Create([f'Create{k}']))
else:
inputs.append(
choose_input(0) | beam.ParDo(
ExpectingSideInputsFn(f'Do{k}'),
*[
beam.pvalue.AsList(choose_input(s))
for s in range(1, num_inputs)
]))
def test_pardo_windowed_side_inputs(self):
with self.create_pipeline() as p:
# Now with some windowing.
pcoll = p | beam.Create(list(
range(10))) | beam.Map(lambda t: window.TimestampedValue(t, t))
# Intentionally choosing non-aligned windows to highlight the transition.
main = pcoll | 'WindowMain' >> beam.WindowInto(window.FixedWindows(5))
side = pcoll | 'WindowSide' >> beam.WindowInto(window.FixedWindows(7))
res = main | beam.Map(
lambda x, s: (x, sorted(s)), beam.pvalue.AsList(side))
assert_that(
res,
equal_to([
# The window [0, 5) maps to the window [0, 7).
(0, list(range(7))),
(1, list(range(7))),
(2, list(range(7))),
(3, list(range(7))),
(4, list(range(7))),
# The window [5, 10) maps to the window [7, 14).
(5, list(range(7, 10))),
(6, list(range(7, 10))),
(7, list(range(7, 10))),
(8, list(range(7, 10))),
(9, list(range(7, 10)))
]),
label='windowed')
def test_flattened_side_input(self, with_transcoding=True):
with self.create_pipeline() as p:
main = p | 'main' >> beam.Create([None])
side1 = p | 'side1' >> beam.Create([('a', 1)])
side2 = p | 'side2' >> beam.Create([('b', 2)])
if with_transcoding:
# Also test non-matching coder types (transcoding required)
third_element = [('another_type')]
else:
third_element = [('b', 3)]
side3 = p | 'side3' >> beam.Create(third_element)
side = (side1, side2) | beam.Flatten()
assert_that(
main | beam.Map(lambda a, b: (a, b), beam.pvalue.AsDict(side)),
equal_to([(None, {
'a': 1, 'b': 2
})]),
label='CheckFlattenAsSideInput')
assert_that((side, side3) | 'FlattenAfter' >> beam.Flatten(),
equal_to([('a', 1), ('b', 2)] + third_element),
label='CheckFlattenOfSideInput')
def test_gbk_side_input(self):
with self.create_pipeline() as p:
main = p | 'main' >> beam.Create([None])
side = p | 'side' >> beam.Create([('a', 1)]) | beam.GroupByKey()
assert_that(
main | beam.Map(lambda a, b: (a, b), beam.pvalue.AsDict(side)),
equal_to([(None, {
'a': [1]
})]))
def test_multimap_side_input(self):
with self.create_pipeline() as p:
main = p | 'main' >> beam.Create(['a', 'b'])
side = p | 'side' >> beam.Create([('a', 1), ('b', 2), ('a', 3)])
assert_that(
main | beam.Map(
lambda k, d: (k, sorted(d[k])), beam.pvalue.AsMultiMap(side)),
equal_to([('a', [1, 3]), ('b', [2])]))
def test_multimap_multiside_input(self):
# A test where two transforms in the same stage consume the same PCollection
# twice as side input.
with self.create_pipeline() as p:
main = p | 'main' >> beam.Create(['a', 'b'])
side = p | 'side' >> beam.Create([('a', 1), ('b', 2), ('a', 3)])
assert_that(
main | 'first map' >> beam.Map(
lambda k,
d,
l: (k, sorted(d[k]), sorted([e[1] for e in l])),
beam.pvalue.AsMultiMap(side),
beam.pvalue.AsList(side))
| 'second map' >> beam.Map(
lambda k,
d,
l: (k[0], sorted(d[k[0]]), sorted([e[1] for e in l])),
beam.pvalue.AsMultiMap(side),
beam.pvalue.AsList(side)),
equal_to([('a', [1, 3], [1, 2, 3]), ('b', [2], [1, 2, 3])]))
def test_multimap_side_input_type_coercion(self):
with self.create_pipeline() as p:
main = p | 'main' >> beam.Create(['a', 'b'])
# The type of this side-input is forced to Any (overriding type
# inference). Without type coercion to Tuple[Any, Any], the usage of this
# side-input in AsMultiMap() below should fail.
side = (
p | 'side' >> beam.Create([('a', 1), ('b', 2),
('a', 3)]).with_output_types(typing.Any))
assert_that(
main | beam.Map(
lambda k, d: (k, sorted(d[k])), beam.pvalue.AsMultiMap(side)),
equal_to([('a', [1, 3]), ('b', [2])]))
def test_pardo_unfusable_side_inputs(self):
def cross_product(elem, sides):
for side in sides:
yield elem, side
with self.create_pipeline() as p:
pcoll = p | beam.Create(['a', 'b'])
assert_that(
pcoll | beam.FlatMap(cross_product, beam.pvalue.AsList(pcoll)),
equal_to([('a', 'a'), ('a', 'b'), ('b', 'a'), ('b', 'b')]))
def test_pardo_unfusable_side_inputs_with_separation(self):
def cross_product(elem, sides):
for side in sides:
yield elem, side
with self.create_pipeline() as p:
pcoll = p | beam.Create(['a', 'b'])
derived = ((pcoll, ) | beam.Flatten()
| beam.Map(lambda x: (x, x))
| beam.GroupByKey()
| 'Unkey' >> beam.Map(lambda kv: kv[0]))
assert_that(
pcoll | beam.FlatMap(cross_product, beam.pvalue.AsList(derived)),
equal_to([('a', 'a'), ('a', 'b'), ('b', 'a'), ('b', 'b')]))
def test_pardo_state_only(self):
index_state_spec = userstate.CombiningValueStateSpec('index', sum)
value_and_index_state_spec = userstate.ReadModifyWriteStateSpec(
'value:index', StrUtf8Coder())
# TODO(ccy): State isn't detected with Map/FlatMap.
class AddIndex(beam.DoFn):
def process(
self,
kv,
index=beam.DoFn.StateParam(index_state_spec),
value_and_index=beam.DoFn.StateParam(value_and_index_state_spec)):
k, v = kv
index.add(1)
value_and_index.write('%s:%s' % (v, index.read()))
yield k, v, index.read(), value_and_index.read()
inputs = [('A', 'a')] * 2 + [('B', 'b')] * 3
expected = [('A', 'a', 1, 'a:1'), ('A', 'a', 2, 'a:2'),
('B', 'b', 1, 'b:1'), ('B', 'b', 2, 'b:2'),
('B', 'b', 3, 'b:3')]
with self.create_pipeline() as p:
assert_that(
p | beam.Create(inputs) | beam.ParDo(AddIndex()), equal_to(expected))
@unittest.skip('TestStream not yet supported')
def test_teststream_pardo_timers(self):
timer_spec = userstate.TimerSpec('timer', userstate.TimeDomain.WATERMARK)
class TimerDoFn(beam.DoFn):
def process(self, element, timer=beam.DoFn.TimerParam(timer_spec)):
unused_key, ts = element
timer.set(ts)
timer.set(2 * ts)
@userstate.on_timer(timer_spec)
def process_timer(self):
yield 'fired'
ts = (
TestStream().add_elements([('k1', 10)]) # Set timer for 20
.advance_watermark_to(100).add_elements([('k2', 100)
]) # Set timer for 200
.advance_watermark_to(1000))
with self.create_pipeline() as p:
_ = (
p
| ts
| beam.ParDo(TimerDoFn())
| beam.Map(lambda x, ts=beam.DoFn.TimestampParam: (x, ts)))
#expected = [('fired', ts) for ts in (20, 200)]
#assert_that(actual, equal_to(expected))
def test_pardo_timers(self):
timer_spec = userstate.TimerSpec('timer', userstate.TimeDomain.WATERMARK)
state_spec = userstate.CombiningValueStateSpec('num_called', sum)
class TimerDoFn(beam.DoFn):
def process(self, element, timer=beam.DoFn.TimerParam(timer_spec)):
unused_key, ts = element
timer.set(ts)
timer.set(2 * ts)
@userstate.on_timer(timer_spec)
def process_timer(
self,
ts=beam.DoFn.TimestampParam,
timer=beam.DoFn.TimerParam(timer_spec),
state=beam.DoFn.StateParam(state_spec)):
if state.read() == 0:
state.add(1)
timer.set(timestamp.Timestamp(micros=2 * ts.micros))
yield 'fired'
with self.create_pipeline() as p:
actual = (
p
| beam.Create([('k1', 10), ('k2', 100)])
| beam.ParDo(TimerDoFn())
| beam.Map(lambda x, ts=beam.DoFn.TimestampParam: (x, ts)))
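      # Each element sets a timer at ts and immediately overwrites it at
      # 2 * ts, so the first firings happen at 20 and 200. The callback then
      # re-sets the timer once per key (guarded by the num_called state),
      # producing the second round of firings at 40 and 400.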
expected = [('fired', ts) for ts in (20, 200, 40, 400)]
assert_that(actual, equal_to(expected))
def test_pardo_timers_clear(self):
timer_spec = userstate.TimerSpec('timer', userstate.TimeDomain.WATERMARK)
clear_timer_spec = userstate.TimerSpec(
'clear_timer', userstate.TimeDomain.WATERMARK)
class TimerDoFn(beam.DoFn):
def process(
self,
element,
timer=beam.DoFn.TimerParam(timer_spec),
clear_timer=beam.DoFn.TimerParam(clear_timer_spec)):
unused_key, ts = element
timer.set(ts)
timer.set(2 * ts)
clear_timer.set(ts)
clear_timer.clear()
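        # clear() cancels the pending clear_timer set above, so the
        # process_clear_timer callback below should never fire.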
@userstate.on_timer(timer_spec)
def process_timer(self):
yield 'fired'
@userstate.on_timer(clear_timer_spec)
def process_clear_timer(self):
yield 'should not fire'
with self.create_pipeline() as p:
actual = (
p
| beam.Create([('k1', 10), ('k2', 100)])
| beam.ParDo(TimerDoFn())
| beam.Map(lambda x, ts=beam.DoFn.TimestampParam: (x, ts)))
expected = [('fired', ts) for ts in (20, 200)]
assert_that(actual, equal_to(expected))
def test_pardo_state_timers(self):
self._run_pardo_state_timers(windowed=False)
def test_pardo_state_timers_non_standard_coder(self):
self._run_pardo_state_timers(windowed=False, key_type=Any)
def test_windowed_pardo_state_timers(self):
self._run_pardo_state_timers(windowed=True)
def _run_pardo_state_timers(self, windowed, key_type=None):
"""
:param windowed: If True, uses an interval window, otherwise a global window
:param key_type: Allows to override the inferred key type. This is useful to
test the use of non-standard coders, e.g. Python's FastPrimitivesCoder.
"""
state_spec = userstate.BagStateSpec('state', beam.coders.StrUtf8Coder())
timer_spec = userstate.TimerSpec('timer', userstate.TimeDomain.WATERMARK)
elements = list('abcdefgh')
key = 'key'
buffer_size = 3
class BufferDoFn(beam.DoFn):
def process(
self,
kv,
ts=beam.DoFn.TimestampParam,
timer=beam.DoFn.TimerParam(timer_spec),
state=beam.DoFn.StateParam(state_spec)):
_, element = kv
state.add(element)
buffer = state.read()
# For real use, we'd keep track of this size separately.
if len(list(buffer)) >= 3:
state.clear()
yield buffer
else:
timer.set(ts + 1)
@userstate.on_timer(timer_spec)
def process_timer(self, state=beam.DoFn.StateParam(state_spec)):
buffer = state.read()
state.clear()
yield buffer
    def is_buffered_correctly(actual):
      # Pickling self in the closure for asserts gives errors (only on
      # Jenkins).
      self = FnApiRunnerTest('__init__')
      # Actual should be a grouping of the inputs into batches of size at
      # most buffer_size, but the actual batching is nondeterministic based
      # on ordering and trigger firing timing.
      self.assertEqual(sorted(sum((list(b) for b in actual), [])), elements)
      self.assertEqual(max(len(list(buffer)) for buffer in actual), buffer_size)
if windowed:
# Elements were assigned to windows based on their parity.
# Assert that each grouping consists of elements belonging to the
# same window to ensure states and timers were properly partitioned.
for b in actual:
parity = set(ord(e) % 2 for e in b)
self.assertEqual(1, len(parity), b)
with self.create_pipeline() as p:
actual = (
p
| beam.Create(elements)
# Send even and odd elements to different windows.
| beam.Map(lambda e: window.TimestampedValue(e, ord(e) % 2))
| beam.WindowInto(
window.FixedWindows(1) if windowed else window.GlobalWindows())
| beam.Map(lambda x: (key, x)).with_output_types(
Tuple[key_type if key_type else type(key), Any])
| beam.ParDo(BufferDoFn()))
assert_that(actual, is_buffered_correctly)
def test_pardo_dynamic_timer(self):
class DynamicTimerDoFn(beam.DoFn):
dynamic_timer_spec = userstate.TimerSpec(
'dynamic_timer', userstate.TimeDomain.WATERMARK)
def process(
self, element,
dynamic_timer=beam.DoFn.TimerParam(dynamic_timer_spec)):
dynamic_timer.set(element[1], dynamic_timer_tag=element[0])
@userstate.on_timer(dynamic_timer_spec)
def dynamic_timer_callback(
self,
tag=beam.DoFn.DynamicTimerTagParam,
timestamp=beam.DoFn.TimestampParam):
yield (tag, timestamp)
with self.create_pipeline() as p:
actual = (
p
| beam.Create([('key1', 10), ('key2', 20), ('key3', 30)])
| beam.ParDo(DynamicTimerDoFn()))
assert_that(actual, equal_to([('key1', 10), ('key2', 20), ('key3', 30)]))
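  # ExpandStringsProvider (defined later in this module) is a
  # RestrictionProvider whose initial restriction is the offset range
  # [0, len(element)), so the SDFs below emit one character per claimed
  # offset.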
def test_sdf(self):
class ExpandingStringsDoFn(beam.DoFn):
def process(
self,
element,
restriction_tracker=beam.DoFn.RestrictionParam(
ExpandStringsProvider())):
assert isinstance(restriction_tracker, RestrictionTrackerView)
cur = restriction_tracker.current_restriction().start
while restriction_tracker.try_claim(cur):
yield element[cur]
cur += 1
with self.create_pipeline() as p:
data = ['abc', 'defghijklmno', 'pqrstuv', 'wxyz']
actual = (p | beam.Create(data) | beam.ParDo(ExpandingStringsDoFn()))
assert_that(actual, equal_to(list(''.join(data))))
def test_sdf_with_dofn_as_restriction_provider(self):
class ExpandingStringsDoFn(beam.DoFn, ExpandStringsProvider):
def process(
self, element, restriction_tracker=beam.DoFn.RestrictionParam()):
assert isinstance(restriction_tracker, RestrictionTrackerView)
cur = restriction_tracker.current_restriction().start
while restriction_tracker.try_claim(cur):
yield element[cur]
cur += 1
with self.create_pipeline() as p:
data = ['abc', 'defghijklmno', 'pqrstuv', 'wxyz']
actual = (p | beam.Create(data) | beam.ParDo(ExpandingStringsDoFn()))
assert_that(actual, equal_to(list(''.join(data))))
def test_sdf_with_check_done_failed(self):
class ExpandingStringsDoFn(beam.DoFn):
def process(
self,
element,
restriction_tracker=beam.DoFn.RestrictionParam(
ExpandStringsProvider())):
assert isinstance(restriction_tracker, RestrictionTrackerView)
cur = restriction_tracker.current_restriction().start
while restriction_tracker.try_claim(cur):
yield element[cur]
cur += 1
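          # Returning inside the loop abandons the rest of the restriction
          # without claiming it, so the SDK's check_done validation fails
          # and the pipeline raises, as asserted below.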
return
with self.assertRaises(Exception):
with self.create_pipeline() as p:
data = ['abc', 'defghijklmno', 'pqrstuv', 'wxyz']
_ = (p | beam.Create(data) | beam.ParDo(ExpandingStringsDoFn()))
def test_sdf_with_watermark_tracking(self):
class ExpandingStringsDoFn(beam.DoFn):
def process(
self,
element,
restriction_tracker=beam.DoFn.RestrictionParam(
ExpandStringsProvider()),
watermark_estimator=beam.DoFn.WatermarkEstimatorParam(
ManualWatermarkEstimator.default_provider())):
cur = restriction_tracker.current_restriction().start
while restriction_tracker.try_claim(cur):
watermark_estimator.set_watermark(timestamp.Timestamp(cur))
assert (
watermark_estimator.current_watermark() == timestamp.Timestamp(
cur))
yield element[cur]
if cur % 2 == 1:
restriction_tracker.defer_remainder(timestamp.Duration(micros=5))
return
cur += 1
with self.create_pipeline() as p:
data = ['abc', 'defghijklmno', 'pqrstuv', 'wxyz']
actual = (p | beam.Create(data) | beam.ParDo(ExpandingStringsDoFn()))
assert_that(actual, equal_to(list(''.join(data))))
def test_sdf_with_dofn_as_watermark_estimator(self):
class ExpandingStringsDoFn(beam.DoFn, beam.WatermarkEstimatorProvider):
def initial_estimator_state(self, element, restriction):
return None
def create_watermark_estimator(self, state):
return beam.io.watermark_estimators.ManualWatermarkEstimator(state)
def process(
self,
element,
restriction_tracker=beam.DoFn.RestrictionParam(
ExpandStringsProvider()),
watermark_estimator=beam.DoFn.WatermarkEstimatorParam(
ManualWatermarkEstimator.default_provider())):
cur = restriction_tracker.current_restriction().start
while restriction_tracker.try_claim(cur):
watermark_estimator.set_watermark(timestamp.Timestamp(cur))
assert (
watermark_estimator.current_watermark() == timestamp.Timestamp(
cur))
yield element[cur]
if cur % 2 == 1:
restriction_tracker.defer_remainder(timestamp.Duration(micros=5))
return
cur += 1
with self.create_pipeline() as p:
data = ['abc', 'defghijklmno', 'pqrstuv', 'wxyz']
actual = (p | beam.Create(data) | beam.ParDo(ExpandingStringsDoFn()))
assert_that(actual, equal_to(list(''.join(data))))
def run_sdf_initiated_checkpointing(self, is_drain=False):
counter = beam.metrics.Metrics.counter('ns', 'my_counter')
class ExpandStringsDoFn(beam.DoFn):
def process(
self,
element,
restriction_tracker=beam.DoFn.RestrictionParam(
ExpandStringsProvider())):
assert isinstance(restriction_tracker, RestrictionTrackerView)
cur = restriction_tracker.current_restriction().start
while restriction_tracker.try_claim(cur):
counter.inc()
yield element[cur]
if cur % 2 == 1:
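            # defer_remainder() initiates an SDF checkpoint: the unclaimed
            # remainder of the restriction is handed back to the runner and
            # resumed in a later bundle.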
restriction_tracker.defer_remainder()
return
cur += 1
with self.create_pipeline(is_drain=is_drain) as p:
data = ['abc', 'defghijklmno', 'pqrstuv', 'wxyz']
actual = (p | beam.Create(data) | beam.ParDo(ExpandStringsDoFn()))
assert_that(actual, equal_to(list(''.join(data))))
if isinstance(p.runner, fn_api_runner.FnApiRunner):
res = p.runner._latest_run_result
counters = res.metrics().query(
beam.metrics.MetricsFilter().with_name('my_counter'))['counters']
self.assertEqual(1, len(counters))
self.assertEqual(counters[0].committed, len(''.join(data)))
def test_sdf_with_sdf_initiated_checkpointing(self):
self.run_sdf_initiated_checkpointing(is_drain=False)
def test_draining_sdf_with_sdf_initiated_checkpointing(self):
self.run_sdf_initiated_checkpointing(is_drain=True)
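  # When draining, the runner truncates each remaining restriction before
  # processing it. The default truncate() keeps the whole restriction for
  # bounded SDFs and drops it for unbounded ones; the next two tests check
  # both defaults, and OffsetRangeProviderWithTruncate (defined later in
  # this module) overrides truncate() to keep the first half.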
def test_sdf_default_truncate_when_bounded(self):
    class SimpleSDF(beam.DoFn):
def process(
self,
element,
restriction_tracker=beam.DoFn.RestrictionParam(
OffsetRangeProvider(use_bounded_offset_range=True))):
assert isinstance(restriction_tracker, RestrictionTrackerView)
cur = restriction_tracker.current_restriction().start
while restriction_tracker.try_claim(cur):
yield cur
cur += 1
with self.create_pipeline(is_drain=True) as p:
      actual = (p | beam.Create([10]) | beam.ParDo(SimpleSDF()))
assert_that(actual, equal_to(range(10)))
def test_sdf_default_truncate_when_unbounded(self):
    class SimpleSDF(beam.DoFn):
def process(
self,
element,
restriction_tracker=beam.DoFn.RestrictionParam(
OffsetRangeProvider(use_bounded_offset_range=False))):
assert isinstance(restriction_tracker, RestrictionTrackerView)
cur = restriction_tracker.current_restriction().start
while restriction_tracker.try_claim(cur):
yield cur
cur += 1
with self.create_pipeline(is_drain=True) as p:
      actual = (p | beam.Create([10]) | beam.ParDo(SimpleSDF()))
assert_that(actual, equal_to([]))
def test_sdf_with_truncate(self):
    class SimpleSDF(beam.DoFn):
def process(
self,
element,
restriction_tracker=beam.DoFn.RestrictionParam(
OffsetRangeProviderWithTruncate())):
assert isinstance(restriction_tracker, RestrictionTrackerView)
cur = restriction_tracker.current_restriction().start
while restriction_tracker.try_claim(cur):
yield cur
cur += 1
with self.create_pipeline(is_drain=True) as p:
      actual = (p | beam.Create([10]) | beam.ParDo(SimpleSDF()))
assert_that(actual, equal_to(range(5)))
def test_group_by_key(self):
with self.create_pipeline() as p:
res = (
p
| beam.Create([('a', 1), ('a', 2), ('b', 3)])
| beam.GroupByKey()
| beam.Map(lambda k_vs: (k_vs[0], sorted(k_vs[1]))))
assert_that(res, equal_to([('a', [1, 2]), ('b', [3])]))
# Runners may special case the Reshuffle transform urn.
def test_reshuffle(self):
with self.create_pipeline() as p:
assert_that(
p | beam.Create([1, 2, 3]) | beam.Reshuffle(), equal_to([1, 2, 3]))
def test_flatten(self, with_transcoding=True):
with self.create_pipeline() as p:
if with_transcoding:
# Additional element which does not match with the first type
additional = [ord('d')]
else:
additional = ['d']
res = (
p | 'a' >> beam.Create(['a']),
p | 'bc' >> beam.Create(['b', 'c']),
p | 'd' >> beam.Create(additional)) | beam.Flatten()
assert_that(res, equal_to(['a', 'b', 'c'] + additional))
def test_flatten_same_pcollections(self, with_transcoding=True):
with self.create_pipeline() as p:
pc = p | beam.Create(['a', 'b'])
assert_that((pc, pc, pc) | beam.Flatten(), equal_to(['a', 'b'] * 3))
def test_combine_per_key(self):
with self.create_pipeline() as p:
res = (
p
| beam.Create([('a', 1), ('a', 2), ('b', 3)])
| beam.CombinePerKey(beam.combiners.MeanCombineFn()))
assert_that(res, equal_to([('a', 1.5), ('b', 3.0)]))
def test_read(self):
# Can't use NamedTemporaryFile as a context
# due to https://bugs.python.org/issue14243
temp_file = tempfile.NamedTemporaryFile(delete=False)
try:
temp_file.write(b'a\nb\nc')
temp_file.close()
with self.create_pipeline() as p:
assert_that(
p | beam.io.ReadFromText(temp_file.name), equal_to(['a', 'b', 'c']))
finally:
os.unlink(temp_file.name)
def test_windowing(self):
with self.create_pipeline() as p:
res = (
p
| beam.Create([1, 2, 100, 101, 102])
| beam.Map(lambda t: window.TimestampedValue(('k', t), t))
| beam.WindowInto(beam.transforms.window.Sessions(10))
| beam.GroupByKey()
| beam.Map(lambda k_vs1: (k_vs1[0], sorted(k_vs1[1]))))
assert_that(res, equal_to([('k', [1, 2]), ('k', [100, 101, 102])]))
def test_custom_merging_window(self):
with self.create_pipeline() as p:
res = (
p
| beam.Create([1, 2, 100, 101, 102])
| beam.Map(lambda t: window.TimestampedValue(('k', t), t))
| beam.WindowInto(CustomMergingWindowFn())
| beam.GroupByKey()
| beam.Map(lambda k_vs1: (k_vs1[0], sorted(k_vs1[1]))))
assert_that(
res, equal_to([('k', [1]), ('k', [101]), ('k', [2, 100, 102])]))
gc.collect()
from apache_beam.runners.portability.fn_api_runner.execution import GenericMergingWindowFn
self.assertEqual(GenericMergingWindowFn._HANDLES, {})
@unittest.skip('BEAM-9119: test is flaky')
def test_large_elements(self):
with self.create_pipeline() as p:
big = (
p
| beam.Create(['a', 'a', 'b'])
|
beam.Map(lambda x: (x, x * data_plane._DEFAULT_SIZE_FLUSH_THRESHOLD)))
side_input_res = (
big
| beam.Map(
lambda x,
side: (x[0], side.count(x[0])),
beam.pvalue.AsList(big | beam.Map(lambda x: x[0]))))
assert_that(
side_input_res,
equal_to([('a', 2), ('a', 2), ('b', 1)]),
label='side')
gbk_res = (big | beam.GroupByKey() | beam.Map(lambda x: x[0]))
assert_that(gbk_res, equal_to(['a', 'b']), label='gbk')
def test_error_message_includes_stage(self):
with self.assertRaises(BaseException) as e_cm:
with self.create_pipeline() as p:
def raise_error(x):
raise RuntimeError(
'This error is expected and does not indicate a test failure.')
# pylint: disable=expression-not-assigned
(
p
| beam.Create(['a', 'b'])
| 'StageA' >> beam.Map(lambda x: x)
| 'StageB' >> beam.Map(lambda x: x)
| 'StageC' >> beam.Map(raise_error)
| 'StageD' >> beam.Map(lambda x: x))
message = e_cm.exception.args[0]
self.assertIn('StageC', message)
self.assertNotIn('StageB', message)
def test_error_traceback_includes_user_code(self):
def first(x):
return second(x)
def second(x):
return third(x)
def third(x):
raise ValueError(
'This error is expected and does not indicate a test failure.')
try:
with self.create_pipeline() as p:
p | beam.Create([0]) | beam.Map(first) # pylint: disable=expression-not-assigned
except Exception: # pylint: disable=broad-except
message = traceback.format_exc()
else:
raise AssertionError('expected exception not raised')
self.assertIn('first', message)
self.assertIn('second', message)
self.assertIn('third', message)
def test_no_subtransform_composite(self):
class First(beam.PTransform):
def expand(self, pcolls):
return pcolls[0]
with self.create_pipeline() as p:
pcoll_a = p | 'a' >> beam.Create(['a'])
pcoll_b = p | 'b' >> beam.Create(['b'])
assert_that((pcoll_a, pcoll_b) | First(), equal_to(['a']))
def test_metrics(self, check_gauge=True):
p = self.create_pipeline()
counter = beam.metrics.Metrics.counter('ns', 'counter')
distribution = beam.metrics.Metrics.distribution('ns', 'distribution')
gauge = beam.metrics.Metrics.gauge('ns', 'gauge')
pcoll = p | beam.Create(['a', 'zzz'])
# pylint: disable=expression-not-assigned
pcoll | 'count1' >> beam.FlatMap(lambda x: counter.inc())
pcoll | 'count2' >> beam.FlatMap(lambda x: counter.inc(len(x)))
pcoll | 'dist' >> beam.FlatMap(lambda x: distribution.update(len(x)))
pcoll | 'gauge' >> beam.FlatMap(lambda x: gauge.set(3))
res = p.run()
res.wait_until_finish()
t1, t2 = res.metrics().query(beam.metrics.MetricsFilter()
.with_name('counter'))['counters']
self.assertEqual(t1.committed + t2.committed, 6)
dist, = res.metrics().query(beam.metrics.MetricsFilter()
.with_name('distribution'))['distributions']
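    # Two distribution updates, len('a') == 1 and len('zzz') == 3, give
    # DistributionData(sum=4, count=2, min=1, max=3).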
self.assertEqual(
dist.committed.data, beam.metrics.cells.DistributionData(4, 2, 1, 3))
self.assertEqual(dist.committed.mean, 2.0)
if check_gauge:
gaug, = res.metrics().query(beam.metrics.MetricsFilter()
.with_name('gauge'))['gauges']
self.assertEqual(gaug.committed.value, 3)
def test_callbacks_with_exception(self):
elements_list = ['1', '2']
    def raise_exception():
raise Exception('raise exception when calling callback')
    class FinalizableDoFnWithException(beam.DoFn):
def process(
self, element, bundle_finalizer=beam.DoFn.BundleFinalizerParam):
        bundle_finalizer.register(raise_exception)
yield element
with self.create_pipeline() as p:
res = (
p
| beam.Create(elements_list)
          | beam.ParDo(FinalizableDoFnWithException()))
assert_that(res, equal_to(['1', '2']))
def test_register_finalizations(self):
event_recorder = EventRecorder(tempfile.gettempdir())
class FinalizableSplittableDoFn(beam.DoFn):
def process(
self,
element,
bundle_finalizer=beam.DoFn.BundleFinalizerParam,
restriction_tracker=beam.DoFn.RestrictionParam(
OffsetRangeProvider(
use_bounded_offset_range=True, checkpoint_only=True))):
        # We use an SDF to ensure that the finalization call happens, by
        # relying on a self-initiated checkpoint.
if 'finalized' in event_recorder.events():
restriction_tracker.try_claim(
restriction_tracker.current_restriction().start)
yield element
restriction_tracker.try_claim(element)
return
if restriction_tracker.try_claim(
restriction_tracker.current_restriction().start):
bundle_finalizer.register(lambda: event_recorder.record('finalized'))
# We sleep here instead of setting a resume time since the resume time
# doesn't need to be honored.
time.sleep(1)
restriction_tracker.defer_remainder()
with self.create_pipeline() as p:
max_retries = 100
res = (
p
| beam.Create([max_retries])
| beam.ParDo(FinalizableSplittableDoFn()))
assert_that(res, equal_to([max_retries]))
event_recorder.cleanup()
def test_sdf_synthetic_source(self):
common_attrs = {
'key_size': 1,
'value_size': 1,
'initial_splitting_num_bundles': 2,
'initial_splitting_desired_bundle_size': 2,
'sleep_per_input_record_sec': 0,
'initial_splitting': 'const'
}
num_source_description = 5
min_num_record = 10
max_num_record = 20
# pylint: disable=unused-variable
source_descriptions = ([
dict({'num_records': random.randint(min_num_record, max_num_record)},
**common_attrs) for i in range(0, num_source_description)
])
total_num_records = 0
for source in source_descriptions:
total_num_records += source['num_records']
with self.create_pipeline() as p:
res = (
p
| beam.Create(source_descriptions)
| beam.ParDo(SyntheticSDFAsSource())
| beam.combiners.Count.Globally())
assert_that(res, equal_to([total_num_records]))
def test_create_value_provider_pipeline_option(self):
# Verify that the runner can execute a pipeline when there are value
# provider pipeline options
# pylint: disable=unused-variable
class FooOptions(PipelineOptions):
@classmethod
def _add_argparse_args(cls, parser):
parser.add_value_provider_argument(
"--foo", help='a value provider argument', default="bar")
RuntimeValueProvider.set_runtime_options({})
with self.create_pipeline() as p:
assert_that(p | beam.Create(['a', 'b']), equal_to(['a', 'b']))
def _test_pack_combiners(self, assert_using_counter_names):
counter = beam.metrics.Metrics.counter('ns', 'num_values')
def min_with_counter(values):
counter.inc()
return min(values)
def max_with_counter(values):
counter.inc()
return max(values)
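    # The APPLY_COMBINER_PACKING annotation below asks the optimizer to pack
    # the two sibling CombineGlobally transforms into a single combiner;
    # whether that happened is later inferred from the packed step name
    # appearing in the counter metrics.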
class PackableCombines(beam.PTransform):
def annotations(self):
return {python_urns.APPLY_COMBINER_PACKING: b''}
def expand(self, pcoll):
assert_that(
pcoll | 'PackableMin' >> beam.CombineGlobally(min_with_counter),
equal_to([10]),
label='AssertMin')
assert_that(
pcoll | 'PackableMax' >> beam.CombineGlobally(max_with_counter),
equal_to([30]),
label='AssertMax')
with self.create_pipeline() as p:
_ = p | beam.Create([10, 20, 30]) | PackableCombines()
res = p.run()
res.wait_until_finish()
packed_step_name_regex = (
r'.*Packed.*PackableMin.*CombinePerKey.*PackableMax.*CombinePerKey.*' +
'Pack.*')
counters = res.metrics().query(beam.metrics.MetricsFilter())['counters']
step_names = set(m.key.step for m in counters if m.key.step)
pipeline_options = p._options
if assert_using_counter_names:
if pipeline_options.view_as(StandardOptions).streaming:
self.assertFalse(
any(re.match(packed_step_name_regex, s) for s in step_names))
else:
self.assertTrue(
any(re.match(packed_step_name_regex, s) for s in step_names))
def test_pack_combiners(self):
self._test_pack_combiners(assert_using_counter_names=True)
# These tests are kept in a separate group so that they are
# not run in FnApiRunnerTestWithBundleRepeat, which repeats bundle
# processing. Repeating bundles breaks the byte sampling metrics, as it
# makes the probability of sampling far too small by unnecessarily
# incrementing the sampling counter on each repeat.
class FnApiRunnerMetricsTest(unittest.TestCase):
def assert_has_counter(
self, mon_infos, urn, labels, value=None, ge_value=None):
# TODO(ajamato): Consider adding a matcher framework
found = 0
matches = []
for mi in mon_infos:
if has_urn_and_labels(mi, urn, labels):
extracted_value = monitoring_infos.extract_counter_value(mi)
if ge_value is not None:
if extracted_value >= ge_value:
found = found + 1
elif value is not None:
if extracted_value == value:
found = found + 1
else:
found = found + 1
ge_value_str = {'ge_value': ge_value} if ge_value else ''
value_str = {'value': value} if value else ''
self.assertEqual(
1,
found,
"Found (%s, %s) Expected only 1 monitoring_info for %s." % (
found,
matches,
(urn, labels, value_str, ge_value_str),
))
def assert_has_distribution(
self, mon_infos, urn, labels, sum=None, count=None, min=None, max=None):
# TODO(ajamato): Consider adding a matcher framework
sum = _matcher_or_equal_to(sum)
count = _matcher_or_equal_to(count)
min = _matcher_or_equal_to(min)
max = _matcher_or_equal_to(max)
found = 0
description = StringDescription()
for mi in mon_infos:
if has_urn_and_labels(mi, urn, labels):
(extracted_count, extracted_sum, extracted_min,
extracted_max) = monitoring_infos.extract_distribution(mi)
increment = 1
if sum is not None:
description.append_text(' sum: ')
sum.describe_to(description)
if not sum.matches(extracted_sum):
increment = 0
if count is not None:
description.append_text(' count: ')
count.describe_to(description)
if not count.matches(extracted_count):
increment = 0
if min is not None:
description.append_text(' min: ')
min.describe_to(description)
if not min.matches(extracted_min):
increment = 0
if max is not None:
description.append_text(' max: ')
max.describe_to(description)
if not max.matches(extracted_max):
increment = 0
found += increment
self.assertEqual(
1,
found,
"Found (%s) Expected only 1 monitoring_info for %s." % (
found,
(urn, labels, str(description)),
))
def create_pipeline(self):
return beam.Pipeline(runner=fn_api_runner.FnApiRunner())
def test_element_count_metrics(self):
class GenerateTwoOutputs(beam.DoFn):
def process(self, element):
yield str(element) + '1'
yield beam.pvalue.TaggedOutput('SecondOutput', str(element) + '2')
yield beam.pvalue.TaggedOutput('SecondOutput', str(element) + '2')
yield beam.pvalue.TaggedOutput('ThirdOutput', str(element) + '3')
class PassThrough(beam.DoFn):
def process(self, element):
yield element
p = self.create_pipeline()
# Produce enough elements to make sure byte sampling occurs.
num_source_elems = 100
pcoll = p | beam.Create(['a%d' % i for i in range(num_source_elems)],
reshuffle=False)
# pylint: disable=expression-not-assigned
pardo = (
'StepThatDoesTwoOutputs' >> beam.ParDo(
GenerateTwoOutputs()).with_outputs(
'SecondOutput', 'ThirdOutput', main='FirstAndMainOutput'))
    # Actually feed the pcollection to the pardo.
    second_output, third_output, first_output = (pcoll | pardo)
    # Consume some of the elements.
merged = ((first_output, second_output, third_output) | beam.Flatten())
merged | ('PassThrough') >> beam.ParDo(PassThrough())
second_output | ('PassThrough2') >> beam.ParDo(PassThrough())
res = p.run()
res.wait_until_finish()
result_metrics = res.monitoring_metrics()
counters = result_metrics.monitoring_infos()
# All element count and byte count metrics must have a PCOLLECTION_LABEL.
self.assertFalse([
x for x in counters if x.urn in [
monitoring_infos.ELEMENT_COUNT_URN,
monitoring_infos.SAMPLED_BYTE_SIZE_URN
] and monitoring_infos.PCOLLECTION_LABEL not in x.labels
])
try:
labels = {
monitoring_infos.PCOLLECTION_LABEL: 'ref_PCollection_PCollection_1'
}
self.assert_has_counter(
counters, monitoring_infos.ELEMENT_COUNT_URN, labels, 1)
# Create output.
labels = {
monitoring_infos.PCOLLECTION_LABEL: 'ref_PCollection_PCollection_3'
}
self.assert_has_counter(
counters,
monitoring_infos.ELEMENT_COUNT_URN,
labels,
num_source_elems)
self.assert_has_distribution(
counters,
monitoring_infos.SAMPLED_BYTE_SIZE_URN,
labels,
min=hamcrest.greater_than(0),
max=hamcrest.greater_than(0),
sum=hamcrest.greater_than(0),
count=hamcrest.greater_than(0))
# GenerateTwoOutputs, main output.
labels = {
monitoring_infos.PCOLLECTION_LABEL: 'ref_PCollection_PCollection_4'
}
self.assert_has_counter(
counters,
monitoring_infos.ELEMENT_COUNT_URN,
labels,
num_source_elems)
self.assert_has_distribution(
counters,
monitoring_infos.SAMPLED_BYTE_SIZE_URN,
labels,
min=hamcrest.greater_than(0),
max=hamcrest.greater_than(0),
sum=hamcrest.greater_than(0),
count=hamcrest.greater_than(0))
# GenerateTwoOutputs, "SecondOutput" output.
labels = {
monitoring_infos.PCOLLECTION_LABEL: 'ref_PCollection_PCollection_5'
}
self.assert_has_counter(
counters,
monitoring_infos.ELEMENT_COUNT_URN,
labels,
2 * num_source_elems)
self.assert_has_distribution(
counters,
monitoring_infos.SAMPLED_BYTE_SIZE_URN,
labels,
min=hamcrest.greater_than(0),
max=hamcrest.greater_than(0),
sum=hamcrest.greater_than(0),
count=hamcrest.greater_than(0))
# GenerateTwoOutputs, "ThirdOutput" output.
labels = {
monitoring_infos.PCOLLECTION_LABEL: 'ref_PCollection_PCollection_6'
}
self.assert_has_counter(
counters,
monitoring_infos.ELEMENT_COUNT_URN,
labels,
num_source_elems)
self.assert_has_distribution(
counters,
monitoring_infos.SAMPLED_BYTE_SIZE_URN,
labels,
min=hamcrest.greater_than(0),
max=hamcrest.greater_than(0),
sum=hamcrest.greater_than(0),
count=hamcrest.greater_than(0))
# Skipping other pcollections due to non-deterministic naming for multiple
# outputs.
# Flatten/Read, main output.
labels = {
monitoring_infos.PCOLLECTION_LABEL: 'ref_PCollection_PCollection_7'
}
self.assert_has_counter(
counters,
monitoring_infos.ELEMENT_COUNT_URN,
labels,
4 * num_source_elems)
self.assert_has_distribution(
counters,
monitoring_infos.SAMPLED_BYTE_SIZE_URN,
labels,
min=hamcrest.greater_than(0),
max=hamcrest.greater_than(0),
sum=hamcrest.greater_than(0),
count=hamcrest.greater_than(0))
# PassThrough, main output
labels = {
monitoring_infos.PCOLLECTION_LABEL: 'ref_PCollection_PCollection_8'
}
self.assert_has_counter(
counters,
monitoring_infos.ELEMENT_COUNT_URN,
labels,
4 * num_source_elems)
self.assert_has_distribution(
counters,
monitoring_infos.SAMPLED_BYTE_SIZE_URN,
labels,
min=hamcrest.greater_than(0),
max=hamcrest.greater_than(0),
sum=hamcrest.greater_than(0),
count=hamcrest.greater_than(0))
# PassThrough2, main output
labels = {
monitoring_infos.PCOLLECTION_LABEL: 'ref_PCollection_PCollection_9'
}
self.assert_has_counter(
counters,
monitoring_infos.ELEMENT_COUNT_URN,
labels,
num_source_elems)
self.assert_has_distribution(
counters,
monitoring_infos.SAMPLED_BYTE_SIZE_URN,
labels,
min=hamcrest.greater_than(0),
max=hamcrest.greater_than(0),
sum=hamcrest.greater_than(0),
count=hamcrest.greater_than(0))
except:
print(res._monitoring_infos_by_stage)
raise
def test_non_user_metrics(self):
p = self.create_pipeline()
pcoll = p | beam.Create(['a', 'zzz'])
# pylint: disable=expression-not-assigned
pcoll | 'MyStep' >> beam.FlatMap(lambda x: None)
res = p.run()
res.wait_until_finish()
result_metrics = res.monitoring_metrics()
    all_metrics_via_monitoring_infos = result_metrics.query()
def assert_counter_exists(metrics, namespace, name, step):
found = 0
metric_key = MetricKey(step, MetricName(namespace, name))
for m in metrics['counters']:
if m.key == metric_key:
found = found + 1
self.assertEqual(
1, found, "Did not find exactly 1 metric for %s." % metric_key)
urns = [
monitoring_infos.START_BUNDLE_MSECS_URN,
monitoring_infos.PROCESS_BUNDLE_MSECS_URN,
monitoring_infos.FINISH_BUNDLE_MSECS_URN,
monitoring_infos.TOTAL_MSECS_URN,
]
for urn in urns:
split = urn.split(':')
namespace = split[0]
name = ':'.join(split[1:])
assert_counter_exists(
          all_metrics_via_monitoring_infos,
namespace,
name,
step='Create/Impulse')
assert_counter_exists(
          all_metrics_via_monitoring_infos, namespace, name, step='MyStep')
  # Due to the somewhat non-deterministic nature of state sampling and sleep,
  # this test is flaky when the sampled state duration is low.
  # Since increasing the state duration significantly would also slow down
  # the test suite, we retry up to twice on failure as a mitigation.
@retry(reraise=True, stop=stop_after_attempt(3))
def test_progress_metrics(self):
p = self.create_pipeline()
_ = (
p
| beam.Create([0, 0, 0, 5e-3 * DEFAULT_SAMPLING_PERIOD_MS],
reshuffle=False)
| beam.Map(time.sleep)
| beam.Map(lambda x: ('key', x))
| beam.GroupByKey()
| 'm_out' >> beam.FlatMap(
lambda x: [
1,
2,
3,
4,
5,
beam.pvalue.TaggedOutput('once', x),
beam.pvalue.TaggedOutput('twice', x),
beam.pvalue.TaggedOutput('twice', x)
]))
res = p.run()
res.wait_until_finish()
def has_mi_for_ptransform(mon_infos, ptransform):
for mi in mon_infos:
if ptransform in mi.labels[monitoring_infos.PTRANSFORM_LABEL]:
return True
return False
try:
# Test the new MonitoringInfo monitoring format.
self.assertEqual(3, len(res._monitoring_infos_by_stage))
pregbk_mis, postgbk_mis = [
mi for stage, mi in res._monitoring_infos_by_stage.items() if stage]
if not has_mi_for_ptransform(pregbk_mis, 'Create/Map(decode)'):
# The monitoring infos above are actually unordered. Swap.
pregbk_mis, postgbk_mis = postgbk_mis, pregbk_mis
# pregbk monitoring infos
labels = {
monitoring_infos.PCOLLECTION_LABEL: 'ref_PCollection_PCollection_3'
}
self.assert_has_counter(
pregbk_mis, monitoring_infos.ELEMENT_COUNT_URN, labels, value=4)
self.assert_has_distribution(
pregbk_mis, monitoring_infos.SAMPLED_BYTE_SIZE_URN, labels)
labels = {
monitoring_infos.PCOLLECTION_LABEL: 'ref_PCollection_PCollection_4'
}
self.assert_has_counter(
pregbk_mis, monitoring_infos.ELEMENT_COUNT_URN, labels, value=4)
self.assert_has_distribution(
pregbk_mis, monitoring_infos.SAMPLED_BYTE_SIZE_URN, labels)
labels = {monitoring_infos.PTRANSFORM_LABEL: 'Map(sleep)'}
self.assert_has_counter(
pregbk_mis,
monitoring_infos.TOTAL_MSECS_URN,
labels,
ge_value=4 * DEFAULT_SAMPLING_PERIOD_MS)
# postgbk monitoring infos
labels = {
monitoring_infos.PCOLLECTION_LABEL: 'ref_PCollection_PCollection_6'
}
self.assert_has_counter(
postgbk_mis, monitoring_infos.ELEMENT_COUNT_URN, labels, value=1)
self.assert_has_distribution(
postgbk_mis, monitoring_infos.SAMPLED_BYTE_SIZE_URN, labels)
labels = {
monitoring_infos.PCOLLECTION_LABEL: 'ref_PCollection_PCollection_7'
}
self.assert_has_counter(
postgbk_mis, monitoring_infos.ELEMENT_COUNT_URN, labels, value=5)
self.assert_has_distribution(
postgbk_mis, monitoring_infos.SAMPLED_BYTE_SIZE_URN, labels)
except:
print(res._monitoring_infos_by_stage)
raise
class FnApiRunnerTestWithGrpc(FnApiRunnerTest):
def create_pipeline(self, is_drain=False):
return beam.Pipeline(
runner=fn_api_runner.FnApiRunner(
default_environment=environments.EmbeddedPythonGrpcEnvironment.
default(),
is_drain=is_drain))
class FnApiRunnerTestWithDisabledCaching(FnApiRunnerTest):
def create_pipeline(self, is_drain=False):
return beam.Pipeline(
runner=fn_api_runner.FnApiRunner(
default_environment=environments.EmbeddedPythonGrpcEnvironment(
state_cache_size=0,
data_buffer_time_limit_ms=0,
capabilities=environments.python_sdk_capabilities(),
artifacts=()),
is_drain=is_drain))
class FnApiRunnerTestWithMultiWorkers(FnApiRunnerTest):
def create_pipeline(self, is_drain=False):
pipeline_options = PipelineOptions(direct_num_workers=2)
p = beam.Pipeline(
runner=fn_api_runner.FnApiRunner(is_drain=is_drain),
options=pipeline_options)
#TODO(https://github.com/apache/beam/issues/19936): Fix these tests.
p._options.view_as(DebugOptions).experiments.remove('beam_fn_api')
return p
def test_metrics(self):
raise unittest.SkipTest("This test is for a single worker only.")
def test_sdf_with_sdf_initiated_checkpointing(self):
raise unittest.SkipTest("This test is for a single worker only.")
def test_draining_sdf_with_sdf_initiated_checkpointing(self):
raise unittest.SkipTest("This test is for a single worker only.")
def test_sdf_with_watermark_tracking(self):
raise unittest.SkipTest("This test is for a single worker only.")
def test_sdf_with_dofn_as_watermark_estimator(self):
raise unittest.SkipTest("This test is for a single worker only.")
def test_register_finalizations(self):
raise unittest.SkipTest("This test is for a single worker only.")
class FnApiRunnerTestWithGrpcAndMultiWorkers(FnApiRunnerTest):
def create_pipeline(self, is_drain=False):
pipeline_options = PipelineOptions(
direct_num_workers=2, direct_running_mode='multi_threading')
p = beam.Pipeline(
runner=fn_api_runner.FnApiRunner(is_drain=is_drain),
options=pipeline_options)
#TODO(https://github.com/apache/beam/issues/19936): Fix these tests.
p._options.view_as(DebugOptions).experiments.remove('beam_fn_api')
return p
def test_metrics(self):
raise unittest.SkipTest("This test is for a single worker only.")
def test_sdf_with_sdf_initiated_checkpointing(self):
raise unittest.SkipTest("This test is for a single worker only.")
def test_draining_sdf_with_sdf_initiated_checkpointing(self):
raise unittest.SkipTest("This test is for a single worker only.")
def test_sdf_with_watermark_tracking(self):
raise unittest.SkipTest("This test is for a single worker only.")
def test_sdf_with_dofn_as_watermark_estimator(self):
raise unittest.SkipTest("This test is for a single worker only.")
def test_register_finalizations(self):
raise unittest.SkipTest("This test is for a single worker only.")
class FnApiRunnerTestWithBundleRepeat(FnApiRunnerTest):
def create_pipeline(self, is_drain=False):
return beam.Pipeline(
runner=fn_api_runner.FnApiRunner(bundle_repeat=3, is_drain=is_drain))
def test_register_finalizations(self):
raise unittest.SkipTest("TODO: Avoid bundle finalizations on repeat.")

class FnApiRunnerTestWithBundleRepeatAndMultiWorkers(FnApiRunnerTest):
  def create_pipeline(self, is_drain=False):
    pipeline_options = PipelineOptions(direct_num_workers=2)
    p = beam.Pipeline(
        runner=fn_api_runner.FnApiRunner(bundle_repeat=3, is_drain=is_drain),
        options=pipeline_options)
    # TODO(https://github.com/apache/beam/issues/19936): Fix these tests.
    p._options.view_as(DebugOptions).experiments.remove('beam_fn_api')
    return p

  def test_register_finalizations(self):
    raise unittest.SkipTest("TODO: Avoid bundle finalizations on repeat.")

  def test_metrics(self):
    raise unittest.SkipTest("This test is for a single worker only.")

  def test_sdf_with_sdf_initiated_checkpointing(self):
    raise unittest.SkipTest("This test is for a single worker only.")

  def test_draining_sdf_with_sdf_initiated_checkpointing(self):
    raise unittest.SkipTest("This test is for a single worker only.")

  def test_sdf_with_watermark_tracking(self):
    raise unittest.SkipTest("This test is for a single worker only.")

  def test_sdf_with_dofn_as_watermark_estimator(self):
    raise unittest.SkipTest("This test is for a single worker only.")

class FnApiRunnerSplitTest(unittest.TestCase):
  def create_pipeline(self, is_drain=False):
    # Must be GRPC so we can send data and split requests concurrently
    # with the bundle process request.
    return beam.Pipeline(
        runner=fn_api_runner.FnApiRunner(
            default_environment=environments.EmbeddedPythonGrpcEnvironment.
            default(),
            is_drain=is_drain))

  def test_checkpoint(self):
    # This split manager will get re-invoked on each smaller split,
    # so N times for N elements.
    element_counter = ElementCounter()

    def split_manager(num_elements):
      # Send at least one element so it can make forward progress.
      element_counter.reset()
      breakpoint = element_counter.set_breakpoint(1)
      # Cede control back to the runner so data can be sent.
      yield
      breakpoint.wait()
      # Split as close to the current position as possible.
      split_result = yield 0.0
      # Verify we split at exactly the first element.
      self.verify_channel_split(split_result, 0, 1)
      # Continue processing.
      breakpoint.clear()

    self.run_split_pipeline(split_manager, list('abc'), element_counter)
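
  # For reference: a split manager is a generator that is invoked with the
  # bundle's element count. A bare `yield` cedes control so the runner can
  # feed data; `yield fraction` requests a split of that fraction of the
  # remaining work and resumes with the resulting split response, whose
  # channel_splits / primary_roots / residual_roots are checked by
  # verify_channel_split below. A minimal sketch (illustrative only):
  #
  #   def split_manager(num_elements):
  #     yield               # let the bundle start processing
  #     result = yield 0.5  # split the remaining work in half
  #     # inspect result.channel_splits here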

  def test_split_half(self):
    total_num_elements = 25
    seen_bundle_sizes = []
    element_counter = ElementCounter()

    def split_manager(num_elements):
      seen_bundle_sizes.append(num_elements)
      if num_elements == total_num_elements:
        element_counter.reset()
        breakpoint = element_counter.set_breakpoint(5)
        yield
        breakpoint.wait()
        # Split the remainder in half (20 elements, then 10).
        split1 = yield 0.5
        self.verify_channel_split(split1, 14, 15)  # remainder is 15 to end
        split2 = yield 0.5
        self.verify_channel_split(split2, 9, 10)  # remainder is 10 to end
        breakpoint.clear()

    self.run_split_pipeline(
        split_manager, range(total_num_elements), element_counter)
    self.assertEqual([25, 15], seen_bundle_sizes)

  def run_split_pipeline(self, split_manager, elements, element_counter=None):
    with fn_runner.split_manager('Identity', split_manager):
      with self.create_pipeline() as p:
        res = (
            p
            | beam.Create(elements)
            | beam.Reshuffle()
            | 'Identity' >> beam.Map(lambda x: x)
            | beam.Map(lambda x: element_counter.increment() or x))
        assert_that(res, equal_to(elements))

  def run_sdf_checkpoint(self, is_drain=False):
    element_counter = ElementCounter()

    def split_manager(num_elements):
      if num_elements > 0:
        element_counter.reset()
        breakpoint = element_counter.set_breakpoint(1)
        yield
        breakpoint.wait()
        yield 0
        breakpoint.clear()

    # Everything should be perfectly split.
    elements = [2, 3]
    expected_groups = [[(2, 0)], [(2, 1)], [(3, 0)], [(3, 1)], [(3, 2)]]
    self.run_sdf_split_pipeline(
        split_manager,
        elements,
        element_counter,
        expected_groups,
        is_drain=is_drain)
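
  # Because the manager waits for one element and then issues `yield 0`
  # (a split at the current position, i.e. a checkpoint), and is re-invoked
  # for every residual bundle, each bundle claims exactly one offset; that is
  # why every expected group above is a singleton.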

  def run_sdf_split_half(self, is_drain=False):
    element_counter = ElementCounter()
    is_first_bundle = True

    def split_manager(num_elements):
      nonlocal is_first_bundle
      if is_first_bundle and num_elements > 0:
        is_first_bundle = False
        breakpoint = element_counter.set_breakpoint(1)
        yield
        breakpoint.wait()
        split1 = yield 0.5
        split2 = yield 0.5
        split3 = yield 0.5
        self.verify_channel_split(split1, 0, 1)
        self.verify_channel_split(split2, -1, 1)
        self.verify_channel_split(split3, -1, 1)
        breakpoint.clear()

    elements = [4, 4]
    expected_groups = [[(4, 0)], [(4, 1)], [(4, 2), (4, 3)],
                       [(4, 0), (4, 1), (4, 2), (4, 3)]]
    self.run_sdf_split_pipeline(
        split_manager,
        elements,
        element_counter,
        expected_groups,
        is_drain=is_drain)

  def run_split_crazy_sdf(self, seed=None, is_drain=False):
    if seed is None:
      seed = random.randrange(1 << 20)
    r = random.Random(seed)
    element_counter = ElementCounter()

    def split_manager(num_elements):
      if num_elements > 0:
        element_counter.reset()
        wait_for = r.randrange(num_elements)
        breakpoint = element_counter.set_breakpoint(wait_for)
        yield
        breakpoint.wait()
        yield r.random()
        yield r.random()
        breakpoint.clear()

    try:
      elements = [r.randrange(5, 10) for _ in range(5)]
      self.run_sdf_split_pipeline(
          split_manager, elements, element_counter, is_drain=is_drain)
    except Exception:
      _LOGGER.error('test_split_crazy_sdf.seed = %s', seed)
      raise

  def test_nosplit_sdf(self):
    def split_manager(num_elements):
      yield

    elements = [1, 2, 3]
    expected_groups = [[(e, k) for k in range(e)] for e in elements]
    self.run_sdf_split_pipeline(
        split_manager, elements, ElementCounter(), expected_groups)

  def test_checkpoint_sdf(self):
    self.run_sdf_checkpoint(is_drain=False)

  def test_checkpoint_draining_sdf(self):
    self.run_sdf_checkpoint(is_drain=True)

  def test_split_half_sdf(self):
    self.run_sdf_split_half(is_drain=False)

  def test_split_half_draining_sdf(self):
    self.run_sdf_split_half(is_drain=True)

  def test_split_crazy_sdf(self, seed=None):
    self.run_split_crazy_sdf(seed=seed, is_drain=False)

  def test_split_crazy_draining_sdf(self, seed=None):
    self.run_split_crazy_sdf(seed=seed, is_drain=True)

  def run_sdf_split_pipeline(
      self,
      split_manager,
      elements,
      element_counter,
      expected_groups=None,
      is_drain=False):
    # Define an SDF that for each input x produces [(x, k) for k in range(x)].

    class EnumerateProvider(beam.transforms.core.RestrictionProvider):
      def initial_restriction(self, element):
        return restriction_trackers.OffsetRange(0, element)

      def create_tracker(self, restriction):
        return restriction_trackers.OffsetRestrictionTracker(restriction)

      def split(self, element, restriction):
        # Don't do any initial splitting, to simplify the test.
        return [restriction]

      def restriction_size(self, element, restriction):
        return restriction.size()

      def is_bounded(self):
        return True

    class EnumerateSdf(beam.DoFn):
      def process(
          self,
          element,
          restriction_tracker=beam.DoFn.RestrictionParam(EnumerateProvider())):
        to_emit = []
        cur = restriction_tracker.current_restriction().start
        while restriction_tracker.try_claim(cur):
          to_emit.append((element, cur))
          element_counter.increment()
          cur += 1
        # Emitting in batches for tighter testing.
        yield to_emit

    expected = [(e, k) for e in elements for k in range(e)]

    with fn_runner.split_manager('SDF', split_manager):
      with self.create_pipeline(is_drain=is_drain) as p:
        grouped = (
            p
            | beam.Create(elements, reshuffle=False)
            | 'SDF' >> beam.ParDo(EnumerateSdf()))
        flat = grouped | beam.FlatMap(lambda x: x)
        assert_that(flat, equal_to(expected))
        if expected_groups:
          assert_that(grouped, equal_to(expected_groups), label='CheckGrouped')

  def verify_channel_split(self, split_result, last_primary, first_residual):
    self.assertEqual(1, len(split_result.channel_splits), split_result)
    channel_split, = split_result.channel_splits
    self.assertEqual(last_primary, channel_split.last_primary_element)
    self.assertEqual(first_residual, channel_split.first_residual_element)
    # There should be a primary and residual application for each element
    # not covered above.
    self.assertEqual(
        first_residual - last_primary - 1,
        len(split_result.primary_roots),
        split_result.primary_roots)
    self.assertEqual(
        first_residual - last_primary - 1,
        len(split_result.residual_roots),
        split_result.residual_roots)
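
  # Worked example: in test_split_half, splitting the 25-element bundle after
  # 5 elements at fraction 0.5 reports last_primary_element=14 and
  # first_residual_element=15, so elements 0..14 stay in the primary,
  # 15..24 become the residual, and 15 - 14 - 1 = 0 element-level roots are
  # expected. In run_sdf_split_half, split2 reports (-1, 1), giving
  # 1 - (-1) - 1 = 1 primary/residual root pair for the element split mid-way.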

class ElementCounter(object):
  """Used to wait until a certain number of elements are seen."""
  def __init__(self):
    self._cv = threading.Condition()
    self.reset()

  def reset(self):
    with self._cv:
      self._breakpoints = collections.defaultdict(list)
      self._count = 0

  def increment(self):
    with self._cv:
      self._count += 1
      self._cv.notify_all()
      breakpoints = list(self._breakpoints[self._count])
    for breakpoint in breakpoints:
      breakpoint.wait()

  def set_breakpoint(self, value):
    with self._cv:
      event = threading.Event()
      self._breakpoints[value].append(event)

    class Breakpoint(object):
      @staticmethod
      def wait(timeout=10):
        with self._cv:
          start = time.time()
          while self._count < value:
            elapsed = time.time() - start
            if elapsed > timeout:
              raise RuntimeError('Timed out waiting for %s' % value)
            self._cv.wait(timeout - elapsed)

      @staticmethod
      def clear():
        event.set()

    return Breakpoint()

  def __reduce__(self):
    # Ensure we get the same element back through a pickling round-trip.
    name = uuid.uuid4().hex
    _pickled_element_counters[name] = self
    return _unpickle_element_counter, (name, )
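
# A minimal ElementCounter usage sketch (illustrative only):
#
#   counter = ElementCounter()
#   bp = counter.set_breakpoint(3)  # the 3rd increment() will block on bp
#   # ... another thread calls counter.increment() once per element ...
#   bp.wait()   # returns once the count has reached 3
#   bp.clear()  # releases the increment() call blocked at the breakpoint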

_pickled_element_counters = {}  # type: Dict[str, ElementCounter]


def _unpickle_element_counter(name):
  return _pickled_element_counters[name]

class EventRecorder(object):
  """Records events from bundle finalization callbacks.

  Records are written to a tmp file rather than kept in memory because an
  in-memory dataset cannot retain the callback records once it has been
  passed into a DoFn.
  """
  def __init__(self, tmp_dir):
    self.tmp_dir = os.path.join(tmp_dir, uuid.uuid4().hex)
    os.mkdir(self.tmp_dir)

  def record(self, content):
    file_path = os.path.join(self.tmp_dir, uuid.uuid4().hex + '.txt')
    with open(file_path, 'w') as f:
      f.write(content)

  def events(self):
    content = []
    record_files = [
        f for f in os.listdir(self.tmp_dir)
        if os.path.isfile(os.path.join(self.tmp_dir, f))
    ]
    for file in record_files:
      with open(os.path.join(self.tmp_dir, file), 'r') as f:
        content.append(f.read())
    return sorted(content)

  def cleanup(self):
    shutil.rmtree(self.tmp_dir)
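
# A minimal EventRecorder usage sketch (illustrative only; the callback
# registration shown is an assumption about the calling DoFn):
#
#   recorder = EventRecorder(tempfile.gettempdir())
#   # Inside a DoFn, a bundle finalization callback records an event:
#   #   bundle_finalizer.register(lambda: recorder.record('finalized'))
#   # After the pipeline completes:
#   assert recorder.events() == ['finalized']
#   recorder.cleanup()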

class ExpandStringsProvider(beam.transforms.core.RestrictionProvider):
  """A RestrictionProvider used for SDF-related tests."""
  def initial_restriction(self, element):
    return restriction_trackers.OffsetRange(0, len(element))

  def create_tracker(self, restriction):
    return restriction_trackers.OffsetRestrictionTracker(restriction)

  def split(self, element, restriction):
    desired_bundle_size = restriction.size() // 2
    return restriction.split(desired_bundle_size)

  def restriction_size(self, element, restriction):
    return restriction.size()
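
# For example, an 8-character element gets the initial restriction
# OffsetRange(0, 8), which split() halves into OffsetRange(0, 4) and
# OffsetRange(4, 8).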

class UnboundedOffsetRestrictionTracker(
    restriction_trackers.OffsetRestrictionTracker):
  def is_bounded(self):
    return False

class OffsetRangeProvider(beam.transforms.core.RestrictionProvider):
  def __init__(self, use_bounded_offset_range, checkpoint_only=False):
    self.use_bounded_offset_range = use_bounded_offset_range
    self.checkpoint_only = checkpoint_only

  def initial_restriction(self, element):
    return restriction_trackers.OffsetRange(0, element)

  def create_tracker(self, restriction):
    if self.checkpoint_only:

      class CheckpointOnlyOffsetRestrictionTracker(
          restriction_trackers.OffsetRestrictionTracker):
        def try_split(self, unused_fraction_of_remainder):
          return super().try_split(0.0)

      return CheckpointOnlyOffsetRestrictionTracker(restriction)
    if self.use_bounded_offset_range:
      return restriction_trackers.OffsetRestrictionTracker(restriction)
    return UnboundedOffsetRestrictionTracker(restriction)

  def split(self, element, restriction):
    return [restriction]

  def restriction_size(self, element, restriction):
    return restriction.size()

class OffsetRangeProviderWithTruncate(OffsetRangeProvider):
  def __init__(self):
    super().__init__(True)

  def truncate(self, element, restriction):
    return restriction_trackers.OffsetRange(
        restriction.start, restriction.stop // 2)
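
# truncate() is consulted when a pipeline is drained: here it halves the
# remaining range, e.g. OffsetRange(0, 10) is truncated to OffsetRange(0, 5),
# so only offsets 0..4 are still processed before the drain completes.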

class FnApiBasedLullLoggingTest(unittest.TestCase):
  def create_pipeline(self):
    return beam.Pipeline(
        runner=fn_api_runner.FnApiRunner(
            default_environment=environments.EmbeddedPythonGrpcEnvironment.
            default(),
            progress_request_frequency=0.5))

class StateBackedTestElementType(object):
  live_element_count = 0

  def __init__(self, num_elements, unused):
    self.num_elements = num_elements
    StateBackedTestElementType.live_element_count += 1
    # Because state-backed iterables are used, we expect only a few instances
    # to be alive at any given time.
    if StateBackedTestElementType.live_element_count > 5:
      raise RuntimeError('Too many live instances.')

  def __del__(self):
    StateBackedTestElementType.live_element_count -= 1

  def __reduce__(self):
    return (self.__class__, (self.num_elements, 'x' * self.num_elements))
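
# Note that __reduce__ pads the pickled form with 'x' * num_elements, making
# each encoded value large. With enough values per key, the grouped iterable
# should spill across state-API pages and be re-fetched lazily, which is
# presumably what keeps live_element_count under the limit asserted above.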

@pytest.mark.it_validatesrunner
class FnApiBasedStateBackedCoderTest(unittest.TestCase):
  def create_pipeline(self):
    return beam.Pipeline(
        runner=fn_api_runner.FnApiRunner(use_state_iterables=True))

  def test_gbk_many_values(self):
    with self.create_pipeline() as p:
      # The number of integers could be a knob to test against
      # different runners' default settings on page size.
      VALUES_PER_ELEMENT = 300
      NUM_OF_ELEMENTS = 200

      r = (
          p
          | beam.Create([None])
          | beam.FlatMap(
              lambda x: ((1, StateBackedTestElementType(VALUES_PER_ELEMENT, _))
                         for _ in range(NUM_OF_ELEMENTS)))
          | beam.GroupByKey()
          | beam.MapTuple(lambda _, vs: sum(e.num_elements for e in vs)))

      assert_that(r, equal_to([VALUES_PER_ELEMENT * NUM_OF_ELEMENTS]))

# TODO(robertwb): Why does pickling break when this is inlined?
class CustomMergingWindowFn(window.WindowFn):
  def assign(self, assign_context):
    return [
        window.IntervalWindow(
            assign_context.timestamp, assign_context.timestamp + 1000)
    ]

  def merge(self, merge_context):
    evens = [w for w in merge_context.windows if w.start % 2 == 0]
    if evens:
      merge_context.merge(
          evens,
          window.IntervalWindow(
              min(w.start for w in evens), max(w.end for w in evens)))

  def get_window_coder(self):
    return coders.IntervalWindowCoder()
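
# Worked example: elements at timestamps 0, 2, and 5 are assigned the windows
# [0, 1000), [2, 1002), and [5, 1005). The first two start on even timestamps
# and merge into [0, 1002); the odd-start window [5, 1005) is left untouched.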

class ExpectingSideInputsFn(beam.DoFn):
  def __init__(self, name):
    self._name = name

  def default_label(self):
    return self._name

  def process(self, element, *side_inputs):
    logging.info('Running %s (side inputs: %s)', self._name, side_inputs)
    if not all(list(s) for s in side_inputs):
      raise ValueError(f'Missing data in side input {side_inputs}')
    yield self._name

class ArrayMultiplyDoFn(beam.DoFn):
  def process_batch(self, batch: np.ndarray, *unused_args,
                    **unused_kwargs) -> Iterator[np.ndarray]:
    assert isinstance(batch, np.ndarray)
    # GeneralPurposeConsumerSet should limit batches to MAX_BATCH_SIZE (4096)
    # elements.
    assert np.size(batch, axis=0) <= 4096
    yield batch * 2

  # infer_output_type must be defined (when there's no process method);
  # otherwise we can't infer that the output type is the same as the input
  # type.
  def infer_output_type(self, input_type):
    return input_type

class ListPlusOneDoFn(beam.DoFn):
  def process_batch(self, batch: List[np.int64], *unused_args,
                    **unused_kwargs) -> Iterator[List[np.int64]]:
    assert isinstance(batch, list)
    yield [element + 1 for element in batch]

  # infer_output_type must be defined (when there's no process method);
  # otherwise we can't infer that the output type is the same as the input
  # type.
  def infer_output_type(self, input_type):
    return input_type
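
# A minimal usage sketch for the batched DoFns above (illustrative only; the
# pipeline shape is an assumption, not taken from a test in this file):
#
#   with beam.Pipeline() as p:
#     res = (
#         p
#         | beam.Create(np.array([1, 2, 3], dtype=np.int64)).with_output_types(
#             np.int64)
#         | beam.ParDo(ArrayMultiplyDoFn()))  # process_batch sees ndarrays
#     assert_that(res, equal_to([2, 4, 6]))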

if __name__ == '__main__':
  logging.getLogger().setLevel(logging.INFO)
  unittest.main()