| # |
| # Licensed to the Apache Software Foundation (ASF) under one or more |
| # contributor license agreements. See the NOTICE file distributed with |
| # this work for additional information regarding copyright ownership. |
| # The ASF licenses this file to You under the Apache License, Version 2.0 |
| # (the "License"); you may not use this file except in compliance with |
| # the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # |
| |
| """Trivial type inference for simple functions. |
| |
| For internal use only; no backwards-compatibility guarantees. |
| """ |
| from __future__ import absolute_import |
| from __future__ import print_function |
| |
| import collections |
| import dis |
| import inspect |
| import pprint |
| import sys |
| import traceback |
| import types |
| from builtins import object |
| from builtins import zip |
| from functools import reduce |
| |
| from apache_beam.typehints import Any |
| from apache_beam.typehints import typehints |
| |
| # pylint: disable=wrong-import-order, wrong-import-position, ungrouped-imports |
| try: # Python 2 |
| import __builtin__ as builtins |
| except ImportError: # Python 3 |
| import builtins |
| # pylint: enable=wrong-import-order, wrong-import-position, ungrouped-imports |
| |
| |
class TypeInferenceError(ValueError):
  """Raised when the type of a value or callable cannot be inferred."""
| |
| |
def instance_to_type(o):
  """Returns the type hint describing the Python value ``o``.

  ``None`` maps to ``NoneType``. Values whose type is not listed in
  ``typehints.DISALLOWED_PRIMITIVE_TYPES`` map to their own type, except that
  Python 2 old-style instances map to their class and ``BoundMethod`` maps to
  ``types.MethodType``. Tuples, lists, sets and dicts map to the matching
  typehints constraint, built recursively from their contents.

  Raises:
    TypeInferenceError: if ``type(o)`` is a disallowed primitive type other
      than tuple, list, set or dict.
  """
  if o is None:
    return type(None)
  value_type = type(o)

  if value_type not in typehints.DISALLOWED_PRIMITIVE_TYPES:
    # pylint: disable=deprecated-types-field
    if sys.version_info[0] == 2 and value_type == types.InstanceType:
      return o.__class__
    return types.MethodType if value_type == BoundMethod else value_type

  if value_type == tuple:
    # Tuples keep one hint per position.
    return typehints.Tuple[[instance_to_type(member) for member in o]]
  if value_type == list:
    return typehints.List[
        typehints.Union[[instance_to_type(member) for member in o]]]
  if value_type == set:
    return typehints.Set[
        typehints.Union[[instance_to_type(member) for member in o]]]
  if value_type == dict:
    key_hint = typehints.Union[[instance_to_type(key) for key in o.keys()]]
    value_hint = typehints.Union[
        [instance_to_type(value) for value in o.values()]]
    return typehints.Dict[key_hint, value_hint]

  raise TypeInferenceError('Unknown forbidden type: %s' % value_type)
| |
| |
def union_list(xs, ys):
  """Returns the position-wise ``union`` of two equally long sequences.

  The length check is an internal invariant (an ``assert``), not input
  validation.
  """
  assert len(xs) == len(ys)
  return [union(*pair) for pair in zip(xs, ys)]
| |
| |
class Const(object):
  """A concrete value known during inference, paired with its type hint.

  Attributes:
    value: The wrapped Python value.
    type: ``instance_to_type(value)``, computed at construction time.
  """

  def __init__(self, value):
    self.value = value
    self.type = instance_to_type(value)

  def __eq__(self, other):
    return isinstance(other, Const) and self.value == other.value

  def __ne__(self, other):
    # TODO(BEAM-5949): Needed for Python 2 compatibility.
    return not self == other

  def __hash__(self):
    try:
      return hash(self.value)
    except TypeError:
      # Consts are placed in sets during inference (e.g. the collected return
      # values), but a global or closure value may be unhashable (a list, a
      # dict, ...). Hashing the value's type keeps equal values of the same
      # type hashing alike instead of aborting inference with a TypeError.
      return hash(type(self.value))

  def __repr__(self):
    # Truncated so that large constants do not flood debug output.
    return 'Const[%s]' % str(self.value)[:100]

  @staticmethod
  def unwrap(x):
    """Returns ``x.type`` if ``x`` is a Const, otherwise ``x`` unchanged."""
    if isinstance(x, Const):
      return x.type
    return x

  @staticmethod
  def unwrap_all(xs):
    """Returns a list with ``Const.unwrap`` applied to every item of ``xs``."""
    return [Const.unwrap(x) for x in xs]
