blob: 7cf26f5dfd23080b46685af69397fdc10568fc11 [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Unit tests for the type-hint objects and decorators."""
# pytype: skip-file
from __future__ import absolute_import
import sys
import typing
import unittest
# patches unittest.TestCase to be python3 compatible
import future.tests.base # pylint: disable=unused-import
import apache_beam as beam
from apache_beam import pvalue
from apache_beam import typehints
from apache_beam.options.pipeline_options import OptionsContext
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to
from apache_beam.typehints import WithTypeHints
from apache_beam.typehints.decorators import get_signature
# These tests often construct a pipeline as value | PTransform to test side
# effects (e.g. errors).
# pylint: disable=expression-not-assigned
class MainInputTest(unittest.TestCase):
  """Checks type hints declared for the main input of ParDo-like transforms.

  Most tests apply a transform directly to a Python list and then either
  compare the output or expect a typehints.TypeCheckError.
  """
  def test_bad_main_input(self):
    # Ints are fed to a function whose hints are (str, int).
    @typehints.with_input_types(str, int)
    def repeat(s, times):
      return s * times

    with self.assertRaises(typehints.TypeCheckError):
      [1, 2, 3] | beam.Map(repeat, 3)

  def test_non_function(self):
    # Builtin methods and types used as the callable, with and without an
    # extra positional argument.
    uppercased = ['a', 'bb', 'c'] | beam.Map(str.upper)
    self.assertEqual(['A', 'BB', 'C'], sorted(uppercased))

    stripped = ['xa', 'bbx', 'xcx'] | beam.Map(str.strip, 'x')
    self.assertEqual(['a', 'bb', 'c'], sorted(stripped))

    decimal_values = ['1', '10', '100'] | beam.Map(int)
    self.assertEqual([1, 10, 100], sorted(decimal_values))

    hex_values = ['1', '10', '100'] | beam.Map(int, 16)
    self.assertEqual([1, 16, 256], sorted(hex_values))

  @unittest.skipIf(
      (3, 0) <= sys.version_info[:2] < (3, 7),
      'Function signatures for builtins are not available in Python 3 before '
      'version 3.7.')
  def test_non_function_fails(self):
    # str.upper applied to ints is expected to be rejected.
    with self.assertRaises(typehints.TypeCheckError):
      [1, 2, 3] | beam.Map(str.upper)

  def test_loose_bounds(self):
    # int elements satisfy a Union[int, float] input hint.
    @typehints.with_input_types(typing.Union[int, float])
    @typehints.with_output_types(str)
    def format_number(x):
      return '%g' % x

    formatted = [1, 2, 3] | beam.Map(format_number)
    self.assertEqual(['1', '2', '3'], sorted(formatted))

  def test_typed_dofn_class(self):
    # Hints are attached by decorating the DoFn class itself.
    @typehints.with_input_types(int)
    @typehints.with_output_types(str)
    class MyDoFn(beam.DoFn):
      def process(self, element):
        return [str(element)]

    stringified = [1, 2, 3] | beam.ParDo(MyDoFn())
    self.assertEqual(['1', '2', '3'], sorted(stringified))

    mismatch = r'requires.*int.*got.*str'

    # Wrong element type on the main input.
    with self.assertRaisesRegex(typehints.TypeCheckError, mismatch):
      ['a', 'b', 'c'] | beam.ParDo(MyDoFn())

    # Chaining the DoFn with itself feeds its str output to an int input.
    with self.assertRaisesRegex(typehints.TypeCheckError, mismatch):
      [1, 2, 3] | (beam.ParDo(MyDoFn()) | 'again' >> beam.ParDo(MyDoFn()))

  def test_typed_callable_iterable_output(self):
    # A plain callable used with ParDo and hinted as Iterable[Iterable[str]].
    @typehints.with_input_types(int)
    @typehints.with_output_types(typehints.Iterable[typehints.Iterable[str]])
    def do_fn(element):
      return [[str(element)] * 2]

    pairs = [1, 2] | beam.ParDo(do_fn)
    self.assertEqual([['1', '1'], ['2', '2']], sorted(pairs))

  def test_typed_dofn_instance(self):
    # Hints are attached to a DoFn instance instead of its class.
    class MyDoFn(beam.DoFn):
      def process(self, element):
        return [str(element)]

    my_do_fn = MyDoFn().with_input_types(int).with_output_types(str)

    stringified = [1, 2, 3] | beam.ParDo(my_do_fn)
    self.assertEqual(['1', '2', '3'], sorted(stringified))

    with self.assertRaises(typehints.TypeCheckError):
      ['a', 'b', 'c'] | beam.ParDo(my_do_fn)

    with self.assertRaises(typehints.TypeCheckError):
      [1, 2, 3] | (beam.ParDo(my_do_fn) | 'again' >> beam.ParDo(my_do_fn))

  def test_filter_type_hint(self):
    # Filter with an int-hinted predicate keeps the odd values.
    @typehints.with_input_types(int)
    def filter_fn(data):
      return data % 2

    odds = [1, 2, 3] | beam.Filter(filter_fn)
    self.assertEqual([1, 3], odds)

  def test_partition(self):
    # Both partitions expose an element type and accept an int-hinted ParDo.
    with TestPipeline() as pipeline:
      numbers = pipeline | beam.Create([1, 2, 3])
      even, odd = (
          numbers | 'even_odd' >> beam.Partition(lambda e, _: e % 2, 2))
      for partition in (even, odd):
        self.assertIsNotNone(partition.element_type)

      even_ids = (
          even
          | 'IdEven' >> beam.ParDo(lambda e: [e]).with_input_types(int))
      odd_ids = (
          odd
          | 'IdOdd' >> beam.ParDo(lambda e: [e]).with_input_types(int))
      assert_that(even_ids, equal_to([2]), label='even_check')
      assert_that(odd_ids, equal_to([1, 3]), label='odd_check')

  def test_typed_dofn_multi_output(self):
    # Tagged outputs are declared up front via with_outputs('odd', 'even').
    class MyDoFn(beam.DoFn):
      def process(self, element):
        if element % 2:
          yield beam.pvalue.TaggedOutput('odd', element)
        else:
          yield beam.pvalue.TaggedOutput('even', element)

    pipeline = TestPipeline()
    numbers = pipeline | beam.Create([1, 2, 3])
    outputs = numbers | beam.ParDo(MyDoFn()).with_outputs('odd', 'even')
    for tag in (None, 'even', 'odd'):
      self.assertIsNotNone(outputs[tag].element_type)

    main_ids = (
        outputs[None]
        | 'id_none' >> beam.ParDo(lambda e: [e]).with_input_types(int))
    even_ids = (
        outputs['even']
        | 'id_even' >> beam.ParDo(lambda e: [e]).with_input_types(int))
    odd_ids = (
        outputs['odd']
        | 'id_odd' >> beam.ParDo(lambda e: [e]).with_input_types(int))
    assert_that(main_ids, equal_to([]), label='none_check')
    assert_that(even_ids, equal_to([2]), label='even_check')
    assert_that(odd_ids, equal_to([1, 3]), label='odd_check')
    pipeline.run()

    # Looking up a tag that was never declared is an error.
    with self.assertRaises(ValueError):
      _ = outputs['undeclared tag']

  def test_typed_dofn_multi_output_no_tags(self):
    # Same DoFn as above, but with_outputs() is called without any tags.
    class MyDoFn(beam.DoFn):
      def process(self, element):
        if element % 2:
          yield beam.pvalue.TaggedOutput('odd', element)
        else:
          yield beam.pvalue.TaggedOutput('even', element)

    pipeline = TestPipeline()
    numbers = pipeline | beam.Create([1, 2, 3])
    outputs = numbers | beam.ParDo(MyDoFn()).with_outputs()
    for tag in (None, 'even', 'odd'):
      self.assertIsNotNone(outputs[tag].element_type)

    main_ids = (
        outputs[None]
        | 'id_none' >> beam.ParDo(lambda e: [e]).with_input_types(int))
    even_ids = (
        outputs['even']
        | 'id_even' >> beam.ParDo(lambda e: [e]).with_input_types(int))
    odd_ids = (
        outputs['odd']
        | 'id_odd' >> beam.ParDo(lambda e: [e]).with_input_types(int))
    assert_that(main_ids, equal_to([]), label='none_check')
    assert_that(even_ids, equal_to([2]), label='even_check')
    assert_that(odd_ids, equal_to([1, 3]), label='odd_check')
    pipeline.run()
class NativeTypesTest(unittest.TestCase):
  """Checks that hints written with the typing module are enforced."""
  def test_good_main_input(self):
    # (str, int) tuples match the declared Tuple[str, int] input.
    @typehints.with_input_types(typing.Tuple[str, int])
    def munge(s_i):
      (s, i) = s_i
      return (s + 's', i * 2)

    fruit_counts = [('apple', 5), ('pear', 3)]
    munged = fruit_counts | beam.Map(munge)
    self.assertEqual([('apples', 10), ('pears', 6)], sorted(munged))

  def test_bad_main_input(self):
    # (str, int) tuples do not match the declared Tuple[str, str] input.
    @typehints.with_input_types(typing.Tuple[str, str])
    def munge(s_i):
      (s, i) = s_i
      return (s + 's', i * 2)

    fruit_counts = [('apple', 5), ('pear', 3)]
    with self.assertRaises(typehints.TypeCheckError):
      fruit_counts | beam.Map(munge)

  def test_bad_main_output(self):
    # Chaining munge with itself feeds Tuple[str, str] to a Tuple[int, int]
    # input.
    @typehints.with_input_types(typing.Tuple[int, int])
    @typehints.with_output_types(typing.Tuple[str, str])
    def munge(a_b):
      (a, b) = a_b
      return (str(a), str(b))

    int_pairs = [(5, 4), (3, 2)]
    with self.assertRaises(typehints.TypeCheckError):
      int_pairs | beam.Map(munge) | 'Again' >> beam.Map(munge)
class SideInputTest(unittest.TestCase):
  """Checks type hints for extra (side-input) arguments of Map.

  Covers positional, keyword, defaulted, *args and **kwargs parameters, as
  well as deferred side inputs wrapped in pvalue.AsSingleton / pvalue.AsIter.
  """
  def _run_repeat_test(self, repeat):
    """Runs both the passing and the failing checks against `repeat`.

    Args:
      repeat: a callable taking (s, times), hinted so that `s` is a str and
        `times` is an int.
    """
    self._run_repeat_test_good(repeat)
    self._run_repeat_test_bad(repeat)

  @OptionsContext(pipeline_type_check=True)
  def _run_repeat_test_good(self, repeat):
    """Checks that a valid int side input works positionally and by keyword."""
    # As a positional argument.
    result = ['a', 'bb', 'c'] | beam.Map(repeat, 3)
    self.assertEqual(['aaa', 'bbbbbb', 'ccc'], sorted(result))

    # As a keyword argument.
    result = ['a', 'bb', 'c'] | beam.Map(repeat, times=3)
    self.assertEqual(['aaa', 'bbbbbb', 'ccc'], sorted(result))

  def _run_repeat_test_bad(self, repeat):
    """Checks that mistyped, surplus or missing side inputs are rejected."""
    # Various mismatches.
    with self.assertRaises(typehints.TypeCheckError):
      ['a', 'bb', 'c'] | beam.Map(repeat, 'z')

    with self.assertRaises(typehints.TypeCheckError):
      ['a', 'bb', 'c'] | beam.Map(repeat, times='z')

    with self.assertRaises(typehints.TypeCheckError):
      ['a', 'bb', 'c'] | beam.Map(repeat, 3, 4)

    # Omitting the side input is only an error when no parameter of `repeat`
    # has a default value.
    if all(param.default == param.empty
           for param in get_signature(repeat).parameters.values()):
      with self.assertRaisesRegex(typehints.TypeCheckError,
                                  r'(takes exactly|missing a required)'):
        ['a', 'bb', 'c'] | beam.Map(repeat)

  def test_basic_side_input_hint(self):
    # Side-input hint given positionally.
    @typehints.with_input_types(str, int)
    def repeat(s, times):
      return s * times

    self._run_repeat_test(repeat)

  def test_keyword_side_input_hint(self):
    # Side-input hint given by parameter name.
    @typehints.with_input_types(str, times=int)
    def repeat(s, times):
      return s * times

    self._run_repeat_test(repeat)

  def test_default_typed_hint(self):
    # The side-input parameter has both a default value and a hint.
    @typehints.with_input_types(str, int)
    def repeat(s, times=3):
      return s * times

    self._run_repeat_test(repeat)

  def test_default_untyped_hint(self):
    # The side-input parameter has a default value but no hint.
    @typehints.with_input_types(str)
    def repeat(s, times=3):
      return s * times

    # No type checking on default arg.
    self._run_repeat_test_good(repeat)

  @OptionsContext(pipeline_type_check=True)
  def test_varargs_side_input_hint(self):
    # The side input is collected by *times.
    @typehints.with_input_types(str, int)
    def repeat(s, *times):
      return s * times[0]

    result = ['a', 'bb', 'c'] | beam.Map(repeat, 3)
    self.assertEqual(['aaa', 'bbbbbb', 'ccc'], sorted(result))

    if sys.version_info >= (3, ):
      with self.assertRaisesRegex(
          typehints.TypeCheckError,
          r'requires Tuple\[int, ...\] but got Tuple\[str, ...\]'):
        ['a', 'bb', 'c'] | beam.Map(repeat, 'z')

  def test_var_positional_only_side_input_hint(self):
    # Test that a lambda that accepts only a VAR_POSITIONAL can accept
    # side-inputs.
    # TODO(BEAM-8247): There's a bug with trivial_inference inferring the output
    #   type when side-inputs are used (their type hints are not passed). Remove
    #   with_output_types(...) when this bug is fixed.
    result = (['a', 'b', 'c']
              | beam.Map(lambda *args: args, 5).with_input_types(
                  str, int).with_output_types(typehints.Tuple[str, int]))
    self.assertEqual([('a', 5), ('b', 5), ('c', 5)], sorted(result))

    # Type hint order doesn't matter for VAR_POSITIONAL: the same hints in
    # the opposite order must be accepted as well.
    result = (['a', 'b', 'c']
              | beam.Map(lambda *args: args, 5).with_input_types(
                  int, str).with_output_types(typehints.Tuple[str, int]))
    self.assertEqual([('a', 5), ('b', 5), ('c', 5)], sorted(result))

    if sys.version_info >= (3, ):
      with self.assertRaisesRegex(
          typehints.TypeCheckError,
          r'requires Tuple\[Union\[int, str\], ...\] but got '
          r'Tuple\[Union\[float, int\], ...\]'):
        _ = [1.2] | beam.Map(lambda *_: 'a', 5).with_input_types(int, str)

  def test_var_keyword_side_input_hint(self):
    # Test that a lambda that accepts a VAR_KEYWORD can accept
    # side-inputs.
    result = (['a', 'b', 'c']
              | beam.Map(lambda e, **kwargs:
                         (e, kwargs), kw=5).with_input_types(str, ignored=int))
    self.assertEqual(
        [('a', {'kw': 5}), ('b', {'kw': 5}), ('c', {'kw': 5})],
        sorted(result))

    if sys.version_info >= (3, ):
      with self.assertRaisesRegex(
          typehints.TypeCheckError,
          r'requires Dict\[str, str\] but got Dict\[str, int\]'):
        _ = (['a', 'b', 'c']
             | beam.Map(lambda e, **_: 'a', kw=5).with_input_types(
                 str, ignored=str))

  def test_deferred_side_inputs(self):
    # The side input is a PCollection wrapped in pvalue.AsSingleton.
    @typehints.with_input_types(str, int)
    def repeat(s, times):
      return s * times

    with TestPipeline() as p:
      main_input = p | beam.Create(['a', 'bb', 'c'])
      side_input = p | 'side' >> beam.Create([3])
      result = main_input | beam.Map(repeat, pvalue.AsSingleton(side_input))
      assert_that(result, equal_to(['aaa', 'bbbbbb', 'ccc']))

    # A str side input must be rejected for the int-hinted parameter.
    bad_side_input = p | 'bad_side' >> beam.Create(['z'])
    with self.assertRaises(typehints.TypeCheckError):
      main_input | 'bis' >> beam.Map(repeat, pvalue.AsSingleton(bad_side_input))

  def test_deferred_side_input_iterable(self):
    # The side input is a PCollection wrapped in pvalue.AsIter.
    @typehints.with_input_types(str, typing.Iterable[str])
    def concat(glue, items):
      return glue.join(sorted(items))

    with TestPipeline() as p:
      main_input = p | beam.Create(['a', 'bb', 'c'])
      side_input = p | 'side' >> beam.Create(['x', 'y', 'z'])
      result = main_input | beam.Map(concat, pvalue.AsIter(side_input))
      assert_that(result, equal_to(['xayaz', 'xbbybbz', 'xcycz']))

    # An iterable of ints must be rejected for the Iterable[str] parameter.
    bad_side_input = p | 'bad_side' >> beam.Create([1, 2, 3])
    with self.assertRaises(typehints.TypeCheckError):
      main_input | 'fail' >> beam.Map(concat, pvalue.AsIter(bad_side_input))
class CustomTransformTest(unittest.TestCase):
  """Checks type hints on a PTransform whose input and output are dicts."""
  class CustomTransform(beam.PTransform):
    """Passes the 'in0' and 'in1' inputs through as 'out0' and 'out1'."""
    def _extract_input_pvalues(self, pvalueish):
      first, second = pvalueish['in0'], pvalueish['in1']
      return pvalueish, (first, second)

    def expand(self, pvalueish):
      passthrough = {}
      passthrough['out0'] = pvalueish['in0']
      passthrough['out1'] = pvalueish['in1']
      return passthrough

    # TODO(robertwb): (typecheck) Make these the default?
    def with_input_types(self, *args, **kwargs):
      """Delegates straight to WithTypeHints.with_input_types."""
      return WithTypeHints.with_input_types(self, *args, **kwargs)

    def with_output_types(self, *args, **kwargs):
      """Delegates straight to WithTypeHints.with_output_types."""
      return WithTypeHints.with_output_types(self, *args, **kwargs)

  # Input shared by the tests: str elements under 'in0', ints under 'in1'.
  test_input = {'in0': ['a', 'b', 'c'], 'in1': [1, 2, 3]}

  def check_output(self, result):
    """Asserts that `result` carries test_input's values under out0/out1."""
    out0 = sorted(result['out0'])
    self.assertEqual(['a', 'b', 'c'], out0)
    out1 = sorted(result['out1'])
    self.assertEqual([1, 2, 3], out1)

  def test_custom_transform(self):
    # No hints at all.
    untyped_result = self.test_input | self.CustomTransform()
    self.check_output(untyped_result)

  def test_keyword_type_hints(self):
    # Hints keyed by the names of the dict entries.
    both_inputs = self.CustomTransform().with_input_types(in0=str, in1=int)
    self.check_output(self.test_input | both_inputs)

    one_input = self.CustomTransform().with_input_types(in0=str)
    self.check_output(self.test_input | one_input)

    both_outputs = self.CustomTransform().with_output_types(
        out0=str, out1=int)
    self.check_output(self.test_input | both_outputs)

    # 'in0' and 'out0' hold strings, so an int hint must be rejected.
    with self.assertRaises(typehints.TypeCheckError):
      self.test_input | self.CustomTransform().with_input_types(in0=int)

    with self.assertRaises(typehints.TypeCheckError):
      self.test_input | self.CustomTransform().with_output_types(out0=int)

  def test_flat_type_hint(self):
    # Type hint is applied to both.
    all_str_input = {'in0': ['a', 'b', 'c'], 'in1': ['x', 'y', 'z']}
    all_str_input | self.CustomTransform().with_input_types(str)

    # test_input mixes str and int, so no single flat hint fits both entries.
    with self.assertRaises(typehints.TypeCheckError):
      self.test_input | self.CustomTransform().with_input_types(str)

    with self.assertRaises(typehints.TypeCheckError):
      self.test_input | self.CustomTransform().with_input_types(int)

    with self.assertRaises(typehints.TypeCheckError):
      self.test_input | self.CustomTransform().with_output_types(int)
class AnnotationsTest(unittest.TestCase):
  """Checks the type hints ParDo reports for builtin callables."""
  def test_pardo_wrapper_builtin_method(self):
    hints = beam.ParDo(str.strip).get_type_hints()
    if sys.version_info >= (3, 7):
      # Python 3.7+ has annotations for CPython builtins
      # (_MethodDescriptorType).
      expected_inputs = ((str, typehints.Any), {})
    else:
      expected_inputs = ((str, ), {})
    self.assertEqual(hints.input_types, expected_inputs)
    self.assertEqual(hints.output_types, ((typehints.Any, ), {}))

  def test_pardo_wrapper_builtin_type(self):
    hints = beam.ParDo(list).get_type_hints()
    if sys.version_info >= (3, 7):
      # Python 3.7+ supports signatures for builtins like 'list'.
      expected_inputs = ((typehints.Any, ), {})
    else:
      positional = (typehints.Any, typehints.decorators._ANY_VAR_POSITIONAL)
      keywords = {
          '__unknown__keywords': typehints.decorators._ANY_VAR_KEYWORD
      }
      expected_inputs = (positional, keywords)
    self.assertEqual(hints.input_types, expected_inputs)
    self.assertEqual(hints.output_types, ((typehints.Any, ), {}))

  def test_pardo_wrapper_builtin_func(self):
    # No hints are expected in either direction for the builtin len.
    hints = beam.ParDo(len).get_type_hints()
    for reported in (hints.input_types, hints.output_types):
      self.assertIsNone(reported)
# Run the tests in this module when the file is executed as a script.
if __name__ == '__main__':
  unittest.main()