blob: b59da72cdad34fc1a729da7f7f835cd604a6a611 [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Tests for apache_beam.typehints.trivial_inference."""
from __future__ import absolute_import
import os
import sys
import unittest
from apache_beam.typehints import trivial_inference
from apache_beam.typehints import typehints
# Module-level value read by TrivialInferenceTest.testGlobals, which checks
# that a reference to a global name is inferred as the type of its value (int).
global_int = 1
@unittest.skipIf(sys.version_info >= (3, 6, 0) and
                 os.environ.get('RUN_SKIPPED_PY3_TESTS') != '1',
                 'This test still needs to be fixed on Python 3.6. '
                 'See BEAM-6877')
class TrivialInferenceTest(unittest.TestCase):
  """Checks return types produced by trivial_inference.infer_return_type.

  The whole class is skipped on Python >= 3.6 unless the environment variable
  RUN_SKIPPED_PY3_TESTS is set to '1' (see BEAM-6877).
  """

  def assertReturnType(self, expected, f, inputs=()):
    """Asserts that the inferred return type of ``f`` equals ``expected``.

    Args:
      expected: the type hint the inference is expected to produce.
      f: the callable whose return type is inferred.
      inputs: type hints (or ``trivial_inference.Const`` values) describing
        f's positional arguments; empty for zero-argument callables.
    """
    # Use assertEqual: the assertEquals alias was deprecated for years and
    # removed from unittest in Python 3.12.
    self.assertEqual(expected, trivial_inference.infer_return_type(f, inputs))

  def testIdentity(self):
    """The identity function returns the type of its argument."""
    self.assertReturnType(int, lambda x: x, [int])

  def testIndexing(self):
    """Indexing picks a Tuple's positional type or a List's element type."""
    self.assertReturnType(int, lambda x: x[0], [typehints.Tuple[int, str]])
    self.assertReturnType(str, lambda x: x[1], [typehints.Tuple[int, str]])
    self.assertReturnType(str, lambda x: x[1], [typehints.List[str]])

  def testTuples(self):
    """Tuple literals, including the empty tuple, become Tuple hints."""
    self.assertReturnType(
        typehints.Tuple[typehints.Tuple[()], int], lambda x: ((), x), [int])
    self.assertReturnType(
        typehints.Tuple[str, int, float], lambda x: (x, 0, 1.0), [str])

  def testGetItem(self):
    """Positive/negative indexing and slicing on Any, Tuple and List."""
    def reverse(ab):
      return ab[-1], ab[0]
    self.assertReturnType(
        typehints.Tuple[typehints.Any, typehints.Any], reverse, [typehints.Any])
    self.assertReturnType(
        typehints.Tuple[int, float], reverse, [typehints.Tuple[float, int]])
    self.assertReturnType(
        typehints.Tuple[int, str], reverse, [typehints.Tuple[str, float, int]])
    self.assertReturnType(
        typehints.Tuple[int, int], reverse, [typehints.List[int]])
    self.assertReturnType(
        typehints.List[int], lambda v: v[::-1], [typehints.List[int]])

  def testUnpack(self):
    """Two-name unpacking of Tuple, List, Any and constant tuple inputs."""
    def reverse(a_b):
      (a, b) = a_b
      return b, a
    any_tuple = typehints.Tuple[typehints.Any, typehints.Any]
    self.assertReturnType(
        typehints.Tuple[int, float], reverse, [typehints.Tuple[float, int]])
    self.assertReturnType(
        typehints.Tuple[int, int], reverse, [typehints.Tuple[int, ...]])
    self.assertReturnType(
        typehints.Tuple[int, int], reverse, [typehints.List[int]])
    # A three-element Tuple unpacked into two names yields the union of all
    # element types for each name.
    self.assertReturnType(
        typehints.Tuple[typehints.Union[int, float, str],
                        typehints.Union[int, float, str]],
        reverse, [typehints.Tuple[int, float, str]])
    self.assertReturnType(any_tuple, reverse, [typehints.Any])
    self.assertReturnType(typehints.Tuple[int, float],
                          reverse, [trivial_inference.Const((1.0, 1))])
    self.assertReturnType(any_tuple,
                          reverse, [trivial_inference.Const((1, 2, 3))])

  def testNoneReturn(self):
    """A function that may return None is a Union with NoneType."""
    def func(a):
      if a == 5:
        return a
      return None
    self.assertReturnType(typehints.Union[int, type(None)], func, [int])

  def testSimpleList(self):
    """List literals are typed by their elements; list() calls are not."""
    self.assertReturnType(
        typehints.List[int],
        lambda xs: [1, 2],
        [typehints.Tuple[int, ...]])
    self.assertReturnType(
        typehints.List[typehints.Any],
        lambda xs: list(xs),  # List is a disallowed builtin
        [typehints.Tuple[int, ...]])

  def testListComprehension(self):
    """A list comprehension over a homogeneous Tuple yields a typed List."""
    self.assertReturnType(
        typehints.List[int],
        lambda xs: [x for x in xs],
        [typehints.Tuple[int, ...]])

  @unittest.skipIf(sys.version_info[0] == 3 and
                   os.environ.get('RUN_SKIPPED_PY3_TESTS') != '1',
                   'This test still needs to be fixed on Python 3. '
                   'See BEAM-6877')
  def testTupleListComprehension(self):
    """List comprehensions over fixed-size Tuples and Iterables of Tuples."""
    self.assertReturnType(
        typehints.List[int],
        lambda xs: [x for x in xs],
        [typehints.Tuple[int, int, int]])
    self.assertReturnType(
        typehints.List[typehints.Union[int, float]],
        lambda xs: [x for x in xs],
        [typehints.Tuple[int, float]])
    # TODO(luke-zhu): This test fails in Python 3
    self.assertReturnType(
        typehints.List[typehints.Tuple[str, int]],
        lambda kvs: [(kvs[0], v) for v in kvs[1]],
        [typehints.Tuple[str, typehints.Iterable[int]]])
    self.assertReturnType(
        typehints.List[typehints.Tuple[str, typehints.Union[str, int], int]],
        lambda L: [(a, a or b, b) for a, b in L],
        [typehints.Iterable[typehints.Tuple[str, int]]])

  def testGenerator(self):
    """A generator function is an Iterable of the union of yielded types."""
    def foo(x, y):
      yield x
      yield y
    self.assertReturnType(typehints.Iterable[int], foo, [int, int])
    self.assertReturnType(
        typehints.Iterable[typehints.Union[int, float]], foo, [int, float])

  def testGeneratorComprehension(self):
    """A generator expression is an Iterable of its element type."""
    self.assertReturnType(
        typehints.Iterable[int],
        lambda xs: (x for x in xs),
        [typehints.Tuple[int, ...]])

  def testBinOp(self):
    """Binary '+' on ints, on an Any operand, and on Lists."""
    self.assertReturnType(int, lambda a, b: a + b, [int, int])
    self.assertReturnType(
        typehints.Any, lambda a, b: a + b, [int, typehints.Any])
    self.assertReturnType(
        typehints.List[typehints.Union[int, str]], lambda a, b: a + b,
        [typehints.List[int], typehints.List[str]])

  def testCall(self):
    """Calls to a closed-over lambda with positional/keyword arguments."""
    f = lambda x, *args: x
    self.assertReturnType(
        typehints.Tuple[int, float], lambda: (f(1), f(2.0, 3)))
    # We could do better here, but this is at least correct.
    self.assertReturnType(
        typehints.Tuple[int, typehints.Any], lambda: (1, f(x=1.0)))

  def testClosure(self):
    """Closed-over locals are typed by their values."""
    x = 1
    y = 1.0
    self.assertReturnType(typehints.Tuple[int, float], lambda: (x, y))

  def testGlobals(self):
    """A module-level global is typed by its value."""
    self.assertReturnType(int, lambda: global_int)

  def testBuiltins(self):
    """len() is inferred to return int."""
    self.assertReturnType(int, lambda x: len(x), [typehints.Any])

  def testGetAttr(self):
    """Known module attributes use their value's type; unknown ones are Any."""
    self.assertReturnType(
        typehints.Tuple[str, typehints.Any],
        lambda: (typehints.__doc__, typehints.fake))

  def testMethod(self):
    """Bound and unbound method calls on a locally defined class."""
    class A(object):
      def m(self, x):
        return x
    self.assertReturnType(int, lambda: A().m(3))
    self.assertReturnType(float, lambda: A.m(A(), 3.0))

  def testAlwaysReturnsEarly(self):
    """Every path returns an int, so the result is int."""
    def some_fn(v):
      if v:
        return 1
      return 2
    self.assertReturnType(int, some_fn)

  def testDict(self):
    """An empty dict literal is a Dict[Any, Any]."""
    self.assertReturnType(
        typehints.Dict[typehints.Any, typehints.Any], lambda: {})

  def testDictComprehension(self):
    # Just ensure it doesn't crash.
    fields = []
    self.assertReturnType(
        typehints.Any,
        lambda row: {f: row[f] for f in fields}, [typehints.Any])
if __name__ == '__main__':
  # Run this module's tests when it is executed directly as a script.
  unittest.main()