| |
| |
class FrameState(object):
  """Stores the state of the frame at a particular point of execution.

  Attributes:
    f: The function whose bytecode is being emulated.
    co: ``f.__code__``.
    vars: List of type hints (or Consts) for the frame's local slots.
    stack: List of type hints (or Consts) modelling the value stack.
  """

  def __init__(self, f, local_vars=None, stack=()):
    self.f = f
    self.co = f.__code__
    # The default of None means "no locals"; list(None) would raise.
    self.vars = [] if local_vars is None else list(local_vars)
    self.stack = list(stack)

  def __eq__(self, other):
    return isinstance(other, FrameState) and self.__dict__ == other.__dict__

  def __ne__(self, other):
    # TODO(BEAM-5949): Needed for Python 2 compatibility.
    return not self == other

  def __hash__(self):
    # Hash the components __eq__ compares (co is derived from f). The lists
    # are converted to tuples: hashing __dict__.items() directly always
    # raised TypeError because vars and stack are lists.
    return hash((self.f, tuple(self.vars), tuple(self.stack)))

  def copy(self):
    """Returns a new FrameState holding shallow copies of vars and stack."""
    return FrameState(self.f, self.vars, self.stack)

  def const_type(self, i):
    """Returns ``co_consts[i]`` wrapped in a Const."""
    return Const(self.co.co_consts[i])

  def get_closure(self, i):
    """Returns the value for cell/free variable index ``i``.

    Indices below ``len(co_cellvars)`` are read from ``vars``; larger
    indices address ``f.__closure__``.
    """
    num_cellvars = len(self.co.co_cellvars)
    if i < num_cellvars:
      return self.vars[i]
    else:
      return self.f.__closure__[i - num_cellvars].cell_contents

  def closure_type(self, i):
    """Returns a TypeConstraint or Const."""
    val = self.get_closure(i)
    if isinstance(val, typehints.TypeConstraint):
      return val
    else:
      return Const(val)

  def get_global(self, i):
    """Returns a Const for the global or builtin ``co_names[i]``, else Any."""
    name = self.get_name(i)
    if name in self.f.__globals__:
      return Const(self.f.__globals__[name])
    if name in builtins.__dict__:
      return Const(builtins.__dict__[name])
    return Any

  def get_name(self, i):
    """Returns ``co_names[i]``."""
    return self.co.co_names[i]

  def __repr__(self):
    return 'Stack: %s Vars: %s' % (self.stack, self.vars)

  def __or__(self, other):
    """Merges two states slot by slot with ``union``; None acts as identity."""
    if self is None:
      return other.copy()
    elif other is None:
      return self.copy()
    return FrameState(self.f, union_list(self.vars, other.vars),
                      union_list(self.stack, other.stack))

  def __ror__(self, left):
    # Reached for ``None | state``: delegate to __or__ with operands swapped.
    return self | left
| |
| |
def union(a, b):
  """Returns the union of two types or Const values.

  Equal inputs are returned as-is, and a falsy operand yields the other one.
  Otherwise both operands are unwrapped from Const. When they have the same
  Python type and one of them has the empty union as its element type, the
  other operand wins; in all remaining cases ``typehints.Union[a, b]`` is
  returned.
  """
  if a == b:
    return a
  if not a:
    return b
  if not b:
    return a

  left = Const.unwrap(a)
  right = Const.unwrap(b)
  # TODO(robertwb): Work this into the Union code in a more generic way.
  if type(left) == type(right):
    nothing = typehints.Union[()]
    if element_type(left) == nothing:
      return right
    if element_type(right) == nothing:
      return left
  return typehints.Union[left, right]
| |
| |
def finalize_hints(type_hint):
  """Sets type hint for empty data structures to Any.

  Walks ``type_hint`` (if it is a TypeConstraint) and, in every
  DictConstraint found, replaces a key or value type equal to the empty
  union with Any. The hint is modified in place and nothing is returned.
  """
  if not isinstance(type_hint, typehints.TypeConstraint):
    return

  def _widen_empty_dict(constraint, _):
    if not isinstance(constraint, typehints.DictConstraint):
      return
    nothing = typehints.Union[()]
    if constraint.key_type == nothing:
      constraint.key_type = Any
    if constraint.value_type == nothing:
      constraint.value_type = Any

  type_hint.visit(_widen_empty_dict, None)
