# blob: de6c1217d73c3788053b5dc7e2ab67eae1b6041f [file] [log] [blame]
import dataclasses
import importlib
import inspect
import os
from typing import List, Optional, Tuple, Type
import git
from docutils import nodes
from docutils.parsers.rst import Directive
import hamilton.io.data_adapters
from hamilton import registry
"""A module to crawl available data adapters and generate documentation for them.
Note these currently link out to the source code on GitHub, but they should
be linking to the documentation instead, which hasn't been generated yet.
"""
# Repository URL and revision used to build the source links in the table.
# Both are read from READTHEDOCS_* environment variables and fall back to
# defaults so that local development builds still work.
GIT_URL = os.getenv("READTHEDOCS_GIT_CLONE_URL", "https://github.com/dagworks-inc/hamilton")
GIT_ID = os.getenv("READTHEDOCS_GIT_IDENTIFIER", "main")

# All the modules that register data adapters.
# When you register a new one, add it here.
# Keep these two imports around: they are executed purely for their import
# side effects (presumably so the registries are populated before the
# directive reads them -- TODO confirm).
importlib.import_module("hamilton.function_modifiers.base")
importlib.import_module("hamilton.io")
def get_git_root(path: str) -> str:
    """Returns the git root of a repo, given an absolute path to
    a file within the repo.

    The repository is located by searching ``path`` and its parent
    directories; the root is whatever ``git rev-parse --show-toplevel``
    reports for it.

    :param path: Path to a file within a git repo
    :return: The root of the git repo
    """
    repo = git.Repo(path, search_parent_directories=True)
    return repo.git.rev_parse("--show-toplevel")
@dataclasses.dataclass
class Param:
    """One parameter of a data adapter, already converted to display strings."""

    # Name of the parameter.
    name: str
    # String representation of the parameter's type.
    type: str
    # String representation of the default value; None when there is no default.
    default: Optional[str] = None
def get_default(param: dataclasses.Field) -> Optional[str]:
    """Gets the default of a dataclass field, if it has one.

    A dataclass field can declare its default either as a plain value or
    through a ``default_factory``. Both cases are reported, so that
    factory-backed fields do not show up as if they had no default.

    :param param: The dataclass field
    :return: ``str(default)`` for a plain default, ``"<factory name>()"`` for
        a ``default_factory``, or None if the field has no default at all.
    """
    if param.default is not dataclasses.MISSING:
        return str(param.default)
    factory = param.default_factory
    if factory is not dataclasses.MISSING:
        # The factory is deliberately not called: it can be arbitrary code,
        # and its name followed by "()" is enough to document the default.
        factory_name = getattr(factory, "__name__", None) or repr(factory)
        return f"{factory_name}()"
    return None
def get_lines_for_class(class_: Type[Type]) -> Tuple[int, int]:
    """Gets the range of lines in which a class is implemented.

    :param class_: The class to get the lines for
    :return: A tuple of the start and end lines. Both are line numbers in the
        class's source file and both are inclusive, so they can be used
        directly in an ``#L<start>-L<end>`` anchor.
    :raises OSError: Propagated from ``inspect.getsourcelines`` when the
        source code cannot be retrieved.
    """
    source_lines, start_line = inspect.getsourcelines(class_)
    # start_line is the number of the first source line, so the last line is
    # len(source_lines) - 1 further down (not len(source_lines)).
    end_line = start_line + len(source_lines) - 1
    return start_line, end_line
def get_class_repr(class_: Type) -> str:
    """Builds the string used to show a type in the generated documentation.

    :param class_: Python class (or type-like object) to represent
    :return: The object's ``__qualname__`` when it has one, otherwise
        ``str(class_)``
    """
    if hasattr(class_, "__qualname__"):
        return class_.__qualname__
    # No __qualname__: this happens with generics or other oddities, so fall
    # back to the plain string form.
    return str(class_)
