blob: 4572f8e497b48e66e88b8317a750e5a28e8e69fb [file]
import pytest
from burr.core.state import State, register_field_serde
def test_state_access():
    """Subscripting a State with a present key yields the stored value."""
    st = State({"foo": "bar"})
    looked_up = st["foo"]
    assert looked_up == "bar"
def test_state_access_missing():
    """Subscripting a State with an absent key raises KeyError."""
    st = State({"foo": "bar"})
    with pytest.raises(KeyError):
        _unused = st["baz"]
def test_state_get():
    """State.get returns the stored value for a present key."""
    st = State({"foo": "bar"})
    fetched = st.get("foo")
    assert fetched == "bar"
def test_state_get_missing():
    """State.get yields None for an absent key when no default is supplied."""
    st = State({"foo": "bar"})
    fetched = st.get("baz")
    assert fetched is None
def test_state_get_missing_default():
    """State.get returns the caller-supplied default for an absent key."""
    st = State({"foo": "bar"})
    fallback = "qux"
    assert st.get("baz", fallback) == fallback
def test_state_in():
    """The ``in`` operator reports key membership on a State."""
    st = State({"foo": "bar"})
    assert "foo" in st, "a present key should be reported as contained"
    assert "baz" not in st, "an absent key should not be reported as contained"
def test_state_get_all():
    """get_all returns every key/value pair the State was built with."""
    contents = {"foo": "bar", "baz": "qux"}
    st = State(dict(contents))
    assert st.get_all() == contents
def test_state_merge():
    """merge combines two States; the argument's value wins when a key is in both."""
    left = State({"foo": "bar", "baz": "qux"})
    right = State({"foo": "baz", "quux": "corge"})
    expected = {"foo": "baz", "baz": "qux", "quux": "corge"}
    combined = left.merge(right)
    assert combined.get_all() == expected
def test_state_subset():
    """subset keeps only the requested keys and drops the rest."""
    st = State({"foo": "bar", "baz": "qux"})
    only_foo = st.subset("foo")
    assert only_foo.get_all() == {"foo": "bar"}
def test_state_append():
    """append adds the given value to the end of an existing list field."""
    st = State({"foo": ["bar"]})
    extended = st.append(foo="baz")
    assert extended.get_all() == {"foo": ["bar", "baz"]}
def test_state_update():
    """update overwrites the named field and leaves the other fields as they were."""
    st = State({"foo": "bar", "baz": "qux"})
    expected = {"foo": "baz", "baz": "qux"}
    assert st.update(foo="baz").get_all() == expected
def test_state_init():
    """Constructing a State from a dict exposes exactly that dict's contents."""
    st = State({"foo": "bar", "baz": "qux"})
    snapshot = st.get_all()
    assert snapshot == {"foo": "bar", "baz": "qux"}
def test_state_wipe_delete():
    """wipe(delete=...) removes the listed keys and keeps everything else."""
    st = State({"foo": "bar", "baz": "qux"})
    remaining = st.wipe(delete=["foo"]).get_all()
    assert remaining == {"baz": "qux"}
def test_state_wipe_keep():
    """wipe(keep=...) retains only the listed keys."""
    st = State({"foo": "bar", "baz": "qux"})
    remaining = st.wipe(keep=["foo"]).get_all()
    assert remaining == {"foo": "bar"}
def test_state_append_validate_failure():
    """append raises ValueError ("non-appendable") when the target field holds a str."""
    st = State({"foo": "bar"})
    with pytest.raises(ValueError, match="non-appendable"):
        st.append(foo="baz", bar="qux")
def test_state_increment():
    """increment adds to an existing int field; a missing field ends up equal to its delta."""
    st = State({"foo": 1})
    bumped = st.increment(foo=2, bar=5)
    expected = {"foo": 3, "bar": 5}
    assert bumped.get_all() == expected
def test_state_increment_validate_failure():
    """increment raises ValueError ("non-integer") when the target field holds a str."""
    st = State({"foo": "bar"})
    with pytest.raises(ValueError, match="non-integer"):
        st.increment(foo="baz", bar="qux")
def test_field_level_serde():
    """A registered per-field serde round-trips that field through serialize/deserialize.

    The other fields ("foo", "baz") come back unchanged in both directions.
    """

    def to_wire(value: str, **kwargs) -> dict:
        return {"value": "serialized_" + value}

    def from_wire(value: dict, **kwargs) -> str:
        return value["value"].replace("serialized_", "")

    register_field_serde("my_field", to_wire, from_wire)

    live_state = State({"foo": {"hi": "world"}, "baz": "qux", "my_field": "testing 123"})
    expected_wire = {
        "foo": {"hi": "world"},
        "baz": "qux",
        "my_field": {"value": "serialized_testing 123"},
    }
    assert live_state.serialize() == expected_wire

    restored = State.deserialize(
        {"foo": {"hi": "world"}, "baz": "qux", "my_field": {"value": "serialized_testing 123"}}
    )
    expected_live = {"foo": {"hi": "world"}, "baz": "qux", "my_field": "testing 123"}
    assert restored.get_all() == expected_live
def test_field_level_serde_bad_serde_function():
    """serialize() raises ValueError when a field serializer returns a non-dict.

    The broken serializer is registered under its own field name rather than the
    shared "my_field". Registrations made through register_field_serde are picked
    up by later State.serialize() calls without being passed in, so they outlive
    this test. Reusing "my_field" would leave a broken serializer in place for any
    other test that serializes that field.
    """
    field_name = "my_bad_serde_field"

    def my_field_serializer(value: str, **kwargs) -> str:
        # Deliberately broken: returns a bare str instead of a dict.
        return f"serialized_{value}"

    def my_field_deserializer(value: dict, **kwargs) -> str:
        # Well-formed; only the serializer is under test here.
        return value["value"].replace("serialized_", "")

    register_field_serde(field_name, my_field_serializer, my_field_deserializer)
    state = State({"foo": {"hi": "world"}, "baz": "qux", field_name: "testing 123"})
    with pytest.raises(ValueError):
        state.serialize()
def test_register_field_serde_check_no_kwargs():
    """register_field_serde rejects serde functions that do not accept ``**kwargs``."""

    def serializer_without_kwargs(value: str) -> dict:
        return {"value": f"serialized_{value}"}

    def deserializer_without_kwargs(value: dict) -> str:
        return value["value"].replace("serialized_", "")

    def serializer_with_kwargs(value: str, **kwargs) -> dict:
        return {"value": f"serialized_{value}"}

    # Neither function accepts **kwargs.
    with pytest.raises(ValueError):
        register_field_serde("my_field", serializer_without_kwargs, deserializer_without_kwargs)

    # The serializer is now acceptable, but the deserializer still lacks **kwargs.
    with pytest.raises(ValueError):
        register_field_serde("my_field", serializer_with_kwargs, deserializer_without_kwargs)