| |
| |
def element_type(hint):
  """Returns the type of the items held by a composite type hint.

  A Const is first replaced by its type. Sequence constraints give their
  inner type, tuple constraints give the union of their member types, and
  any other hint gives Any.
  """
  unwrapped = Const.unwrap(hint)
  if isinstance(unwrapped, typehints.SequenceTypeConstraint):
    inner = unwrapped.inner_type
  elif isinstance(unwrapped, typehints.TupleHint.TupleConstraint):
    inner = typehints.Union[unwrapped.tuple_types]
  else:
    inner = Any
  return inner
| |
| |
def key_value_types(kv_type):
  """Returns the key and value type of a KV type.

  Only two-element tuple constraints are understood; anything else yields
  ``(Any, Any)``.
  """
  # TODO(robertwb): Unions of tuples, etc.
  # TODO(robertwb): Assert?
  is_pair = (isinstance(kv_type, typehints.TupleHint.TupleConstraint) and
             len(kv_type.tuple_types) == 2)
  return kv_type.tuple_types if is_pair else (Any, Any)
| |
| |
# Callables whose return type is fixed regardless of their arguments;
# infer_return_type answers these from this table without analysing them.
known_return_types = {
    len: int,
    hash: int,
}
| |
| |
class BoundMethod(object):
  """Stand-in for a bound method when only the instance's type is known.

  Attributes:
    func: The method's underlying function.
    type: The class the method was looked up on.
  """

  def __init__(self, func, type):  # pylint: disable=redefined-builtin
    """Instantiates a bound method object.

    Args:
      func (types.FunctionType): The method's underlying function
      type (type): The class of the method.
    """
    self.func, self.type = func, type
| |
| |
def hashable(c):
  """Returns True if ``hash(c)`` succeeds, False if it raises TypeError."""
  try:
    hash(c)
  except TypeError:
    return False
  return True
| |
| |
def infer_return_type(c, input_types, debug=False, depth=5):
  """Analyses a callable to deduce its return type.

  Args:
    c: A Python callable to infer the return type of.
    input_types: A sequence (list, tuple, ...) of inputs corresponding to the
      input types.
    debug: Whether to print verbose debugging information.
    depth: Maximum inspection depth during type inference.

  Returns:
    A TypeConstraint that the return value of this function will (likely)
    satisfy given the specified inputs. Any is returned when inference fails.
    The one exception is ``debug=True``, where unexpected exceptions (other
    than TypeInferenceError) are re-raised.
  """
  try:
    if hashable(c) and c in known_return_types:
      return known_return_types[c]
    elif isinstance(c, types.FunctionType):
      return infer_return_type_func(c, input_types, debug, depth)
    elif isinstance(c, types.MethodType):
      if c.__self__ is not None:
        # The bound instance becomes the first positional argument. list()
        # allows input_types to be any sequence, e.g. a tuple.
        input_types = [Const(c.__self__)] + list(input_types)
      return infer_return_type_func(c.__func__, input_types, debug, depth)
    elif isinstance(c, BoundMethod):
      # Only the instance's type is known; it stands in for ``self``.
      input_types = [c.type] + list(input_types)
      return infer_return_type_func(c.func, input_types, debug, depth)
    elif inspect.isclass(c):
      # Calling a class yields an instance of it; the disallowed builtin
      # containers map to their generic hints with unknown element types.
      if c in typehints.DISALLOWED_PRIMITIVE_TYPES:
        return {
            list: typehints.List[Any],
            set: typehints.Set[Any],
            tuple: typehints.Tuple[Any, ...],
            dict: typehints.Dict[Any, Any]
        }[c]
      return c
    else:
      return Any
  except TypeInferenceError:
    if debug:
      traceback.print_exc()
    return Any
  except Exception:
    if debug:
      sys.stdout.flush()
      raise
    else:
      return Any