@dataclasses.dataclass
class AdapterInfo:
    """Display information about a single registered data adapter class.

    ``load_params`` is only populated for ``DataLoader`` subclasses and
    ``save_params`` only for ``DataSaver`` subclasses; the other one is None.
    """

    # Value returned by the adapter's ``name()``.
    key: str
    # ``__name__`` of the adapter class.
    class_name: str
    # ``__module__`` of the adapter class.
    class_path: str
    # Dataclass fields of the adapter, or None if it is not a DataLoader.
    load_params: Optional[List[Param]]
    # Dataclass fields of the adapter, or None if it is not a DataSaver.
    save_params: Optional[List[Param]]
    # String representations of the types returned by ``applicable_types()``.
    applicable_types: List[str]
    # File in which the adapter class is defined.
    file_: str
    # Line range of the adapter class within ``file_``.
    line_nos: Tuple[int, int]

    @staticmethod
    def from_loader(loader: Type[hamilton.io.data_adapters.AdapterCommon]) -> "AdapterInfo":
        """Utility constructor to create the AdapterInfo from an adapter class.

        Despite the name, this accepts both loader and saver classes.

        :param loader: Data adapter class (a dataclass) to describe
        :return: AdapterInfo derived from it
        """

        def params_if_subclass(base: type) -> Optional[List[Param]]:
            # The adapter's dataclass fields, but only if it derives from base.
            if not issubclass(loader, base):
                return None
            return [
                Param(name=p.name, type=get_class_repr(p.type), default=get_default(p))
                for p in dataclasses.fields(loader)
            ]

        return AdapterInfo(
            key=loader.name(),
            class_name=loader.__name__,
            class_path=loader.__module__,
            load_params=params_if_subclass(hamilton.io.data_adapters.DataLoader),
            save_params=params_if_subclass(hamilton.io.data_adapters.DataSaver),
            applicable_types=[get_class_repr(t) for t in loader.applicable_types()],
            file_=inspect.getfile(loader),
            line_nos=get_lines_for_class(loader),
        )
def _collect_loaders(saver_or_loader: str) -> List[Type[hamilton.io.data_adapters.AdapterCommon]]:
    """Collects all adapter classes of one kind from the registry.

    :param saver_or_loader: "loader" reads ``registry.LOADER_REGISTRY``; any
        other value reads ``registry.SAVER_REGISTRY``
    :return: Every class found in the chosen registry, in registry order,
        with repeated classes listed only once
    """
    if saver_or_loader == "loader":
        chosen_registry = registry.LOADER_REGISTRY
    else:
        chosen_registry = registry.SAVER_REGISTRY

    unique_classes = []
    for registered_classes in chosen_registry.values():
        for adapter_cls in registered_classes:
            if adapter_cls in unique_classes:
                # Already collected from an earlier registry entry.
                continue
            unique_classes.append(adapter_cls)
    return unique_classes
# Utility functions to render different components of the adapter in table cells
def render_key(key: str):
    """Renders the adapter key as the content of a table cell.

    :param key: The key to display
    :return: A single-element list containing a docutils Text node
    """
    # Only the text is passed: the second positional argument of nodes.Text
    # (rawsource) is ignored by docutils and deprecated.
    return [nodes.Text(key)]
def render_class_name(class_name: str):
    """Renders a class name as literal (code-styled) text for a table cell.

    :param class_name: The class name to display
    :return: A single-element list containing a docutils literal node
    """
    literal_node = nodes.literal(text=class_name)
    return [literal_node]
def render_class_path(class_path: str, file_: str, line_start: int, line_end: int):
    """Renders the class path as an HTML link to the class's source lines.

    The link is built from GIT_URL and GIT_ID plus the path of ``file_``
    relative to the root of its git repository.

    :param class_path: Text shown for the link
    :param file_: Path of the file that defines the class
    :param line_start: First line of the linked range
    :param line_end: Last line of the linked range
    :return: A single-element list containing a raw HTML node with the link
    """
    repo_root = get_git_root(file_)
    relative_file = os.path.relpath(file_, repo_root)
    line_anchor = f"L{line_start}-L{line_end}"
    href = f"{GIT_URL}/blob/{GIT_ID}/{relative_file}#{line_anchor}"
    link_html = f'<a href="{href}">{class_path}</a>'
    return [nodes.raw("", link_html, format="html")]
