#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# pytype: skip-file
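"""Tests for apache_beam.runners.worker.log_handler."""
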
import logging
import re
import unittest

import grpc

import apache_beam as beam
from apache_beam.coders.coders import FastPrimitivesCoder
from apache_beam.portability import common_urns
from apache_beam.portability.api import beam_fn_api_pb2
from apache_beam.portability.api import beam_fn_api_pb2_grpc
from apache_beam.portability.api import endpoints_pb2
from apache_beam.runners import common
from apache_beam.runners.common import NameContext
from apache_beam.runners.worker import bundle_processor
from apache_beam.runners.worker import log_handler
from apache_beam.runners.worker import operations
from apache_beam.runners.worker import statesampler
from apache_beam.runners.worker.bundle_processor import BeamTransformFactory
from apache_beam.runners.worker.bundle_processor import BundleProcessor
from apache_beam.transforms.window import GlobalWindow
from apache_beam.utils import thread_pool_executor
from apache_beam.utils.windowed_value import WindowedValue

_LOGGER = logging.getLogger(__name__)


@BeamTransformFactory.register_urn('beam:internal:testexn:v1', bytes)
def create_exception_dofn(
    factory, transform_id, transform_proto, payload, consumers):
  """Returns a test DoFn that raises the given exception."""
  class RaiseException(beam.DoFn):
    def __init__(self, msg):
      self.msg = msg.decode()

    def process(self, _):
      raise RuntimeError(self.msg)

  return bundle_processor._create_simple_pardo_operation(
      factory,
      transform_id,
      transform_proto,
      consumers,
      RaiseException(payload))


class TestOperation(operations.Operation):
  """Test operation that forwards its payload to consumers."""
  class Spec:
    def __init__(self, transform_proto):
      self.output_coders = [
          FastPrimitivesCoder() for _ in transform_proto.outputs
      ]

  def __init__(
      self,
      transform_proto,
      name_context,
      counter_factory,
      state_sampler,
      consumers,
      payload,
  ):
    super().__init__(
        name_context,
        self.Spec(transform_proto),
        counter_factory,
        state_sampler)
    self.payload = payload

    for _, consumer_ops in consumers.items():
      for consumer in consumer_ops:
        self.add_receiver(consumer, 0)

  def start(self):
    super().start()

    # Not using windowing logic, so just using simple defaults here.
    if self.payload:
      self.process(
          WindowedValue(self.payload, timestamp=0, windows=[GlobalWindow()]))

  def process(self, windowed_value):
    self.output(windowed_value)


@BeamTransformFactory.register_urn('beam:internal:testop:v1', bytes)
def create_test_op(factory, transform_id, transform_proto, payload, consumers):
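  """Returns a TestOperation that injects its payload into the pipeline."""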
  return TestOperation(
      transform_proto,
      common.NameContext(transform_proto.unique_name, transform_id),
      factory.counter_factory,
      factory.state_sampler,
      consumers,
      payload)


class BeamFnLoggingServicer(beam_fn_api_pb2_grpc.BeamFnLoggingServicer):
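  """Fake logging service that records the log entries it receives."""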
  def __init__(self):
    self.log_records_received = []

  def Logging(self, request_iterator, context):
    for log_record in request_iterator:
      self.log_records_received.append(log_record)

    yield beam_fn_api_pb2.LogControl()


class FnApiLogRecordHandlerTest(unittest.TestCase):
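  """Tests that FnApiLogRecordHandler forwards log records over the Fn API."""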
  def setUp(self):
    self.test_logging_service = BeamFnLoggingServicer()
    self.server = grpc.server(thread_pool_executor.shared_unbounded_instance())
    beam_fn_api_pb2_grpc.add_BeamFnLoggingServicer_to_server(
        self.test_logging_service, self.server)
    self.test_port = self.server.add_insecure_port('[::]:0')
    self.server.start()

    self.logging_service_descriptor = endpoints_pb2.ApiServiceDescriptor()
    self.logging_service_descriptor.url = 'localhost:%s' % self.test_port
    self.fn_log_handler = log_handler.FnApiLogRecordHandler(
        self.logging_service_descriptor)
    logging.getLogger().setLevel(logging.INFO)
    logging.getLogger().addHandler(self.fn_log_handler)

  def tearDown(self):
    # Wait up to 5 seconds for the server to stop.
    self.server.stop(5)

  def _verify_fn_log_handler(self, num_log_entries):
    msg = 'Testing fn logging'
    _LOGGER.debug('Debug Message 1')
    for idx in range(num_log_entries):
      _LOGGER.info('%s: %s', msg, idx)
    _LOGGER.debug('Debug Message 2')

    # Wait for logs to be sent to server.
    self.fn_log_handler.close()

    num_received_log_entries = 0
    for outer in self.test_logging_service.log_records_received:
      for log_entry in outer.log_entries:
        self.assertEqual(
            beam_fn_api_pb2.LogEntry.Severity.INFO, log_entry.severity)
        self.assertEqual(
            '%s: %s' % (msg, num_received_log_entries), log_entry.message)
        self.assertTrue(
            re.match(r'.*log_handler_test.py:\d+', log_entry.log_location),
            log_entry.log_location)
        self.assertGreater(log_entry.timestamp.seconds, 0)
        self.assertGreaterEqual(log_entry.timestamp.nanos, 0)
        num_received_log_entries += 1

    self.assertEqual(num_received_log_entries, num_log_entries)

  def assertContains(self, haystack, needle):
    self.assertTrue(
        needle in haystack, 'Expected %r to contain %r.' % (haystack, needle))

  def test_exc_info(self):
    try:
      raise ValueError('some message')
    except ValueError:
      _LOGGER.error('some error', exc_info=True)

    self.fn_log_handler.close()

    log_entry = self.test_logging_service.log_records_received[0].log_entries[0]
    self.assertContains(log_entry.message, 'some error')
    self.assertContains(log_entry.trace, 'some message')
    self.assertContains(log_entry.trace, 'log_handler_test.py')

  def test_format_bad_message(self):
    # Emit directly to the handler rather than through the root logger: we
    # know this record will raise an exception during formatting, and we don't
    # want it to reach the other registered handlers.
    self.fn_log_handler.emit(
        logging.LogRecord(
            'name',
            logging.ERROR,
            'pathname',
            777,
            'TestLog %d', (None, ),
            exc_info=None))
    self.fn_log_handler.close()

    log_entry = self.test_logging_service.log_records_received[0].log_entries[0]
    self.assertContains(
        log_entry.message,
        "Failed to format 'TestLog %d' with args '(None,)' during logging.")

  def test_context(self):
    try:
      with statesampler.instruction_id('A'):
        tracker = statesampler.for_test()
        with tracker.scoped_state(NameContext('name', 'tid'), 'stage'):
          _LOGGER.info('message a')
      with statesampler.instruction_id('B'):
        _LOGGER.info('message b')
      _LOGGER.info('message c')

      self.fn_log_handler.close()
      a, b, c = sum(
          [list(logs.log_entries)
           for logs in self.test_logging_service.log_records_received], [])

      self.assertEqual(a.instruction_id, 'A')
      self.assertEqual(b.instruction_id, 'B')
      self.assertEqual(c.instruction_id, '')
      self.assertEqual(a.transform_id, 'tid')
      self.assertEqual(b.transform_id, '')
      self.assertEqual(c.transform_id, '')
    finally:
      statesampler.set_current_tracker(None)

  def test_extracts_transform_id_during_exceptions(self):
    """Tests that transform ids are captured during user code exceptions."""
    descriptor = beam_fn_api_pb2.ProcessBundleDescriptor()

    # Boilerplate for the DoFn: windowing strategy and window coder.
    WINDOWING_ID = 'window'
    WINDOW_CODER_ID = 'cw'
    window = descriptor.windowing_strategies[WINDOWING_ID]
    window.window_fn.urn = common_urns.global_windows.urn
    window.window_coder_id = WINDOW_CODER_ID
    window.trigger.default.SetInParent()
    window_coder = descriptor.coders[WINDOW_CODER_ID]
    window_coder.spec.urn = common_urns.StandardCoders.Enum.GLOBAL_WINDOW.urn

    # Input collection to the exception-raising DoFn.
    INPUT_PCOLLECTION_ID = 'pc-in'
    INPUT_CODER_ID = 'c-in'
    descriptor.pcollections[
        INPUT_PCOLLECTION_ID].unique_name = INPUT_PCOLLECTION_ID
    descriptor.pcollections[INPUT_PCOLLECTION_ID].coder_id = INPUT_CODER_ID
    descriptor.pcollections[
        INPUT_PCOLLECTION_ID].windowing_strategy_id = WINDOWING_ID
    descriptor.coders[
        INPUT_CODER_ID].spec.urn = common_urns.StandardCoders.Enum.BYTES.urn

    # Output collection from the exception-raising DoFn.
    OUTPUT_PCOLLECTION_ID = 'pc-out'
    OUTPUT_CODER_ID = 'c-out'
    descriptor.pcollections[
        OUTPUT_PCOLLECTION_ID].unique_name = OUTPUT_PCOLLECTION_ID
    descriptor.pcollections[OUTPUT_PCOLLECTION_ID].coder_id = OUTPUT_CODER_ID
    descriptor.pcollections[
        OUTPUT_PCOLLECTION_ID].windowing_strategy_id = WINDOWING_ID
    descriptor.coders[
        OUTPUT_CODER_ID].spec.urn = common_urns.StandardCoders.Enum.BYTES.urn

    # Add a simple transform to inject an element into the fake pipeline.
    TEST_OP_TRANSFORM_ID = 'test_op'
    test_transform = descriptor.transforms[TEST_OP_TRANSFORM_ID]
    test_transform.outputs['None'] = INPUT_PCOLLECTION_ID
    test_transform.spec.urn = 'beam:internal:testop:v1'
    test_transform.spec.payload = b'hello, world!'

    # Add the DoFn that raises the exception.
    TEST_EXCEPTION_TRANSFORM_ID = 'test_transform'
    test_transform = descriptor.transforms[TEST_EXCEPTION_TRANSFORM_ID]
    test_transform.inputs['0'] = INPUT_PCOLLECTION_ID
    test_transform.outputs['None'] = OUTPUT_PCOLLECTION_ID
    test_transform.spec.urn = 'beam:internal:testexn:v1'
    test_transform.spec.payload = b'expected exception'

    # Create and process a fake bundle. The instruction id doesn't matter
    # here.
    processor = BundleProcessor(set(), descriptor, None, None)
    with self.assertRaisesRegex(RuntimeError, 'expected exception'):
      processor.process_bundle('instruction_id')

    self.fn_log_handler.close()
    logs = [
        log for logs in self.test_logging_service.log_records_received
        for log in logs.log_entries
    ]

    # The first log entry should be the error raised by the test DoFn, tagged
    # with the id of the transform that raised it.
    actual_log = logs[0]
    self.assertEqual(
        actual_log.severity, beam_fn_api_pb2.LogEntry.Severity.ERROR)
    self.assertIn('expected exception', actual_log.message)
    self.assertEqual(actual_log.transform_id, 'test_transform')


# Test cases for batching behavior: fewer entries than one batch, an exact
# multiple of the batch size, and several batches plus a remainder.
data = {
    'one_batch': log_handler.FnApiLogRecordHandler._MAX_BATCH_SIZE - 47,
    'exact_multiple': log_handler.FnApiLogRecordHandler._MAX_BATCH_SIZE,
    'multi_batch': log_handler.FnApiLogRecordHandler._MAX_BATCH_SIZE * 3 + 47
}


def _create_test(name, num_logs):
  setattr(
      FnApiLogRecordHandlerTest,
      'test_%s' % name,
      lambda self: self._verify_fn_log_handler(num_logs))


for test_name, num_logs_entries in data.items():
  _create_test(test_name, num_logs_entries)

if __name__ == '__main__':
  unittest.main()