Fix custom coders not being used in Reshuffle (non global window) (#33363)

* Fix typehint in ReshufflePerKey on global window setting.

* Only update the type hint on global window setting. Need more work in non-global windows.

* Apply yapf

* Fix some failed tests.

* Revert change to setup.py

* Fix custom coders not being used in reshuffle in non-global windows

* Revert changes in setup.py. Reformat.

* Make WindowedValue a generic class. Support its conversion to the correct type constraint in Beam.

* Cython does not support Python generic class. Add a subclass as a workroundand keep it un-cythonized.

* Add comments

* Fix type error.

* Remove the base class of WindowedValue in TypedWindowedValue.

* Move TypedWindowedValue out from windowed_value.py

* Revise the comments

* Fix the module location when matching.

* Fix test failure where __name__ of a type alias not found in python 3.9

* Add a note about the window coder.

---------

Co-authored-by: Robert Bradshaw <robertwb@gmail.com>
diff --git a/sdks/python/apache_beam/coders/coders.py b/sdks/python/apache_beam/coders/coders.py
index a0c55da..724f268 100644
--- a/sdks/python/apache_beam/coders/coders.py
+++ b/sdks/python/apache_beam/coders/coders.py
@@ -1402,6 +1402,14 @@
     return hash(
         (self.wrapped_value_coder, self.timestamp_coder, self.window_coder))
 
+  @classmethod
+  def from_type_hint(cls, typehint, registry):
+    # type: (Any, CoderRegistry) -> WindowedValueCoder
+    # Ideally this'd take two parameters so that one could hint at
+    # the window type as well instead of falling back to the
+    # pickle coders.
+    return cls(registry.get_coder(typehint.inner_type))
+
 
 Coder.register_structured_urn(
     common_urns.coders.WINDOWED_VALUE.urn, WindowedValueCoder)
diff --git a/sdks/python/apache_beam/coders/typecoders.py b/sdks/python/apache_beam/coders/typecoders.py
index 1667cb7..892f508d 100644
--- a/sdks/python/apache_beam/coders/typecoders.py
+++ b/sdks/python/apache_beam/coders/typecoders.py
@@ -94,6 +94,8 @@
     self._register_coder_internal(str, coders.StrUtf8Coder)
     self._register_coder_internal(typehints.TupleConstraint, coders.TupleCoder)
     self._register_coder_internal(typehints.DictConstraint, coders.MapCoder)
+    self._register_coder_internal(
+        typehints.WindowedTypeConstraint, coders.WindowedValueCoder)
     # Default fallback coders applied in that order until the first matching
     # coder found.
     default_fallback_coders = [
diff --git a/sdks/python/apache_beam/transforms/util.py b/sdks/python/apache_beam/transforms/util.py
index 43d4a6c..c9fd2c7 100644
--- a/sdks/python/apache_beam/transforms/util.py
+++ b/sdks/python/apache_beam/transforms/util.py
@@ -74,6 +74,7 @@
 from apache_beam.transforms.window import TimestampedValue
 from apache_beam.typehints import trivial_inference
 from apache_beam.typehints.decorators import get_signature
+from apache_beam.typehints.native_type_compatibility import TypedWindowedValue
 from apache_beam.typehints.sharded_key_type import ShardedKeyType
 from apache_beam.utils import shared
 from apache_beam.utils import windowed_value
@@ -972,9 +973,8 @@
         key, windowed_values = element
         return [wv.with_value((key, wv.value)) for wv in windowed_values]
 
-      # TODO(https://github.com/apache/beam/issues/33356): Support reshuffling
-      # unpicklable objects with a non-global window setting.
-      ungrouped = pcoll | Map(reify_timestamps).with_output_types(Any)
+      ungrouped = pcoll | Map(reify_timestamps).with_input_types(
+          Tuple[K, V]).with_output_types(Tuple[K, TypedWindowedValue[V]])
 
     # TODO(https://github.com/apache/beam/issues/19785) Using global window as
     # one of the standard window. This is to mitigate the Dataflow Java Runner
diff --git a/sdks/python/apache_beam/transforms/util_test.py b/sdks/python/apache_beam/transforms/util_test.py
index 7f166f7..db73310 100644
--- a/sdks/python/apache_beam/transforms/util_test.py
+++ b/sdks/python/apache_beam/transforms/util_test.py
@@ -1010,32 +1010,33 @@
           equal_to(expected_data),
           label="formatted_after_reshuffle")
 
+  global _Unpicklable
+  global _UnpicklableCoder
+
+  class _Unpicklable(object):
+    def __init__(self, value):
+      self.value = value
+
+    def __getstate__(self):
+      raise NotImplementedError()
+
+    def __setstate__(self, state):
+      raise NotImplementedError()
+
+  class _UnpicklableCoder(beam.coders.Coder):
+    def encode(self, value):
+      return str(value.value).encode()
+
+    def decode(self, encoded):
+      return _Unpicklable(int(encoded.decode()))
+
+    def to_type_hint(self):
+      return _Unpicklable
+
+    def is_deterministic(self):
+      return True
+
   def test_reshuffle_unpicklable_in_global_window(self):
-    global _Unpicklable
-
-    class _Unpicklable(object):
-      def __init__(self, value):
-        self.value = value
-
-      def __getstate__(self):
-        raise NotImplementedError()
-
-      def __setstate__(self, state):
-        raise NotImplementedError()
-
-    class _UnpicklableCoder(beam.coders.Coder):
-      def encode(self, value):
-        return str(value.value).encode()
-
-      def decode(self, encoded):
-        return _Unpicklable(int(encoded.decode()))
-
-      def to_type_hint(self):
-        return _Unpicklable
-
-      def is_deterministic(self):
-        return True
-
     beam.coders.registry.register_coder(_Unpicklable, _UnpicklableCoder)
 
     with TestPipeline() as pipeline:
@@ -1049,6 +1050,20 @@
           | beam.Map(lambda u: u.value * 10))
       assert_that(result, equal_to(expected_data))
 
