blob: 5e58c66b50c5a7e45b2c5023c689c5b074d91fef [file]
import burr.core.application
from burr.core import Action, Condition, State, default
class ProcessDataAction(Action):
    """Burr action that reads ``data_path`` and writes the two data splits.

    Per its declared reads/writes, this action is meant to produce
    ``training_data`` and ``evaluation_data`` from ``data_path``. ``run`` is
    an unimplemented placeholder in this example.
    """

    @property
    def reads(self) -> list[str]:
        """State keys this action declares as inputs."""
        return ["data_path"]

    def run(self, state: State) -> dict:
        """Placeholder computation step; currently does nothing and returns ``None``.

        ``update`` expects the result to be a dict with ``"training_data"``
        and ``"evaluation_data"`` entries, so a real implementation has to
        provide both.
        """
        pass

    @property
    def writes(self) -> list[str]:
        """State keys this action declares as outputs."""
        return ["training_data", "evaluation_data"]

    def update(self, result: dict, state: State) -> State:
        """Store both data splits from ``result`` in the state.

        :param result: Output of ``run``; must contain ``"training_data"``
            and ``"evaluation_data"``.
        :param state: State to update.
        :return: New state with both declared write keys set.
        :raises KeyError: If ``result`` lacks either key.
        """
        # Every key declared in ``writes`` is set here, so that
        # ``evaluation_data`` does not stay at its initial value.
        return state.update(
            training_data=result["training_data"],
            evaluation_data=result["evaluation_data"],
        )
class TrainModel(Action):
    """Burr action that records one training step's outputs in the state.

    Reads ``training_data`` and ``epochs``; writes ``models``,
    ``training_metrics`` and ``epochs``. ``run`` is an unimplemented
    placeholder in this example.
    """

    @property
    def reads(self) -> list[str]:
        """State keys this action declares as inputs."""
        return ["training_data", "epochs"]

    def run(self, state: State) -> dict:
        """Placeholder computation step; currently does nothing and returns ``None``.

        ``update`` expects the result to be a dict with ``"epochs"``,
        ``"model"`` and ``"metrics"`` entries.
        """
        pass

    @property
    def writes(self) -> list[str]:
        """State keys this action declares as outputs."""
        return ["models", "training_metrics", "epochs"]

    def update(self, result: dict, state: State) -> State:
        """Overwrite the epoch counter and append the new model and metrics.

        :param result: Output of ``run``; must contain ``"epochs"``,
            ``"model"`` and ``"metrics"``.
        :param state: State to update.
        :return: New state with ``epochs`` replaced and ``models`` /
            ``training_metrics`` each extended by one entry.
        :raises KeyError: If ``result`` lacks any of the expected keys.
        """
        return state.update(
            epochs=result["epochs"],  # overwrite each epoch
        ).append(
            # append -- note this can get big if your model is big, so
            # you'll want to overwrite but store the conditions, or log
            # somewhere
            models=result["model"],
            # The state key has to match the declared ``training_metrics``
            # write; the result key produced by ``run`` is ``"metrics"``.
            training_metrics=result["metrics"],
        )
class ValidateModel(Action):
    """Burr action that appends validation metrics to the state.

    Declares ``models`` and ``evaluation_data`` as inputs and
    ``validation_metrics`` as its only output. ``run`` is an unimplemented
    placeholder in this example.
    """

    @property
    def reads(self) -> list[str]:
        """State keys this action declares as inputs."""
        return ["models", "evaluation_data"]

    @property
    def writes(self) -> list[str]:
        """State keys this action declares as outputs."""
        return ["validation_metrics"]

    def run(self, state: State) -> dict:
        """Placeholder computation step; does nothing and returns ``None``.

        ``update`` expects the result to be a dict with a
        ``"validation_metrics"`` entry.
        """

    def update(self, result: dict, state: State) -> State:
        """Append ``result["validation_metrics"]`` to ``validation_metrics``."""
        latest_metrics = result["validation_metrics"]
        return state.append(validation_metrics=latest_metrics)
class BestModel(Action):
    """Burr action that stores the selected model under ``best_model``.

    Declares ``validation_metrics`` and ``models`` as inputs and
    ``best_model`` as its only output. ``run`` is an unimplemented
    placeholder in this example.
    """

    @property
    def reads(self) -> list[str]:
        """State keys this action declares as inputs."""
        return ["validation_metrics", "models"]

    @property
    def writes(self) -> list[str]:
        """State keys this action declares as outputs."""
        return ["best_model"]

    def run(self, state: State) -> dict:
        """Placeholder computation step; does nothing and returns ``None``.

        ``update`` expects the result to be a dict with a ``"best_model"``
        entry.
        """

    def update(self, result: dict, state: State) -> State:
        """Overwrite ``best_model`` in the state with ``result["best_model"]``."""
        chosen_model = result["best_model"]
        return state.update(best_model=chosen_model)
def application(epochs: int) -> burr.core.application.Application:
    """Build the training-loop application.

    The graph is ``process_data -> train_model -> validate_model``; from
    ``validate_model`` it moves to ``best_model`` when the state's ``epochs``
    value exceeds the ``epochs`` argument, and otherwise loops back to
    ``train_model``.

    :param epochs: Threshold interpolated into the ``validate_model ->
        best_model`` transition expression.
    :return: The built application, with ``process_data`` as entrypoint.
    """
    # NOTE(review): the initial ``epochs`` state value is a fixed 10 and is
    # independent of the ``epochs`` argument -- confirm this is intended.
    initial_state = {
        "data_path": "data.csv",
        "epochs": 10,
        "training_data": None,
        "evaluation_data": None,
        "models": [],
        "training_metrics": [],
        "validation_metrics": [],
        "best_model": None,
    }
    actions = {
        "process_data": ProcessDataAction(),
        "train_model": TrainModel(),
        "validate_model": ValidateModel(),
        "best_model": BestModel(),
    }
    # The conditional transition is listed before the default one leaving
    # the same action, matching the original ordering.
    transitions = [
        ("process_data", "train_model", default),
        ("train_model", "validate_model", default),
        ("validate_model", "best_model", Condition.expr(f"epochs>{epochs}")),
        ("validate_model", "train_model", default),
    ]

    builder = burr.core.ApplicationBuilder()
    builder = builder.with_state(**initial_state)
    builder = builder.with_actions(**actions)
    builder = builder.with_transitions(*transitions)
    builder = builder.with_entrypoint("process_data")
    return builder.build()
if __name__ == "__main__":
    # Script entry point: build the application and render its graph to an
    # image; the application itself is not run here.
    # The threshold of 100 is arbitrary -- doing good data science is up to
    # you...
    app = application(100)
    app.visualize(
        output_file_path="ml_training.png",
        include_conditions=True,
        view=True,
    )