Handle Union input type annotations more flexibly
Previously, a parameter annotated with a Union had to match the upstream dependency's type exactly.
Now the upstream type only needs to match one of the Union's member types. This adds code for that
and unit tests.
diff --git a/hamilton/base.py b/hamilton/base.py
index eafa966..46edb60 100644
--- a/hamilton/base.py
+++ b/hamilton/base.py
@@ -119,7 +119,7 @@
This is used when the function graph is being created and we're statically type checking the annotations
for compatibility.
- :param node_type: The type of the node.
+ :param node_type: The type of the upstream node whose value would flow into the input.
:param input_type: The type of the input that would flow into the node.
:return:
"""
@@ -159,7 +159,31 @@
 
     @staticmethod
     def check_node_type_equivalence(node_type: typing.Type, input_type: typing.Type) -> bool:
-        return node_type == input_type
+        """Checks whether an upstream node's type is compatible with an input's type annotation.
+
+        The types are compatible when they are equal; when `input_type` is a Union and `node_type`
+        matches one of its members; or when `node_type` is itself a Union whose every member is
+        compatible with `input_type` (i.e. it is a subset of the input's Union).
+
+        :param node_type: The type of the upstream node.
+        :param input_type: The type annotation of the input the node's value would flow into.
+        :return: True if the types are compatible, False otherwise.
+        """
+        if node_type == input_type:
+            return True
+        if typing_inspect.is_union_type(node_type):
+            # An upstream Union may yield any of its members, so every member must be accepted.
+            return all(
+                SimplePythonDataFrameGraphAdapter.check_node_type_equivalence(nt, input_type)
+                for nt in typing_inspect.get_args(node_type)
+            )
+        if typing_inspect.is_union_type(input_type):
+            # A single upstream type only needs to match one member of the input's Union.
+            return any(
+                SimplePythonDataFrameGraphAdapter.check_node_type_equivalence(node_type, ut)
+                for ut in typing_inspect.get_args(input_type)
+            )
+        return False
 
     def execute_node(self, node: node.Node, kwargs: typing.Dict[str, typing.Any]) -> typing.Any:
         return node.callable(**kwargs)
diff --git a/tests/test_base.py b/tests/test_base.py
index 0fd6e68..a912ebd 100644
--- a/tests/test_base.py
+++ b/tests/test_base.py
@@ -64,6 +64,7 @@
(str, 'abc'),
(typing.Union[int, pd.Series], pd.Series([1,2,3])),
(typing.Union[int, pd.Series], 1),
+ (typing.Union[int, typing.Union[float, pd.Series]], 1.0),  # typing flattens this to Union[int, float, pd.Series]
], ids=[
'test-any',
'test-subclass',
@@ -77,6 +78,7 @@
'test-type-match-str',
'test-union-match-series',
'test-union-match-int',
+ 'test-union-match-nested-float',
])
def test_SimplePythonDataFrameGraphAdapter_check_input_type_match(node_type, input_type):
"""Tests check_input_type of SimplePythonDataFrameGraphAdapter"""
@@ -115,6 +117,76 @@
     assert actual is False
 
 
+@pytest.mark.parametrize('node_type,input_type', [
+    (typing.Any, typing.Any),
+    (pd.Series, pd.Series),
+    (T, T),
+    (typing.List, typing.List),
+    (typing.Dict, typing.Dict),
+    (dict, dict),
+    (list, list),
+    (int, int),
+    (float, float),
+    (str, str),
+    (typing.Union[int, pd.Series], typing.Union[int, pd.Series]),
+    (pd.Series, typing.Union[int, pd.Series]),
+    (int, typing.Union[int, pd.Series]),
+    (float, typing.Union[int, typing.Union[float, pd.Series]]),
+], ids=[
+    'test-any',
+    'test-type-match-series',
+    'test-typevar',
+    'test-generic-list',
+    'test-generic-dict',
+    'test-type-match-dict',
+    'test-type-match-list',
+    'test-type-match-int',
+    'test-type-match-float',
+    'test-type-match-str',
+    'test-union-match-exact',
+    'test-union-match-subset-series',
+    'test-union-match-subset-int',
+    'test-union-match-subset-nested-float',
+])
+def test_SimplePythonDataFrameGraphAdapter_check_node_type_equivalence_match(node_type, input_type):
+    """Tests matches for check_node_type_equivalence function"""
+    adapter = base.SimplePythonDataFrameGraphAdapter()
+    actual = adapter.check_node_type_equivalence(node_type, input_type)
+    assert actual is True
+
+
+@pytest.mark.parametrize('node_type,input_type', [
+    (typing.Union[int, pd.Series], typing.Any),
+    (pd.DataFrame, pd.Series),
+    (typing.List, list),
+    (typing.Dict, dict),
+    (dict, list),
+    (list, dict),
+    (int, float),
+    (float, int),
+    (str, int),
+    (typing.Union[int, pd.Series], float),
+    (float, typing.Union[int, pd.Series]),  # input is a Union that does not contain the node's type
+], ids=[
+    'test-any-mismatch',
+    'test-class-mismatch',
+    'test-generic-mismatch-list',
+    'test-generic-mismatch-dict',
+    'test-type-mismatch-dict',
+    'test-type-mismatch-list',
+    'test-type-mismatch-int',
+    'test-type-mismatch-float',
+    'test-type-mismatch-str',
+    'test-union-mismatch-float',
+    'test-union-input-mismatch-float',
+])
+def test_SimplePythonDataFrameGraphAdapter_check_node_type_equivalence_mismatch(node_type, input_type):
+    """Tests mismatches for check_node_type_equivalence function"""
+    adapter = base.SimplePythonDataFrameGraphAdapter()
+    actual = adapter.check_node_type_equivalence(node_type, input_type)
+    assert actual is False
+
+
@pytest.mark.parametrize('outputs,expected_result', [
({'a': pd.Series([1, 2, 3])},
pd.DataFrame({'a': pd.Series([1, 2, 3])})),