| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
| from typing import Callable |
| |
| import pytest |
| |
| from burr.core import Action, Condition, Result, State, default |
| from burr.core.graph import GraphBuilder, _validate_actions, _validate_transitions |
| |
| |
| # TODO -- share versions among tests, this is duplicated in test_application.py |
class PassedInAction(Action):
    """Action test double whose behavior is injected through the constructor.

    The run and update logic are plain callables supplied by the test, and the
    declared reads, writes, inputs and tags are returned verbatim by the
    matching properties.
    """

    def __init__(
        self,
        reads: list[str],
        writes: list[str],
        fn: Callable[..., dict],
        update_fn: Callable[[dict, State], State],
        inputs: list[str],
        tags: list[str] = None,
    ):
        """Store the injected callables and declarations.

        :param reads: Keys reported by the ``reads`` property.
        :param writes: Keys reported by the ``writes`` property.
        :param fn: Callable invoked by ``run`` with the state and any run kwargs.
        :param update_fn: Callable invoked by ``update`` with the result and state.
        :param inputs: Names reported by the ``inputs`` property.
        :param tags: Optional tags; ``None`` is reported as an empty list.
        """
        super().__init__()
        self._fn, self._update_fn = fn, update_fn
        self._reads, self._writes = reads, writes
        self._inputs = inputs
        self._tags = tags

    # -- declarations -------------------------------------------------------

    @property
    def reads(self) -> list[str]:
        """Return the ``reads`` list given at construction."""
        return self._reads

    @property
    def writes(self) -> list[str]:
        """Return the ``writes`` list given at construction."""
        return self._writes

    @property
    def inputs(self) -> list[str]:
        """Return the ``inputs`` list given at construction."""
        return self._inputs

    @property
    def tags(self) -> list[str]:
        """Return the configured tags, or an empty list when none were given."""
        if self._tags:
            return self._tags
        return []

    # -- behavior -----------------------------------------------------------

    def run(self, state: State, **run_kwargs) -> dict:
        """Delegate to the injected ``fn``, forwarding the state and run kwargs."""
        return self._fn(state, **run_kwargs)

    def update(self, result: dict, state: State) -> State:
        """Delegate to the injected ``update_fn`` with the result and state."""
        return self._update_fn(result, state)
| |
| |
def test__validate_transitions_correct():
    """A conditional self-loop plus a default exit validates without raising."""
    transitions = [
        ("counter", "counter", Condition.expr("count < 10")),
        ("counter", "result", default),
    ]
    known_actions = {"counter", "result"}
    _validate_transitions(transitions, known_actions)
| |
| |
def test__validate_transitions_missing_action():
    """A transition targeting an action absent from the action set is rejected."""
    # "result" is referenced by a transition but deliberately left out here.
    known_actions = {"counter"}
    with pytest.raises(ValueError, match="not found"):
        transitions = [
            ("counter", "counter", Condition.expr("count < 10")),
            ("counter", "result", default),
        ]
        _validate_transitions(transitions, known_actions)
| |
| |
def test__validate_transitions_redundant_transition():
    """A transition listed after a default from the same action is rejected."""
    known_actions = {"counter", "result"}
    with pytest.raises(ValueError, match="redundant"):
        conditional_loop = ("counter", "counter", Condition.expr("count < 10"))
        default_exit = ("counter", "result", default)
        # Unreachable: "counter" already has a default transition before this one.
        shadowed_default = ("counter", "counter", default)
        _validate_transitions(
            [conditional_loop, default_exit, shadowed_default],
            known_actions,
        )
| |
| |
def test__validate_actions_valid():
    """A list holding a single Result action validates without raising."""
    single_action = [Result("test")]
    _validate_actions(single_action)
| |
| |
def test__validate_actions_empty():
    """An empty action list is rejected with an "at least one" error."""
    no_actions = []
    with pytest.raises(ValueError, match="at least one"):
        _validate_actions(no_actions)
| |
| |
def test__validate_actions_duplicated():
    """Two identically constructed Result actions trigger a "duplicated" error."""
    with pytest.raises(ValueError, match="duplicated"):
        repeated = [Result("test"), Result("test")]
        _validate_actions(repeated)
