[Python] Fix: Propagate resource hints through with_exception_handling (#36090)

* Implement two-way propagation for resource hint, fix Python with_exception_handling + JAX-on-Beam = pipeline failure

* Add unit test for resource hint propagation in ParDo.

* Propagate resource hint in a more intuiative way
diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py
index 1bfc732..2304faf 100644
--- a/sdks/python/apache_beam/transforms/core.py
+++ b/sdks/python/apache_beam/transforms/core.py
@@ -1678,7 +1678,8 @@
         timeout,
         error_handler,
         on_failure_callback,
-        allow_unsafe_userstate_in_process)
+        allow_unsafe_userstate_in_process,
+        self.get_resource_hints())
 
   def with_error_handler(self, error_handler, **exception_handling_kwargs):
     """An alias for `with_exception_handling(error_handler=error_handler, ...)`
@@ -2284,7 +2285,8 @@
       timeout,
       error_handler,
       on_failure_callback,
-      allow_unsafe_userstate_in_process):
+      allow_unsafe_userstate_in_process,
+      resource_hints):
     if partial and use_subprocess:
       raise ValueError('partial and use_subprocess are mutually incompatible.')
     self._fn = fn
@@ -2301,6 +2303,7 @@
     self._error_handler = error_handler
     self._on_failure_callback = on_failure_callback
     self._allow_unsafe_userstate_in_process = allow_unsafe_userstate_in_process
+    self._resource_hints = resource_hints
 
   def expand(self, pcoll):
     if self._allow_unsafe_userstate_in_process:
@@ -2317,17 +2320,23 @@
       wrapped_fn = _TimeoutDoFn(self._fn, timeout=self._timeout)
     else:
       wrapped_fn = self._fn
-    result = pcoll | ParDo(
+    pardo = ParDo(
         _ExceptionHandlingWrapperDoFn(
             wrapped_fn,
             self._dead_letter_tag,
             self._exc_class,
             self._partial,
             self._on_failure_callback,
-            self._allow_unsafe_userstate_in_process),
+            self._allow_unsafe_userstate_in_process,
+        ),
         *self._args,
-        **self._kwargs).with_outputs(
-            self._dead_letter_tag, main=self._main_tag, allow_unknown_tags=True)
+        **self._kwargs,
+    )
+    # This is the fix: propagate hints.
+    pardo.get_resource_hints().update(self._resource_hints)
+
+    result = pcoll | pardo.with_outputs(
+        self._dead_letter_tag, main=self._main_tag, allow_unknown_tags=True)
     #TODO(BEAM-18957): Fix when type inference supports tagged outputs.
     result[self._main_tag].element_type = self._fn.infer_output_type(
         pcoll.element_type)
diff --git a/sdks/python/apache_beam/transforms/core_test.py b/sdks/python/apache_beam/transforms/core_test.py
index 3e5e767..0d680c9 100644
--- a/sdks/python/apache_beam/transforms/core_test.py
+++ b/sdks/python/apache_beam/transforms/core_test.py
@@ -30,6 +30,7 @@
 from apache_beam.coders import coders
 from apache_beam.testing.util import assert_that
 from apache_beam.testing.util import equal_to
+from apache_beam.transforms.resources import ResourceHint
 from apache_beam.transforms.userstate import BagStateSpec
 from apache_beam.transforms.userstate import ReadModifyWriteStateSpec
 from apache_beam.transforms.userstate import TimerSpec
@@ -416,6 +417,94 @@
       assert_that(good, equal_to([0, 1, 2]), 'good')
       assert_that(bad_elements, equal_to([(1, 5), (1, 10)]), 'bad')
 
+  def test_tags_with_exception_handling_then_resource_hint(self):
+    class TagHint(ResourceHint):
+      urn = 'beam:resources:tags:v1'
+
+    ResourceHint.register_resource_hint('tags', TagHint)
+    with beam.Pipeline() as pipeline:
+      ok, unused_errors = (
+        pipeline
+        | beam.Create([1])
+        | beam.Map(lambda x: x)
+        .with_exception_handling()
+        .with_resource_hints(tags='test_tag')
+      )
+    pd = ok.producer.transform
+    self.assertIsInstance(pd, beam.transforms.core.ParDo)
+    while hasattr(pd.fn, 'fn'):
+      pd = pd.fn
+    self.assertEqual(
+        pd.get_resource_hints(),
+        {'beam:resources:tags:v1': b'test_tag'},
+    )
+
+  def test_tags_with_exception_handling_timeout_then_resource_hint(self):
+    class TagHint(ResourceHint):
+      urn = 'beam:resources:tags:v1'
+
+    ResourceHint.register_resource_hint('tags', TagHint)
+    with beam.Pipeline() as pipeline:
+      ok, unused_errors = (
+        pipeline
+        | beam.Create([1])
+        | beam.Map(lambda x: x)
+        .with_exception_handling(timeout=1)
+        .with_resource_hints(tags='test_tag')
+      )
+    pd = ok.producer.transform
+    self.assertIsInstance(pd, beam.transforms.core.ParDo)
+    while hasattr(pd.fn, 'fn'):
+      pd = pd.fn
+    self.assertEqual(
+        pd.get_resource_hints(),
+        {'beam:resources:tags:v1': b'test_tag'},
+    )
+
+  def test_tags_with_resource_hint_then_exception_handling(self):
+    class TagHint(ResourceHint):
+      urn = 'beam:resources:tags:v1'
+
+    ResourceHint.register_resource_hint('tags', TagHint)
+    with beam.Pipeline() as pipeline:
+      ok, unused_errors = (
+        pipeline
+        | beam.Create([1])
+        | beam.Map(lambda x: x)
+        .with_resource_hints(tags='test_tag')
+        .with_exception_handling()
+      )
+    pd = ok.producer.transform
+    self.assertIsInstance(pd, beam.transforms.core.ParDo)
+    while hasattr(pd.fn, 'fn'):
+      pd = pd.fn
+    self.assertEqual(
+        pd.get_resource_hints(),
+        {'beam:resources:tags:v1': b'test_tag'},
+    )
+
+  def test_tags_with_resource_hint_then_exception_handling_timeout(self):
+    class TagHint(ResourceHint):
+      urn = 'beam:resources:tags:v1'
+
+    ResourceHint.register_resource_hint('tags', TagHint)
+    with beam.Pipeline() as pipeline:
+      ok, unused_errors = (
+        pipeline
+        | beam.Create([1])
+        | beam.Map(lambda x: x)
+        .with_resource_hints(tags='test_tag')
+        .with_exception_handling(timeout=1)
+      )
+    pd = ok.producer.transform
+    self.assertIsInstance(pd, beam.transforms.core.ParDo)
+    while hasattr(pd.fn, 'fn'):
+      pd = pd.fn
+    self.assertEqual(
+        pd.get_resource_hints(),
+        {'beam:resources:tags:v1': b'test_tag'},
+    )
+
 
 def test_callablewrapper_typehint():
   T = TypeVar("T")
diff --git a/sdks/python/apache_beam/transforms/ptransform.py b/sdks/python/apache_beam/transforms/ptransform.py
index d2cf836..cac8a8f 100644
--- a/sdks/python/apache_beam/transforms/ptransform.py
+++ b/sdks/python/apache_beam/transforms/ptransform.py
@@ -1164,6 +1164,10 @@
   def __rrshift__(self, label):
     return _NamedPTransform(self.transform, label)
 
+  def with_resource_hints(self, **kwargs):
+    self.transform.with_resource_hints(**kwargs)
+    return self
+
   def __getattr__(self, attr):
     transform_attr = getattr(self.transform, attr)
     if callable(transform_attr):