| |
| |
def infer_return_type_func(f, input_types, debug=False, depth=0):
  """Analyses a function to deduce its return type.

  The function's bytecode is emulated abstractly. Each local slot and stack
  entry holds a type hint (or a Const wrapping a known value), and the
  states reaching each jump target are merged with ``|`` (FrameState.__or__).
  A backward jump is re-followed (at most 5 times per jump instruction) while
  the merged state at its target keeps changing.

  Args:
    f: A Python function object to infer the return type of.
    input_types: A sequence of inputs corresponding to the input types.
    debug: Whether to print verbose debugging information.
    depth: Maximum inspection depth during type inference. Calls made by
      ``f`` are only analysed recursively while ``depth > 0``.

  Returns:
    A TypeConstraint that the return value of this function will (likely)
    satisfy given the specified inputs. If any value is yielded, the result
    is ``Iterable`` of the union of the yielded types.

  Raises:
    TypeInferenceError: if no type can be inferred, e.g. an opcode is not
      supported or no return value is ever reached.
  """
  if debug:
    print()
    print(f, id(f), input_types)
    dis.dis(f)
  from . import opcodes
  # Index the opcodes module's names by their upper-cased form so emulation
  # helpers can be looked up by dis opcode name.
  simple_ops = dict((k.upper(), v) for k, v in opcodes.__dict__.items())

  co = f.__code__
  code = co.co_code
  end = len(code)
  pc = 0
  extended_arg = 0  # Python 2 only.
  free = None

  yields = set()
  returns = set()
  # TODO(robertwb): Default args via inspect module.
  # Locals not covered by input_types start as the empty union.
  local_vars = list(input_types) + [typehints.Union[()]] * (
      len(co.co_varnames) - len(input_types))
  state = FrameState(f, local_vars)
  # Merged FrameState recorded for each jump-target offset (None if none).
  states = collections.defaultdict(lambda: None)
  # How often each backward jump (keyed by the offset after it) was followed.
  jumps = collections.defaultdict(int)

  # In Python 3, use dis library functions to disassemble bytecode and handle
  # EXTENDED_ARGs.
  is_py3 = sys.version_info[0] == 3
  if is_py3:
    ofs_table = {}  # offset -> instruction
    for instruction in dis.get_instructions(f):
      ofs_table[instruction.offset] = instruction

  # Python 2 - 3.5: 1 byte opcode + optional 2 byte arg (1 or 3 bytes).
  # Python 3.6+: 1 byte opcode + 1 byte arg (2 bytes, arg may be ignored).
  if sys.version_info >= (3, 6):
    inst_size = 2
    opt_arg_size = 0
  else:
    inst_size = 1
    opt_arg_size = 2

  last_pc = -1
  while pc < end:  # pylint: disable=too-many-nested-blocks
    start = pc
    if is_py3:
      instruction = ofs_table[pc]
      op = instruction.opcode
    else:
      op = ord(code[pc])
    if debug:
      print('-->' if pc == last_pc else ' ', end=' ')
      print(repr(pc).rjust(4), end=' ')
      print(dis.opname[op].ljust(20), end=' ')

    pc += inst_size
    # Reset per instruction so that argument-less opcodes never see a stale
    # argument (or an unbound one, if the first opcode takes no argument).
    arg = None
    if op >= dis.HAVE_ARGUMENT:
      if is_py3:
        arg = instruction.arg
      else:
        arg = ord(code[pc]) + ord(code[pc + 1]) * 256 + extended_arg
      extended_arg = 0
      pc += opt_arg_size
      if op == dis.EXTENDED_ARG:
        extended_arg = arg * 65536
      if debug:
        print(str(arg).rjust(5), end=' ')
        if op in dis.hasconst:
          print('(' + repr(co.co_consts[arg]) + ')', end=' ')
        elif op in dis.hasname:
          print('(' + co.co_names[arg] + ')', end=' ')
        elif op in dis.hasjrel:
          print('(to ' + repr(pc + arg) + ')', end=' ')
        elif op in dis.haslocal:
          print('(' + co.co_varnames[arg] + ')', end=' ')
        elif op in dis.hascompare:
          print('(' + dis.cmp_op[arg] + ')', end=' ')
        elif op in dis.hasfree:
          if free is None:
            free = co.co_cellvars + co.co_freevars
          print('(' + free[arg] + ')', end=' ')

    # Actually emulate the op.
    if state is None and states[start] is None:
      # No control reaches here (yet).
      if debug:
        print()
      continue
    state |= states[start]

    opname = dis.opname[op]
    jmp = jmp_state = None
    if opname.startswith('CALL_FUNCTION'):
      if sys.version_info < (3, 6):
        # Each keyword takes up two arguments on the stack (name and value).
        standard_args = (arg & 0xFF) + 2 * (arg >> 8)
        var_args = 'VAR' in opname
        kw_args = 'KW' in opname
        pop_count = standard_args + var_args + kw_args + 1
        if depth <= 0:
          return_type = Any
        elif arg >> 8:
          # TODO(robertwb): Handle this case.
          return_type = Any
        elif isinstance(state.stack[-pop_count], Const):
          # TODO(robertwb): Handle this better.
          if var_args or kw_args:
            state.stack[-1] = Any
            state.stack[-var_args - kw_args] = Any
          return_type = infer_return_type(state.stack[-pop_count].value,
                                          state.stack[1 - pop_count:],
                                          debug=debug,
                                          depth=depth - 1)
        else:
          return_type = Any
        state.stack[-pop_count:] = [return_type]
      else:  # Python 3.6+
        if opname == 'CALL_FUNCTION':
          pop_count = arg + 1
          # The callee may be a non-Const hint (e.g. Any), which has no
          # ``.value``; treat its result as Any instead of failing.
          if depth > 0 and isinstance(state.stack[-pop_count], Const):
            return_type = infer_return_type(state.stack[-pop_count].value,
                                            state.stack[1 - pop_count:],
                                            debug=debug,
                                            depth=depth - 1)
          else:
            return_type = Any
        elif opname == 'CALL_FUNCTION_KW':
          # TODO(udim): Handle keyword arguments. Requires passing them by name
          # to infer_return_type.
          pop_count = arg + 2
          return_type = Any
        elif opname == 'CALL_FUNCTION_EX':
          # stack[-has_kwargs]: Map of keyword args.
          # stack[-1 - has_kwargs]: Iterable of positional args.
          # stack[-2 - has_kwargs]: Function to call.
          has_kwargs = arg & 1  # type: int
          pop_count = has_kwargs + 2
          if has_kwargs:
            # TODO(udim): Unimplemented. Requires same functionality as a
            # CALL_FUNCTION_KW implementation.
            return_type = Any
          else:
            args = state.stack[-1]
            _callable = state.stack[-2]
            if isinstance(args, typehints.ListConstraint):
              # Case where there's a single var_arg argument.
              args = [args]
            elif isinstance(args, typehints.TupleConstraint):
              args = list(args._inner_types())
            # Recurse only for a known callee, a usable argument list and a
            # remaining depth budget (mirroring CALL_FUNCTION/CALL_METHOD).
            if (depth > 0 and isinstance(_callable, Const) and
                isinstance(args, list)):
              return_type = infer_return_type(_callable.value,
                                              args,
                                              debug=debug,
                                              depth=depth - 1)
            else:
              return_type = Any
        else:
          raise TypeInferenceError('unable to handle %s' % opname)
        state.stack[-pop_count:] = [return_type]
    elif opname == 'CALL_METHOD':
      pop_count = 1 + arg
      # LOAD_METHOD will return a non-Const (Any) if loading from an Any.
      if isinstance(state.stack[-pop_count], Const) and depth > 0:
        return_type = infer_return_type(state.stack[-pop_count].value,
                                        state.stack[1 - pop_count:],
                                        debug=debug,
                                        depth=depth - 1)
      else:
        return_type = typehints.Any
      state.stack[-pop_count:] = [return_type]
    elif opname in simple_ops:
      if debug:
        print("Executing simple op " + opname)
      simple_ops[opname](state, arg)
    elif opname == 'RETURN_VALUE':
      returns.add(state.stack[-1])
      state = None
    elif opname == 'YIELD_VALUE':
      yields.add(state.stack[-1])
    elif opname == 'JUMP_FORWARD':
      jmp = pc + arg
      jmp_state = state
      state = None
    elif opname == 'JUMP_ABSOLUTE':
      jmp = arg
      jmp_state = state
      state = None
    elif opname in ('POP_JUMP_IF_TRUE', 'POP_JUMP_IF_FALSE'):
      state.stack.pop()
      jmp = arg
      jmp_state = state.copy()
    elif opname in ('JUMP_IF_TRUE_OR_POP', 'JUMP_IF_FALSE_OR_POP'):
      jmp = arg
      jmp_state = state.copy()
      state.stack.pop()
    elif opname == 'FOR_ITER':
      jmp = pc + arg
      jmp_state = state.copy()
      jmp_state.stack.pop()
      state.stack.append(element_type(state.stack[-1]))
    else:
      raise TypeInferenceError('unable to handle %s' % opname)

    if jmp is not None:
      # TODO(robertwb): Is this guaranteed to converge?
      new_state = states[jmp] | jmp_state
      # Re-run a backward jump while it still changes the target's state.
      if jmp < pc and new_state != states[jmp] and jumps[pc] < 5:
        jumps[pc] += 1
        pc = jmp
      states[jmp] = new_state

    if debug:
      print()
      print(state)
      pprint.pprint(dict(item for item in states.items() if item[1]))

  if yields:
    result = typehints.Iterable[reduce(union, Const.unwrap_all(yields))]
  elif returns:
    result = reduce(union, Const.unwrap_all(returns))
  else:
    # No RETURN_VALUE was reached during emulation. reduce() would fail on
    # the empty set with an opaque TypeError; report it as documented.
    raise TypeInferenceError('No return value reached in %s' % (f,))
  finalize_hints(result)

  if debug:
    print(f, id(f), input_types, '->', result)
  return result