-
Notifications
You must be signed in to change notification settings - Fork 2
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: add method to list leaf nodes #188
Changes from 2 commits
7ea4371
93af61b
ef2131e
e5c2b60
b8e5b23
a044e97
19b0700
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,9 +5,10 @@ | |
from collections.abc import Callable, Hashable, Iterable, Sequence | ||
from itertools import chain | ||
from types import UnionType | ||
from typing import TYPE_CHECKING, Any, TypeVar, get_args, get_type_hints, overload | ||
from typing import TYPE_CHECKING, Any, TypeVar, cast, get_args, get_type_hints, overload | ||
|
||
from ._provider import Provider, ToProvider | ||
from ._utils import key_name | ||
from .data_graph import DataGraph, to_task_graph | ||
from .display import pipeline_html_repr | ||
from .handler import ErrorHandler, HandleAsComputeTimeException | ||
|
@@ -89,7 +90,9 @@ def compute(self, tp: type | Iterable[type] | UnionType, **kwargs: Any) -> Any: | |
""" | ||
return self.get(tp, **kwargs).compute() | ||
|
||
def visualize(self, tp: type | Iterable[type], **kwargs: Any) -> graphviz.Digraph: | ||
def visualize( | ||
self, tp: type | Iterable[type] | None = None, **kwargs: Any | ||
) -> graphviz.Digraph: | ||
""" | ||
Return a graphviz Digraph object representing the graph for the given keys. | ||
|
||
|
@@ -103,6 +106,8 @@ def visualize(self, tp: type | Iterable[type], **kwargs: Any) -> graphviz.Digrap | |
kwargs: | ||
Keyword arguments passed to :py:class:`graphviz.Digraph`. | ||
""" | ||
if tp is None: | ||
tp = self.leafs() | ||
return self.get(tp, handler=HandleAsComputeTimeException()).visualize(**kwargs) | ||
|
||
def get( | ||
|
@@ -200,6 +205,15 @@ def _repr_html_(self) -> str: | |
nodes = ((key, data) for key, data in self.underlying_graph.nodes.items()) | ||
return pipeline_html_repr(nodes) | ||
|
||
def leafs(self) -> tuple[type, ...]: | ||
"""Returns the keys that are not inputs to any other providers.""" | ||
sink_nodes = [ | ||
cast(type, node) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not sure I understand. What is this There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's just for appeasing mypy. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Okay took another look now and it was actually not necessary to cast in case we just specified the return as There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Indeed. Our keys are, in general not There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We are using Is that just a mistake or is there a difference between the contexts that makes |
||
for node, degree in self.underlying_graph.out_degree | ||
if degree == 0 | ||
] | ||
return tuple(sorted(sink_nodes, key=key_name)) | ||
|
||
|
||
def get_mapped_node_names( | ||
graph: DataGraph, base_name: type, *, index_names: Sequence[Hashable] | None = None | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pipeline
.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think
sink_nodes
is a better name thanleaf
, but they are both "leaking" the graph abstraction belowPipeline
. Maybe we should have something more specific to the domain, likefinal_results
oroutput_keys
.I'm not so sure about this. We could make it a free function, but what's the advantage of doing so? It's intrinsically tightly coupled to the
Pipeline
class. Having it be a method makes it easier for people to find it using dot-notation access.