[BEAM-2240] Always augment exception with step name.
diff --git a/sdks/python/apache_beam/pipeline_test.py b/sdks/python/apache_beam/pipeline_test.py
index c6b1e48..8aa8a8a 100644
--- a/sdks/python/apache_beam/pipeline_test.py
+++ b/sdks/python/apache_beam/pipeline_test.py
@@ -278,7 +278,7 @@
with self.assertRaises(ValueError):
with Pipeline() as p:
# pylint: disable=expression-not-assigned
- p | Create([ValueError]) | Map(raise_exception)
+ p | Create([ValueError('msg')]) | Map(raise_exception)
# TODO(BEAM-1894).
# def test_eager_pipeline(self):
diff --git a/sdks/python/apache_beam/runners/common.py b/sdks/python/apache_beam/runners/common.py
index 0aef0a1..86db711 100644
--- a/sdks/python/apache_beam/runners/common.py
+++ b/sdks/python/apache_beam/runners/common.py
@@ -20,6 +20,7 @@
"""Worker operations executor."""
import sys
+import traceback
from apache_beam.internal import util
from apache_beam.metrics.execution import ScopedMetricsContainer
@@ -409,13 +410,22 @@
def _reraise_augmented(self, exn):
if getattr(exn, '_tagged_with_step', False) or not self.step_name:
raise
- args = exn.args
- if args and isinstance(args[0], str):
- args = (args[0] + " [while running '%s']" % self.step_name,) + args[1:]
- # Poor man's exception chaining.
- raise type(exn), args, sys.exc_info()[2]
- else:
- raise
+ step_annotation = " [while running '%s']" % self.step_name
+ # To emulate exception chaining (not available in Python 2).
+ original_traceback = sys.exc_info()[2]
+ try:
+ # Attempt to construct the same kind of exception
+ # with an augmented message.
+ new_exn = type(exn)(exn.args[0] + step_annotation, *exn.args[1:])
+ new_exn._tagged_with_step = True # Could raise attribute error.
+ except: # pylint: disable=bare-except
+ # If anything goes wrong, construct a RuntimeError whose message
+ # records the original exception's type and message.
+ new_exn = RuntimeError(
+ traceback.format_exception_only(type(exn), exn)[-1].strip()
+ + step_annotation)
+ new_exn._tagged_with_step = True
+ raise new_exn, None, original_traceback
class _OutputProcessor(object):
diff --git a/sdks/python/apache_beam/runners/portability/maptask_executor_runner_test.py b/sdks/python/apache_beam/runners/portability/maptask_executor_runner_test.py
index aebd2e1..062e6f9 100644
--- a/sdks/python/apache_beam/runners/portability/maptask_executor_runner_test.py
+++ b/sdks/python/apache_beam/runners/portability/maptask_executor_runner_test.py
@@ -41,9 +41,9 @@
return beam.Pipeline(runner=maptask_executor_runner.MapTaskExecutorRunner())
def test_assert_that(self):
- with self.assertRaises(BeamAssertException):
+ with self.assertRaisesRegexp(BeamAssertException, 'bad_assert'):
with self.create_pipeline() as p:
- assert_that(p | beam.Create(['a', 'b']), equal_to(['a']))
+ assert_that(p | beam.Create(['a', 'b']), equal_to(['a']), 'bad_assert')
def test_create(self):
with self.create_pipeline() as p:
@@ -204,6 +204,21 @@
| beam.Map(lambda (k, vs): (k, sorted(vs))))
assert_that(res, equal_to([('k', [1, 2]), ('k', [100, 101, 102])]))
+ def test_errors(self):
+ with self.assertRaises(BaseException) as e_cm:
+ with self.create_pipeline() as p:
+ def raise_error(x):
+ raise RuntimeError('x')
+ # pylint: disable=expression-not-assigned
+ (p
+ | beam.Create(['a', 'b'])
+ | 'StageA' >> beam.Map(lambda x: x)
+ | 'StageB' >> beam.Map(lambda x: x)
+ | 'StageC' >> beam.Map(raise_error)
+ | 'StageD' >> beam.Map(lambda x: x))
+ self.assertIn('StageC', e_cm.exception.args[0])
+ self.assertNotIn('StageB', e_cm.exception.args[0])
+
if __name__ == '__main__':
logging.getLogger().setLevel(logging.INFO)