# blob: 47b6cca880d367a0a8e20ca2874664fb9f3b1c30 [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# pytype: skip-file
import sys
import time
import traceback
import unittest
from typing import Any
from typing import List
from typing import Optional
from apache_beam.coders import FastPrimitivesCoder
from apache_beam.coders import WindowedValueCoder
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.portability.api import beam_fn_api_pb2
from apache_beam.runners.worker.data_sampler import DataSampler
from apache_beam.runners.worker.data_sampler import OutputSampler
from apache_beam.transforms.window import GlobalWindow
from apache_beam.utils.windowed_value import WindowedValue
# Transform id used by the tests whenever they do not name their own
# transforms (default in make_test_descriptor and gen_sample).
MAIN_TRANSFORM_ID = 'transform'
# PCollection id used by the tests whenever they do not name their own outputs
# (default in make_test_descriptor).
MAIN_PCOLLECTION_ID = 'pcoll'
# Shared coder instance: returned by primitives_coder_factory and used to
# encode the elements of the expected sample responses.
PRIMITIVES_CODER = FastPrimitivesCoder()
class DataSamplerTest(unittest.TestCase):
  """Tests DataSampler: sampler creation and lookup, filtering, exceptions."""
  def make_test_descriptor(
      self,
      outputs: Optional[List[str]] = None,
      transforms: Optional[List[str]] = None
  ) -> beam_fn_api_pb2.ProcessBundleDescriptor:
    """Builds a descriptor in which every transform declares every output.

    Args:
      outputs: ids used as both the output tag and the PCollection id.
        Defaults to [MAIN_PCOLLECTION_ID].
      transforms: ids of the transforms to create. Defaults to
        [MAIN_TRANSFORM_ID].

    Returns:
      The populated ProcessBundleDescriptor.
    """
    outputs = outputs or [MAIN_PCOLLECTION_ID]
    transforms = transforms or [MAIN_TRANSFORM_ID]

    descriptor = beam_fn_api_pb2.ProcessBundleDescriptor()
    for transform_id in transforms:
      transform = descriptor.transforms[transform_id]
      for output in outputs:
        transform.outputs[output] = output
    return descriptor

  def setUp(self):
    self.data_sampler = DataSampler.create(
        PipelineOptions(experiments=[DataSampler._ENABLE_DATA_SAMPLING]),
        sample_every_sec=0.1)

  def tearDown(self):
    self.data_sampler.stop()

  def _use_data_sampler(self, experiments: List[str]) -> None:
    """Replaces the sampler from setUp with one built from `experiments`.

    The sampler created in setUp is stopped first: tearDown only stops
    whatever self.data_sampler refers to when the test ends, so the original
    would otherwise never be stopped.
    """
    self.data_sampler.stop()
    self.data_sampler = DataSampler.create(
        PipelineOptions(experiments=experiments), sample_every_sec=0.1)

  def _expected_samples(self, elements_by_pcollection):
    """Builds the response holding one encoded element per PCollection id.

    Args:
      elements_by_pcollection: dict of PCollection id -> the single element
        expected for it, encoded here with PRIMITIVES_CODER.
    """
    return beam_fn_api_pb2.SampleDataResponse(
        element_samples={
            pcollection_id: beam_fn_api_pb2.SampleDataResponse.ElementList(
                elements=[
                    beam_fn_api_pb2.SampledElement(
                        element=PRIMITIVES_CODER.encode_nested(element))
                ])
            for pcollection_id,
            element in elements_by_pcollection.items()
        })

  @staticmethod
  def _capture_exc_info():
    """Raises and catches an Exception('test'), returning its exc_info."""
    try:
      raise Exception('test')
    except Exception:
      return sys.exc_info()

  def primitives_coder_factory(self, _):
    """Coder factory passed to initialize_samplers; ignores its argument."""
    return PRIMITIVES_CODER

  def gen_sample(
      self,
      data_sampler: DataSampler,
      element: Any,
      output_index: int,
      transform_id: str = MAIN_TRANSFORM_ID):
    """Generates a sample for the given transform's output.

    Args:
      data_sampler: the DataSampler whose output sampler receives the element.
      element: the value to place in the element sampler.
      output_index: index of the output within the transform's outputs.
      transform_id: id of the transform that owns the output.
    """
    # Use the sampler that was passed in rather than always reaching for
    # self.data_sampler; every caller in this class passes self.data_sampler.
    element_sampler = data_sampler.sampler_for_output(
        transform_id, output_index).element_sampler
    element_sampler.el = element
    element_sampler.has_element = True

  def test_single_output(self):
    """Simple test for a single sample."""
    descriptor = self.make_test_descriptor()
    self.data_sampler.initialize_samplers(
        MAIN_TRANSFORM_ID, descriptor, self.primitives_coder_factory)

    self.gen_sample(self.data_sampler, 'a', output_index=0)

    samples = self.data_sampler.wait_for_samples([MAIN_PCOLLECTION_ID])
    self.assertEqual(
        samples, self._expected_samples({MAIN_PCOLLECTION_ID: 'a'}))

  def test_not_initialized(self):
    """Tests that transforms fail gracefully if not properly initialized."""
    with self.assertLogs() as cm:
      self.data_sampler.sampler_for_output(MAIN_TRANSFORM_ID, 0)
    self.assertRegex(cm.output[0], 'Out-of-bounds access.*')

  def map_outputs_to_indices(
      self, outputs, descriptor, transform_id=MAIN_TRANSFORM_ID):
    """Maps each output id to its position in the transform's output map.

    The position is the index of the id once the transform's outputs map is
    converted to a list.
    """
    tag_list = list(descriptor.transforms[transform_id].outputs)
    return {output: tag_list.index(output) for output in outputs}

  def test_sampler_mapping(self):
    """Tests that the ElementSamplers are created for the correct output."""
    # Initialize the DataSampler with the following outputs. The order here may
    # get shuffled when inserting into the descriptor.
    pcollection_ids = ['o0', 'o1', 'o2']
    descriptor = self.make_test_descriptor(outputs=pcollection_ids)
    samplers = self.data_sampler.initialize_samplers(
        MAIN_TRANSFORM_ID, descriptor, self.primitives_coder_factory)

    # Create a map from the PCollection id to the index into the transform
    # output. This mirrors what happens when operators are created. The index of
    # an output is where in the PTransform.outputs it is located (when the map
    # is converted to a list).
    outputs = self.map_outputs_to_indices(pcollection_ids, descriptor)

    # Assert that the mapping is correct, i.e. that we can go from the
    # PCollection id -> output index and that this is the same as the created
    # samplers.
    for pcollection_id in pcollection_ids:
      index = outputs[pcollection_id]
      with self.subTest(pcollection_id=pcollection_id):
        self.assertEqual(
            self.data_sampler.sampler_for_output(MAIN_TRANSFORM_ID,
                                                 index).element_sampler,
            samplers[index].element_sampler)

  def test_multiple_outputs(self):
    """Tests that multiple PCollections have their own sampler."""
    elements = {'o0': 'a', 'o1': 'b', 'o2': 'c'}
    pcollection_ids = list(elements)
    descriptor = self.make_test_descriptor(outputs=pcollection_ids)
    outputs = self.map_outputs_to_indices(pcollection_ids, descriptor)
    self.data_sampler.initialize_samplers(
        MAIN_TRANSFORM_ID, descriptor, self.primitives_coder_factory)

    for pcollection_id, element in elements.items():
      self.gen_sample(
          self.data_sampler, element, output_index=outputs[pcollection_id])

    samples = self.data_sampler.wait_for_samples(['o0', 'o1', 'o2'])
    self.assertEqual(samples, self._expected_samples(elements))

  def test_multiple_transforms(self):
    """Test that multiple transforms with the same PCollections can be sampled.
    """
    # Initialize two transforms, both with the same two outputs.
    pcollection_ids = ['o0', 'o1']
    descriptor = self.make_test_descriptor(
        outputs=pcollection_ids, transforms=['t0', 't1'])
    t0_outputs = self.map_outputs_to_indices(
        pcollection_ids, descriptor, transform_id='t0')
    t1_outputs = self.map_outputs_to_indices(
        pcollection_ids, descriptor, transform_id='t1')

    self.data_sampler.initialize_samplers(
        't0', descriptor, self.primitives_coder_factory)
    self.data_sampler.initialize_samplers(
        't1', descriptor, self.primitives_coder_factory)

    # The OutputSampler is on a different thread so we don't test the same
    # PCollections to ensure that no data race occurs.
    self.gen_sample(
        self.data_sampler,
        'a',
        output_index=t0_outputs['o0'],
        transform_id='t0')
    self.gen_sample(
        self.data_sampler,
        'd',
        output_index=t1_outputs['o1'],
        transform_id='t1')
    samples = self.data_sampler.wait_for_samples(['o0', 'o1'])
    self.assertEqual(samples, self._expected_samples({'o0': 'a', 'o1': 'd'}))

    # Second round: swap which transform feeds which PCollection.
    self.gen_sample(
        self.data_sampler,
        'b',
        output_index=t0_outputs['o1'],
        transform_id='t0')
    self.gen_sample(
        self.data_sampler,
        'c',
        output_index=t1_outputs['o0'],
        transform_id='t1')
    samples = self.data_sampler.wait_for_samples(['o0', 'o1'])
    self.assertEqual(samples, self._expected_samples({'o0': 'c', 'o1': 'b'}))

  def test_sample_filters_single_pcollection_ids(self):
    """Tests the samples can be filtered based on a single pcollection id."""
    elements = {'o0': 'a', 'o1': 'b', 'o2': 'c'}
    pcollection_ids = list(elements)
    descriptor = self.make_test_descriptor(outputs=pcollection_ids)
    outputs = self.map_outputs_to_indices(pcollection_ids, descriptor)
    self.data_sampler.initialize_samplers(
        MAIN_TRANSFORM_ID, descriptor, self.primitives_coder_factory)

    for pcollection_id, element in elements.items():
      self.gen_sample(
          self.data_sampler, element, output_index=outputs[pcollection_id])

    # Query one PCollection at a time; each response holds only that id.
    for pcollection_id, element in elements.items():
      samples = self.data_sampler.wait_for_samples([pcollection_id])
      self.assertEqual(
          samples, self._expected_samples({pcollection_id: element}))

  def test_sample_filters_multiple_pcollection_ids(self):
    """Tests the samples can be filtered based on multiple pcollection ids."""
    elements = {'o0': 'a', 'o1': 'b', 'o2': 'c'}
    pcollection_ids = list(elements)
    descriptor = self.make_test_descriptor(outputs=pcollection_ids)
    outputs = self.map_outputs_to_indices(pcollection_ids, descriptor)
    self.data_sampler.initialize_samplers(
        MAIN_TRANSFORM_ID, descriptor, self.primitives_coder_factory)

    for pcollection_id, element in elements.items():
      self.gen_sample(
          self.data_sampler, element, output_index=outputs[pcollection_id])

    samples = self.data_sampler.wait_for_samples(['o0', 'o2'])
    self.assertEqual(samples, self._expected_samples({'o0': 'a', 'o2': 'c'}))

  def test_can_sample_exceptions(self):
    """Tests that exceptions sampled can be queried by the DataSampler."""
    descriptor = self.make_test_descriptor()
    self.data_sampler.initialize_samplers(
        MAIN_TRANSFORM_ID, descriptor, self.primitives_coder_factory)

    sampler = self.data_sampler.sampler_for_output(MAIN_TRANSFORM_ID, 0)
    sampler.sample_exception(
        'a', self._capture_exc_info(), MAIN_TRANSFORM_ID, 'instid')

    samples = self.data_sampler.wait_for_samples([MAIN_PCOLLECTION_ID])
    self.assertGreater(len(samples.element_samples), 0)

  def test_create_experiments(self):
    """Tests that the experiments correctly make the DataSampler."""
    enable_exception_exp = DataSampler._ENABLE_ALWAYS_ON_EXCEPTION_SAMPLING
    disable_exception_exp = DataSampler._DISABLE_ALWAYS_ON_EXCEPTION_SAMPLING
    enable_sampling_exp = DataSampler._ENABLE_DATA_SAMPLING

    # No experiments at all: no sampler is created.
    self.assertIsNone(DataSampler.create(PipelineOptions()))

    no_sampler_cases = (
        [disable_exception_exp],
        [enable_exception_exp, disable_exception_exp],
    )
    for exp in no_sampler_cases:
      with self.subTest(experiments=exp):
        self.assertIsNone(DataSampler.create(PipelineOptions(experiments=exp)))

    sampler_cases = (
        [enable_exception_exp],
        [enable_sampling_exp],
        [enable_sampling_exp, enable_exception_exp, disable_exception_exp],
    )
    for exp in sampler_cases:
      with self.subTest(experiments=exp):
        self.assertIsNotNone(
            DataSampler.create(PipelineOptions(experiments=exp)))

  def test_samples_all_with_both_experiments(self):
    """Tests that using both sampling experiments samples everything."""
    self._use_data_sampler([
        DataSampler._ENABLE_DATA_SAMPLING,
        DataSampler._ENABLE_ALWAYS_ON_EXCEPTION_SAMPLING
    ])

    # Create a descriptor with one transform with two outputs, 'a' and 'b'.
    descriptor = self.make_test_descriptor(outputs=['a', 'b'])
    self.data_sampler.initialize_samplers(
        MAIN_TRANSFORM_ID, descriptor, self.primitives_coder_factory)

    # Get the samplers for the two outputs.
    # N.B. the order of the samplers is not guaranteed due to Protobuf not
    # guaranteeing map iteration order.
    first_sampler = self.data_sampler.sampler_for_output(MAIN_TRANSFORM_ID, 0)
    second_sampler = self.data_sampler.sampler_for_output(MAIN_TRANSFORM_ID, 1)

    # Sample an exception on one output; it is expected in the final response.
    first_sampler.sample_exception(
        'first', self._capture_exc_info(), MAIN_TRANSFORM_ID, 'instid')

    # Sample a normal element on the other output; with data sampling enabled
    # it is expected in the final response as well.
    second_sampler.element_sampler.el = 'second'
    second_sampler.element_sampler.has_element = True

    samples = self.data_sampler.wait_for_samples(['a', 'b'])
    self.assertEqual(len(samples.element_samples), 2)

    sample_elements = [
        s.elements[0] for s in samples.element_samples.values()
    ]
    num_exceptions = sum(
        1 for element in sample_elements if element.HasField('exception'))
    self.assertEqual(
        num_exceptions,
        1,
        'Only one of the samples should have an exception, found: '
        f'{sample_elements}')

  def test_only_sample_exceptions(self):
    """Tests that the exception sampling experiment only samples exceptions."""
    self._use_data_sampler([DataSampler._ENABLE_ALWAYS_ON_EXCEPTION_SAMPLING])

    # Create a descriptor with one transform with two outputs, 'a' and 'b'.
    descriptor = self.make_test_descriptor(outputs=['a', 'b'])
    self.data_sampler.initialize_samplers(
        MAIN_TRANSFORM_ID, descriptor, self.primitives_coder_factory)

    # Get the samplers for the two outputs.
    # N.B. the order of the samplers is not guaranteed due to Protobuf not
    # guaranteeing map iteration order.
    first_sampler = self.data_sampler.sampler_for_output(MAIN_TRANSFORM_ID, 0)
    second_sampler = self.data_sampler.sampler_for_output(MAIN_TRANSFORM_ID, 1)

    # Sample an exception on one output; it is expected in the final response.
    first_sampler.sample_exception(
        'first', self._capture_exc_info(), MAIN_TRANSFORM_ID, 'instid')

    # Sample a normal element on the other output; it is not expected in the
    # final response because only exception sampling is enabled.
    second_sampler.element_sampler.el = 'second'
    second_sampler.element_sampler.has_element = True

    samples = self.data_sampler.wait_for_samples([])
    self.assertEqual(len(samples.element_samples), 1)

    value = list(samples.element_samples.values())[0]
    # Reading a message field never yields None, so check presence explicitly.
    self.assertTrue(value.elements[0].HasField('exception'))
