blob: 4c485fe75fd36e0a8d6883a604d41e0547cdbb42 [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Type hinting decorators allowing static or runtime type-checking for the SDK.
This module defines decorators which utilize the type-hints defined in
'type_hints.py' to allow annotation of the types of function arguments and
return values.
Type-hints for functions are annotated using two separate decorators. One is for
type-hinting the types of function arguments, the other for type-hinting the
function return value. Type-hints can either be specified in the form of
positional arguments::
@with_input_types(int, int)
def add(a, b):
return a + b
Keyword arguments::
@with_input_types(a=int, b=int)
def add(a, b):
return a + b
Or even a mix of both::
@with_input_types(int, b=int)
def add(a, b):
return a + b
Example usage for type-hinting arguments only::
@with_input_types(a=str)
def to_lower(a):
return a.lower()
Example usage for type-hinting return values only::
@with_output_types(Tuple[int, bool])
def compress_point(ec_point):
return ec_point.x, ec_point.y < 0
Example usage for type-hinting both arguments and return values::
@with_input_types(a=int)
@with_output_types(str)
def int_to_str(a):
return str(a)
Type-hinting a function with arguments that unpack tuples is also supported
(in Python 2 only). As an example, such a function would be defined as::
def foo((a, b)):
...
The valid type-hint for such a function looks like the following::
@with_input_types(a=int, b=int)
def foo((a, b)):
...
Notice that we hint the type of each unpacked argument independently, rather
than hinting the type of the tuple as a whole (Tuple[int, int]).
Optionally, type-hints can be type-checked at runtime. To toggle this behavior
this module defines two functions: 'enable_run_time_type_checking' and
'disable_run_time_type_checking'. NOTE: for this toggle behavior to work
properly it must appear at the top of the module where all functions are
defined, or before importing a module containing type-hinted functions.
"""
from __future__ import absolute_import
import inspect
import logging
import sys
import types
from builtins import next
from builtins import object
from builtins import zip
from apache_beam.typehints import native_type_compatibility
from apache_beam.typehints import typehints
from apache_beam.typehints.typehints import CompositeTypeHintError
from apache_beam.typehints.typehints import SimpleTypeHintError
from apache_beam.typehints.typehints import check_constraint
from apache_beam.typehints.typehints import validate_composite_type_param
try:
import funcsigs # Python 2 only.
except ImportError:
funcsigs = None
# Public API of this module.
__all__ = [
    'with_input_types',
    'with_output_types',
    'WithTypeHints',
    'TypeCheckError',
]

# This is missing in the builtin types module. str.upper is arbitrary, any
# method on a C-implemented type will do.
# pylint: disable=invalid-name
_MethodDescriptorType = type(str.upper)
# pylint: enable=invalid-name

# Catch-all hints used for *args / **kwargs parameters that have no explicit
# hint or annotation.
_ANY_VAR_POSITIONAL = typehints.Tuple[typehints.Any, ...]
_ANY_VAR_KEYWORD = typehints.Dict[typehints.Any, typehints.Any]

# TODO(BEAM-8280): Remove this when from_callable is ready to be enabled.
# While this is False, IOTypeHints.from_callable always returns None.
_enable_from_callable = False

# Keep a reference to the original inspect function. It is used as the
# "real" implementation by getfullargspec below, and to revert the
# monkeypatch applied in getcallargs_forhints_impl_py2. Python 2 has no
# inspect.getfullargspec, so inspect.getargspec is used there instead;
# _use_full_argspec records which of the two was found.
try:
  _original_getfullargspec = inspect.getfullargspec
  _use_full_argspec = True
except AttributeError:  # Python 2
  _original_getfullargspec = inspect.getargspec
  _use_full_argspec = False
def getfullargspec(func):
  """Python 2 only: a more permissive replacement for inspect.getargspec.

  Besides plain functions, this also accepts:
    - classes: the spec of ``__init__`` is reported, minus its first
      (``self``) parameter;
    - other callables: the spec of their ``__call__`` method is reported;
    - callables whose ``__call__`` cannot be inspected either: a catch-all
      spec with one positional argument ``'_'`` plus ``'__unknown__varargs'``
      and ``'__unknown__keywords'``, whose names won't match any real
      argument. Arguments with the ``__unknown__`` prefix are ignored by the
      type checking code.

  On Python 3 use get_signature instead; calling this there fails an assert.

  Raises:
    TypeError: if func is not inspectable, not a class and not callable.
  """
  assert sys.version_info < (3,), 'This method should not be used in Python 3'
  try:
    return _original_getfullargspec(func)
  except TypeError:
    if isinstance(func, type):
      init_spec = getfullargspec(func.__init__)
      # Drop the implicit 'self' parameter of __init__.
      del init_spec.args[0]
      return init_spec
    if not callable(func):
      # Re-raise the original TypeError: nothing more we can inspect.
      raise
    try:
      return _original_getfullargspec(func.__call__)
    except TypeError:
      varargs_name = '__unknown__varargs'
      keywords_name = '__unknown__keywords'
      if not _use_full_argspec:  # Python 2
        return inspect.ArgSpec(['_'], varargs_name, keywords_name, ())
      return inspect.FullArgSpec(
          ['_'], varargs_name, keywords_name, (), [], {}, {})
def get_signature(func):
  """Returns an inspect.Signature for func on both Python 2 and Python 3.

  inspect.signature is used rather than getfullargspec because, in the
  latter, 'the "self" parameter is always reported, even for bound methods'
  https://github.com/python/cpython/blob/44f91c388a6f4da9ed3300df32ca290b8aa104ea/Lib/inspect.py#L1103

  The stdlib inspect module is preferred; funcsigs is the fallback when
  inspect has no ``signature`` (Python 2).

  If the signature cannot be computed (ValueError), a catch-all signature
  ``(_, *__unknown__varargs, **__unknown__keywords)`` is used. Two further
  specializations are then applied:
    - C-implemented method descriptors (e.g. str.strip): an unannotated first
      parameter is annotated with the descriptor's owning class;
    - classes: a missing return annotation becomes typehints.normalize(func).
  """
  sig_module = inspect if hasattr(inspect, 'signature') else funcsigs
  param_cls = sig_module.Parameter
  try:
    signature = sig_module.signature(func)
  except ValueError:
    # Fall back on a catch-all signature.
    signature = sig_module.Signature([
        param_cls('_', param_cls.POSITIONAL_OR_KEYWORD),
        param_cls('__unknown__varargs', param_cls.VAR_POSITIONAL),
        param_cls('__unknown__keywords', param_cls.VAR_KEYWORD),
    ])

  # Hint the first argument of certain builtins, such as str.strip.
  if isinstance(func, _MethodDescriptorType):
    parameters = list(signature.parameters.values())
    first = parameters[0]
    if first.annotation == first.empty:
      parameters[0] = first.replace(annotation=func.__objclass__)
      signature = signature.replace(parameters=parameters)

  # Hint the return value of type callables.
  missing_return = signature.return_annotation == signature.empty
  if missing_return and isinstance(func, type):
    signature = signature.replace(return_annotation=typehints.normalize(func))

  return signature
class IOTypeHints(object):
  """Encapsulates all type hint information about a Dataflow construct.

  This should primarily be used via the WithTypeHints mixin class, though
  may also be attached to other objects (such as Python functions).

  Attributes:
    input_types: (tuple, dict) List of typing types, and an optional
      dictionary. May be None. The list and dict correspond to args and
      kwargs.
    output_types: (tuple, dict) List of typing types, and an optional
      dictionary (unused). Only the first element of the list is used. May be
      None.
  """
  __slots__ = ('input_types', 'output_types')

  def __init__(self, input_types=None, output_types=None):
    self.input_types = input_types
    self.output_types = output_types

  @staticmethod
  def from_callable(fn):
    """Construct an IOTypeHints object from a callable's signature.

    Supports Python 3 annotations. For partial annotations, sets unknown types
    to Any, _ANY_VAR_POSITIONAL, or _ANY_VAR_KEYWORD.

    Returns:
      A new IOTypeHints or None if no annotations found. Always None while the
      module-level _enable_from_callable flag is False.
    """
    if not _enable_from_callable:
      return None
    signature = get_signature(fn)
    if (all(param.annotation == param.empty
            for param in signature.parameters.values())
        and signature.return_annotation == signature.empty):
      return None
    input_args = []
    input_kwargs = {}
    for param in signature.parameters.values():
      if param.annotation == param.empty:
        if param.kind == param.VAR_POSITIONAL:
          input_args.append(_ANY_VAR_POSITIONAL)
        elif param.kind == param.VAR_KEYWORD:
          input_kwargs[param.name] = _ANY_VAR_KEYWORD
        elif param.kind == param.KEYWORD_ONLY:
          input_kwargs[param.name] = typehints.Any
        else:
          input_args.append(typehints.Any)
      else:
        if param.kind in [param.KEYWORD_ONLY, param.VAR_KEYWORD]:
          input_kwargs[param.name] = param.annotation
        else:
          assert param.kind in [param.POSITIONAL_ONLY,
                                param.POSITIONAL_OR_KEYWORD,
                                param.VAR_POSITIONAL], \
                 'Unsupported Parameter kind: %s' % param.kind
          input_args.append(param.annotation)
    output_args = []
    if signature.return_annotation != signature.empty:
      output_args.append(signature.return_annotation)
    else:
      output_args.append(typehints.Any)
    return IOTypeHints(input_types=(tuple(input_args), input_kwargs),
                       output_types=(tuple(output_args), {}))

  def set_input_types(self, *args, **kwargs):
    """Stores the given positional and keyword hints as (args, kwargs)."""
    self.input_types = args, kwargs

  def set_output_types(self, *args, **kwargs):
    """Stores the given positional and keyword hints as (args, kwargs)."""
    self.output_types = args, kwargs

  def simple_output_type(self, context):
    """Returns the single positional output hint, or None if none is set.

    Args:
      context: Describes the caller; only used in the error message.

    Raises:
      TypeError: if output hints are set but are not exactly one positional
        hint with no keyword hints.
    """
    if not self.output_types:
      return None
    args, kwargs = self.output_types
    if len(args) != 1 or kwargs:
      raise TypeError(
          'Expected single output type hint for %s but got: %s' % (
              context, self.output_types))
    return args[0]

  def has_simple_output_type(self):
    """Whether there's a single positional output type (always a bool)."""
    # bool(): the bare `and` chain used to leak None / () for unset hints.
    return bool(self.output_types and len(self.output_types[0]) == 1 and
                not self.output_types[1])

  def strip_iterable(self):
    """Removes outer Iterable (or equivalent) from output type.

    Only affects instances with simple output types, otherwise is a no-op.
    Example: Generator[Tuple[int, int]] becomes Tuple[int, int]

    Raises:
      ValueError if output type is simple and not iterable (as reported by
      typehints.get_yielded_type).
    """
    if not self.has_simple_output_type():
      return
    yielded_type = typehints.get_yielded_type(self.output_types[0][0])
    self.output_types = ((yielded_type,), {})

  def copy(self):
    """Returns a shallow copy (the hint tuples themselves are shared)."""
    return IOTypeHints(self.input_types, self.output_types)

  def with_defaults(self, hints):
    """Returns hints where this object's missing parts are taken from `hints`.

    Input and output hints are filled independently: each side of self is
    kept if it is non-empty, otherwise the corresponding side of `hints` is
    used. Returns self unchanged if `hints` is falsy.
    """
    if not hints:
      return self
    if self._has_input_types():
      input_types = self.input_types
    else:
      input_types = hints.input_types
    if self._has_output_types():
      output_types = self.output_types
    else:
      output_types = hints.output_types
    return IOTypeHints(input_types, output_types)

  def _has_input_types(self):
    # True if either the positional tuple or the keyword dict is non-empty.
    return self.input_types is not None and any(self.input_types)

  def _has_output_types(self):
    # True if either the positional tuple or the keyword dict is non-empty.
    return self.output_types is not None and any(self.output_types)

  def __bool__(self):
    return self._has_input_types() or self._has_output_types()

  def __repr__(self):
    return 'IOTypeHints[inputs=%s, outputs=%s]' % (
        self.input_types, self.output_types)
class WithTypeHints(object):
  """Mixin that lets an object carry, set and resolve its own type hints.

  The hints live in ``self._type_hints`` (an IOTypeHints). Subclasses that do
  not call this __init__ still work: the holder is created on first use.
  """

  def __init__(self, *unused_args, **unused_kwargs):
    self._type_hints = IOTypeHints()

  def _get_or_create_type_hints(self):
    # __init__ may not have been called, so create the holder lazily.
    try:
      hints = self._type_hints
    except AttributeError:
      self._type_hints = IOTypeHints()
      hints = self._type_hints
    return hints

  def get_type_hints(self):
    """Returns this object's hints with missing parts filled from defaults.

    Missing input/output hints are filled, in order of precedence, from:
      - self.default_type_hints();
      - the hints attached to self.__class__.
    """
    own_hints = self._get_or_create_type_hints()
    with_instance_defaults = own_hints.with_defaults(self.default_type_hints())
    return with_instance_defaults.with_defaults(
        get_type_hints(self.__class__))

  def default_type_hints(self):
    """Hook for subclasses to supply fallback hints; None means 'no default'."""
    return None

  def with_input_types(self, *arg_hints, **kwarg_hints):
    """Sets the input hints (converted to Beam types) and returns self."""
    to_beam = native_type_compatibility.convert_to_beam_types
    beam_args = to_beam(arg_hints)
    beam_kwargs = to_beam(kwarg_hints)
    self._get_or_create_type_hints().set_input_types(*beam_args, **beam_kwargs)
    return self

  def with_output_types(self, *arg_hints, **kwarg_hints):
    """Sets the output hints (converted to Beam types) and returns self."""
    to_beam = native_type_compatibility.convert_to_beam_types
    beam_args = to_beam(arg_hints)
    beam_kwargs = to_beam(kwarg_hints)
    self._get_or_create_type_hints().set_output_types(
        *beam_args, **beam_kwargs)
    return self
class TypeCheckError(Exception):
  """Raised when values or hints violate the declared type hints."""
def _positional_arg_hints(arg, hints):
  """Returns the type of a (possibly tuple-packed) positional argument.

  A plain argument name maps to its hint in `hints`, or to Any if it has
  none. A nested list of names (a Python 2 tuple-unpacking argument) maps
  recursively to a Tuple of the element hints. E.g. for
  ``lambda ((a, b), c): None`` the single positional argument is (as returned
  by inspect) [[a, b], c], which has type Tuple[Tuple[int, Any], float] under
  the hints {a: int, b: Any, c: float}.
  """
  if not isinstance(arg, list):
    return hints.get(arg, typehints.Any)
  element_hints = [_positional_arg_hints(sub_arg, hints) for sub_arg in arg]
  return typehints.Tuple[element_hints]
def _unpack_positional_arg_hints(arg, hint):
"""Unpacks the given hint according to the nested structure of arg.
For example, if arg is [[a, b], c] and hint is Tuple[Any, int], then
this function would return ((Any, Any), int) so it can be used in conjunction
with inspect.getcallargs.
"""
if isinstance(arg, list):
tuple_constraint = typehints.Tuple[[typehints.Any] * len(arg)]
if not typehints.is_consistent_with(hint, tuple_constraint):
raise TypeCheckError('Bad tuple arguments for %s: expected %s, got %s' %
(arg, tuple_constraint, hint))
if isinstance(hint, typehints.TupleConstraint):
return tuple(_unpack_positional_arg_hints(a, t)
for a, t in zip(arg, hint.tuple_types))
return (typehints.Any,) * len(arg)
return hint
def getcallargs_forhints(func, *typeargs, **typekwargs):
  """Like inspect.getcallargs, with support for declaring default args as Any.

  In Python 2, understands that Tuple[] and an Any unpack. Dispatches to the
  Python 2 or Python 3 implementation based on the running interpreter.

  Returns:
    (Dict[str, Any]) A dictionary from arguments names to values.
  """
  if sys.version_info < (3,):
    impl = getcallargs_forhints_impl_py2
  else:
    impl = getcallargs_forhints_impl_py3
  return impl(func, typeargs, typekwargs)
def getcallargs_forhints_impl_py2(func, typeargs, typekwargs):
  """Python 2 implementation of getcallargs_forhints.

  Binds positional/keyword type hints to func's parameters via
  inspect.getcallargs, with inspect.getargspec temporarily replaced by the
  more permissive module-level getfullargspec, then patches up defaulted and
  variadic parameters.

  Args:
    func: The callable whose parameters the hints are bound to.
    typeargs: (tuple) Positional type hints.
    typekwargs: (dict) Keyword type hints.

  Returns:
    (dict) Mapping from parameter name to type hint. Parameters left at their
    default are mapped to typehints.Any; the *args / **kwargs parameters are
    mapped to the hint given under their own name in typekwargs, or else to
    _ANY_VAR_POSITIONAL / _ANY_VAR_KEYWORD.

  Raises:
    TypeCheckError: if inspect.getcallargs rejects the hints (e.g. wrong
      number of arguments), or a tuple-unpacking parameter's hint is not
      consistent with a tuple of the right length.
  """
  argspec = getfullargspec(func)
  # Turn Tuple[x, y] into (x, y) so getcallargs can do the proper unpacking.
  packed_typeargs = [_unpack_positional_arg_hints(arg, hint)
                     for (arg, hint) in zip(argspec.args, typeargs)]
  # Positional hints beyond the named parameters (destined for *args) are
  # passed through unchanged.
  packed_typeargs += list(typeargs[len(packed_typeargs):])
  # Monkeypatch inspect.getfullargspec to allow passing non-function objects.
  # getfullargspec (getargspec on Python 2) are used by inspect.getcallargs.
  # (What is actually replaced here is inspect.getargspec.)
  # NOTE(review): this patches a module global, so the replacement is
  # process-wide while it is in effect.
  # TODO(BEAM-5490): Reimplement getcallargs and stop relying on monkeypatch.
  inspect.getargspec = getfullargspec
  try:
    callargs = inspect.getcallargs(func, *packed_typeargs, **typekwargs)
  except TypeError as e:
    raise TypeCheckError(e)
  finally:
    # Revert monkey-patch.
    inspect.getargspec = _original_getfullargspec
  if argspec.defaults:
    # Declare any default arguments to be Any.
    # Walk the parameters from last to first; the k-th from the end pairs
    # with defaults[-k-1]. A parameter whose bound value *is* its default
    # object received no hint and is declared Any.
    # NOTE(review): a hint that happens to be the very same object as the
    # default would also be replaced by Any.
    for k, var in enumerate(reversed(argspec.args)):
      if k >= len(argspec.defaults):
        break
      if callargs.get(var, None) is argspec.defaults[-k-1]:
        callargs[var] = typehints.Any
  # Patch up varargs and keywords
  if argspec.varargs:
    # TODO(BEAM-8122): This will always assign _ANY_VAR_POSITIONAL. Should be
    # "callargs.get(...) or _ANY_VAR_POSITIONAL".
    callargs[argspec.varargs] = typekwargs.get(
        argspec.varargs, _ANY_VAR_POSITIONAL)
  # `keywords` is the Python 2 ArgSpec name for the **kwargs parameter.
  varkw = argspec.keywords
  if varkw:
    # TODO(robertwb): Consider taking the union of key and value types.
    callargs[varkw] = typekwargs.get(varkw, _ANY_VAR_KEYWORD)
  # TODO(BEAM-5878) Support kwonlyargs.
  return callargs
def _normalize_var_positional_hint(hint):
  """Converts a var_positional hint into Tuple[Union[<types>], ...] form.

  Args:
    hint: (tuple) Should be either a tuple of one or more types, or a single
      Tuple[<type>, ...].

  Returns:
    The single Tuple[<type>, ...] hint unchanged, or
    Tuple[Union[<types>], ...] built from all entries of `hint`.

  Raises:
    TypeCheckError if hint does not have the right form.
  """
  if not hint or type(hint) is not tuple:
    # Format as a 1-tuple: a tuple-valued hint (e.g. the empty tuple) would
    # otherwise be consumed as the '%' argument list and raise TypeError
    # instead of the intended TypeCheckError.
    raise TypeCheckError('Unexpected VAR_POSITIONAL value: %s' % (hint,))
  if len(hint) == 1 and isinstance(hint[0], typehints.TupleSequenceConstraint):
    # Example: tuple(Tuple[Any, ...]) -> Tuple[Any, ...]
    return hint[0]
  # Example: tuple(int, str) -> Tuple[Union[int, str], ...]
  return typehints.Tuple[typehints.Union[hint], ...]
def _normalize_var_keyword_hint(hint, arg_name):
  """Converts a var_keyword hint into Dict[<key type>, <value type>] form.

  Args:
    hint: (dict) Should either contain a pair (arg_name,
      Dict[<key type>, <value type>]), or one or more possible types for the
      value.
    arg_name: (str) The keyword receiving this hint.

  Returns:
    The Dict hint given under arg_name if it is the only entry, otherwise
    Dict[str, Union[<all values of hint>]].

  Raises:
    TypeCheckError if hint does not have the right form.
  """
  if not hint or type(hint) is not dict:
    # Format as a 1-tuple: a tuple-valued hint would otherwise be consumed as
    # the '%' argument list and raise TypeError instead of TypeCheckError.
    raise TypeCheckError('Unexpected VAR_KEYWORD value: %s' % (hint,))
  keys = list(hint.keys())
  values = list(hint.values())
  if (len(values) == 1 and
      keys[0] == arg_name and
      isinstance(values[0], typehints.DictConstraint)):
    # Example: dict(kwargs=Dict[str, Any]) -> Dict[str, Any]
    return values[0]
  # Example: dict(k1=str, k2=int) -> Dict[str, Union[str,int]]
  return typehints.Dict[str, typehints.Union[values]]
def getcallargs_forhints_impl_py3(func, type_args, type_kwargs):
  """Binds type_args and type_kwargs to the parameters of func.

  Works like inspect.getcallargs, adapted for type hint checks: bound
  variadic hints are normalized, and unbound parameters are filled from their
  annotation, or else with Any / _ANY_VAR_POSITIONAL / _ANY_VAR_KEYWORD.

  Returns:
    A dict mapping parameter name to type hint; {} (after logging a warning)
    if get_signature raises ValueError.

  Raises:
    TypeCheckError: if the hints cannot be bound, e.g. too few or too many.
  """
  try:
    signature = get_signature(func)
  except ValueError as e:
    logging.warning('Could not get signature for function: %s: %s', func, e)
    return {}
  try:
    bindings = signature.bind(*type_args, **type_kwargs)
  except TypeError as e:
    # Might be raised due to too few or too many arguments.
    raise TypeCheckError(e)

  hints = bindings.arguments
  for param in signature.parameters.values():
    name, kind = param.name, param.kind
    if name in hints:
      # Bound: bring variadic hints into canonical Tuple/Dict form.
      if kind == param.VAR_POSITIONAL:
        hints[name] = _normalize_var_positional_hint(hints[name])
      elif kind == param.VAR_KEYWORD:
        hints[name] = _normalize_var_keyword_hint(hints[name], name)
      continue

    # Unbound: must have a default or be variadic.
    if param.annotation != param.empty:
      hints[name] = param.annotation
    elif kind == param.VAR_POSITIONAL:
      hints[name] = _ANY_VAR_POSITIONAL
    elif kind == param.VAR_KEYWORD:
      hints[name] = _ANY_VAR_KEYWORD
    elif param.default is not param.empty:
      # Declare unbound parameters with defaults to be Any.
      hints[name] = typehints.Any
    else:
      # This case should be caught by signature.bind() above.
      raise ValueError('Unexpected unbound parameter: %s' % param.name)
  return dict(hints)
def get_type_hints(fn):
  """Gets the type hint associated with an arbitrary object fn.

  Always returns a valid IOTypeHints object, creating one if necessary. When
  possible the new object is stored on fn as ``_type_hints`` so later calls
  return the same instance; objects that reject new attributes get a fresh,
  uncached IOTypeHints each time.
  """
  # pylint: disable=protected-access
  if hasattr(fn, '_type_hints'):
    return fn._type_hints
  try:
    fn._type_hints = IOTypeHints()
  except (AttributeError, TypeError):
    # Can't add arbitrary attributes to this object,
    # but might have some restrictions anyways...
    fresh = IOTypeHints()
    # Python 3.7 introduces annotations for _MethodDescriptorTypes.
    if isinstance(fn, _MethodDescriptorType) and sys.version_info < (3, 7):
      fresh.set_input_types(fn.__objclass__)
    return fresh
  return fn._type_hints
  # pylint: enable=protected-access
def with_input_types(*positional_hints, **keyword_hints):
  """A decorator that attaches input type-hints to a function or class.

  Hints can be specified positionally (in the order of the decorated
  callable's parameters), by keyword (matching parameter names), or as a mix
  of both. They are converted to Beam type hints and stored on the decorated
  object's IOTypeHints (see get_type_hints); the decorated object itself is
  returned unchanged.

  To be used as:

  .. testcode::

    from apache_beam.typehints import with_input_types

    @with_input_types(str)
    def upper(s):
      return s.upper()

  Or:

  .. testcode::

    from apache_beam.typehints import with_input_types
    from apache_beam.typehints import List
    from apache_beam.typehints import Tuple

    @with_input_types(ls=List[Tuple[int, int]])
    def increment(ls):
      return [(i + 1, j + 1) for (i, j) in ls]

  Args:
    *positional_hints: Positional type-hints having identical order as the
      function's formal arguments. Values for this argument must either be a
      built-in Python type or an instance of a
      :class:`~apache_beam.typehints.typehints.TypeConstraint` created by
      'indexing' a
      :class:`~apache_beam.typehints.typehints.CompositeTypeHint` instance
      with a type parameter.
    **keyword_hints: Keyword arguments mirroring the names of the parameters
      to the decorated functions. The value of each keyword argument must
      either be one of the allowed built-in Python types, a custom class, or
      an instance of a :class:`~apache_beam.typehints.typehints.TypeConstraint`
      created by 'indexing' a
      :class:`~apache_beam.typehints.typehints.CompositeTypeHint` instance
      with a type parameter.

  Raises:
    Whatever ``validate_composite_type_param`` raises for an invalid hint.
    Hints are only validated when the decorated object is a plain Python
    function.

  Returns:
    A decorator that records the hints on its argument and returns it.
  """
  beam_args = native_type_compatibility.convert_to_beam_types(positional_hints)
  beam_kwargs = native_type_compatibility.convert_to_beam_types(keyword_hints)
  # Only the converted hints should be visible to the closure below.
  del positional_hints, keyword_hints

  def annotate(f):
    if isinstance(f, types.FunctionType):
      every_hint = list(beam_args) + list(beam_kwargs.values())
      for hint in every_hint:
        validate_composite_type_param(
            hint, error_msg_prefix='All type hint arguments')
    get_type_hints(f).set_input_types(*beam_args, **beam_kwargs)
    return f

  return annotate
def with_output_types(*return_type_hint, **kwargs):
  """A decorator that attaches a return-value type-hint to a function or class.

  Exactly one positional type-hint is accepted. If the decorated function
  returns several values, annotate them with ``Tuple[type_1, type_2]``. The
  hint is converted to a Beam type hint and validated immediately, then
  stored on the decorated object's IOTypeHints (see get_type_hints); the
  decorated object itself is returned unchanged.

  This decorator is intended to be used like:

  .. testcode::

    from apache_beam.typehints import with_output_types
    from apache_beam.typehints import Set

    class Coordinate(object):
      def __init__(self, x, y):
        self.x = x
        self.y = y

    @with_output_types(Set[Coordinate])
    def parse_ints(ints):
      return {Coordinate(i, i) for i in ints}

  Or with a simple type-hint:

  .. testcode::

    from apache_beam.typehints import with_output_types

    @with_output_types(bool)
    def negate(p):
      return not p if p else p

  Args:
    *return_type_hint: A type-hint specifying the proper return type of the
      function. This argument should either be a built-in Python type or an
      instance of a :class:`~apache_beam.typehints.typehints.TypeConstraint`
      created by 'indexing' a
      :class:`~apache_beam.typehints.typehints.CompositeTypeHint`.
    **kwargs: Not used.

  Raises:
    :class:`~exceptions.ValueError`: If any kwarg parameters are passed in,
      or the number of positional arguments is not exactly ``1``.
    Whatever ``validate_composite_type_param`` raises if the hint is invalid.

  Returns:
    A decorator that records the hint on its argument and returns it.
  """
  if kwargs:
    raise ValueError("All arguments for the 'returns' decorator must be "
                     "positional arguments.")
  if len(return_type_hint) != 1:
    raise ValueError("'returns' accepts only a single positional argument. In "
                     "order to specify multiple return types, use the 'Tuple' "
                     "type-hint.")
  (raw_hint,) = return_type_hint
  beam_hint = native_type_compatibility.convert_to_beam_type(raw_hint)
  validate_composite_type_param(
      beam_hint, error_msg_prefix='All type hint arguments')

  def annotate(f):
    get_type_hints(f).set_output_types(beam_hint)
    return f

  return annotate
def _check_instance_type(
    type_constraint, instance, var_name=None, verbose=False):
  """A helper function to report type-hint constraint violations.

  Args:
    type_constraint: An instance of a 'TypeConstraint' or a built-in Python
      type.
    instance: The candidate object which will be checked to satisfy
      'type_constraint'.
    var_name: If 'instance' is an argument, then the actual name for the
      parameter in the original function definition. None means 'instance'
      is reported as a return value.
    verbose: If True, the offending value itself is included in the message
      for simple type-hint violations.

  Raises:
    TypeCheckError: If 'instance' fails to meet the type-constraint of
      'type_constraint'.
  """
  hint_type = (
      "argument: '%s'" % (var_name,) if var_name is not None
      else 'return type')
  try:
    check_constraint(type_constraint, instance)
  except SimpleTypeHintError:
    if verbose:
      # Wrap in a 1-tuple: a tuple-valued instance would otherwise be
      # unpacked as several '%' arguments and raise TypeError, masking the
      # TypeCheckError below.
      verbose_instance = '%s, ' % (instance,)
    else:
      verbose_instance = ''
    raise TypeCheckError('Type-hint for %s violated. Expected an '
                         'instance of %s, instead found %san instance of %s.'
                         % (hint_type, type_constraint,
                            verbose_instance, type(instance)))
  except CompositeTypeHintError as e:
    raise TypeCheckError('Type-hint for %s violated: %s' % (hint_type, e))
def _interleave_type_check(type_constraint, var_name=None):
"""Lazily type-check the type-hint for a lazily generated sequence type.
This function can be applied as a decorator or called manually in a curried
manner:
* @_interleave_type_check(List[int])
def gen():
yield 5
or
* gen = _interleave_type_check(Tuple[int, int], 'coord_gen')(gen)
As a result, all type-checking for the passed generator will occur at 'yield'
time. This way, we avoid having to depleat the generator in order to
type-check it.
Args:
type_constraint: An instance of a TypeConstraint. The output yielded of
'gen' will be type-checked according to this type constraint.
var_name: The variable name binded to 'gen' if type-checking a function
argument. Used solely for templating in error message generation.
Returns:
A function which takes a generator as an argument and returns a wrapped
version of the generator that interleaves type-checking at 'yield'
iteration. If the generator received is already wrapped, then it is simply
returned to avoid nested wrapping.
"""
def wrapper(gen):
if isinstance(gen, GeneratorWrapper):
return gen
return GeneratorWrapper(
gen,
lambda x: _check_instance_type(type_constraint, x, var_name)
)
return wrapper
class GeneratorWrapper(object):
  """A wrapper around a generator, allows execution of a callback per yield.

  Additionally, wrapping a generator with this class allows one to assign
  arbitrary attributes to a generator object just as with a function object.
  Attributes not found on the wrapper are looked up on the wrapped generator.

  Attributes:
    internal_gen: A instance of a generator object. As part of 'step' of the
      generator, the yielded object will be passed to 'interleave_func'.
    interleave_func: A callback accepting a single argument. This function
      will be called with the result of each yielded 'step' in the internal
      generator.
  """

  def __init__(self, gen, interleave_func):
    self.internal_gen = gen
    self.interleave_func = interleave_func

  def __getattr__(self, attr):
    # Only consulted when normal lookup fails, so __next__ / __iter__ (defined
    # on the class) never get here.
    # TODO(laolu): May also want to intercept 'send' in the future if we move
    # to a GeneratorHint with 3 type-params:
    #   * Generator[send_type, return_type, yield_type]
    if attr in ('internal_gen', 'interleave_func'):
      # Only missing when __init__ has not run (e.g. an instance created via
      # __new__ during copying/unpickling). Delegating would look up
      # self.internal_gen again and recurse forever.
      raise AttributeError(
          '%r object has no attribute %r' % (type(self).__name__, attr))
    return getattr(self.internal_gen, attr)

  def __next__(self):
    next_val = next(self.internal_gen)
    self.interleave_func(next_val)
    return next_val

  # Python 2 iterator protocol name.
  next = __next__

  def __iter__(self):
    for x in self.internal_gen:
      self.interleave_func(x)
      yield x