[Python] Fix: Propagate resource hints through with_exception_handling (#36090)
* Implement two-way propagation for resource hints; fix the pipeline failure caused by combining Python with_exception_handling with JAX-on-Beam
* Add unit test for resource hint propagation in ParDo.
* Propagate resource hint in a more intuitive way
diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py
index 1bfc732..2304faf 100644
--- a/sdks/python/apache_beam/transforms/core.py
+++ b/sdks/python/apache_beam/transforms/core.py
@@ -1678,7 +1678,8 @@
timeout,
error_handler,
on_failure_callback,
- allow_unsafe_userstate_in_process)
+ allow_unsafe_userstate_in_process,
+ self.get_resource_hints())
def with_error_handler(self, error_handler, **exception_handling_kwargs):
"""An alias for `with_exception_handling(error_handler=error_handler, ...)`
@@ -2284,7 +2285,8 @@
timeout,
error_handler,
on_failure_callback,
- allow_unsafe_userstate_in_process):
+ allow_unsafe_userstate_in_process,
+ resource_hints):
if partial and use_subprocess:
raise ValueError('partial and use_subprocess are mutually incompatible.')
self._fn = fn
@@ -2301,6 +2303,7 @@
self._error_handler = error_handler
self._on_failure_callback = on_failure_callback
self._allow_unsafe_userstate_in_process = allow_unsafe_userstate_in_process
+ self._resource_hints = resource_hints
def expand(self, pcoll):
if self._allow_unsafe_userstate_in_process:
@@ -2317,17 +2320,23 @@
wrapped_fn = _TimeoutDoFn(self._fn, timeout=self._timeout)
else:
wrapped_fn = self._fn
- result = pcoll | ParDo(
+ pardo = ParDo(
_ExceptionHandlingWrapperDoFn(
wrapped_fn,
self._dead_letter_tag,
self._exc_class,
self._partial,
self._on_failure_callback,
- self._allow_unsafe_userstate_in_process),
+ self._allow_unsafe_userstate_in_process,
+ ),
*self._args,
- **self._kwargs).with_outputs(
- self._dead_letter_tag, main=self._main_tag, allow_unknown_tags=True)
+ **self._kwargs,
+ )
+ # Carry the original transform's resource hints over to this ParDo.
+ pardo.get_resource_hints().update(self._resource_hints)
+
+ result = pcoll | pardo.with_outputs(
+ self._dead_letter_tag, main=self._main_tag, allow_unknown_tags=True)
#TODO(BEAM-18957): Fix when type inference supports tagged outputs.
result[self._main_tag].element_type = self._fn.infer_output_type(
pcoll.element_type)
diff --git a/sdks/python/apache_beam/transforms/core_test.py b/sdks/python/apache_beam/transforms/core_test.py
index 3e5e767..0d680c9 100644
--- a/sdks/python/apache_beam/transforms/core_test.py
+++ b/sdks/python/apache_beam/transforms/core_test.py
@@ -30,6 +30,7 @@
from apache_beam.coders import coders
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to
+from apache_beam.transforms.resources import ResourceHint
from apache_beam.transforms.userstate import BagStateSpec
from apache_beam.transforms.userstate import ReadModifyWriteStateSpec
from apache_beam.transforms.userstate import TimerSpec
@@ -416,6 +417,94 @@
assert_that(good, equal_to([0, 1, 2]), 'good')
assert_that(bad_elements, equal_to([(1, 5), (1, 10)]), 'bad')
+ def test_tags_with_exception_handling_then_resource_hint(self):
+ class TagHint(ResourceHint):
+ urn = 'beam:resources:tags:v1'
+
+ ResourceHint.register_resource_hint('tags', TagHint)
+ with beam.Pipeline() as pipeline:
+ ok, unused_errors = (
+ pipeline
+ | beam.Create([1])
+ | beam.Map(lambda x: x)
+ .with_exception_handling()
+ .with_resource_hints(tags='test_tag')
+ )
+ pd = ok.producer.transform
+ self.assertIsInstance(pd, beam.transforms.core.ParDo)
+ while hasattr(pd.fn, 'fn'):
+ pd = pd.fn
+ self.assertEqual(
+ pd.get_resource_hints(),
+ {'beam:resources:tags:v1': b'test_tag'},
+ )
+
+ def test_tags_with_exception_handling_timeout_then_resource_hint(self):
+ class TagHint(ResourceHint):
+ urn = 'beam:resources:tags:v1'
+
+ ResourceHint.register_resource_hint('tags', TagHint)
+ with beam.Pipeline() as pipeline:
+ ok, unused_errors = (
+ pipeline
+ | beam.Create([1])
+ | beam.Map(lambda x: x)
+ .with_exception_handling(timeout=1)
+ .with_resource_hints(tags='test_tag')
+ )
+ pd = ok.producer.transform
+ self.assertIsInstance(pd, beam.transforms.core.ParDo)
+ while hasattr(pd.fn, 'fn'):
+ pd = pd.fn
+ self.assertEqual(
+ pd.get_resource_hints(),
+ {'beam:resources:tags:v1': b'test_tag'},
+ )
+
+ def test_tags_with_resource_hint_then_exception_handling(self):
+ class TagHint(ResourceHint):
+ urn = 'beam:resources:tags:v1'
+
+ ResourceHint.register_resource_hint('tags', TagHint)
+ with beam.Pipeline() as pipeline:
+ ok, unused_errors = (
+ pipeline
+ | beam.Create([1])
+ | beam.Map(lambda x: x)
+ .with_resource_hints(tags='test_tag')
+ .with_exception_handling()
+ )
+ pd = ok.producer.transform
+ self.assertIsInstance(pd, beam.transforms.core.ParDo)
+ while hasattr(pd.fn, 'fn'):
+ pd = pd.fn
+ self.assertEqual(
+ pd.get_resource_hints(),
+ {'beam:resources:tags:v1': b'test_tag'},
+ )
+
+ def test_tags_with_resource_hint_then_exception_handling_timeout(self):
+ class TagHint(ResourceHint):
+ urn = 'beam:resources:tags:v1'
+
+ ResourceHint.register_resource_hint('tags', TagHint)
+ with beam.Pipeline() as pipeline:
+ ok, unused_errors = (
+ pipeline
+ | beam.Create([1])
+ | beam.Map(lambda x: x)
+ .with_resource_hints(tags='test_tag')
+ .with_exception_handling(timeout=1)
+ )
+ pd = ok.producer.transform
+ self.assertIsInstance(pd, beam.transforms.core.ParDo)
+ while hasattr(pd.fn, 'fn'):
+ pd = pd.fn
+ self.assertEqual(
+ pd.get_resource_hints(),
+ {'beam:resources:tags:v1': b'test_tag'},
+ )
+
def test_callablewrapper_typehint():
T = TypeVar("T")
diff --git a/sdks/python/apache_beam/transforms/ptransform.py b/sdks/python/apache_beam/transforms/ptransform.py
index d2cf836..cac8a8f 100644
--- a/sdks/python/apache_beam/transforms/ptransform.py
+++ b/sdks/python/apache_beam/transforms/ptransform.py
@@ -1164,6 +1164,10 @@
def __rrshift__(self, label):
return _NamedPTransform(self.transform, label)
+ def with_resource_hints(self, **kwargs):
+ self.transform.with_resource_hints(**kwargs)
+ return self
+
def __getattr__(self, attr):
transform_attr = getattr(self.transform, attr)
if callable(transform_attr):