blob: 151347fb74c4a5b41996d49247261c284ce2db1d [file]
import polars as pl
import pytest
from polars.testing import assert_frame_equal
from hamilton import driver, node
from hamilton.function_modifiers.base import NodeInjector
from hamilton.plugins.h_polars import with_columns
from .resources import with_columns_end_to_end
def test_create_column_nodes_pass_dataframe():
def dummy_fn_with_columns(col_1: pl.Series) -> pl.Series:
return col_1 + 100
def target_fn(some_var: int, upstream_df: pl.DataFrame) -> pl.DataFrame:
return upstream_df
dummy_node = node.Node.from_fn(target_fn)
decorator = with_columns(
dummy_fn_with_columns, on_input="upstream_df", select=["dummy_fn_with_columns"]
)
injectable_params = NodeInjector.find_injectable_params([dummy_node])
inject_parameter, initial_nodes = decorator.get_initial_nodes(
fn=target_fn, params=injectable_params
)
assert inject_parameter == "upstream_df"
assert len(initial_nodes) == 0
def test_create_column_nodes_extract_single_columns():
def dummy_fn_with_columns(col_1: pl.Series) -> pl.Series:
return col_1 + 100
def dummy_df() -> pl.DataFrame:
return pl.DataFrame({"col_1": [1, 2, 3, 4], "col_2": [11, 12, 13, 14]})
def target_fn(upstream_df: pl.DataFrame) -> pl.DataFrame:
return upstream_df
dummy_node = node.Node.from_fn(target_fn)
decorator = with_columns(
dummy_fn_with_columns, columns_to_pass=["col_1"], select=["dummy_fn_with_columns"]
)
injectable_params = NodeInjector.find_injectable_params([dummy_node])
inject_parameter, initial_nodes = decorator.get_initial_nodes(
fn=target_fn, params=injectable_params
)
assert inject_parameter == "upstream_df"
assert len(initial_nodes) == 1
assert initial_nodes[0].name == "col_1"
assert initial_nodes[0].type == pl.Series
pl.testing.assert_series_equal(
initial_nodes[0].callable(upstream_df=dummy_df()),
pl.Series([1, 2, 3, 4]),
check_names=False,
)
def test_create_column_nodes_extract_multiple_columns():
def dummy_fn_with_columns(col_1: pl.Series) -> pl.Series:
return col_1 + 100
def dummy_df() -> pl.DataFrame:
return pl.DataFrame({"col_1": [1, 2, 3, 4], "col_2": [11, 12, 13, 14]})
def target_fn(upstream_df: pl.DataFrame) -> pl.DataFrame:
return upstream_df
dummy_node = node.Node.from_fn(target_fn)
decorator = with_columns(
dummy_fn_with_columns, columns_to_pass=["col_1", "col_2"], select=["dummy_fn_with_columns"]
)
injectable_params = NodeInjector.find_injectable_params([dummy_node])
inject_parameter, initial_nodes = decorator.get_initial_nodes(
fn=target_fn, params=injectable_params
)
assert inject_parameter == "upstream_df"
assert len(initial_nodes) == 2
assert initial_nodes[0].name == "col_1"
assert initial_nodes[1].name == "col_2"
assert initial_nodes[0].type == pl.Series
assert initial_nodes[1].type == pl.Series
pl.testing.assert_series_equal(
initial_nodes[0].callable(upstream_df=dummy_df()),
pl.Series([1, 2, 3, 4]),
check_names=False,
)
pl.testing.assert_series_equal(
initial_nodes[1].callable(upstream_df=dummy_df()),
pl.Series([11, 12, 13, 14]),
check_names=False,
)
def test_no_matching_select_column_error():
def dummy_fn_with_columns(col_1: pl.Series) -> pl.Series:
return col_1 + 100
def target_fn(upstream_df: pl.DataFrame) -> pl.DataFrame:
return upstream_df
dummy_node = node.Node.from_fn(target_fn)
select = "wrong_column"
decorator = with_columns(
dummy_fn_with_columns, columns_to_pass=["col_1", "col_2"], select=select
)
injectable_params = NodeInjector.find_injectable_params([dummy_node])
with pytest.raises(ValueError):
decorator.inject_nodes(injectable_params, {}, fn=target_fn)
def test_append_into_original_df():
def dummy_fn_with_columns(col_1: pl.Series) -> pl.Series:
return col_1 + 100
def dummy_df() -> pl.DataFrame:
return pl.DataFrame({"col_1": [1, 2, 3, 4], "col_2": [11, 12, 13, 14]})
def target_fn(upstream_df: pl.DataFrame) -> pl.DataFrame:
return upstream_df
decorator = with_columns(
dummy_fn_with_columns, columns_to_pass=["col_1", "col_2"], select=["dummy_fn_with_columns"]
)
output_nodes, _ = decorator.chain_subdag_nodes(
fn=target_fn, inject_parameter="upstream_df", generated_nodes=[]
)
merge_node = output_nodes[-1]
output_df = merge_node.callable(
upstream_df=dummy_df(),
dummy_fn_with_columns=dummy_fn_with_columns(col_1=pl.Series([1, 2, 3, 4])),
)
assert merge_node.name == "__append"
assert merge_node.type == pl.DataFrame
pl.testing.assert_series_equal(output_df["col_1"], pl.Series([1, 2, 3, 4]), check_names=False)
pl.testing.assert_series_equal(
output_df["col_2"], pl.Series([11, 12, 13, 14]), check_names=False
)
pl.testing.assert_series_equal(
output_df["dummy_fn_with_columns"], pl.Series([101, 102, 103, 104]), check_names=False
)
def test_override_original_column_in_df():
def dummy_df() -> pl.DataFrame:
return pl.DataFrame({"col_1": [1, 2, 3, 4], "col_2": [11, 12, 13, 14]})
def target_fn(upstream_df: pl.DataFrame) -> pl.DataFrame:
return upstream_df
def col_1() -> pl.Series:
return pl.col("col_1") * 100
decorator = with_columns(col_1, on_input="upstream_df", select=["col_1"])
output_nodes, _ = decorator.chain_subdag_nodes(
fn=target_fn, inject_parameter="upstream_df", generated_nodes=[]
)
merge_node = output_nodes[-1]
output_df = merge_node.callable(upstream_df=dummy_df(), col_1=col_1())
assert merge_node.name == "__append"
assert merge_node.type == pl.DataFrame
pl.testing.assert_series_equal(
output_df["col_1"], pl.Series([100, 200, 300, 400]), check_names=False
)
pl.testing.assert_series_equal(
output_df["col_2"], pl.Series([11, 12, 13, 14]), check_names=False
)
def test_assign_custom_namespace_with_columns():
def dummy_fn_with_columns(col_1: pl.Series) -> pl.Series:
return col_1 + 100
def target_fn(upstream_df: pl.DataFrame) -> pl.DataFrame:
return upstream_df
dummy_node = node.Node.from_fn(target_fn)
decorator = with_columns(
dummy_fn_with_columns,
columns_to_pass=["col_1", "col_2"],
select=["dummy_fn_with_columns"],
namespace="dummy_namespace",
)
nodes_ = decorator.transform_dag([dummy_node], {}, target_fn)
assert nodes_[0].name == "target_fn"
assert nodes_[1].name == "dummy_namespace.dummy_fn_with_columns"
assert nodes_[2].name == "dummy_namespace.col_1"
assert nodes_[3].name == "dummy_namespace.__append"
def test_end_to_end_with_columns_automatic_extract():
config_5 = {
"factor": 5,
}
dr = driver.Builder().with_modules(with_columns_end_to_end).with_config(config_5).build()
result = dr.execute(final_vars=["final_df"], inputs={"user_factor": 1000})["final_df"]
expected_df = pl.DataFrame(
{
"col_1": [1, 2, 3, 4],
"col_2": [11, 12, 13, 14],
"col_3": [1, 1, 1, 1],
"subtract_1_from_2": [10, 10, 10, 10],
"multiply_3": [5, 5, 5, 5],
"add_1_by_user_adjustment_factor": [1001, 1002, 1003, 1004],
"multiply_2_by_upstream_3": [33, 36, 39, 42],
}
)
pl.testing.assert_frame_equal(result, expected_df)
config_7 = {
"factor": 7,
}
dr = driver.Builder().with_modules(with_columns_end_to_end).with_config(config_7).build()
result = dr.execute(final_vars=["final_df"], inputs={"user_factor": 1000})["final_df"]
expected_df = pl.DataFrame(
{
"col_1": [1, 2, 3, 4],
"col_2": [11, 12, 13, 14],
"col_3": [1, 1, 1, 1],
"subtract_1_from_2": [10, 10, 10, 10],
"multiply_3": [7, 7, 7, 7],
"add_1_by_user_adjustment_factor": [1001, 1002, 1003, 1004],
"multiply_2_by_upstream_3": [33, 36, 39, 42],
}
)
assert_frame_equal(result, expected_df)
def test_end_to_end_with_columns_pass_dataframe():
config_5 = {
"factor": 5,
}
dr = driver.Builder().with_modules(with_columns_end_to_end).with_config(config_5).build()
result = dr.execute(final_vars=["final_df_2"])["final_df_2"]
expected_df = pl.DataFrame(
{
"col_1": [1, 2, 3, 4],
"col_2": [11, 12, 13, 14],
"col_3": [0, 2, 4, 6],
"multiply_3": [0, 10, 20, 30],
}
)
assert_frame_equal(result, expected_df)