blob: ac2a82b6fbc6ed4567f725750bf9eae4ccdd1400 [file] [log] [blame]
"""A few tests for combining different decorators.
While this should not be necessary -- we should be able to test the decorator lifecycle functions,
it is useful to have a few tests that demonstrate that common use-cases are supported.
Note we also have some more end-to-end cases in test_layered.py"""
from typing import Dict
import pandas as pd
from hamilton.function_modifiers import base as fm_base
from hamilton.function_modifiers import extract_columns, extract_fields, subdag, tag
def test_subdag_and_extract_columns():
    """Check node resolution when ``@extract_columns`` is stacked on top of ``@subdag``.

    The resolved node set must contain the namespaced subdag nodes, the ``foo_bar``
    dataframe node, and one node per extracted column.
    """

    def foo() -> pd.Series:
        return pd.Series([1, 2, 3])

    def bar() -> pd.Series:
        return pd.Series([1, 2, 3])

    @extract_columns("foo", "bar")
    @subdag(foo, bar)
    def foo_bar(foo: pd.Series, bar: pd.Series) -> pd.DataFrame:
        return pd.DataFrame({"foo": foo, "bar": bar})

    by_name = {}
    for resolved in fm_base.resolve_nodes(foo_bar, {}):
        by_name[resolved.name] = resolved

    expected_names = ["bar", "foo", "foo_bar", "foo_bar.bar", "foo_bar.foo"]
    assert sorted(by_name.keys()) == expected_names

    # Every extracted column must take the dataframe node as its only input.
    for column in ("foo", "bar"):
        assert sorted(by_name[column].input_types.keys()) == ["foo_bar"]
def test_subdag_and_extract_fields():
    """Check node resolution when ``@extract_fields`` is stacked on top of ``@subdag``.

    The resolved node set must contain the namespaced subdag nodes, the ``foo_bar``
    dict-producing node, and one node per extracted field.
    """

    def foo() -> int:
        return 1

    def bar() -> int:
        return 2

    # ``bar`` is annotated as ``int`` to agree with the upstream ``bar()`` function and
    # with the declared ``Dict[str, int]`` return type.
    @extract_fields({"foo": int, "bar": int})
    @subdag(foo, bar)
    def foo_bar(foo: int, bar: int) -> Dict[str, int]:
        return {"foo": foo, "bar": bar}

    nodes = fm_base.resolve_nodes(foo_bar, {})
    nodes_by_name = {node.name: node for node in nodes}
    assert sorted(nodes_by_name) == ["bar", "foo", "foo_bar", "foo_bar.bar", "foo_bar.foo"]
    # The extracted fields should depend on the thing from which they are extracted
    assert sorted(nodes_by_name["foo"].input_types.keys()) == ["foo_bar"]
    assert sorted(nodes_by_name["bar"].input_types.keys()) == ["foo_bar"]
def test_subdag_and_extract_fields_with_tags():
    """Check that targeted ``@tag`` decorators reach the fields extracted from a subdag.

    Two ``@tag`` decorators, each with a different ``target_``, are stacked on top of
    ``@extract_fields`` and ``@subdag``; each extracted field node must carry the tag
    value that was aimed at it.
    """

    def foo() -> int:
        return 1

    def bar() -> int:
        return 2

    # ``bar`` is annotated as ``int`` to agree with the upstream ``bar()`` function and
    # with the declared ``Dict[str, int]`` return type.
    @tag(a="b", target_="foo")
    @tag(a="c", target_="bar")
    @extract_fields({"foo": int, "bar": int})
    @subdag(foo, bar)
    def foo_bar(foo: int, bar: int) -> Dict[str, int]:
        return {"foo": foo, "bar": bar}

    nodes = fm_base.resolve_nodes(foo_bar, {})
    nodes_by_name = {node.name: node for node in nodes}
    assert sorted(nodes_by_name) == ["bar", "foo", "foo_bar", "foo_bar.bar", "foo_bar.foo"]
    # The extracted fields should depend on the thing from which they are extracted
    assert sorted(nodes_by_name["foo"].input_types.keys()) == ["foo_bar"]
    assert sorted(nodes_by_name["bar"].input_types.keys()) == ["foo_bar"]
    # Each targeted tag must land on its own field node only.
    assert nodes_by_name["foo"].tags["a"] == "b"
    assert nodes_by_name["bar"].tags["a"] == "c"
def test_subdag_and_extract_fields_dangling_nodes():
    """Check resolution when a subdag function (``bar``) is not consumed by ``foo_bar``.

    ``foo_bar`` only takes ``foo`` as a parameter, so the ``bar`` function inside the
    subdag is left dangling; the extracted fields must still resolve and depend on
    ``foo_bar``.
    """

    def foo() -> int:
        return 1

    def bar() -> int:
        return 2

    @extract_fields({"foo": int, "bar": int})
    @subdag(foo, bar)
    def foo_bar(foo: int) -> Dict[str, int]:
        return {"foo": foo, "bar": 5}

    by_name = {}
    for resolved in fm_base.resolve_nodes(foo_bar, {}):
        by_name[resolved.name] = resolved

    # Only a subset check: it is unclear whether foo_bar.bar should be present at all,
    # so its presence or absence is deliberately left unasserted.
    required_names = {"bar", "foo", "foo_bar", "foo_bar.foo"}
    assert required_names.issubset(by_name)

    # Every extracted field must take the dict-producing node as its only input.
    for field_name in ("foo", "bar"):
        assert sorted(by_name[field_name].input_types.keys()) == ["foo_bar"]