blob: 9b109969bf3f9b42093a1af9abebb036dc0b567e [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Tests for decorators module."""
# pytype: skip-file
import functools
import sys
import typing
import unittest
from apache_beam import Map
from apache_beam.typehints import Any
from apache_beam.typehints import Dict
from apache_beam.typehints import List
from apache_beam.typehints import Tuple
from apache_beam.typehints import TypeCheckError
from apache_beam.typehints import TypeVariable
from apache_beam.typehints import WithTypeHints
from apache_beam.typehints import decorators
from apache_beam.typehints import typehints
# Beam type variable used in function annotations and expected hints in the
# tests below.
T = TypeVariable('T')
# typing-module counterpart of T, used to check conversion of typing hints to
# Beam hints.
# Name is 'T' so it converts to a beam type with the same name.
# mypy requires that the name of the variable match, so we must ignore this.
T_typing = typing.TypeVar('T') # type: ignore
class IOTypeHintsTest(unittest.TestCase):
  """Tests for IOTypeHints, get_signature and getcallargs_forhints."""

  def test_get_signature(self):
    """get_signature() lists the parameters of a plain function in order."""

    # Basic coverage only to make sure function works.
    def fn(a, b=1, *c, **d):
      return a, b, c, d

    s = decorators.get_signature(fn)
    self.assertListEqual(list(s.parameters), ['a', 'b', 'c', 'd'])

  def test_get_signature_builtin(self):
    """get_signature() handles the builtin ``list``."""
    # Tests a builtin function for 3.7+ and fallback result for older versions.
    s = decorators.get_signature(list)
    if sys.version_info < (3, 7):
      self.assertListEqual(
          list(s.parameters),
          ['_', '__unknown__varargs', '__unknown__keywords'])
    else:
      self.assertListEqual(list(s.parameters), ['iterable'])
    self.assertEqual(s.return_annotation, List[Any])

  def test_from_callable_without_annotations(self):
    """from_callable() returns None for a function without annotations."""
    def fn(a, b=None, *args, **kwargs):
      return a, b, args, kwargs

    th = decorators.IOTypeHints.from_callable(fn)
    self.assertIsNone(th)

  def test_from_callable_builtin(self):
    """from_callable() returns None for the builtin ``len``."""
    th = decorators.IOTypeHints.from_callable(len)
    self.assertIsNone(th)

  def test_from_callable_method_descriptor(self):
    """from_callable() produces hints for the descriptor ``str.strip``."""
    # from_callable() injects an annotation in this special type of builtin.
    th = decorators.IOTypeHints.from_callable(str.strip)
    if sys.version_info >= (3, 7):
      self.assertEqual(th.input_types, ((str, Any), {}))
    else:
      self.assertEqual(
          th.input_types,
          ((str, decorators._ANY_VAR_POSITIONAL), {
              '__unknown__keywords': decorators._ANY_VAR_KEYWORD
          }))
    self.assertEqual(th.output_types, ((Any, ), {}))

  def test_strip_iterable_not_simple_output_noop(self):
    """strip_iterable() leaves an output hint with two types unchanged."""
    th = decorators.IOTypeHints(
        input_types=None, output_types=((int, str), {}), origin=[])
    th = th.strip_iterable()
    self.assertEqual(((int, str), {}), th.output_types)

  def _test_strip_iterable(self, before, expected_after):
    """Asserts strip_iterable() maps output hint before to expected_after.

    Args:
      before: the single positional output type hint to start from.
      expected_after: the single positional output type hint expected after
        calling strip_iterable().
    """
    th = decorators.IOTypeHints(
        input_types=None, output_types=((before, ), {}), origin=[])
    after = th.strip_iterable()
    self.assertEqual(((expected_after, ), {}), after.output_types)

  def _test_strip_iterable_fail(self, before):
    """Asserts strip_iterable() raises ValueError('...not iterable...').

    Args:
      before: the single positional output type hint to start from.
    """
    with self.assertRaisesRegex(ValueError, r'not iterable'):
      self._test_strip_iterable(before, None)

  def test_strip_iterable(self):
    """strip_iterable() on supported and unsupported output hints."""
    self._test_strip_iterable(None, None)
    self._test_strip_iterable(typehints.Any, typehints.Any)
    self._test_strip_iterable(typehints.Iterable[str], str)
    self._test_strip_iterable(typehints.List[str], str)
    self._test_strip_iterable(typehints.Iterator[str], str)
    self._test_strip_iterable(typehints.Generator[str], str)
    self._test_strip_iterable(typehints.Tuple[str], str)
    self._test_strip_iterable(
        typehints.Tuple[str, int], typehints.Union[str, int])
    self._test_strip_iterable(typehints.Tuple[str, ...], str)
    self._test_strip_iterable(typehints.KV[str, int], typehints.Union[str, int])
    self._test_strip_iterable(typehints.Set[str], str)
    self._test_strip_iterable(typehints.FrozenSet[str], str)

    # Hints for which a ValueError matching 'not iterable' is expected.
    self._test_strip_iterable_fail(typehints.Union[str, int])
    self._test_strip_iterable_fail(typehints.Optional[str])
    self._test_strip_iterable_fail(
        typehints.WindowedValue[str])  # type: ignore[misc]
    self._test_strip_iterable_fail(typehints.Dict[str, int])

  def test_make_traceback(self):
    """The origin mentions this module but not the name _make_traceback."""
    origin = ''.join(
        decorators.IOTypeHints.empty().with_input_types(str).origin)
    self.assertRegex(origin, __name__)
    self.assertNotRegex(origin, r'\b_make_traceback')

  def test_origin(self):
    """debug_str() mentions the calls that were used to build the hints."""
    th = decorators.IOTypeHints.empty()
    self.assertListEqual([], th.origin)
    th = th.with_input_types(str)
    self.assertRegex(th.debug_str(), r'with_input_types')
    th = th.with_output_types(str)
    self.assertRegex(th.debug_str(), r'(?s)with_output_types.*with_input_types')

    th = decorators.IOTypeHints.empty().with_output_types(str)
    th2 = decorators.IOTypeHints.empty().with_input_types(int)
    th = th.with_defaults(th2)
    self.assertRegex(th.debug_str(), r'(?s)based on:.*\'str\'.*and:.*\'int\'')

  def test_with_defaults_noop_does_not_grow_origin(self):
    """No-op with_defaults() calls return the very same object."""
    th = decorators.IOTypeHints.empty()
    # Keep a reference (rather than an id()) so identity is checked directly
    # and a failure message shows the objects involved.
    original = th
    th = th.with_defaults(None)
    self.assertIs(original, th)
    th = th.with_defaults(decorators.IOTypeHints.empty())
    self.assertIs(original, th)

    th = th.with_input_types(str)
    original = th
    th = th.with_defaults(th)
    self.assertIs(original, th)

    # Defaults that add output hints must yield a different object.
    th2 = th.with_output_types(int)
    th = th.with_defaults(th2)
    self.assertIsNot(original, th)

  def test_from_callable(self):
    """from_callable() on a function whose parameters are all annotated."""
    def fn(
        a: int,
        b: str = '',
        *args: Tuple[T],
        foo: List[int],
        **kwargs: Dict[str, str]) -> Tuple[Any, ...]:
      return a, b, args, foo, kwargs

    th = decorators.IOTypeHints.from_callable(fn)
    self.assertEqual(
        th.input_types, ((int, str, Tuple[T]), {
            'foo': List[int], 'kwargs': Dict[str, str]
        }))
    self.assertEqual(th.output_types, ((Tuple[Any, ...], ), {}))

  def test_from_callable_partial_annotations(self):
    """from_callable() on a function with only some annotations."""
    def fn(a: int, b=None, *args, foo: List[int], **kwargs):
      return a, b, args, foo, kwargs

    th = decorators.IOTypeHints.from_callable(fn)
    self.assertEqual(
        th.input_types,
        ((int, Any, Tuple[Any, ...]), {
            'foo': List[int], 'kwargs': Dict[Any, Any]
        }))
    self.assertEqual(th.output_types, ((Any, ), {}))

  def test_from_callable_class(self):
    """from_callable() on a class uses __init__ args and the class itself."""
    class Class(object):
      def __init__(self, unused_arg: int):
        pass

    th = decorators.IOTypeHints.from_callable(Class)
    self.assertEqual(th.input_types, ((int, ), {}))
    self.assertEqual(th.output_types, ((Class, ), {}))

  def test_from_callable_method(self):
    """from_callable() on a method accessed via the class and an instance."""
    class Class(object):
      def method(self, arg: T = None) -> None:
        pass

    # Accessed through the class: the hints include an entry for `self`.
    th = decorators.IOTypeHints.from_callable(Class.method)
    self.assertEqual(th.input_types, ((Any, T), {}))
    self.assertEqual(th.output_types, ((None, ), {}))

    # Accessed through an instance: only `arg` appears in the hints.
    th = decorators.IOTypeHints.from_callable(Class().method)
    self.assertEqual(th.input_types, ((T, ), {}))
    self.assertEqual(th.output_types, ((None, ), {}))

  def test_from_callable_convert_to_beam_types(self):
    """from_callable() with typing-module annotations yields Beam hints."""
    def fn(
        a: typing.List[int],
        b: str = '',
        *args: typing.Tuple[T_typing],
        foo: typing.List[int],
        **kwargs: typing.Dict[str, str]) -> typing.Tuple[typing.Any, ...]:
      return a, b, args, foo, kwargs

    th = decorators.IOTypeHints.from_callable(fn)
    self.assertEqual(
        th.input_types,
        ((List[int], str, Tuple[T]), {
            'foo': List[int], 'kwargs': Dict[str, str]
        }))
    self.assertEqual(th.output_types, ((Tuple[Any, ...], ), {}))

  def test_from_callable_partial(self):
    """from_callable() on a functools.partial mentions 'unknown' in debug."""
    def fn(a: int) -> int:
      return a

    # functools.partial objects don't have __name__ attributes by default.
    fn = functools.partial(fn, 1)
    th = decorators.IOTypeHints.from_callable(fn)
    self.assertRegex(th.debug_str(), r'unknown')

  def test_getcallargs_forhints(self):
    """getcallargs_forhints() merges passed hints with the annotations."""
    def fn(
        a: int,
        b: str = '',
        *args: Tuple[T],
        foo: List[int],
        **kwargs: Dict[str, str]) -> Tuple[Any, ...]:
      return a, b, args, foo, kwargs

    callargs = decorators.getcallargs_forhints(fn, float, foo=List[str])
    self.assertDictEqual(
        callargs,
        {
            'a': float,
            'b': str,
            'args': Tuple[T],
            'foo': List[str],
            'kwargs': Dict[str, str]
        })

  def test_getcallargs_forhints_default_arg(self):
    """getcallargs_forhints() does not treat default values as hints."""
    # Default args are not necessarily types, so they should be ignored.
    def fn(a=List[int], b=None, *args, foo=(), **kwargs) -> Tuple[Any, ...]:
      return a, b, args, foo, kwargs

    callargs = decorators.getcallargs_forhints(fn)
    self.assertDictEqual(
        callargs,
        {
            'a': Any,
            'b': Any,
            'args': Tuple[Any, ...],
            'foo': Any,
            'kwargs': Dict[Any, Any]
        })

  def test_getcallargs_forhints_missing_arg(self):
    """getcallargs_forhints() raises TypeCheckError for missing arguments."""
    def fn(a, b=None, *args, foo, **kwargs):
      return a, b, args, foo, kwargs

    with self.assertRaisesRegex(decorators.TypeCheckError, "missing.*'a'"):
      decorators.getcallargs_forhints(fn, foo=List[int])
    with self.assertRaisesRegex(decorators.TypeCheckError, "missing.*'foo'"):
      decorators.getcallargs_forhints(fn, 5)

  def test_origin_annotated(self):
    """debug_str() for hints obtained from Map over an annotated function."""
    def annotated(e: str) -> str:
      return e

    t = Map(annotated)
    th = t.get_type_hints()
    th = th.with_input_types(str)
    self.assertRegex(th.debug_str(), r'with_input_types')
    th = th.with_output_types(str)
    self.assertRegex(
        th.debug_str(),
        r'(?s)with_output_types.*with_input_types.*Map.annotated')
class WithTypeHintsTest(unittest.TestCase):
  """Tests for how WithTypeHints.get_type_hints() combines hint sources."""

  def test_get_type_hints_no_settings(self):
    """Without any hints configured, both input and output types are None."""
    class Base(WithTypeHints):
      pass

    th = Base().get_type_hints()
    # assertIsNone checks identity with None and matches the style used by the
    # other tests in this module.
    self.assertIsNone(th.input_types)
    self.assertIsNone(th.output_types)

  def test_get_type_hints_class_decorators(self):
    """Hints set through class decorators are returned."""
    @decorators.with_input_types(int, str)
    @decorators.with_output_types(int)
    class Base(WithTypeHints):
      pass

    th = Base().get_type_hints()
    self.assertEqual(th.input_types, ((int, str), {}))
    self.assertEqual(th.output_types, ((int, ), {}))

  def test_get_type_hints_class_defaults(self):
    """Hints returned by default_type_hints() are returned."""
    class Base(WithTypeHints):
      def default_type_hints(self):
        return decorators.IOTypeHints(
            input_types=((int, str), {}), output_types=((int, ), {}), origin=[])

    th = Base().get_type_hints()
    self.assertEqual(th.input_types, ((int, str), {}))
    self.assertEqual(th.output_types, ((int, ), {}))

  def test_get_type_hints_precedence_defaults_over_decorators(self):
    """default_type_hints() wins over decorators where it sets a value."""
    @decorators.with_input_types(int)
    @decorators.with_output_types(str)
    class Base(WithTypeHints):
      def default_type_hints(self):
        return decorators.IOTypeHints(
            input_types=((float, ), {}), output_types=None, origin=[])

    th = Base().get_type_hints()
    # Input comes from default_type_hints(); output, which the defaults leave
    # as None, comes from the decorator.
    self.assertEqual(th.input_types, ((float, ), {}))
    self.assertEqual(th.output_types, ((str, ), {}))

  def test_get_type_hints_precedence_instance_over_defaults(self):
    """Instance-level with_input_types() wins over default_type_hints()."""
    class Base(WithTypeHints):
      def default_type_hints(self):
        return decorators.IOTypeHints(
            input_types=((float, ), {}), output_types=((str, ), {}), origin=[])

    th = Base().with_input_types(int).get_type_hints()
    self.assertEqual(th.input_types, ((int, ), {}))
    self.assertEqual(th.output_types, ((str, ), {}))

  def test_inherits_does_not_modify(self):
    """Instance hints are equal to, but distinct from, the class hints."""
    # See BEAM-8629.
    @decorators.with_output_types(int)
    class Subclass(WithTypeHints):
      def __init__(self):
        pass  # intentionally avoiding super call

    # These should be equal, but not the same object lest mutating the instance
    # mutates the class.
    self.assertIsNot(
        Subclass()._get_or_create_type_hints(), Subclass._type_hints)
    self.assertEqual(Subclass().get_type_hints(), Subclass._type_hints)
    self.assertNotEqual(
        Subclass().with_input_types(str)._type_hints, Subclass._type_hints)
class DecoratorsTest(unittest.TestCase):
  """Tests for disable_type_annotations() and the no_annotations decorator."""

  def tearDown(self):
    """Resets the module-level flag that a test may have switched on."""
    decorators._disable_from_callable = False

  def test_disable_type_annotations(self):
    """disable_type_annotations() flips _disable_from_callable to True."""
    self.assertFalse(decorators._disable_from_callable)
    decorators.disable_type_annotations()
    self.assertTrue(decorators._disable_from_callable)

  def test_no_annotations_on_same_function(self):
    """Wrapping a function in no_annotations() stops the type check error."""
    def fn(a: int) -> int:
      return a

    mismatch = r'requires .*int.* but got .*str'
    with self.assertRaisesRegex(TypeCheckError, mismatch):
      _ = ['a', 'b', 'c'] | Map(fn)

    # Same pipeline doesn't raise without annotations on fn.
    fn = decorators.no_annotations(fn)
    _ = ['a', 'b', 'c'] | Map(fn)

  def test_no_annotations_on_diff_function(self):
    """A function decorated with no_annotations accepts mismatched input."""
    def fn(a: int) -> int:
      return a

    _ = [1, 2, 3] | Map(fn)  # Doesn't raise - correct types.

    mismatch = r'requires .*int.* but got .*str'
    with self.assertRaisesRegex(TypeCheckError, mismatch):
      _ = ['a', 'b', 'c'] | Map(fn)

    @decorators.no_annotations
    def fn2(a: int) -> int:
      return a

    _ = ['a', 'b', 'c'] | Map(fn2)  # Doesn't raise - no input type hints.
if __name__ == '__main__':
  # Run the test cases in this module when it is executed as a script.
  unittest.main()