| |
| |
# Counter action shared by the graph-builder tests in this module: running it
# yields the state's "count" (0 when absent) plus one, and updating applies the
# result dict to the state.
base_counter_action = PassedInAction(
    fn=lambda current: {"count": current.get("count", 0) + 1},
    update_fn=lambda outcome, current: current.update(**outcome),
    reads=["count"],
    writes=["count"],
    inputs=[],
    tags=["tag1", "tag2"],
)
| |
| |
def test_graph_builder_builds():
    """Building from two actions and two transitions yields a graph with two of each."""
    builder = GraphBuilder().with_actions(
        counter=base_counter_action,
        result=Result("count"),
    )
    loop_while_small = ("counter", "counter", Condition.expr("count < 10"))
    exit_to_result = ("counter", "result")
    graph = builder.with_transitions(loop_while_small, exit_to_result).build()

    assert len(graph.actions) == 2
    assert len(graph.transitions) == 2
| |
| |
def test_graph_builder_with_graph():
    """Compose two pre-built graphs into a third and check the built result.

    Two single-action graphs (each with a conditional self-loop) are fed into a
    fresh builder through ``with_graph``; a ``result`` action and two connecting
    transitions are then added. The final graph is expected to expose three
    actions (counter, counter2, result) and four transitions (the two
    self-loops plus the two connecting ones).
    """
    graph1 = (
        GraphBuilder()
        .with_actions(counter=base_counter_action)
        .with_transitions(("counter", "counter", Condition.expr("count < 10")))
        .build()
    )
    graph2 = (
        GraphBuilder()
        .with_actions(counter2=base_counter_action)
        .with_transitions(("counter2", "counter2", Condition.expr("count < 20")))
        .build()
    )
    graph = (
        GraphBuilder()
        .with_graph(graph1)
        .with_graph(graph2)
        .with_actions(result=Result("count"))
        .with_transitions(
            ("counter", "counter2"),
            ("counter2", "result"),
        )
        # Build so the assertions below run against the validated Graph rather
        # than the builder's pending action/transition lists.
        .build()
    )
    assert len(graph.actions) == 3
    assert len(graph.transitions) == 4
| |
| |
def test_graph_builder_get_next_node():
    """With ``None`` as the first argument and entrypoint "counter", the next node is "counter"."""
    transitions = (
        ("counter", "counter", Condition.expr("count < 10")),
        ("counter", "result"),
    )
    graph = (
        GraphBuilder()
        .with_actions(counter=base_counter_action, result=Result("count"))
        .with_transitions(*transitions)
        .build()
    )
    assert len(graph.actions) == 2
    assert len(graph.transitions) == 2

    initial_state = State({"count": 0})
    next_node = graph.get_next_node(None, initial_state, entrypoint="counter")
    assert next_node.name == "counter"
| |
| |
def test_get_actions_by_tag():
    """Tag lookup returns every action carrying the tag and raises for unknown tags."""

    def _tagged_counter(tags):
        # Same counter behavior for both actions; only the tags differ.
        return PassedInAction(
            reads=["count"],
            writes=["count"],
            fn=lambda state: {"count": state.get("count", 0) + 1},
            update_fn=lambda result, state: state.update(**result),
            inputs=[],
            tags=tags,
        )

    first_action = _tagged_counter(["tag1", "tag2"])
    second_action = _tagged_counter(["tag1", "tag3"])

    builder = GraphBuilder().with_actions(counter1=first_action, counter2=second_action)
    builder = builder.with_transitions(("counter1", "counter2"))
    builder = builder.with_transitions(("counter2", "counter1"))
    graph = builder.build()

    # tag1 is on both actions; tag2 and tag3 are each on exactly one.
    expected_counts = {"tag1": 2, "tag2": 1, "tag3": 1}
    for tag, expected in expected_counts.items():
        assert len(graph.get_actions_by_tag(tag)) == expected

    with pytest.raises(ValueError, match="not found"):
        graph.get_actions_by_tag("tag4")