def render_adapter_params(load_params: Optional[List[Param]]):
    """Renders the parameters of an adapter for a table cell.

    Each parameter becomes a field holding its name and a literal of the form
    ``type`` or ``type=default``. Consecutive fields are separated by a raw
    HTML line break.

    :param load_params: The parameters to render, or None if there are none
        to show for this adapter
    :return: A raw HTML ``<div/>`` node when ``load_params`` is None,
        otherwise a docutils field_list node
    """
    if load_params is None:
        return nodes.raw("", "<div/>", format="html")

    field_list = nodes.field_list()
    last_index = len(load_params) - 1
    for index, param in enumerate(load_params):
        name_node = nodes.Text(param.name)
        if param.default is None:
            signature = param.type
        else:
            signature = param.type + "=" + param.default
        body_node = nodes.literal(text=signature)
        field_list += nodes.field("", name_node, body_node)
        # Separator goes between fields only, not after the last one.
        if index < last_index:
            field_list += nodes.raw("", "<br/>", format="html")
    return field_list
def render_applicable_types(applicable_types: List[str]):
    """Renders the applicable types of an adapter for a table cell.

    Every type becomes a field containing the type as a literal, and each
    field is followed by a raw HTML line break (including the last one).

    :param applicable_types: String representations of the types
    :return: A docutils field_list node
    """
    field_list = nodes.field_list()
    for type_name in applicable_types:
        type_field = nodes.field("", nodes.literal(text=type_name), nodes.Text(""))
        line_break = nodes.raw("", "<br/>", format="html")
        field_list += [type_field, line_break]
    return field_list
class DataAdapterTableDirective(Directive):
    """Custom directive to render a table of all data adapters. Takes in one argument
    that is either 'loader' or 'saver' to indicate which adapters to render."""

    has_content = True
    required_arguments = 1  # Number of required arguments

    # Relative widths of the rendered columns, in order:
    # key, params, applicable types, module.
    _COLUMN_WIDTHS = (1, 2, 1, 1)

    @staticmethod
    def _build_row(cells):
        """Builds a table row with one entry per item of ``cells``.

        :param cells: For each column, a node or a list of nodes to put in the entry
        :return: The populated row node
        """
        row = nodes.row()
        for cell_content in cells:
            entry = nodes.entry()
            entry += cell_content
            row += entry
        return row

    def run(self):
        """Runs the directive. This does the following:

        1. Collects all loaders (or savers) from the registry
        2. Creates a table with the following columns:
            - Key
            - Load/save params
            - Applicable types
            - Module (linked to the source code)
        3. Returns the table

        :return: A list of nodes that Sphinx will render, consisting of the table node
        :raises ValueError: If the argument is neither 'loader' nor 'saver'
        """
        saver_or_loader = self.arguments[0]
        if saver_or_loader not in ("loader", "saver"):
            raise ValueError(
                f"loader_or_saver must be one of 'loader' or 'saver', got {saver_or_loader}"
            )
        table_data = [
            AdapterInfo.from_loader(loader) for loader in _collect_loaders(saver_or_loader)
        ]

        # Create the table; cols has to match the number of columns actually rendered.
        table_node = nodes.table()
        tgroup = nodes.tgroup(cols=len(self._COLUMN_WIDTHS))
        table_node += tgroup
        tgroup += [nodes.colspec(colwidth=width) for width in self._COLUMN_WIDTHS]

        # Header row
        thead = nodes.thead()
        thead += self._build_row(
            [
                nodes.paragraph(text="key"),
                nodes.paragraph(text=f"{saver_or_loader} params"),
                nodes.paragraph(text="types"),
                nodes.paragraph(text="module"),
            ]
        )
        tgroup += thead

        # One body row per adapter
        tbody = nodes.tbody()
        tgroup += tbody
        for row_data in table_data:
            params = row_data.load_params if saver_or_loader == "loader" else row_data.save_params
            tbody += self._build_row(
                [
                    render_key(row_data.key),
                    render_adapter_params(params),
                    render_applicable_types(row_data.applicable_types),
                    render_class_path(row_data.class_path, row_data.file_, *row_data.line_nos),
                ]
            )
        return [table_node]
def setup(app):
    """Extension entry point: registers the ``data_adapter_table`` directive.

    :param app: The application object the directive is added to
    """
    directive_name = "data_adapter_table"
    app.add_directive(directive_name, DataAdapterTableDirective)