Fix custom coders not being used in Reshuffle (non global window) (#33363)
* Fix typehint in ReshufflePerKey on global window setting.
* Only update the type hint on global window setting. Need more work in non-global windows.
* Apply yapf
* Fix some failed tests.
* Revert change to setup.py
* Fix custom coders not being used in reshuffle in non-global windows
* Revert changes in setup.py. Reformat.
* Make WindowedValue a generic class. Support its conversion to the correct type constraint in Beam.
* Cython does not support Python generic class. Add a subclass as a workroundand keep it un-cythonized.
* Add comments
* Fix type error.
* Remove the base class of WindowedValue in TypedWindowedValue.
* Move TypedWindowedValue out from windowed_value.py
* Revise the comments
* Fix the module location when matching.
* Fix test failure where __name__ of a type alias not found in python 3.9
* Add a note about the window coder.
---------
Co-authored-by: Robert Bradshaw <robertwb@gmail.com>
diff --git a/sdks/python/apache_beam/coders/coders.py b/sdks/python/apache_beam/coders/coders.py
index a0c55da..724f268 100644
--- a/sdks/python/apache_beam/coders/coders.py
+++ b/sdks/python/apache_beam/coders/coders.py
@@ -1402,6 +1402,14 @@
return hash(
(self.wrapped_value_coder, self.timestamp_coder, self.window_coder))
+ @classmethod
+ def from_type_hint(cls, typehint, registry):
+ # type: (Any, CoderRegistry) -> WindowedValueCoder
+ # Ideally this'd take two parameters so that one could hint at
+ # the window type as well instead of falling back to the
+ # pickle coders.
+ return cls(registry.get_coder(typehint.inner_type))
+
Coder.register_structured_urn(
common_urns.coders.WINDOWED_VALUE.urn, WindowedValueCoder)
diff --git a/sdks/python/apache_beam/coders/typecoders.py b/sdks/python/apache_beam/coders/typecoders.py
index 1667cb7..892f508d 100644
--- a/sdks/python/apache_beam/coders/typecoders.py
+++ b/sdks/python/apache_beam/coders/typecoders.py
@@ -94,6 +94,8 @@
self._register_coder_internal(str, coders.StrUtf8Coder)
self._register_coder_internal(typehints.TupleConstraint, coders.TupleCoder)
self._register_coder_internal(typehints.DictConstraint, coders.MapCoder)
+ self._register_coder_internal(
+ typehints.WindowedTypeConstraint, coders.WindowedValueCoder)
# Default fallback coders applied in that order until the first matching
# coder found.
default_fallback_coders = [
diff --git a/sdks/python/apache_beam/transforms/util.py b/sdks/python/apache_beam/transforms/util.py
index 43d4a6c..c9fd2c7 100644
--- a/sdks/python/apache_beam/transforms/util.py
+++ b/sdks/python/apache_beam/transforms/util.py
@@ -74,6 +74,7 @@
from apache_beam.transforms.window import TimestampedValue
from apache_beam.typehints import trivial_inference
from apache_beam.typehints.decorators import get_signature
+from apache_beam.typehints.native_type_compatibility import TypedWindowedValue
from apache_beam.typehints.sharded_key_type import ShardedKeyType
from apache_beam.utils import shared
from apache_beam.utils import windowed_value
@@ -972,9 +973,8 @@
key, windowed_values = element
return [wv.with_value((key, wv.value)) for wv in windowed_values]
- # TODO(https://github.com/apache/beam/issues/33356): Support reshuffling
- # unpicklable objects with a non-global window setting.
- ungrouped = pcoll | Map(reify_timestamps).with_output_types(Any)
+ ungrouped = pcoll | Map(reify_timestamps).with_input_types(
+ Tuple[K, V]).with_output_types(Tuple[K, TypedWindowedValue[V]])
# TODO(https://github.com/apache/beam/issues/19785) Using global window as
# one of the standard window. This is to mitigate the Dataflow Java Runner
diff --git a/sdks/python/apache_beam/transforms/util_test.py b/sdks/python/apache_beam/transforms/util_test.py
index 7f166f7..db73310 100644
--- a/sdks/python/apache_beam/transforms/util_test.py
+++ b/sdks/python/apache_beam/transforms/util_test.py
@@ -1010,32 +1010,33 @@
equal_to(expected_data),
label="formatted_after_reshuffle")
+ global _Unpicklable
+ global _UnpicklableCoder
+
+ class _Unpicklable(object):
+ def __init__(self, value):
+ self.value = value
+
+ def __getstate__(self):
+ raise NotImplementedError()
+
+ def __setstate__(self, state):
+ raise NotImplementedError()
+
+ class _UnpicklableCoder(beam.coders.Coder):
+ def encode(self, value):
+ return str(value.value).encode()
+
+ def decode(self, encoded):
+ return _Unpicklable(int(encoded.decode()))
+
+ def to_type_hint(self):
+ return _Unpicklable
+
+ def is_deterministic(self):
+ return True
+
def test_reshuffle_unpicklable_in_global_window(self):
- global _Unpicklable
-
- class _Unpicklable(object):
- def __init__(self, value):
- self.value = value
-
- def __getstate__(self):
- raise NotImplementedError()
-
- def __setstate__(self, state):
- raise NotImplementedError()
-
- class _UnpicklableCoder(beam.coders.Coder):
- def encode(self, value):
- return str(value.value).encode()
-
- def decode(self, encoded):
- return _Unpicklable(int(encoded.decode()))
-
- def to_type_hint(self):
- return _Unpicklable
-
- def is_deterministic(self):
- return True
-
beam.coders.registry.register_coder(_Unpicklable, _UnpicklableCoder)
with TestPipeline() as pipeline:
@@ -1049,6 +1050,20 @@
| beam.Map(lambda u: u.value * 10))
assert_that(result, equal_to(expected_data))
+ def test_reshuffle_unpicklable_in_non_global_window(self):
+ beam.coders.registry.register_coder(_Unpicklable, _UnpicklableCoder)
+
+ with TestPipeline() as pipeline:
+ data = [_Unpicklable(i) for i in range(5)]
+ expected_data = [0, 0, 0, 10, 10, 10, 20, 20, 20, 30, 30, 30, 40, 40, 40]
+ result = (
+ pipeline
+ | beam.Create(data)
+ | beam.WindowInto(window.SlidingWindows(size=3, period=1))
+ | beam.Reshuffle()
+ | beam.Map(lambda u: u.value * 10))
+ assert_that(result, equal_to(expected_data))
+
class WithKeysTest(unittest.TestCase):
def setUp(self):
diff --git a/sdks/python/apache_beam/typehints/native_type_compatibility.py b/sdks/python/apache_beam/typehints/native_type_compatibility.py
index 6f704b3..381d4f7 100644
--- a/sdks/python/apache_beam/typehints/native_type_compatibility.py
+++ b/sdks/python/apache_beam/typehints/native_type_compatibility.py
@@ -24,9 +24,13 @@
import sys
import types
import typing
+from typing import Generic
+from typing import TypeVar
from apache_beam.typehints import typehints
+T = TypeVar('T')
+
_LOGGER = logging.getLogger(__name__)
# Describes an entry in the type map in convert_to_beam_type.
@@ -216,6 +220,18 @@
return typ
+# During type inference of WindowedValue, we need to pass in the inner value
+# type. This cannot be achieved immediately with WindowedValue class because it
+# is not parameterized. Changing it to a generic class (e.g. WindowedValue[T])
+# could work in theory. However, the class is cythonized and it seems that
+# cython does not handle generic classes well.
+# The workaround here is to create a separate class solely for the type
+# inference purpose. This class should never be used for creating instances.
+class TypedWindowedValue(Generic[T]):
+ def __init__(self, *args, **kwargs):
+ raise NotImplementedError("This class is solely for type inference")
+
+
def convert_to_beam_type(typ):
"""Convert a given typing type to a Beam type.
@@ -267,6 +283,12 @@
# TODO(https://github.com/apache/beam/issues/20076): Currently unhandled.
_LOGGER.info('Converting NewType type hint to Any: "%s"', typ)
return typehints.Any
+ elif typ_module == 'apache_beam.typehints.native_type_compatibility' and \
+ getattr(typ, "__name__", typ.__origin__.__name__) == 'TypedWindowedValue':
+ # Need to pass through WindowedValue class so that it can be converted
+ # to the correct type constraint in Beam
+ # This is needed to fix https://github.com/apache/beam/issues/33356
+ pass
elif (typ_module != 'typing') and (typ_module != 'collections.abc'):
# Only translate types from the typing and collections.abc modules.
return typ
@@ -324,6 +346,10 @@
match=_match_is_exactly_collection,
arity=1,
beam_type=typehints.Collection),
+ _TypeMapEntry(
+ match=_match_issubclass(TypedWindowedValue),
+ arity=1,
+ beam_type=typehints.WindowedValue),
]
# Find the first matching entry.
diff --git a/sdks/python/apache_beam/typehints/typehints.py b/sdks/python/apache_beam/typehints/typehints.py
index 0e18e88..a65a0f7 100644
--- a/sdks/python/apache_beam/typehints/typehints.py
+++ b/sdks/python/apache_beam/typehints/typehints.py
@@ -1213,6 +1213,15 @@
repr(self.inner_type),
instance.value.__class__.__name__))
+ def bind_type_variables(self, bindings):
+ bound_inner_type = bind_type_variables(self.inner_type, bindings)
+ if bound_inner_type == self.inner_type:
+ return self
+ return WindowedValue[bound_inner_type]
+
+ def __repr__(self):
+ return 'WindowedValue[%s]' % repr(self.inner_type)
+
class GeneratorHint(IteratorHint):
"""A Generator type hint.