Updates from_functions and from_modules to use the same base "compile"
function.
This reduces the number of code paths -- we'll eventually use
compile for everything. It also lets us expose functions at the
user level, enabling `with_functions`.
Desired behavior:
- .with_modules() is the main entrypoint; users pass code that's stored in versionable files
- .allow_module_overrides() (current behavior) allows .with_modules(module_a, module_b) to have module_b override nodes from module_a
- .with_functions() allows building a DAG from functions alone or in combination with .with_modules(). Node names must be unique (current behavior)
- .with_functions() + .allow_module_overrides() means the content from .with_functions() is applied last.
- nodes produced by .with_functions() are overridden in order if there are duplicates (later definitions win)
diff --git a/dag_example_module.png b/dag_example_module.png
index 5351fb7..6bacf46 100644
--- a/dag_example_module.png
+++ b/dag_example_module.png
Binary files differ
diff --git a/hamilton/async_driver.py b/hamilton/async_driver.py
index 63d3a8b..1c79ce6 100644
--- a/hamilton/async_driver.py
+++ b/hamilton/async_driver.py
@@ -5,7 +5,7 @@
import time
import typing
import uuid
-from types import ModuleType
+from types import FunctionType, ModuleType
from typing import Any, Dict, Optional, Tuple
import hamilton.lifecycle.base as lifecycle_base
@@ -199,6 +199,7 @@
result_builder: Optional[base.ResultMixin] = None,
adapters: typing.List[lifecycle.LifecycleAdapter] = None,
allow_module_overrides: bool = False,
+ functions: typing.List[FunctionType] = None,
):
"""Instantiates an asynchronous driver.
@@ -249,6 +250,7 @@
*async_adapters, # note async adapters will not be called during synchronous execution -- this is for access later
],
allow_module_overrides=allow_module_overrides,
+ functions=functions,
)
self.initialized = False
diff --git a/hamilton/driver.py b/hamilton/driver.py
index 868764e..0646673 100644
--- a/hamilton/driver.py
+++ b/hamilton/driver.py
@@ -13,7 +13,7 @@
import typing
import uuid
from datetime import datetime
-from types import ModuleType
+from types import FunctionType, ModuleType
from typing import (
Any,
Callable,
@@ -402,6 +402,7 @@
self,
config: Dict[str, Any],
*modules: ModuleType,
+ functions: List[FunctionType] = None,
adapter: Optional[
Union[lifecycle_base.LifecycleAdapter, List[lifecycle_base.LifecycleAdapter]]
] = None,
@@ -435,13 +436,15 @@
if adapter.does_hook("pre_do_anything", is_async=False):
adapter.call_all_lifecycle_hooks_sync("pre_do_anything")
error = None
+ self.graph_functions = functions if functions is not None else []
self.graph_modules = modules
try:
- self.graph = graph.FunctionGraph.from_modules(
- *modules,
+ self.graph = graph.FunctionGraph.compile(
+ modules=list(modules),
+ functions=functions if functions is not None else [],
config=config,
adapter=adapter,
- allow_module_overrides=allow_module_overrides,
+ allow_node_overrides=allow_module_overrides,
)
if _materializers:
materializer_factories, extractor_factories = self._process_materializers(
@@ -1866,6 +1869,7 @@
# common fields
self.config = {}
self.modules = []
+ self.functions = []
self.materializers = []
# Allow later modules to override nodes of the same name
@@ -1927,6 +1931,17 @@
self.modules.extend(modules)
return self
+ def with_functions(self, *functions: FunctionType) -> "Builder":
+ """Adds the specified functions to the list of functions used to build the DAG.
+ This can be called multiple times. If allow_module_overrides is set, these
+ functions can override nodes from modules or from previously added functions.
+
+ :param functions: Functions to add; they are passed to the graph after any modules.
+ :return: self
+ """
+ self.functions.extend(functions)
+ return self
+
def with_adapter(self, adapter: base.HamiltonGraphAdapter) -> "Builder":
"""Sets the adapter to use.
@@ -2168,6 +2183,7 @@
_graph_executor=graph_executor,
_use_legacy_adapter=False,
allow_module_overrides=self._allow_module_overrides,
+ functions=self.functions,
)
def copy(self) -> "Builder":
diff --git a/hamilton/graph.py b/hamilton/graph.py
index 43ccd24..68ae94d 100644
--- a/hamilton/graph.py
+++ b/hamilton/graph.py
@@ -13,7 +13,7 @@
import pathlib
import uuid
from enum import Enum
-from types import ModuleType
+from types import FunctionType, ModuleType
from typing import Any, Callable, Collection, Dict, FrozenSet, List, Optional, Set, Tuple, Type
import hamilton.lifecycle.base as lifecycle_base
@@ -142,17 +142,18 @@
return nodes
-def create_function_graph(
+def compile_to_nodes(
*functions: List[Tuple[str, Callable]],
config: Dict[str, Any],
adapter: lifecycle_base.LifecycleAdapterSet = None,
fg: Optional["FunctionGraph"] = None,
- allow_module_overrides: bool = False,
+ allow_node_level_overrides: bool = False,
) -> Dict[str, node.Node]:
"""Creates a graph of all available functions & their dependencies.
:param modules: A set of modules over which one wants to compute the function graph
:param config: Dictionary that we will inspect to get values from in building the function graph.
:param adapter: The adapter that adapts our node type checking based on the context.
+ :param allow_node_level_overrides: Whether or not to allow node names to override each other
:return: list of nodes in the graph.
If it needs to be more complicated, we'll return an actual networkx graph and get all the rest of the logic for free
"""
@@ -170,7 +171,7 @@
for n in fm_base.resolve_nodes(f, config):
if n.name in config:
continue # This makes sure we overwrite things if they're in the config...
- if n.name in nodes and not allow_module_overrides:
+ if n.name in nodes and not allow_node_level_overrides:
raise ValueError(
f"Cannot define function {n.name} more than once."
f" Already defined by function {f}"
@@ -714,12 +715,41 @@
self.adapter = adapter
@staticmethod
+ def compile(
+ modules: List[ModuleType],
+ functions: List[FunctionType],
+ config: Dict[str, Any],
+ adapter: lifecycle_base.LifecycleAdapterSet = None,
+ allow_node_overrides: bool = False,
+ ) -> "FunctionGraph":
+ """Base level static function for compiling a function graph from modules
+ (crawled via find_functions) and/or functions passed directly. Module functions
+ are passed to compile_to_nodes before `functions`, so the latter are applied last.
+
+ :param modules: Modules to crawl for functions
+ :param functions: Functions to add; bare functions are paired with their __name__
+ :param config: Config to use for setting up the DAG
+ :param adapter: Adapter to use for node resolution
+ :param allow_node_overrides: Whether later nodes may override earlier ones of the same name
+ :return: The compiled function graph
+ """
+ module_functions = sum([find_functions(module) for module in modules], [])
+ nodes = compile_to_nodes(
+ *module_functions,
+ *[f if isinstance(f, tuple) else (f.__name__, f) for f in functions],
+ config=config,
+ adapter=adapter,
+ allow_node_level_overrides=allow_node_overrides,
+ )
+ return FunctionGraph(nodes, config, adapter)
+
+ @staticmethod
def from_modules(
*modules: ModuleType,
config: Dict[str, Any],
adapter: lifecycle_base.LifecycleAdapterSet = None,
allow_module_overrides: bool = False,
- ):
+ ) -> "FunctionGraph":
"""Initializes a function graph from the specified modules. Note that this was the old
way we constructed FunctionGraph -- this is not a public-facing API, so we replaced it
with a constructor that takes in nodes directly. If you hacked in something using
@@ -732,28 +762,28 @@
:return: a function graph.
"""
- functions = sum([find_functions(module) for module in modules], [])
- return FunctionGraph.from_functions(
- *functions,
+ return FunctionGraph.compile(
+ modules=modules,
+ functions=[],
config=config,
adapter=adapter,
- allow_module_overrides=allow_module_overrides,
+ allow_node_overrides=allow_module_overrides,
)
@staticmethod
def from_functions(
- *functions,
+ *functions: FunctionType,
config: Dict[str, Any],
adapter: lifecycle_base.LifecycleAdapterSet = None,
allow_module_overrides: bool = False,
) -> "FunctionGraph":
- nodes = create_function_graph(
- *functions,
+ return FunctionGraph.compile(
+ modules=[],
+ functions=functions,
config=config,
adapter=adapter,
- allow_module_overrides=allow_module_overrides,
+ allow_node_overrides=allow_module_overrides,
)
- return FunctionGraph(nodes, config, adapter)
def with_nodes(self, nodes: Dict[str, Node]) -> "FunctionGraph":
"""Creates a new function graph with the additional specified nodes.
diff --git a/pyproject.toml b/pyproject.toml
index 24f8da0..e7945e1 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -57,8 +57,7 @@
"diskcache",
# required for all the plugins
"dlt",
- # furo -- install from main for now until the next release is out:
- "furo @ git+https://github.com/pradyunsg/furo@main",
+ "furo",
"gitpython", # Required for parsing git info for generation of data-adapter docs
"grpcio-status",
"lightgbm",