[BEAM-10258] Support type hint annotations on PTransform's expand() (#12009)

* [BEAM-10258] Support type hint annotations on PTransform's expand()

* Fixup: apply YAPF

* Moving PCollectionTypeConstraint to typehints.py

* Uses Generic[T] instead of PCollectionTypeConstraint

* Fixup: apply YAPF

* Remove unused imports

* Force user to wrap typehints in PCollections

* Add unit tests for various usages of typehints on PTransforms

* Add tests that use typehints on real pipelines

* Fixup: apply YAPF

* Fix bad merge

* Support PDone, PBegin, and better handling of error cases

* Fix test syntax

* Refactors strip_pcoll_input() and strip_pcoll_output() to a shared function

* Add unit tests

* Add more tests

* Add website documentation

* Fix linting issues

* Fix linting issue by using multi-line function annotations

* Fix more lint errors

* Fix import order, and other changes for PR

* Fix ungrouped-imports error

* Alphabetically order the imports

* Fixup: apply YAPF

* Fixes a bug where a type can have an empty __args__ attribute

* Fix bug in website snippet code

* Fixup: apply YAPF

* Fixup: apply YAPF

* Fix NoneType error

* Fix NoneType error part 2

* Use classes instead of strings during typecheck, and add tests

* Resolve circular import error and fix readability issues

* Fix lint errors

* Add back accidentally removed test

* Support None as an output annotation

* Show incorrect type in error message

Co-authored-by: Udi Meiri <udim@users.noreply.github.com>

* Allow Pipeline as an input

* Fix import bug

* Alphabetically order imports inside function (but really this is just to force re-run the tests)

* Display warning instead of throwing error for oddly formed type hints

* Convert to Beam types

* Add test for generic TypeVars

* Fix bug by skipping DoOutputsTuple

* Fix typo

* Add test for DoOutputsTuple

* Fix lint errors

Co-authored-by: Udi Meiri <udim@users.noreply.github.com>
diff --git a/sdks/python/apache_beam/examples/snippets/snippets_test_py3.py b/sdks/python/apache_beam/examples/snippets/snippets_test_py3.py
index 0f0b668..5eb1d4b 100644
--- a/sdks/python/apache_beam/examples/snippets/snippets_test_py3.py
+++ b/sdks/python/apache_beam/examples/snippets/snippets_test_py3.py
@@ -96,6 +96,18 @@
       ids = numbers | 'to_id' >> beam.Map(my_fn)
       # [END type_hints_map_annotations]
 
+    # Example using an annotated PTransform.
+    with self.assertRaises(typehints.TypeCheckError):
+      # [START type_hints_ptransforms]
+      from apache_beam.pvalue import PCollection
+
+      class IntToStr(beam.PTransform):
+        def expand(self, pcoll: PCollection[int]) -> PCollection[str]:
+          return pcoll | beam.Map(lambda elem: str(elem))
+
+      ids = numbers | 'convert to str' >> IntToStr()
+      # [END type_hints_ptransforms]
+
 
 if __name__ == '__main__':
   logging.getLogger().setLevel(logging.INFO)
diff --git a/sdks/python/apache_beam/transforms/ptransform.py b/sdks/python/apache_beam/transforms/ptransform.py
index 2ab121e..4a0e80f 100644
--- a/sdks/python/apache_beam/transforms/ptransform.py
+++ b/sdks/python/apache_beam/transforms/ptransform.py
@@ -68,13 +68,16 @@
 from apache_beam.internal import pickler
 from apache_beam.internal import util
 from apache_beam.portability import python_urns
+from apache_beam.pvalue import DoOutputsTuple
 from apache_beam.transforms.display import DisplayDataItem
 from apache_beam.transforms.display import HasDisplayData
 from apache_beam.typehints import native_type_compatibility
 from apache_beam.typehints import typehints
+from apache_beam.typehints.decorators import IOTypeHints
 from apache_beam.typehints.decorators import TypeCheckError
 from apache_beam.typehints.decorators import WithTypeHints
 from apache_beam.typehints.decorators import get_signature