+  def test_reshuffle_unpicklable_in_non_global_window(self):
+    beam.coders.registry.register_coder(_Unpicklable, _UnpicklableCoder)
+
+    with TestPipeline() as pipeline:
+      data = [_Unpicklable(i) for i in range(5)]
+      expected_data = [0, 0, 0, 10, 10, 10, 20, 20, 20, 30, 30, 30, 40, 40, 40]
+      result = (
+          pipeline
+          | beam.Create(data)
+          | beam.WindowInto(window.SlidingWindows(size=3, period=1))
+          | beam.Reshuffle()
+          | beam.Map(lambda u: u.value * 10))
+      assert_that(result, equal_to(expected_data))
+
 
 class WithKeysTest(unittest.TestCase):
   def setUp(self):
diff --git a/sdks/python/apache_beam/typehints/native_type_compatibility.py b/sdks/python/apache_beam/typehints/native_type_compatibility.py
index 6f704b3..381d4f7 100644
--- a/sdks/python/apache_beam/typehints/native_type_compatibility.py
+++ b/sdks/python/apache_beam/typehints/native_type_compatibility.py
@@ -24,9 +24,13 @@
 import sys
 import types
 import typing
+from typing import Generic
+from typing import TypeVar
 
 from apache_beam.typehints import typehints
 
+T = TypeVar('T')
+
 _LOGGER = logging.getLogger(__name__)
 
 # Describes an entry in the type map in convert_to_beam_type.
@@ -216,6 +220,18 @@
   return typ
 
 
+# During type inference of WindowedValue, we need to pass in the inner value
+# type. This cannot be achieved immediately with WindowedValue class because it
+# is not parameterized. Changing it to a generic class (e.g. WindowedValue[T])
+# could work in theory. However, the class is cythonized and it seems that
+# cython does not handle generic classes well.
+# The workaround here is to create a separate class solely for the type
+# inference purpose. This class should never be used for creating instances.
+class TypedWindowedValue(Generic[T]):
+  def __init__(self, *args, **kwargs):
+    raise NotImplementedError("This class is solely for type inference")
+
+
 def convert_to_beam_type(typ):
   """Convert a given typing type to a Beam type.
 
@@ -267,6 +283,12 @@
     # TODO(https://github.com/apache/beam/issues/20076): Currently unhandled.
     _LOGGER.info('Converting NewType type hint to Any: "%s"', typ)
     return typehints.Any
+  elif typ_module == 'apache_beam.typehints.native_type_compatibility' and \
+      getattr(typ, "__name__", typ.__origin__.__name__) == 'TypedWindowedValue':
+    # Need to pass through WindowedValue class so that it can be converted
+    # to the correct type constraint in Beam
+    # This is needed to fix https://github.com/apache/beam/issues/33356
+    pass
   elif (typ_module != 'typing') and (typ_module != 'collections.abc'):
     # Only translate types from the typing and collections.abc modules.
     return typ
@@ -324,6 +346,10 @@
           match=_match_is_exactly_collection,
           arity=1,
           beam_type=typehints.Collection),
+      _TypeMapEntry(
+          match=_match_issubclass(TypedWindowedValue),
+          arity=1,
+          beam_type=typehints.WindowedValue),
   ]
 
   # Find the first matching entry.
diff --git a/sdks/python/apache_beam/typehints/typehints.py b/sdks/python/apache_beam/typehints/typehints.py
index 0e18e88..a65a0f7 100644
--- a/sdks/python/apache_beam/typehints/typehints.py
+++ b/sdks/python/apache_beam/typehints/typehints.py
@@ -1213,6 +1213,15 @@
               repr(self.inner_type),
               instance.value.__class__.__name__))
 
+  def bind_type_variables(self, bindings):
+    bound_inner_type = bind_type_variables(self.inner_type, bindings)
+    if bound_inner_type == self.inner_type:
+      return self
+    return WindowedValue[bound_inner_type]
+
+  def __repr__(self):
+    return 'WindowedValue[%s]' % repr(self.inner_type)
+
 
 class GeneratorHint(IteratorHint):
   """A Generator type hint.