[BEAM-10258] Support type hint annotations on PTransform's expand() (#12009)
* [BEAM-10258] Support type hint annotations on PTransform's expand()
* Fixup: apply YAPF
* Moving PCollectionTypeConstraint to typehints.py
* Uses Generic[T] instead of PCollectionTypeConstraint
* Fixup: apply YAPF
* Remove unused imports
* Force user to wrap typehints in PCollections
* Add unit tests for various usages of typehints on PTransforms
* Add tests that use typehints on real pipelines
* Fixup: apply YAPF
* Fix bad merge
* Support PDone, PBegin, and better handling of error cases
* Fix test syntax
* Refactors strip_pcoll_input() and strip_pcoll_output() to a shared function
* Add unit tests
* Add more tests
* Add website documentation
* Fix linting issues
* Fix linting issue by using multi-line function annotations
* Fix more lint errors
* Fix import order, and other changes for PR
* Fix ungrouped-imports error
* Alphabetically order the imports
* Fixup: apply YAPF
* Fixes a bug where a type can have an empty __args__ attribute
* Fix bug in website snippet code
* Fixup: apply YAPF
* Fixup: apply YAPF
* Fix NoneType error
* Fix NoneType error part 2
* Use classes instead of strings during typecheck, and add tests
* Resolve circular import error and fix readability issues
* Fix lint errors
* Add back accidentally removed test
* Support None as an output annotation
* Show incorrect type in error message
Co-authored-by: Udi Meiri <udim@users.noreply.github.com>
* Allow Pipeline as an input
* Fix import bug
* Alphabetically order imports inside function (but really this is just to force re-run the tests)
* Display warning instead of throwing error for oddly formed type hints
* Convert to Beam types
* Add test for generic TypeVars
* Fix bug by skipping DoOutputsTuple
* Fix typo
* Add test for DoOutputsTuple
* Fix lint errors
Co-authored-by: Udi Meiri <udim@users.noreply.github.com>
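
Illustration of the unwrapping rule described above. This is a minimal, self-contained sketch and not the code in decorators.py: PCollection and PBegin here are hypothetical stand-ins for the classes in apache_beam.pvalue, and element_type() is a made-up helper that only approximates what strip_pcoll_helper() does (no warning is logged and no conversion to Beam types is performed).

```python
import typing
from typing import Any, Generic, TypeVar

T = TypeVar('T')


class PCollection(Generic[T]):
  """Hypothetical stand-in for apache_beam.pvalue.PCollection."""


class PBegin(object):
  """Hypothetical stand-in for apache_beam.pvalue.PBegin."""


def element_type(hint, special_containers=(PBegin, )):
  """Returns the element type wrapped by hint, or None if hint is ignored."""
  if hint in special_containers or hint is PCollection:
    # Bare PCollection (or PBegin): there is no element type to check.
    return Any
  if getattr(hint, '__origin__', None) is not PCollection:
    # Not wrapped in a PCollection: the real code warns and drops the hint.
    return None
  args = getattr(hint, '__args__', None)
  # PCollection[X] yields X; an empty __args__ falls back to Any.
  return args[0] if args else Any


def expand(pcoll: PCollection[int]) -> PCollection[str]:
  return pcoll


hints = typing.get_type_hints(expand)
input_type = element_type(hints['pcoll'])
output_type = element_type(hints['return'])
```

With these stand-ins, input_type is int and output_type is str, while a bare int annotation yields None, which corresponds to the case where the real implementation ignores the hint.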
diff --git a/sdks/python/apache_beam/examples/snippets/snippets_test_py3.py b/sdks/python/apache_beam/examples/snippets/snippets_test_py3.py
index 0f0b668..5eb1d4b 100644
--- a/sdks/python/apache_beam/examples/snippets/snippets_test_py3.py
+++ b/sdks/python/apache_beam/examples/snippets/snippets_test_py3.py
@@ -96,6 +96,18 @@
ids = numbers | 'to_id' >> beam.Map(my_fn)
# [END type_hints_map_annotations]
+ # Example using an annotated PTransform.
+ with self.assertRaises(typehints.TypeCheckError):
+ # [START type_hints_ptransforms]
+ from apache_beam.pvalue import PCollection
+
+ class IntToStr(beam.PTransform):
+ def expand(self, pcoll: PCollection[int]) -> PCollection[str]:
+ return pcoll | beam.Map(lambda elem: str(elem))
+
+ ids = numbers | 'convert to str' >> IntToStr()
+ # [END type_hints_ptransforms]
+
if __name__ == '__main__':
logging.getLogger().setLevel(logging.INFO)
diff --git a/sdks/python/apache_beam/transforms/ptransform.py b/sdks/python/apache_beam/transforms/ptransform.py
index 2ab121e..4a0e80f 100644
--- a/sdks/python/apache_beam/transforms/ptransform.py
+++ b/sdks/python/apache_beam/transforms/ptransform.py
@@ -68,13 +68,16 @@
from apache_beam.internal import pickler
from apache_beam.internal import util
from apache_beam.portability import python_urns
+from apache_beam.pvalue import DoOutputsTuple
from apache_beam.transforms.display import DisplayDataItem
from apache_beam.transforms.display import HasDisplayData
from apache_beam.typehints import native_type_compatibility
from apache_beam.typehints import typehints
+from apache_beam.typehints.decorators import IOTypeHints
from apache_beam.typehints.decorators import TypeCheckError
from apache_beam.typehints.decorators import WithTypeHints
from apache_beam.typehints.decorators import get_signature
+from apache_beam.typehints.decorators import get_type_hints
from apache_beam.typehints.decorators import getcallargs_forhints
from apache_beam.typehints.trivial_inference import instance_to_type
from apache_beam.typehints.typehints import validate_composite_type_param
@@ -350,6 +353,14 @@
# type: () -> str
return self.__class__.__name__
+ def default_type_hints(self):
+ fn_type_hints = IOTypeHints.from_callable(self.expand)
+ if fn_type_hints is not None:
+ fn_type_hints = fn_type_hints.strip_pcoll()
+
+ # Prefer class decorator type hints for backwards compatibility.
+ return get_type_hints(self.__class__).with_defaults(fn_type_hints)
+
def with_input_types(self, input_type_hint):
"""Annotates the input type of a :class:`PTransform` with a type-hint.
@@ -419,6 +430,8 @@
root_hint = (
arg_hints[0] if len(arg_hints) == 1 else arg_hints or kwarg_hints)
for context, pvalue_, hint in _ZipPValues().visit(pvalueish, root_hint):
+ if isinstance(pvalue_, DoOutputsTuple):
+ continue
if pvalue_.element_type is None:
# TODO(robertwb): It's a bug that we ever get here. (typecheck)
continue
diff --git a/sdks/python/apache_beam/typehints/decorators.py b/sdks/python/apache_beam/typehints/decorators.py
index dad0a31..4cd7681 100644
--- a/sdks/python/apache_beam/typehints/decorators.py
+++ b/sdks/python/apache_beam/typehints/decorators.py
@@ -105,6 +105,7 @@
from typing import Optional
from typing import Tuple
from typing import TypeVar
+from typing import Union
from apache_beam.typehints import native_type_compatibility
from apache_beam.typehints import typehints
@@ -378,6 +379,75 @@
self.output_types and len(self.output_types[0]) == 1 and
not self.output_types[1])
+ def strip_pcoll(self):
+ from apache_beam.pipeline import Pipeline
+ from apache_beam.pvalue import PBegin
+ from apache_beam.pvalue import PDone
+
+ return self.strip_pcoll_helper(self.input_types,
+ self._has_input_types,
+ 'input_types',
+ [Pipeline, PBegin],
+ 'This input type hint will be ignored '
+ 'and not used for type-checking purposes. '
+ 'Typically, input type hints for a '
+ 'PTransform are single (or nested) types '
+ 'wrapped by a PCollection, or PBegin.',
+ 'strip_pcoll_input()').\
+ strip_pcoll_helper(self.output_types,
+ self.has_simple_output_type,
+ 'output_types',
+ [PDone, None],
+ 'This output type hint will be ignored '
+ 'and not used for type-checking purposes. '
+ 'Typically, output type hints for a '
+ 'PTransform are single (or nested) types '
+ 'wrapped by a PCollection, PDone, or None.',
+ 'strip_pcoll_output()')
+
+ def strip_pcoll_helper(
+ self,
+ my_type, # type: Any
+ has_my_type, # type: Callable[[], bool]
+ my_key, # type: str
+ special_containers, # type: List[Union[PBegin, PDone, PCollection]]
+ error_str, # type: str
+ source_str # type: str
+ ):
+ # type: (...) -> IOTypeHints
+ from apache_beam.pvalue import PCollection
+
+ if not has_my_type() or not my_type or len(my_type[0]) != 1:
+ return self
+
+ my_type = my_type[0][0]
+
+ if isinstance(my_type, typehints.AnyTypeConstraint):
+ return self
+
+ # Build a new list so the caller's list is not mutated.
+ special_containers = special_containers + [PCollection]
+ kwarg_dict = {}
+
+ if (my_type not in special_containers and
+ getattr(my_type, '__origin__', None) != PCollection):
+ logging.warning(error_str + ' Got: %s instead.' % my_type)
+ kwarg_dict[my_key] = None
+ return self._replace(
+ origin=self._make_origin([self], tb=False, msg=[source_str]),
+ **kwarg_dict)
+
+ if (getattr(my_type, '__args__', -1) in [-1, None] or
+ len(my_type.__args__) == 0):
+ # e.g. PCollection (or PBegin/PDone)
+ kwarg_dict[my_key] = ((typehints.Any, ), {})
+ else:
+ # e.g. PCollection[type]
+ kwarg_dict[my_key] = ((convert_to_beam_type(my_type.__args__[0]), ), {})
+
+ return self._replace(
+ origin=self._make_origin([self], tb=False, msg=[source_str]),
+ **kwarg_dict)
+
def strip_iterable(self):
# type: () -> IOTypeHints
diff --git a/sdks/python/apache_beam/typehints/typed_pipeline_test_py3.py b/sdks/python/apache_beam/typehints/typed_pipeline_test_py3.py
index 2016871..e12930d 100644
--- a/sdks/python/apache_beam/typehints/typed_pipeline_test_py3.py
+++ b/sdks/python/apache_beam/typehints/typed_pipeline_test_py3.py
@@ -22,6 +22,7 @@
from __future__ import absolute_import
+import typing
import unittest
import apache_beam as beam
@@ -257,6 +258,135 @@
result = [1, 2, 3] | beam.FlatMap(fn) | beam.Map(fn2)
self.assertCountEqual([4, 6], result)
+ def test_typed_ptransform_with_no_error(self):
+ class StrToInt(beam.PTransform):
+ def expand(
+ self,
+ pcoll: beam.pvalue.PCollection[str]) -> beam.pvalue.PCollection[int]:
+ return pcoll | beam.Map(lambda x: int(x))
+
+ class IntToStr(beam.PTransform):
+ def expand(
+ self,
+ pcoll: beam.pvalue.PCollection[int]) -> beam.pvalue.PCollection[str]:
+ return pcoll | beam.Map(lambda x: str(x))
+
+ _ = ['1', '2', '3'] | StrToInt() | IntToStr()
+
+ def test_typed_ptransform_with_bad_typehints(self):
+ class StrToInt(beam.PTransform):
+ def expand(
+ self,
+ pcoll: beam.pvalue.PCollection[str]) -> beam.pvalue.PCollection[int]:
+ return pcoll | beam.Map(lambda x: int(x))
+
+ class IntToStr(beam.PTransform):
+ def expand(
+ self,
+ pcoll: beam.pvalue.PCollection[str]) -> beam.pvalue.PCollection[str]:
+ return pcoll | beam.Map(lambda x: str(x))
+
+ with self.assertRaisesRegex(typehints.TypeCheckError,
+ "Input type hint violation at IntToStr: "
+ "expected <class 'str'>, got <class 'int'>"):
+ _ = ['1', '2', '3'] | StrToInt() | IntToStr()
+
+ def test_typed_ptransform_with_bad_input(self):
+ class StrToInt(beam.PTransform):
+ def expand(
+ self,
+ pcoll: beam.pvalue.PCollection[str]) -> beam.pvalue.PCollection[int]:
+ return pcoll | beam.Map(lambda x: int(x))
+
+ class IntToStr(beam.PTransform):
+ def expand(
+ self,
+ pcoll: beam.pvalue.PCollection[int]) -> beam.pvalue.PCollection[str]:
+ return pcoll | beam.Map(lambda x: str(x))
+
+ with self.assertRaisesRegex(typehints.TypeCheckError,
+ "Input type hint violation at StrToInt: "
+ "expected <class 'str'>, got <class 'int'>"):
+ # Feed integers to a PTransform that expects strings
+ _ = [1, 2, 3] | StrToInt() | IntToStr()
+
+ def test_typed_ptransform_with_partial_typehints(self):
+ class StrToInt(beam.PTransform):
+ def expand(self, pcoll) -> beam.pvalue.PCollection[int]:
+ return pcoll | beam.Map(lambda x: int(x))
+
+ class IntToStr(beam.PTransform):
+ def expand(
+ self,
+ pcoll: beam.pvalue.PCollection[int]) -> beam.pvalue.PCollection[str]:
+ return pcoll | beam.Map(lambda x: str(x))
+
+ # Feed integers to a PTransform that should expect strings but has no
+ # input typehint, so it accepts any input type.
+ _ = [1, 2, 3] | StrToInt() | IntToStr()
+
+ def test_typed_ptransform_with_bare_wrappers(self):
+ class StrToInt(beam.PTransform):
+ def expand(
+ self, pcoll: beam.pvalue.PCollection) -> beam.pvalue.PCollection:
+ return pcoll | beam.Map(lambda x: int(x))
+
+ class IntToStr(beam.PTransform):
+ def expand(
+ self,
+ pcoll: beam.pvalue.PCollection[int]) -> beam.pvalue.PCollection[str]:
+ return pcoll | beam.Map(lambda x: str(x))
+
+ _ = [1, 2, 3] | StrToInt() | IntToStr()
+
+ def test_typed_ptransform_with_no_typehints(self):
+ class StrToInt(beam.PTransform):
+ def expand(self, pcoll):
+ return pcoll | beam.Map(lambda x: int(x))
+
+ class IntToStr(beam.PTransform):
+ def expand(
+ self,
+ pcoll: beam.pvalue.PCollection[int]) -> beam.pvalue.PCollection[str]:
+ return pcoll | beam.Map(lambda x: str(x))
+
+ # Feed integers to a PTransform that should expect strings but has no
+ # typehints at all, so it accepts any input type.
+ _ = [1, 2, 3] | StrToInt() | IntToStr()
+
+ def test_typed_ptransform_with_generic_annotations(self):
+ T = typing.TypeVar('T')
+
+ class IntToInt(beam.PTransform):
+ def expand(
+ self,
+ pcoll: beam.pvalue.PCollection[T]) -> beam.pvalue.PCollection[T]:
+ return pcoll | beam.Map(lambda x: x)
+
+ class IntToStr(beam.PTransform):
+ def expand(
+ self,
+ pcoll: beam.pvalue.PCollection[T]) -> beam.pvalue.PCollection[str]:
+ return pcoll | beam.Map(lambda x: str(x))
+
+ _ = [1, 2, 3] | IntToInt() | IntToStr()
+
+ def test_typed_ptransform_with_do_outputs_tuple_compiles(self):
+ class MyDoFn(beam.DoFn):
+ def process(self, element: int, *args, **kwargs):
+ if element % 2:
+ yield beam.pvalue.TaggedOutput('odd', 1)
+ else:
+ yield beam.pvalue.TaggedOutput('even', 1)
+
+ class MyPTransform(beam.PTransform):
+ def expand(self, pcoll: beam.pvalue.PCollection[int]):
+ return pcoll | beam.ParDo(MyDoFn()).with_outputs('odd', 'even')
+
+ # This test fails if you remove the following line from ptransform.py
+ # if isinstance(pvalue_, DoOutputsTuple): continue
+ _ = [1, 2, 3] | MyPTransform()
+
class AnnotationsTest(unittest.TestCase):
def test_pardo_dofn(self):
diff --git a/sdks/python/apache_beam/typehints/typehints_test_py3.py b/sdks/python/apache_beam/typehints/typehints_test_py3.py
index a7c23f0..5a36330 100644
--- a/sdks/python/apache_beam/typehints/typehints_test_py3.py
+++ b/sdks/python/apache_beam/typehints/typehints_test_py3.py
@@ -23,11 +23,19 @@
from __future__ import absolute_import
from __future__ import print_function
+import typing
import unittest
+import apache_beam.typehints.typehints as typehints
+from apache_beam import Map
+from apache_beam import PTransform
+from apache_beam.pvalue import PBegin
+from apache_beam.pvalue import PCollection
+from apache_beam.pvalue import PDone
from apache_beam.transforms.core import DoFn
from apache_beam.typehints import KV
from apache_beam.typehints import Iterable
+from apache_beam.typehints.typehints import Any
class TestParDoAnnotations(unittest.TestCase):
@@ -46,11 +54,221 @@
def process(self, element: int) -> Iterable[str]:
pass
- print(MyDoFn().get_type_hints())
th = MyDoFn().get_type_hints()
self.assertEqual(th.input_types, ((int, ), {}))
self.assertEqual(th.output_types, ((str, ), {}))
+class TestPTransformAnnotations(unittest.TestCase):
+ def test_pep484_annotations(self):
+ class MyPTransform(PTransform):
+ def expand(self, pcoll: PCollection[int]) -> PCollection[str]:
+ return pcoll | Map(lambda num: str(num))
+
+ th = MyPTransform().get_type_hints()
+ self.assertEqual(th.input_types, ((int, ), {}))
+ self.assertEqual(th.output_types, ((str, ), {}))
+
+ def test_annotations_without_input_pcollection_wrapper(self):
+ class MyPTransform(PTransform):
+ def expand(self, pcoll: int) -> PCollection[str]:
+ return pcoll | Map(lambda num: str(num))
+
+ error_str = (
+ r'This input type hint will be ignored and not used for '
+ r'type-checking purposes. Typically, input type hints for a '
+ r'PTransform are single (or nested) types wrapped by a '
+ r'PCollection, or PBegin. Got: {} instead.'.format(int))
+
+ with self.assertLogs(level='WARN') as log:
+ MyPTransform().get_type_hints()
+ self.assertIn(error_str, log.output[0])
+
+ def test_annotations_without_output_pcollection_wrapper(self):
+ class MyPTransform(PTransform):
+ def expand(self, pcoll: PCollection[int]) -> str:
+ return pcoll | Map(lambda num: str(num))
+
+ error_str = (
+ r'This output type hint will be ignored and not used for '
+ r'type-checking purposes. Typically, output type hints for a '
+ r'PTransform are single (or nested) types wrapped by a '
+ r'PCollection, PDone, or None. Got: {} instead.'.format(str))
+
+ with self.assertLogs(level='WARN') as log:
+ th = MyPTransform().get_type_hints()
+ self.assertIn(error_str, log.output[0])
+ self.assertEqual(th.input_types, ((int, ), {}))
+ self.assertEqual(th.output_types, None)
+
+ def test_annotations_without_input_internal_type(self):
+ class MyPTransform(PTransform):
+ def expand(self, pcoll: PCollection) -> PCollection[str]:
+ return pcoll | Map(lambda num: str(num))
+
+ th = MyPTransform().get_type_hints()
+ self.assertEqual(th.input_types, ((Any, ), {}))
+ self.assertEqual(th.output_types, ((str, ), {}))
+
+ def test_annotations_without_output_internal_type(self):
+ class MyPTransform(PTransform):
+ def expand(self, pcoll: PCollection[int]) -> PCollection:
+ return pcoll | Map(lambda num: str(num))
+
+ th = MyPTransform().get_type_hints()
+ self.assertEqual(th.input_types, ((int, ), {}))
+ self.assertEqual(th.output_types, ((Any, ), {}))
+
+ def test_annotations_without_any_internal_type(self):
+ class MyPTransform(PTransform):
+ def expand(self, pcoll: PCollection) -> PCollection:
+ return pcoll | Map(lambda num: str(num))
+
+ th = MyPTransform().get_type_hints()
+ self.assertEqual(th.input_types, ((Any, ), {}))
+ self.assertEqual(th.output_types, ((Any, ), {}))
+
+ def test_annotations_without_input_typehint(self):
+ class MyPTransform(PTransform):
+ def expand(self, pcoll) -> PCollection[str]:
+ return pcoll | Map(lambda num: str(num))
+
+ th = MyPTransform().get_type_hints()
+ self.assertEqual(th.input_types, ((Any, ), {}))
+ self.assertEqual(th.output_types, ((str, ), {}))
+
+ def test_annotations_without_output_typehint(self):
+ class MyPTransform(PTransform):
+ def expand(self, pcoll: PCollection[int]):
+ return pcoll | Map(lambda num: str(num))
+
+ th = MyPTransform().get_type_hints()
+ self.assertEqual(th.input_types, ((int, ), {}))
+ self.assertEqual(th.output_types, ((Any, ), {}))
+
+ def test_annotations_without_any_typehints(self):
+ class MyPTransform(PTransform):
+ def expand(self, pcoll):
+ return pcoll | Map(lambda num: str(num))
+
+ th = MyPTransform().get_type_hints()
+ self.assertEqual(th.input_types, None)
+ self.assertEqual(th.output_types, None)
+
+ def test_annotations_with_pbegin(self):
+ class MyPTransform(PTransform):
+ def expand(self, pcoll: PBegin):
+ return pcoll | Map(lambda num: str(num))
+
+ th = MyPTransform().get_type_hints()
+ self.assertEqual(th.input_types, ((Any, ), {}))
+ self.assertEqual(th.output_types, ((Any, ), {}))
+
+ def test_annotations_with_pdone(self):
+ class MyPTransform(PTransform):
+ def expand(self, pcoll) -> PDone:
+ return pcoll | Map(lambda num: str(num))
+
+ th = MyPTransform().get_type_hints()
+ self.assertEqual(th.input_types, ((Any, ), {}))
+ self.assertEqual(th.output_types, ((Any, ), {}))
+
+ def test_annotations_with_none_input(self):
+ class MyPTransform(PTransform):
+ def expand(self, pcoll: None) -> PCollection[str]:
+ return pcoll | Map(lambda num: str(num))
+
+ error_str = (
+ r'This input type hint will be ignored and not used for '
+ r'type-checking purposes. Typically, input type hints for a '
+ r'PTransform are single (or nested) types wrapped by a '
+ r'PCollection, or PBegin. Got: {} instead.'.format(None))
+
+ with self.assertLogs(level='WARN') as log:
+ th = MyPTransform().get_type_hints()
+ self.assertIn(error_str, log.output[0])
+ self.assertEqual(th.input_types, None)
+ self.assertEqual(th.output_types, ((str, ), {}))
+
+ def test_annotations_with_none_output(self):
+ class MyPTransform(PTransform):
+ def expand(self, pcoll) -> None:
+ return pcoll | Map(lambda num: str(num))
+
+ th = MyPTransform().get_type_hints()
+ self.assertEqual(th.input_types, ((Any, ), {}))
+ self.assertEqual(th.output_types, ((Any, ), {}))
+
+ def test_annotations_with_arbitrary_output(self):
+ class MyPTransform(PTransform):
+ def expand(self, pcoll) -> str:
+ return pcoll | Map(lambda num: str(num))
+
+ th = MyPTransform().get_type_hints()
+ self.assertEqual(th.input_types, ((Any, ), {}))
+ self.assertEqual(th.output_types, None)
+
+ def test_annotations_with_arbitrary_input_and_output(self):
+ class MyPTransform(PTransform):
+ def expand(self, pcoll: int) -> str:
+ return pcoll | Map(lambda num: str(num))
+
+ input_error_str = (
+ r'This input type hint will be ignored and not used for '
+ r'type-checking purposes. Typically, input type hints for a '
+ r'PTransform are single (or nested) types wrapped by a '
+ r'PCollection, or PBegin. Got: {} instead.'.format(int))
+
+ output_error_str = (
+ r'This output type hint will be ignored and not used for '
+ r'type-checking purposes. Typically, output type hints for a '
+ r'PTransform are single (or nested) types wrapped by a '
+ r'PCollection, PDone, or None. Got: {} instead.'.format(str))
+
+ with self.assertLogs(level='WARN') as log:
+ th = MyPTransform().get_type_hints()
+ self.assertIn(input_error_str, log.output[0])
+ self.assertIn(output_error_str, log.output[1])
+ self.assertEqual(th.input_types, None)
+ self.assertEqual(th.output_types, None)
+
+ def test_typing_module_annotations_are_converted_to_beam_annotations(self):
+ class MyPTransform(PTransform):
+ def expand(
+ self, pcoll: PCollection[typing.Dict[str, str]]
+ ) -> PCollection[typing.Dict[str, str]]:
+ return pcoll
+
+ th = MyPTransform().get_type_hints()
+ self.assertEqual(th.input_types, ((typehints.Dict[str, str], ), {}))
+ self.assertEqual(th.output_types, ((typehints.Dict[str, str], ), {}))
+
+ def test_nested_typing_annotations_are_converted_to_beam_annotations(self):
+ class MyPTransform(PTransform):
+ def expand(self, pcoll:
+ PCollection[typing.Union[int, typing.Any, typing.Dict[str, float]]]) \
+ -> PCollection[typing.Union[int, typing.Any, typing.Dict[str, float]]]:
+ return pcoll
+
+ th = MyPTransform().get_type_hints()
+ self.assertEqual(
+ th.input_types,
+ ((typehints.Union[int, typehints.Any, typehints.Dict[str,
+ float]], ), {}))
+ self.assertEqual(
+ th.output_types,
+ ((typehints.Union[int, typehints.Any, typehints.Dict[str,
+ float]], ), {}))
+
+ def test_mixed_annotations_are_converted_to_beam_annotations(self):
+ class MyPTransform(PTransform):
+ def expand(self, pcoll: typing.Any) -> typehints.Any:
+ return pcoll
+
+ th = MyPTransform().get_type_hints()
+ self.assertEqual(th.input_types, ((typehints.Any, ), {}))
+ self.assertEqual(th.output_types, ((typehints.Any, ), {}))
+
+
if __name__ == '__main__':
unittest.main()
diff --git a/website/www/site/content/en/documentation/sdks/python-type-safety.md b/website/www/site/content/en/documentation/sdks/python-type-safety.md
index 074460c..755f795 100644
--- a/website/www/site/content/en/documentation/sdks/python-type-safety.md
+++ b/website/www/site/content/en/documentation/sdks/python-type-safety.md
@@ -71,7 +71,7 @@
Using Annotations has the added benefit of allowing use of a static type checker (such as mypy) to additionally type check your code.
If you already use a type checker, using annotations instead of decorators reduces code duplication.
However, annotations do not cover all the use cases that decorators and inline declarations do.
-Two such are the `expand` of a composite transform and lambda functions.
+For instance, they do not work for lambda functions.
### Declaring Type Hints Using Type Annotations
@@ -82,6 +82,7 @@
Annotations are currently supported on:
- `process()` methods on `DoFn` subclasses.
+ - `expand()` methods on `PTransform` subclasses.
- Functions passed to: `ParDo`, `Map`, `FlatMap`, `Filter`.
The following code declares an `int` input and a `str` output type hint on the `to_id` transform, using annotations on `my_fn`.
@@ -90,6 +91,15 @@
{{< code_sample "sdks/python/apache_beam/examples/snippets/snippets_test_py3.py" type_hints_map_annotations >}}
{{< /highlight >}}
+Annotations can also be used on the `expand()` method of `PTransform` subclasses.
+A valid annotation is `PBegin`, `PDone`, `None`, or a `PCollection` that wraps an internal (nested) type.
+The following code uses annotations to declare type hints on a custom `PTransform` that takes a `PCollection[int]` input
+and outputs a `PCollection[str]`.
+
+{{< highlight py >}}
+{{< code_sample "sdks/python/apache_beam/examples/snippets/snippets_test_py3.py" type_hints_ptransforms >}}
+{{< /highlight >}}
+
The following code declares `int` input and output type hints on `filter_evens`, using annotations on `FilterEvensDoFn.process`.
Since `process` returns a generator, the output type for a DoFn producing a `PCollection[int]` is annotated as `Iterable[int]` (`Generator[int, None, None]` would also work here).
Beam will remove the outer iterable of the return type on the `DoFn.process` method and functions passed to `FlatMap` to deduce the element type of resulting PCollection .
@@ -182,6 +192,7 @@
* `Iterable[T]`
* `Iterator[T]`
* `Generator[T]`
+* `PCollection[T]`
**Note:** The `Tuple[T, U]` type hint is a tuple with a fixed number of heterogeneously typed elements, while the `Tuple[T, ...]` type hint is a tuple with a variable of homogeneously typed elements.