+from apache_beam.typehints.decorators import get_type_hints
 from apache_beam.typehints.decorators import getcallargs_forhints
 from apache_beam.typehints.trivial_inference import instance_to_type
 from apache_beam.typehints.typehints import validate_composite_type_param
@@ -350,6 +353,14 @@
     # type: () -> str
     return self.__class__.__name__
 
+  def default_type_hints(self):
+    fn_type_hints = IOTypeHints.from_callable(self.expand)
+    if fn_type_hints is not None:
+      fn_type_hints = fn_type_hints.strip_pcoll()
+
+    # Prefer class decorator type hints for backwards compatibility.
+    return get_type_hints(self.__class__).with_defaults(fn_type_hints)
+
   def with_input_types(self, input_type_hint):
     """Annotates the input type of a :class:`PTransform` with a type-hint.
 
@@ -419,6 +430,8 @@
     root_hint = (
         arg_hints[0] if len(arg_hints) == 1 else arg_hints or kwarg_hints)
     for context, pvalue_, hint in _ZipPValues().visit(pvalueish, root_hint):
+      if isinstance(pvalue_, DoOutputsTuple):
+        continue
       if pvalue_.element_type is None:
         # TODO(robertwb): It's a bug that we ever get here. (typecheck)
         continue
diff --git a/sdks/python/apache_beam/typehints/decorators.py b/sdks/python/apache_beam/typehints/decorators.py
index dad0a31..4cd7681 100644
--- a/sdks/python/apache_beam/typehints/decorators.py
+++ b/sdks/python/apache_beam/typehints/decorators.py
@@ -105,6 +105,7 @@
 from typing import Optional
 from typing import Tuple
 from typing import TypeVar
+from typing import Union
 
 from apache_beam.typehints import native_type_compatibility
 from apache_beam.typehints import typehints
@@ -378,6 +379,75 @@
         self.output_types and len(self.output_types[0]) == 1 and
         not self.output_types[1])
 
+  def strip_pcoll(self):
+    from apache_beam.pipeline import Pipeline
+    from apache_beam.pvalue import PBegin
+    from apache_beam.pvalue import PDone
+
+    return self.strip_pcoll_helper(self.input_types,
+                                   self._has_input_types,
+                                   'input_types',
+                                    [Pipeline, PBegin],
+                                   'This input type hint will be ignored '
+                                   'and not used for type-checking purposes. '
+                                   'Typically, input type hints for a '
+                                   'PTransform are single (or nested) types '
+                                   'wrapped by a PCollection, or PBegin.',
+                                   'strip_pcoll_input()').\
+                strip_pcoll_helper(self.output_types,
+                                   self.has_simple_output_type,
+                                   'output_types',
+                                   [PDone, None],
+                                   'This output type hint will be ignored '
+                                   'and not used for type-checking purposes. '
+                                   'Typically, output type hints for a '
+                                   'PTransform are single (or nested) types '
+                                   'wrapped by a PCollection, PDone, or None.',
+                                   'strip_pcoll_output()')
+
+  def strip_pcoll_helper(
+      self,
+      my_type,            # type: any
+      has_my_type,        # type: Callable[[], bool]
+      my_key,             # type: str
+      special_containers,   # type: List[Union[PBegin, PDone, PCollection]]
+      error_str,          # type: str
+      source_str          # type: str
+      ):
+    # type: (...) -> IOTypeHints
+    from apache_beam.pvalue import PCollection
+
+    if not has_my_type() or not my_type or len(my_type[0]) != 1:
+      return self
+
+    my_type = my_type[0][0]
+
+    if isinstance(my_type, typehints.AnyTypeConstraint):
+      return self
+
+    special_containers += [PCollection]
+    kwarg_dict = {}
+
+    if (my_type not in special_containers and
+        getattr(my_type, '__origin__', None) != PCollection):
+      logging.warning(error_str + ' Got: %s instead.' % my_type)
+      kwarg_dict[my_key] = None
+      return self._replace(
+          origin=self._make_origin([self], tb=False, msg=[source_str]),
+          **kwarg_dict)
+
+    if (getattr(my_type, '__args__', -1) in [-1, None] or
+        len(my_type.__args__) == 0):
+      # e.g. PCollection (or PBegin/PDone)
+      kwarg_dict[my_key] = ((typehints.Any, ), {})
+    else:
+      # e.g. PCollection[type]
+      kwarg_dict[my_key] = ((convert_to_beam_type(my_type.__args__[0]), ), {})
+
+    return self._replace(
+        origin=self._make_origin([self], tb=False, msg=[source_str]),
+        **kwarg_dict)
+
   def strip_iterable(self):
     # type: () -> IOTypeHints
 
diff --git a/sdks/python/apache_beam/typehints/typed_pipeline_test_py3.py b/sdks/python/apache_beam/typehints/typed_pipeline_test_py3.py
index 2016871..e12930d 100644
--- a/sdks/python/apache_beam/typehints/typed_pipeline_test_py3.py
+++ b/sdks/python/apache_beam/typehints/typed_pipeline_test_py3.py
@@ -22,6 +22,7 @@
 
 from __future__ import absolute_import
 
+import typing
 import unittest
 
 import apache_beam as beam
@@ -257,6 +258,135 @@
     result = [1, 2, 3] | beam.FlatMap(fn) | beam.Map(fn2)
     self.assertCountEqual([4, 6], result)
 
+  def test_typed_ptransform_with_no_error(self):
+    class StrToInt(beam.PTransform):
+      def expand(
+          self,
+          pcoll: beam.pvalue.PCollection[str]) -> beam.pvalue.PCollection[int]:
+        return pcoll | beam.Map(lambda x: int(x))
+
+    class IntToStr(beam.PTransform):
+      def expand(
+          self,
+          pcoll: beam.pvalue.PCollection[int]) -> beam.pvalue.PCollection[str]:
+        return pcoll | beam.Map(lambda x: str(x))
+
+    _ = ['1', '2', '3'] | StrToInt() | IntToStr()
+
+  def test_typed_ptransform_with_bad_typehints(self):
+    class StrToInt(beam.PTransform):
+      def expand(
+          self,
+          pcoll: beam.pvalue.PCollection[str]) -> beam.pvalue.PCollection[int]:
+        return pcoll | beam.Map(lambda x: int(x))
+
+    class IntToStr(beam.PTransform):
+      def expand(
+          self,
+          pcoll: beam.pvalue.PCollection[str]) -> beam.pvalue.PCollection[str]:
+        return pcoll | beam.Map(lambda x: str(x))
+
+    with self.assertRaisesRegex(typehints.TypeCheckError,
+                                "Input type hint violation at IntToStr: "
+                                "expected <class 'str'>, got <class 'int'>"):
+      _ = ['1', '2', '3'] | StrToInt() | IntToStr()
+
+  def test_typed_ptransform_with_bad_input(self):
+    class StrToInt(beam.PTransform):
+      def expand(
+          self,
+          pcoll: beam.pvalue.PCollection[str]) -> beam.pvalue.PCollection[int]:
+        return pcoll | beam.Map(lambda x: int(x))
+
+    class IntToStr(beam.PTransform):
+      def expand(
+          self,
+          pcoll: beam.pvalue.PCollection[int]) -> beam.pvalue.PCollection[str]:
+        return pcoll | beam.Map(lambda x: str(x))
+
+    with self.assertRaisesRegex(typehints.TypeCheckError,
+                                "Input type hint violation at StrToInt: "
+                                "expected <class 'str'>, got <class 'int'>"):
+      # Feed integers to a PTransform that expects strings
+      _ = [1, 2, 3] | StrToInt() | IntToStr()
+
+  def test_typed_ptransform_with_partial_typehints(self):
+    class StrToInt(beam.PTransform):
+      def expand(self, pcoll) -> beam.pvalue.PCollection[int]:
+        return pcoll | beam.Map(lambda x: int(x))
+
+    class IntToStr(beam.PTransform):
+      def expand(
+          self,
+          pcoll: beam.pvalue.PCollection[int]) -> beam.pvalue.PCollection[str]:
+        return pcoll | beam.Map(lambda x: str(x))
+
+    # Feed integers to a PTransform that should expect strings
+    # but has no typehints so it expects any
+    _ = [1, 2, 3] | StrToInt() | IntToStr()
+
+  def test_typed_ptransform_with_bare_wrappers(self):
+    class StrToInt(beam.PTransform):
+      def expand(
+          self, pcoll: beam.pvalue.PCollection) -> beam.pvalue.PCollection:
+        return pcoll | beam.Map(lambda x: int(x))
+
+    class IntToStr(beam.PTransform):
+      def expand(
+          self,
+          pcoll: beam.pvalue.PCollection[int]) -> beam.pvalue.PCollection[str]:
+        return pcoll | beam.Map(lambda x: str(x))
+
+    _ = [1, 2, 3] | StrToInt() | IntToStr()
+
+  def test_typed_ptransform_with_no_typehints(self):
+    class StrToInt(beam.PTransform):
+      def expand(self, pcoll):
+        return pcoll | beam.Map(lambda x: int(x))
+
+    class IntToStr(beam.PTransform):
+      def expand(
+          self,
+          pcoll: beam.pvalue.PCollection[int]) -> beam.pvalue.PCollection[str]:
+        return pcoll | beam.Map(lambda x: str(x))
+
+    # Feed integers to a PTransform that should expect strings
+    # but has no typehints so it expects any
+    _ = [1, 2, 3] | StrToInt() | IntToStr()
+
+  def test_typed_ptransform_with_generic_annotations(self):
+    T = typing.TypeVar('T')
+
+    class IntToInt(beam.PTransform):
+      def expand(
+          self,
+          pcoll: beam.pvalue.PCollection[T]) -> beam.pvalue.PCollection[T]:
+        return pcoll | beam.Map(lambda x: x)
+
+    class IntToStr(beam.PTransform):
+      def expand(
+          self,
+          pcoll: beam.pvalue.PCollection[T]) -> beam.pvalue.PCollection[str]:
+        return pcoll | beam.Map(lambda x: str(x))
+
+    _ = [1, 2, 3] | IntToInt() | IntToStr()
+
+  def test_typed_ptransform_with_do_outputs_tuple_compiles(self):
+    class MyDoFn(beam.DoFn):
+      def process(self, element: int, *args, **kwargs):
+        if element % 2:
+          yield beam.pvalue.TaggedOutput('odd', 1)
+        else:
+          yield beam.pvalue.TaggedOutput('even', 1)
+
+    class MyPTransform(beam.PTransform):
+      def expand(self, pcoll: beam.pvalue.PCollection[int]):
+        return pcoll | beam.ParDo(MyDoFn()).with_outputs('odd', 'even')
+
+    # This test fails if you remove the following line from ptransform.py
+    # if isinstance(pvalue_, DoOutputsTuple): continue
+    _ = [1, 2, 3] | MyPTransform()
+
 
 class AnnotationsTest(unittest.TestCase):
   def test_pardo_dofn(self):
diff --git a/sdks/python/apache_beam/typehints/typehints_test_py3.py b/sdks/python/apache_beam/typehints/typehints_test_py3.py
index a7c23f0..5a36330 100644
--- a/sdks/python/apache_beam/typehints/typehints_test_py3.py
+++ b/sdks/python/apache_beam/typehints/typehints_test_py3.py
@@ -23,11 +23,19 @@
 from __future__ import absolute_import
 from __future__ import print_function
 
+import typing
 import unittest
 
+import apache_beam.typehints.typehints as typehints
+from apache_beam import Map
+from apache_beam import PTransform
+from apache_beam.pvalue import PBegin
+from apache_beam.pvalue import PCollection
+from apache_beam.pvalue import PDone
 from apache_beam.transforms.core import DoFn
 from apache_beam.typehints import KV
 from apache_beam.typehints import Iterable
+from apache_beam.typehints.typehints import Any
 
 
 class TestParDoAnnotations(unittest.TestCase):
@@ -46,11 +54,221 @@
       def process(self, element: int) -> Iterable[str]:
         pass
 
-    print(MyDoFn().get_type_hints())
     th = MyDoFn().get_type_hints()
     self.assertEqual(th.input_types, ((int, ), {}))
     self.assertEqual(th.output_types, ((str, ), {}))
 
 
+class TestPTransformAnnotations(unittest.TestCase):
+  def test_pep484_annotations(self):
+    class MyPTransform(PTransform):
+      def expand(self, pcoll: PCollection[int]) -> PCollection[str]:
+        return pcoll | Map(lambda num: str(num))
+
+    th = MyPTransform().get_type_hints()
+    self.assertEqual(th.input_types, ((int, ), {}))
+    self.assertEqual(th.output_types, ((str, ), {}))
+
+  def test_annotations_without_input_pcollection_wrapper(self):
+    class MyPTransform(PTransform):
+      def expand(self, pcoll: int) -> PCollection[str]:
+        return pcoll | Map(lambda num: str(num))
+
+    error_str = (
+        r'This input type hint will be ignored and not used for '
+        r'type-checking purposes. Typically, input type hints for a '
+        r'PTransform are single (or nested) types wrapped by a '
+        r'PCollection, or PBegin. Got: {} instead.'.format(int))
+
+    with self.assertLogs(level='WARN') as log:
+      MyPTransform().get_type_hints()
+      self.assertIn(error_str, log.output[0])
+
+  def test_annotations_without_output_pcollection_wrapper(self):
+    class MyPTransform(PTransform):
+      def expand(self, pcoll: PCollection[int]) -> str:
+        return pcoll | Map(lambda num: str(num))
+
+    error_str = (
+        r'This output type hint will be ignored and not used for '
+        r'type-checking purposes. Typically, output type hints for a '
+        r'PTransform are single (or nested) types wrapped by a '
+        r'PCollection, PDone, or None. Got: {} instead.'.format(str))
+
+    with self.assertLogs(level='WARN') as log:
+      th = MyPTransform().get_type_hints()
+      self.assertIn(error_str, log.output[0])
+      self.assertEqual(th.input_types, ((int, ), {}))
+      self.assertEqual(th.output_types, None)
+
+  def test_annotations_without_input_internal_type(self):
+    class MyPTransform(PTransform):
+      def expand(self, pcoll: PCollection) -> PCollection[str]:
+        return pcoll | Map(lambda num: str(num))
+
+    th = MyPTransform().get_type_hints()
+    self.assertEqual(th.input_types, ((Any, ), {}))
+    self.assertEqual(th.output_types, ((str, ), {}))
+
+  def test_annotations_without_output_internal_type(self):
+    class MyPTransform(PTransform):
+      def expand(self, pcoll: PCollection[int]) -> PCollection:
+        return pcoll | Map(lambda num: str(num))
+
+    th = MyPTransform().get_type_hints()
+    self.assertEqual(th.input_types, ((int, ), {}))
+    self.assertEqual(th.output_types, ((Any, ), {}))
+
+  def test_annotations_without_any_internal_type(self):
+    class MyPTransform(PTransform):
+      def expand(self, pcoll: PCollection) -> PCollection:
+        return pcoll | Map(lambda num: str(num))
+
+    th = MyPTransform().get_type_hints()
+    self.assertEqual(th.input_types, ((Any, ), {}))
+    self.assertEqual(th.output_types, ((Any, ), {}))
+
+  def test_annotations_without_input_typehint(self):
+    class MyPTransform(PTransform):
+      def expand(self, pcoll) -> PCollection[str]:
+        return pcoll | Map(lambda num: str(num))
+
+    th = MyPTransform().get_type_hints()
+    self.assertEqual(th.input_types, ((Any, ), {}))
+    self.assertEqual(th.output_types, ((str, ), {}))
+
+  def test_annotations_without_output_typehint(self):
+    class MyPTransform(PTransform):
+      def expand(self, pcoll: PCollection[int]):
+        return pcoll | Map(lambda num: str(num))
+
+    th = MyPTransform().get_type_hints()
+    self.assertEqual(th.input_types, ((int, ), {}))
+    self.assertEqual(th.output_types, ((Any, ), {}))
+
+  def test_annotations_without_any_typehints(self):
+    class MyPTransform(PTransform):
+      def expand(self, pcoll):
+        return pcoll | Map(lambda num: str(num))
+
+    th = MyPTransform().get_type_hints()
+    self.assertEqual(th.input_types, None)
+    self.assertEqual(th.output_types, None)
+
+  def test_annotations_with_pbegin(self):
+    class MyPTransform(PTransform):
+      def expand(self, pcoll: PBegin):
+        return pcoll | Map(lambda num: str(num))
+
+    th = MyPTransform().get_type_hints()
+    self.assertEqual(th.input_types, ((Any, ), {}))
+    self.assertEqual(th.output_types, ((Any, ), {}))
+
+  def test_annotations_with_pdone(self):
+    class MyPTransform(PTransform):
+      def expand(self, pcoll) -> PDone:
+        return pcoll | Map(lambda num: str(num))
+
+    th = MyPTransform().get_type_hints()
+    self.assertEqual(th.input_types, ((Any, ), {}))
+    self.assertEqual(th.output_types, ((Any, ), {}))
+
+  def test_annotations_with_none_input(self):
+    class MyPTransform(PTransform):
+      def expand(self, pcoll: None) -> PCollection[str]:
+        return pcoll | Map(lambda num: str(num))
+
+    error_str = (
+        r'This input type hint will be ignored and not used for '
+        r'type-checking purposes. Typically, input type hints for a '
+        r'PTransform are single (or nested) types wrapped by a '
+        r'PCollection, or PBegin. Got: {} instead.'.format(None))
+
+    with self.assertLogs(level='WARN') as log:
+      th = MyPTransform().get_type_hints()
+      self.assertIn(error_str, log.output[0])
+      self.assertEqual(th.input_types, None)
+      self.assertEqual(th.output_types, ((str, ), {}))
+
+  def test_annotations_with_none_output(self):
+    class MyPTransform(PTransform):
+      def expand(self, pcoll) -> None:
+        return pcoll | Map(lambda num: str(num))
+
+    th = MyPTransform().get_type_hints()
+    self.assertEqual(th.input_types, ((Any, ), {}))
+    self.assertEqual(th.output_types, ((Any, ), {}))
+
+  def test_annotations_with_arbitrary_output(self):
+    class MyPTransform(PTransform):
+      def expand(self, pcoll) -> str:
+        return pcoll | Map(lambda num: str(num))
+
+    th = MyPTransform().get_type_hints()
+    self.assertEqual(th.input_types, ((Any, ), {}))
+    self.assertEqual(th.output_types, None)
+
+  def test_annotations_with_arbitrary_input_and_output(self):
+    class MyPTransform(PTransform):
+      def expand(self, pcoll: int) -> str:
+        return pcoll | Map(lambda num: str(num))
+
+    input_error_str = (
+        r'This input type hint will be ignored and not used for '
+        r'type-checking purposes. Typically, input type hints for a '
+        r'PTransform are single (or nested) types wrapped by a '
+        r'PCollection, or PBegin. Got: {} instead.'.format(int))
+
+    output_error_str = (
+        r'This output type hint will be ignored and not used for '
+        r'type-checking purposes. Typically, output type hints for a '
+        r'PTransform are single (or nested) types wrapped by a '
+        r'PCollection, PDone, or None. Got: {} instead.'.format(str))
+
+    with self.assertLogs(level='WARN') as log:
+      th = MyPTransform().get_type_hints()
+      self.assertIn(input_error_str, log.output[0])
+      self.assertIn(output_error_str, log.output[1])
+      self.assertEqual(th.input_types, None)
+      self.assertEqual(th.output_types, None)
+
+  def test_typing_module_annotations_are_converted_to_beam_annotations(self):
+    class MyPTransform(PTransform):
+      def expand(
+          self, pcoll: PCollection[typing.Dict[str, str]]
+      ) -> PCollection[typing.Dict[str, str]]:
+        return pcoll
+
+    th = MyPTransform().get_type_hints()
+    self.assertEqual(th.input_types, ((typehints.Dict[str, str], ), {}))
+    self.assertEqual(th.input_types, ((typehints.Dict[str, str], ), {}))
+
+  def test_nested_typing_annotations_are_converted_to_beam_annotations(self):
+    class MyPTransform(PTransform):
+      def expand(self, pcoll:
+         PCollection[typing.Union[int, typing.Any, typing.Dict[str, float]]]) \
+      -> PCollection[typing.Union[int, typing.Any, typing.Dict[str, float]]]:
+        return pcoll
+
+    th = MyPTransform().get_type_hints()
+    self.assertEqual(
+        th.input_types,
+        ((typehints.Union[int, typehints.Any, typehints.Dict[str,
+                                                             float]], ), {}))
+    self.assertEqual(
+        th.input_types,
+        ((typehints.Union[int, typehints.Any, typehints.Dict[str,
+                                                             float]], ), {}))
+
+  def test_mixed_annotations_are_converted_to_beam_annotations(self):
+    class MyPTransform(PTransform):
+      def expand(self, pcoll: typing.Any) -> typehints.Any:
+        return pcoll
+
+    th = MyPTransform().get_type_hints()
+    self.assertEqual(th.input_types, ((typehints.Any, ), {}))
+    self.assertEqual(th.input_types, ((typehints.Any, ), {}))
+
+
 if __name__ == '__main__':
   unittest.main()
diff --git a/website/www/site/content/en/documentation/sdks/python-type-safety.md b/website/www/site/content/en/documentation/sdks/python-type-safety.md
index 074460c..755f795 100644
--- a/website/www/site/content/en/documentation/sdks/python-type-safety.md
+++ b/website/www/site/content/en/documentation/sdks/python-type-safety.md
@@ -71,7 +71,7 @@
 Using Annotations has the added benefit of allowing use of a static type checker (such as mypy) to additionally type check your code.
 If you already use a type checker, using annotations instead of decorators reduces code duplication.
 However, annotations do not cover all the use cases that decorators and inline declarations do.
-Two such are the `expand` of a composite transform and lambda functions.
+For instance, they do not work for lambda functions.
 
 ### Declaring Type Hints Using Type Annotations
 
@@ -82,6 +82,7 @@
 Annotations are currently supported on:
 
  - `process()` methods on `DoFn` subclasses.
+ - `expand()` methods on `PTransform` subclasses.
  - Functions passed to: `ParDo`, `Map`, `FlatMap`, `Filter`.
 
 The following code declares an `int` input and a `str` output type hint on the `to_id` transform, using annotations on `my_fn`.
@@ -90,6 +91,15 @@
 {{< code_sample "sdks/python/apache_beam/examples/snippets/snippets_test_py3.py" type_hints_map_annotations >}}
 {{< /highlight >}}
 
+The following code demonstrates how to use annotations on `PTransform` subclasses. 
+A valid annotation is a `PCollection` that wraps an internal (nested) type, `PBegin`, `PDone`, or `None`. 
+The following code declares typehints on a custom PTransform, that takes a `PCollection[int]` input
+and outputs a `PCollection[str]`, using annotations.
+
+{{< highlight py >}}
+{{< code_sample "sdks/python/apache_beam/examples/snippets/snippets_test_py3.py" type_hints_ptransforms >}}
+{{< /highlight >}}
+
 The following code declares `int` input and output type hints on `filter_evens`, using annotations on `FilterEvensDoFn.process`.
 Since `process` returns a generator, the output type for a DoFn producing a `PCollection[int]` is annotated as `Iterable[int]` (`Generator[int, None, None]` would also work here).
 Beam will remove the outer iterable of the return type on the `DoFn.process` method and functions passed to `FlatMap` to deduce the element type of resulting PCollection .
@@ -182,6 +192,7 @@
 * `Iterable[T]`
 * `Iterator[T]`
 * `Generator[T]`
+* `PCollection[T]`
 
 **Note:** The `Tuple[T, U]` type hint is a tuple with a fixed number of heterogeneously typed elements, while the `Tuple[T, ...]` type hint is a tuple with a variable of homogeneously typed elements.