| r"""Defines a docutils directive for inserting inheritance diagrams. |
| |
| Provide the directive with one or more classes or modules (separated |
| by whitespace). For modules, all of the classes in that module will |
| be used. |
| |
| Example:: |
| |
| Given the following classes: |
| |
| class A: pass |
| class B(A): pass |
| class C(A): pass |
| class D(B, C): pass |
| class E(B): pass |
| |
| .. inheritance-diagram: D E |
| |
| Produces a graph like the following: |
| |
| A |
| / \ |
| B C |
| / \ / |
| E D |
| |
| The graph is inserted as a PNG+image map into HTML and a PDF in |
| LaTeX. |
| """ |
| |
| from __future__ import annotations |
| |
| import builtins |
| import hashlib |
| import inspect |
| import re |
| from collections.abc import Iterable, Sequence |
| from importlib import import_module |
| from os import path |
| from typing import TYPE_CHECKING, Any, cast |
| |
| from docutils import nodes |
| from docutils.parsers.rst import directives |
| |
| import sphinx |
| from sphinx import addnodes |
| from sphinx.ext.graphviz import ( |
| figure_wrapper, |
| graphviz, |
| render_dot_html, |
| render_dot_latex, |
| render_dot_texinfo, |
| ) |
| from sphinx.util.docutils import SphinxDirective |
| |
| if TYPE_CHECKING: |
| from docutils.nodes import Node |
| |
| from sphinx.application import Sphinx |
| from sphinx.environment import BuildEnvironment |
| from sphinx.util.typing import OptionSpec |
| from sphinx.writers.html import HTML5Translator |
| from sphinx.writers.latex import LaTeXTranslator |
| from sphinx.writers.texinfo import TexinfoTranslator |
| |
| module_sig_re = re.compile(r'''^(?:([\w.]*)\.)? # module names |
| (\w+) \s* $ # class/final module name |
| ''', re.VERBOSE) |
| |
| |
| py_builtins = [obj for obj in vars(builtins).values() |
| if inspect.isclass(obj)] |
| |
| |
| def try_import(objname: str) -> Any: |
| """Import a object or module using *name* and *currentmodule*. |
| *name* should be a relative name from *currentmodule* or |
| a fully-qualified name. |
| |
| Returns imported object or module. If failed, returns None value. |
| """ |
| try: |
| return import_module(objname) |
| except TypeError: |
| # Relative import |
| return None |
| except ImportError: |
| matched = module_sig_re.match(objname) |
| |
| if not matched: |
| return None |
| |
| modname, attrname = matched.groups() |
| |
| if modname is None: |
| return None |
| try: |
| module = import_module(modname) |
| return getattr(module, attrname, None) |
| except ImportError: |
| return None |
| |
| |
| def import_classes(name: str, currmodule: str) -> Any: |
| """Import a class using its fully-qualified *name*.""" |
| target = None |
| |
| # import class or module using currmodule |
| if currmodule: |
| target = try_import(currmodule + '.' + name) |
| |
| # import class or module without currmodule |
| if target is None: |
| target = try_import(name) |
| |
| if target is None: |
| raise InheritanceException( |
| 'Could not import class or module %r specified for ' |
| 'inheritance diagram' % name) |
| |
| if inspect.isclass(target): |
| # If imported object is a class, just return it |
| return [target] |
| elif inspect.ismodule(target): |
| # If imported object is a module, return classes defined on it |
| classes = [] |
| for cls in target.__dict__.values(): |
| if inspect.isclass(cls) and cls.__module__ == target.__name__: |
| classes.append(cls) |
| return classes |
| raise InheritanceException('%r specified for inheritance diagram is ' |
| 'not a class or module' % name) |
| |
| |
| class InheritanceException(Exception): |
| pass |
| |
| |
| class InheritanceGraph: |
| """ |
| Given a list of classes, determines the set of classes that they inherit |
| from all the way to the root "object", and then is able to generate a |
| graphviz dot graph from them. |
| """ |
| def __init__(self, class_names: list[str], currmodule: str, show_builtins: bool = False, |
| private_bases: bool = False, parts: int = 0, |
| aliases: dict[str, str] | None = None, top_classes: Sequence[Any] = (), |
| ) -> None: |
| """*class_names* is a list of child classes to show bases from. |
| |
| If *show_builtins* is True, then Python builtins will be shown |
| in the graph. |
| """ |
| self.class_names = class_names |
| classes = self._import_classes(class_names, currmodule) |
| self.class_info = self._class_info(classes, show_builtins, |
| private_bases, parts, aliases, top_classes) |
| if not self.class_info: |
| msg = 'No classes found for inheritance diagram' |
| raise InheritanceException(msg) |
| |
| def _import_classes(self, class_names: list[str], currmodule: str) -> list[Any]: |
| """Import a list of classes.""" |
| classes: list[Any] = [] |
| for name in class_names: |
| classes.extend(import_classes(name, currmodule)) |
| return classes |
| |
| def _class_info(self, classes: list[Any], show_builtins: bool, private_bases: bool, |
| parts: int, aliases: dict[str, str] | None, top_classes: Sequence[Any], |
| ) -> list[tuple[str, str, list[str], str]]: |
| """Return name and bases for all classes that are ancestors of |
| *classes*. |
| |
| *parts* gives the number of dotted name parts to include in the |
| displayed node names, from right to left. If given as a negative, the |
| number of parts to drop from the left. A value of 0 displays the full |
| dotted name. E.g. ``sphinx.ext.inheritance_diagram.InheritanceGraph`` |
| with ``parts=2`` or ``parts=-2`` gets displayed as |
| ``inheritance_diagram.InheritanceGraph``, and as |
| ``ext.inheritance_diagram.InheritanceGraph`` with ``parts=3`` or |
| ``parts=-1``. |
| |
| *top_classes* gives the name(s) of the top most ancestor class to |
| traverse to. Multiple names can be specified separated by comma. |
| """ |
| all_classes = {} |
| |
| def recurse(cls: Any) -> None: |
| if not show_builtins and cls in py_builtins: |
| return |
| if not private_bases and cls.__name__.startswith('_'): |
| return |
| |
| nodename = self.class_name(cls, parts, aliases) |
| fullname = self.class_name(cls, 0, aliases) |
| |
| # Use first line of docstring as tooltip, if available |
| tooltip = None |
| try: |
| if cls.__doc__: |
| doc = cls.__doc__.strip().split("\n")[0] |
| if doc: |
| tooltip = '"%s"' % doc.replace('"', '\\"') |
| except Exception: # might raise AttributeError for strange classes |
| pass |
| |
| baselist: list[str] = [] |
| all_classes[cls] = (nodename, fullname, baselist, tooltip) |
| |
| if fullname in top_classes: |
| return |
| |
| for base in cls.__bases__: |
| if not show_builtins and base in py_builtins: |
| continue |
| if not private_bases and base.__name__.startswith('_'): |
| continue |
| baselist.append(self.class_name(base, parts, aliases)) |
| if base not in all_classes: |
| recurse(base) |
| |
| for cls in classes: |
| recurse(cls) |
| |
| return list(all_classes.values()) # type: ignore[arg-type] |
| |
| def class_name( |
| self, cls: Any, parts: int = 0, aliases: dict[str, str] | None = None, |
| ) -> str: |
| """Given a class object, return a fully-qualified name. |
| |
| This works for things I've tested in matplotlib so far, but may not be |
| completely general. |
| """ |
| module = cls.__module__ |
| if module in ('__builtin__', 'builtins'): |
| fullname = cls.__name__ |
| else: |
| fullname = f'{module}.{cls.__qualname__}' |
| if parts == 0: |
| result = fullname |
| else: |
| name_parts = fullname.split('.') |
| result = '.'.join(name_parts[-parts:]) |
| if aliases is not None and result in aliases: |
| return aliases[result] |
| return result |
| |
| def get_all_class_names(self) -> list[str]: |
| """Get all of the class names involved in the graph.""" |
| return [fullname for (_, fullname, _, _) in self.class_info] |
| |
| # These are the default attrs for graphviz |
| default_graph_attrs = { |
| 'rankdir': 'LR', |
| 'size': '"8.0, 12.0"', |
| 'bgcolor': 'transparent', |
| } |
| default_node_attrs = { |
| 'shape': 'box', |
| 'fontsize': 10, |
| 'height': 0.25, |
| 'fontname': '"Vera Sans, DejaVu Sans, Liberation Sans, ' |
| 'Arial, Helvetica, sans"', |
| 'style': '"setlinewidth(0.5),filled"', |
| 'fillcolor': 'white', |
| } |
| default_edge_attrs = { |
| 'arrowsize': 0.5, |
| 'style': '"setlinewidth(0.5)"', |
| } |
| |
| def _format_node_attrs(self, attrs: dict[str, Any]) -> str: |
| return ','.join(['%s=%s' % x for x in sorted(attrs.items())]) |
| |
| def _format_graph_attrs(self, attrs: dict[str, Any]) -> str: |
| return ''.join(['%s=%s;\n' % x for x in sorted(attrs.items())]) |
| |
| def generate_dot(self, name: str, urls: dict[str, str] | None = None, |
| env: BuildEnvironment | None = None, |
| graph_attrs: dict | None = None, |
| node_attrs: dict | None = None, |
| edge_attrs: dict | None = None, |
| ) -> str: |
| """Generate a graphviz dot graph from the classes that were passed in |
| to __init__. |
| |
| *name* is the name of the graph. |
| |
| *urls* is a dictionary mapping class names to HTTP URLs. |
| |
| *graph_attrs*, *node_attrs*, *edge_attrs* are dictionaries containing |
| key/value pairs to pass on as graphviz properties. |
| """ |
| if urls is None: |
| urls = {} |
| g_attrs = self.default_graph_attrs.copy() |
| n_attrs = self.default_node_attrs.copy() |
| e_attrs = self.default_edge_attrs.copy() |
| if graph_attrs is not None: |
| g_attrs.update(graph_attrs) |
| if node_attrs is not None: |
| n_attrs.update(node_attrs) |
| if edge_attrs is not None: |
| e_attrs.update(edge_attrs) |
| if env: |
| g_attrs.update(env.config.inheritance_graph_attrs) |
| n_attrs.update(env.config.inheritance_node_attrs) |
| e_attrs.update(env.config.inheritance_edge_attrs) |
| |
| res: list[str] = [] |
| res.append('digraph %s {\n' % name) |
| res.append(self._format_graph_attrs(g_attrs)) |
| |
| for name, fullname, bases, tooltip in sorted(self.class_info): |
| # Write the node |
| this_node_attrs = n_attrs.copy() |
| if fullname in urls: |
| this_node_attrs['URL'] = '"%s"' % urls[fullname] |
| this_node_attrs['target'] = '"_top"' |
| if tooltip: |
| this_node_attrs['tooltip'] = tooltip |
| res.append(' "%s" [%s];\n' % |
| (name, self._format_node_attrs(this_node_attrs))) |
| |
| # Write the edges |
| for base_name in bases: |
| res.append(' "%s" -> "%s" [%s];\n' % |
| (base_name, name, |
| self._format_node_attrs(e_attrs))) |
| res.append('}\n') |
| return ''.join(res) |
| |
| |
| class inheritance_diagram(graphviz): |
| """ |
| A docutils node to use as a placeholder for the inheritance diagram. |
| """ |
| pass |
| |
| |
| class InheritanceDiagram(SphinxDirective): |
| """ |
| Run when the inheritance_diagram directive is first encountered. |
| """ |
| has_content = False |
| required_arguments = 1 |
| optional_arguments = 0 |
| final_argument_whitespace = True |
| option_spec: OptionSpec = { |
| 'parts': int, |
| 'private-bases': directives.flag, |
| 'caption': directives.unchanged, |
| 'top-classes': directives.unchanged_required, |
| } |
| |
| def run(self) -> list[Node]: |
| node = inheritance_diagram() |
| node.document = self.state.document |
| class_names = self.arguments[0].split() |
| class_role = self.env.get_domain('py').role('class') |
| # Store the original content for use as a hash |
| node['parts'] = self.options.get('parts', 0) |
| node['content'] = ', '.join(class_names) |
| node['top-classes'] = [] |
| for cls in self.options.get('top-classes', '').split(','): |
| cls = cls.strip() |
| if cls: |
| node['top-classes'].append(cls) |
| |
| # Create a graph starting with the list of classes |
| try: |
| graph = InheritanceGraph( |
| class_names, self.env.ref_context.get('py:module'), # type: ignore[arg-type] |
| parts=node['parts'], |
| private_bases='private-bases' in self.options, |
| aliases=self.config.inheritance_alias, |
| top_classes=node['top-classes']) |
| except InheritanceException as err: |
| return [node.document.reporter.warning(err, line=self.lineno)] |
| |
| # Create xref nodes for each target of the graph's image map and |
| # add them to the doc tree so that Sphinx can resolve the |
| # references to real URLs later. These nodes will eventually be |
| # removed from the doctree after we're done with them. |
| for name in graph.get_all_class_names(): |
| refnodes, x = class_role( # type: ignore[call-arg,misc] |
| 'class', ':class:`%s`' % name, name, 0, self.state) # type: ignore[arg-type] |
| node.extend(refnodes) |
| # Store the graph object so we can use it to generate the |
| # dot file later |
| node['graph'] = graph |
| |
| if 'caption' not in self.options: |
| self.add_name(node) |
| return [node] |
| else: |
| figure = figure_wrapper(self, node, self.options['caption']) |
| self.add_name(figure) |
| return [figure] |
| |
| |
| def get_graph_hash(node: inheritance_diagram) -> str: |
| encoded = (node['content'] + str(node['parts'])).encode() |
| return hashlib.md5(encoded, usedforsecurity=False).hexdigest()[-10:] |
| |
| |
| def html_visit_inheritance_diagram(self: HTML5Translator, node: inheritance_diagram) -> None: |
| """ |
| Output the graph for HTML. This will insert a PNG with clickable |
| image map. |
| """ |
| graph = node['graph'] |
| |
| graph_hash = get_graph_hash(node) |
| name = 'inheritance%s' % graph_hash |
| |
| # Create a mapping from fully-qualified class names to URLs. |
| graphviz_output_format = self.builder.env.config.graphviz_output_format.upper() |
| current_filename = path.basename(self.builder.current_docname + self.builder.out_suffix) |
| urls = {} |
| pending_xrefs = cast(Iterable[addnodes.pending_xref], node) |
| for child in pending_xrefs: |
| if child.get('refuri') is not None: |
| # Construct the name from the URI if the reference is external via intersphinx |
| if not child.get('internal', True): |
| refname = child['refuri'].rsplit('#', 1)[-1] |
| else: |
| refname = child['reftitle'] |
| |
| urls[refname] = child.get('refuri') |
| elif child.get('refid') is not None: |
| if graphviz_output_format == 'SVG': |
| urls[child['reftitle']] = current_filename + '#' + child.get('refid') |
| else: |
| urls[child['reftitle']] = '#' + child.get('refid') |
| |
| dotcode = graph.generate_dot(name, urls, env=self.builder.env) |
| render_dot_html(self, node, dotcode, {}, 'inheritance', 'inheritance', |
| alt='Inheritance diagram of ' + node['content']) |
| raise nodes.SkipNode |
| |
| |
| def latex_visit_inheritance_diagram(self: LaTeXTranslator, node: inheritance_diagram) -> None: |
| """ |
| Output the graph for LaTeX. This will insert a PDF. |
| """ |
| graph = node['graph'] |
| |
| graph_hash = get_graph_hash(node) |
| name = 'inheritance%s' % graph_hash |
| |
| dotcode = graph.generate_dot(name, env=self.builder.env, |
| graph_attrs={'size': '"6.0,6.0"'}) |
| render_dot_latex(self, node, dotcode, {}, 'inheritance') |
| raise nodes.SkipNode |
| |
| |
| def texinfo_visit_inheritance_diagram(self: TexinfoTranslator, node: inheritance_diagram, |
| ) -> None: |
| """ |
| Output the graph for Texinfo. This will insert a PNG. |
| """ |
| graph = node['graph'] |
| |
| graph_hash = get_graph_hash(node) |
| name = 'inheritance%s' % graph_hash |
| |
| dotcode = graph.generate_dot(name, env=self.builder.env, |
| graph_attrs={'size': '"6.0,6.0"'}) |
| render_dot_texinfo(self, node, dotcode, {}, 'inheritance') |
| raise nodes.SkipNode |
| |
| |
| def skip(self: nodes.NodeVisitor, node: inheritance_diagram) -> None: |
| raise nodes.SkipNode |
| |
| |
| def setup(app: Sphinx) -> dict[str, Any]: |
| app.setup_extension('sphinx.ext.graphviz') |
| app.add_node( |
| inheritance_diagram, |
| latex=(latex_visit_inheritance_diagram, None), |
| html=(html_visit_inheritance_diagram, None), |
| text=(skip, None), |
| man=(skip, None), |
| texinfo=(texinfo_visit_inheritance_diagram, None)) |
| app.add_directive('inheritance-diagram', InheritanceDiagram) |
| app.add_config_value('inheritance_graph_attrs', {}, False) |
| app.add_config_value('inheritance_node_attrs', {}, False) |
| app.add_config_value('inheritance_edge_attrs', {}, False) |
| app.add_config_value('inheritance_alias', {}, False) |
| return {'version': sphinx.__display_version__, 'parallel_read_safe': True} |