[BEAM-2240] Always augment exception with step name.
diff --git a/sdks/python/apache_beam/pipeline_test.py b/sdks/python/apache_beam/pipeline_test.py
index c6b1e48..8aa8a8a 100644
--- a/sdks/python/apache_beam/pipeline_test.py
+++ b/sdks/python/apache_beam/pipeline_test.py
@@ -278,7 +278,7 @@
     with self.assertRaises(ValueError):
       with Pipeline() as p:
         # pylint: disable=expression-not-assigned
-        p | Create([ValueError]) | Map(raise_exception)
+        p | Create([ValueError('msg')]) | Map(raise_exception)
 
   # TODO(BEAM-1894).
   # def test_eager_pipeline(self):
diff --git a/sdks/python/apache_beam/runners/common.py b/sdks/python/apache_beam/runners/common.py
index 0aef0a1..86db711 100644
--- a/sdks/python/apache_beam/runners/common.py
+++ b/sdks/python/apache_beam/runners/common.py
@@ -20,6 +20,7 @@
 """Worker operations executor."""
 
 import sys
+import traceback
 
 from apache_beam.internal import util
 from apache_beam.metrics.execution import ScopedMetricsContainer
@@ -409,13 +410,22 @@
   def _reraise_augmented(self, exn):
     if getattr(exn, '_tagged_with_step', False) or not self.step_name:
       raise
-    args = exn.args
-    if args and isinstance(args[0], str):
-      args = (args[0] + " [while running '%s']" % self.step_name,) + args[1:]
-      # Poor man's exception chaining.
-      raise type(exn), args, sys.exc_info()[2]
-    else:
-      raise
+    step_annotation = " [while running '%s']" % self.step_name
+    # To emulate exception chaining (not available in Python 2).
+    original_traceback = sys.exc_info()[2]
+    try:
+      # Attempt to construct the same kind of exception
+      # with an augmented message.
+      new_exn = type(exn)(exn.args[0] + step_annotation, *exn.args[1:])
+      new_exn._tagged_with_step = True  # Could raise attribute error.
+    except:  # pylint: disable=bare-except
+      # If anything goes wrong, construct a RuntimeError whose message
+      # records the original exception's type and message.
+      new_exn = RuntimeError(
+          traceback.format_exception_only(type(exn), exn)[-1].strip()
+          + step_annotation)
+      new_exn._tagged_with_step = True
+    raise new_exn, None, original_traceback
 
 
 class _OutputProcessor(object):
diff --git a/sdks/python/apache_beam/runners/portability/maptask_executor_runner_test.py b/sdks/python/apache_beam/runners/portability/maptask_executor_runner_test.py
index aebd2e1..062e6f9 100644
--- a/sdks/python/apache_beam/runners/portability/maptask_executor_runner_test.py
+++ b/sdks/python/apache_beam/runners/portability/maptask_executor_runner_test.py
@@ -41,9 +41,9 @@
     return beam.Pipeline(runner=maptask_executor_runner.MapTaskExecutorRunner())
 
   def test_assert_that(self):
-    with self.assertRaises(BeamAssertException):
+    with self.assertRaisesRegexp(BeamAssertException, 'bad_assert'):
       with self.create_pipeline() as p:
-        assert_that(p | beam.Create(['a', 'b']), equal_to(['a']))
+        assert_that(p | beam.Create(['a', 'b']), equal_to(['a']), 'bad_assert')
 
   def test_create(self):
     with self.create_pipeline() as p:
@@ -204,6 +204,21 @@
              | beam.Map(lambda (k, vs): (k, sorted(vs))))
       assert_that(res, equal_to([('k', [1, 2]), ('k', [100, 101, 102])]))
 
+  def test_errors(self):
+    with self.assertRaises(BaseException) as e_cm:
+      with self.create_pipeline() as p:
+        def raise_error(x):
+          raise RuntimeError('x')
+        # pylint: disable=expression-not-assigned
+        (p
+         | beam.Create(['a', 'b'])
+         | 'StageA' >> beam.Map(lambda x: x)
+         | 'StageB' >> beam.Map(lambda x: x)
+         | 'StageC' >> beam.Map(raise_error)
+         | 'StageD' >> beam.Map(lambda x: x))
+    self.assertIn('StageC', e_cm.exception.args[0])
+    self.assertNotIn('StageB', e_cm.exception.args[0])
+
 
 if __name__ == '__main__':
   logging.getLogger().setLevel(logging.INFO)