class OutputSamplerTest(unittest.TestCase):
  """Tests OutputSampler: timed sampling, buffering and exception samples."""
  def tearDown(self):
    # A test can fail before it assigns self.sampler; don't pile an
    # AttributeError on top of that failure.
    sampler = getattr(self, 'sampler', None)
    if sampler is not None:
      sampler.stop()

  def wait_for_samples(self, output_sampler: OutputSampler, expected_num: int):
    """Waits for the expected number of samples for the given sampler.

    Polls every 0.1s, without clearing the sampler, for up to 30 seconds.

    Args:
      output_sampler: the sampler to poll.
      expected_num: the exact number of samples to wait for.

    Returns:
      The samples, once exactly `expected_num` of them are present.

    Raises:
      AssertionError: if the expected number is not seen before the deadline.
    """
    # Monotonic clock so a wall-clock adjustment cannot stretch or cut short
    # the deadline.
    deadline = time.monotonic() + 30
    while time.monotonic() < deadline:
      time.sleep(0.1)
      samples = output_sampler.flush(clear=False)
      if samples and len(samples) == expected_num:
        return samples
    self.fail('Timed out waiting for samples')

  def _sampled_elements(self, *values):
    """Returns SampledElements for `values` encoded with PRIMITIVES_CODER."""
    return [
        beam_fn_api_pb2.SampledElement(
            element=PRIMITIVES_CODER.encode_nested(value)) for value in values
    ]

  def _sampled_exception(self, element, exc_info):
    """Returns the SampledElement expected for an exception sample.

    Uses the transform id 'tid' and instruction id 'instid' that the exception
    tests pass to sample_exception, and the formatted traceback of `exc_info`
    as the error text.
    """
    return beam_fn_api_pb2.SampledElement(
        element=PRIMITIVES_CODER.encode_nested(element),
        exception=beam_fn_api_pb2.SampledElement.Exception(
            instruction_id='instid',
            transform_id='tid',
            error=''.join(traceback.format_exception(*exc_info))))

  @staticmethod
  def _capture_exc_info():
    """Raises and catches an Exception('test'), returning its exc_info."""
    try:
      raise Exception('test')
    except Exception:
      return sys.exc_info()

  def test_can_sample(self):
    """Tests that the underlying timer can sample."""
    self.sampler = OutputSampler(PRIMITIVES_CODER, sample_every_sec=0.05)
    element_sampler = self.sampler.element_sampler
    element_sampler.el = 'a'
    element_sampler.has_element = True

    self.wait_for_samples(self.sampler, expected_num=1)
    self.assertEqual(self.sampler.flush(), self._sampled_elements('a'))

  def test_acts_like_circular_buffer(self):
    """Tests that the buffer overwrites old samples."""
    self.sampler = OutputSampler(
        PRIMITIVES_CODER, max_samples=2, sample_every_sec=0)
    element_sampler = self.sampler.element_sampler

    for i in range(10):
      element_sampler.el = i
      element_sampler.has_element = True
      self.sampler.sample()

    # Only the two most recent of the ten elements remain.
    self.assertEqual(self.sampler.flush(), self._sampled_elements(8, 9))

  def test_samples_multiple_times(self):
    """Tests that the underlying timer repeats."""
    self.sampler = OutputSampler(
        PRIMITIVES_CODER, max_samples=10, sample_every_sec=0.05)

    # Always samples the first ten.
    for i in range(10):
      self.sampler.element_sampler.el = i
      self.sampler.element_sampler.has_element = True
      self.wait_for_samples(self.sampler, i + 1)

    self.assertEqual(
        self.sampler.flush(), self._sampled_elements(*range(10)))

  def test_can_sample_windowed_value(self):
    """Tests that values with WindowedValueCoders are sampled wholesale."""
    coder = WindowedValueCoder(FastPrimitivesCoder())
    value = WindowedValue('Hello, World!', 0, [GlobalWindow()])

    self.sampler = OutputSampler(coder, sample_every_sec=0)
    element_sampler = self.sampler.element_sampler
    element_sampler.el = value
    element_sampler.has_element = True
    self.sampler.sample()

    self.assertEqual(
        self.sampler.flush(),
        [beam_fn_api_pb2.SampledElement(element=coder.encode_nested(value))])

  def test_can_sample_non_windowed_value(self):
    """Tests that windowed values with non-windowed coders sample only the
    value.

    This is important because the Python SDK wraps all values in a WindowedValue
    even if the coder is not a WindowedValueCoder. In this case, the value must
    be retrieved from the WindowedValue to match the correct coder.
    """
    value = WindowedValue('Hello, World!', 0, [GlobalWindow()])

    self.sampler = OutputSampler(PRIMITIVES_CODER, sample_every_sec=0)
    element_sampler = self.sampler.element_sampler
    element_sampler.el = value
    element_sampler.has_element = True
    self.sampler.sample()

    self.assertEqual(
        self.sampler.flush(), self._sampled_elements('Hello, World!'))

  def test_can_sample_exceptions(self):
    """Tests that exceptions are sampled."""
    val = WindowedValue('Hello, World!', 0, [GlobalWindow()])
    exc_info = self._capture_exc_info()

    self.sampler = OutputSampler(PRIMITIVES_CODER, sample_every_sec=0)
    self.sampler.sample_exception(
        el=val, exc_info=exc_info, transform_id='tid', instruction_id='instid')

    self.assertEqual(
        self.sampler.flush(),
        [self._sampled_exception('Hello, World!', exc_info)])

  def test_can_sample_multiple_exceptions(self):
    """Tests that multiple exceptions in the same PCollection are sampled."""
    exc_info = self._capture_exc_info()

    self.sampler = OutputSampler(PRIMITIVES_CODER, sample_every_sec=0)
    self.sampler.sample_exception(
        el='a', exc_info=exc_info, transform_id='tid', instruction_id='instid')
    self.sampler.sample_exception(
        el='b', exc_info=exc_info, transform_id='tid', instruction_id='instid')

    self.assertEqual(
        self.sampler.flush(),
        [
            self._sampled_exception('a', exc_info),
            self._sampled_exception('b', exc_info),
        ])
if __name__ == '__main__':
  # Run this module's test cases when the file is executed as a script.
  unittest.main()