blob: 34da17e894657279bacfad0c904273abd15a277b [file]
import json
import os
import uuid
from typing import Tuple
import pytest
import burr
from burr.core import Result, State, action, default, expr
from burr.tracking import LocalTrackingClient
from burr.tracking.common.models import ApplicationModel, BeginEntryModel, EndEntryModel
@action(reads=["counter", "break_at"], writes=["counter"])
def counter(state: State) -> Tuple[dict, State]:
    """Increment ``counter`` by one, failing when it hits ``break_at``.

    :param state: Current state; must provide ``counter`` and ``break_at``.
    :return: A ``(result, new_state)`` pair where ``result`` is
        ``{"counter": <incremented value>}`` and ``new_state`` is ``state``
        updated with that result.
    :raises ValueError: With message ``"Broken"`` when the incremented value
        equals ``state["break_at"]``.
    """
    incremented = state["counter"] + 1
    # The failure is triggered on the *new* value, before state is updated.
    if state["break_at"] == incremented:
        raise ValueError("Broken")
    result = {"counter": incremented}
    return result, state.update(**result)
def sample_application(project_name: str, log_dir: str, app_id: str, broken: bool = False):
    """Build a small counting application wired to the local tracker.

    The application starts at the ``counter`` action, loops on it while
    ``counter < 2``, and then moves on to ``result``.

    :param project_name: Project name handed to the tracker.
    :param log_dir: Passed to the local tracker as ``storage_dir``.
    :param app_id: Passed to the local tracker as ``app_id``.
    :param broken: When true, ``break_at`` is set to ``2`` so the ``counter``
        action raises once the count reaches two; otherwise ``break_at`` is
        ``-1``, which an incrementing counter starting at 0 never equals.
    :return: The built application.
    """
    break_at = 2 if broken else -1
    tracker_params = {"storage_dir": log_dir, "app_id": app_id}
    transitions = (
        ("counter", "counter", expr("counter < 2")),  # just count to two for testing
        ("counter", "result", default),
    )

    builder = burr.core.ApplicationBuilder()
    builder = builder.with_state(counter=0, break_at=break_at)
    builder = builder.with_actions(counter=counter, result=Result("counter"))
    builder = builder.with_transitions(*transitions)
    builder = builder.with_entrypoint("counter")
    builder = builder.with_tracker(project_name, tracker="local", params=tracker_params)
    return builder.build()
def test_application_tracks_end_to_end(tmpdir: str):
    """Run the sample app to completion and verify the tracker's output files.

    Checks that the graph file describes the application (entrypoint and the
    two actions) and that the log file holds three begin/end entry pairs
    (two ``counter`` steps plus ``result``), none of which carry an exception.
    """
    app_id = str(uuid.uuid4())
    log_dir = os.path.join(tmpdir, "tracking")
    project_name = "test_application_tracks_end_to_end"

    application = sample_application(project_name, log_dir, app_id)
    application.run(until=["result"])

    # Tracker output is expected under <log_dir>/<project_name>/<app_id>/.
    results_dir = os.path.join(log_dir, project_name, app_id)
    assert os.path.exists(results_dir)
    log_output = os.path.join(results_dir, LocalTrackingClient.LOG_FILENAME)
    assert os.path.exists(log_output)
    graph_output = os.path.join(results_dir, LocalTrackingClient.GRAPH_FILENAME)
    assert os.path.exists(graph_output)

    # The log file holds one JSON document per line.
    with open(log_output) as log_file:
        log_contents = [json.loads(raw_line) for raw_line in log_file]
    with open(graph_output) as graph_file:
        graph_contents = json.load(graph_file)

    assert graph_contents["type"] == "application"
    app_model = ApplicationModel.parse_obj(graph_contents)
    assert app_model.entrypoint == "counter"
    first_action, second_action = app_model.actions[0], app_model.actions[1]
    assert first_action.name == "counter"
    assert second_action.name == "result"

    # Split the log into begin/end records, validating each against its model.
    begin_entries = []
    end_entries = []
    for entry in log_contents:
        if entry["type"] == "begin_entry":
            begin_entries.append(BeginEntryModel.parse_obj(entry))
        elif entry["type"] == "end_entry":
            end_entries.append(EndEntryModel.parse_obj(entry))

    assert len(begin_entries) == 3
    assert len(end_entries) == 3
    assert all(not entry.exception for entry in end_entries)
def test_application_tracks_end_to_end_broken(tmpdir: str):
    """Run the failing sample app and verify the tracker records the error.

    The app is built with ``broken=True``, so ``run`` is expected to raise
    ``ValueError``. The graph file must still describe the application, and
    the log file must hold two begin/end entry pairs, with the last end entry
    carrying an exception string that mentions ``"Broken"``.
    """
    app_id = str(uuid.uuid4())
    log_dir = os.path.join(tmpdir, "tracking")
    project_name = "test_application_tracks_end_to_end"

    application = sample_application(project_name, log_dir, app_id, broken=True)
    with pytest.raises(ValueError):
        application.run(until=["result"])

    # Tracker output is expected under <log_dir>/<project_name>/<app_id>/.
    results_dir = os.path.join(log_dir, project_name, app_id)
    assert os.path.exists(results_dir)
    log_output = os.path.join(results_dir, LocalTrackingClient.LOG_FILENAME)
    assert os.path.exists(log_output)
    graph_output = os.path.join(results_dir, LocalTrackingClient.GRAPH_FILENAME)
    assert os.path.exists(graph_output)

    # The log file holds one JSON document per line.
    with open(log_output) as log_file:
        log_contents = [json.loads(raw_line) for raw_line in log_file]
    with open(graph_output) as graph_file:
        graph_contents = json.load(graph_file)

    assert graph_contents["type"] == "application"
    app_model = ApplicationModel.parse_obj(graph_contents)
    assert app_model.entrypoint == "counter"
    first_action, second_action = app_model.actions[0], app_model.actions[1]
    assert first_action.name == "counter"
    assert second_action.name == "result"

    # Split the log into begin/end records, validating each against its model.
    begin_entries = []
    end_entries = []
    for entry in log_contents:
        if entry["type"] == "begin_entry":
            begin_entries.append(BeginEntryModel.parse_obj(entry))
        elif entry["type"] == "end_entry":
            end_entries.append(EndEntryModel.parse_obj(entry))

    assert len(begin_entries) == 2
    assert len(end_entries) == 2
    # The failing step is the last one logged; its end entry holds the error text.
    last_exception = end_entries[-1].exception
    assert len(last_exception) > 0 and "Broken" in last_exception