feat: add method to list leaf nodes #188

Merged
merged 7 commits into from
Oct 31, 2024
Changes from 2 commits
18 changes: 16 additions & 2 deletions src/sciline/pipeline.py
@@ -5,9 +5,10 @@
 from collections.abc import Callable, Hashable, Iterable, Sequence
 from itertools import chain
 from types import UnionType
-from typing import TYPE_CHECKING, Any, TypeVar, get_args, get_type_hints, overload
+from typing import TYPE_CHECKING, Any, TypeVar, cast, get_args, get_type_hints, overload

 from ._provider import Provider, ToProvider
+from ._utils import key_name
 from .data_graph import DataGraph, to_task_graph
 from .display import pipeline_html_repr
 from .handler import ErrorHandler, HandleAsComputeTimeException
@@ -89,7 +90,9 @@ def compute(self, tp: type | Iterable[type] | UnionType, **kwargs: Any) -> Any:
         """
         return self.get(tp, **kwargs).compute()

-    def visualize(self, tp: type | Iterable[type], **kwargs: Any) -> graphviz.Digraph:
+    def visualize(
+        self, tp: type | Iterable[type] | None = None, **kwargs: Any
+    ) -> graphviz.Digraph:
         """
         Return a graphviz Digraph object representing the graph for the given keys.

@@ -103,6 +106,8 @@ def visualize(self, tp: type | Iterable[type], **kwargs: Any) -> graphviz.Digrap
         kwargs:
             Keyword arguments passed to :py:class:`graphviz.Digraph`.
         """
+        if tp is None:
+            tp = self.leafs()
         return self.get(tp, handler=HandleAsComputeTimeException()).visualize(**kwargs)

     def get(
@@ -200,6 +205,15 @@ def _repr_html_(self) -> str:
         nodes = ((key, data) for key, data in self.underlying_graph.nodes.items())
         return pipeline_html_repr(nodes)

+    def leafs(self) -> tuple[type, ...]:
Member

  • Can we consider making this a free function? We should avoid bloating the interface of Pipeline.
  • What is a leaf? Can we use some clearer terminology such as used in mathematics for DAGs, or by NetworkX? Sink nodes, for example

Contributor Author

@jokasimr, Oct 28, 2024

> What is a leaf? Can we use some clearer terminology such as used in mathematics for DAGs, or by NetworkX? Sink nodes, for example

I think sink_nodes is a better name than leaf, but both "leak" the graph abstraction that sits below Pipeline. Maybe we should use something more specific to the domain, such as final_results or output_keys.

> Can we consider making this a free function? We should avoid bloating the interface of Pipeline.

I'm not so sure about this. We could make it a free function, but what's the advantage of doing so? It's intrinsically tightly coupled to the Pipeline class. Having it be a method makes it easier for people to find it using dot-notation access.
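
To make the terminology concrete, here is a minimal sketch of "sink nodes" on a toy dependency graph. The graph contents and the helper name `sink_nodes` are illustrative assumptions, not part of sciline; the PR itself reads the same information from the `out_degree` view of `underlying_graph`.

```python
# Hypothetical dependency graph: each key maps to the keys that consume it.
consumers: dict[str, list[str]] = {
    "raw_data": ["cleaned_data"],
    "cleaned_data": ["summary", "plot"],
    "summary": [],
    "plot": [],
}


def sink_nodes(graph: dict[str, list[str]]) -> tuple[str, ...]:
    """Return the nodes with out-degree zero, sorted for a stable order."""
    return tuple(sorted(node for node, targets in graph.items() if not targets))


print(sink_nodes(consumers))
```

Whichever name is chosen (leafs, sink_nodes, final_results, output_keys), the selection rule is the same: keep the keys that nothing else takes as input.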

+        """Returns the keys that are not inputs to any other providers."""
+        sink_nodes = [
+            cast(type, node)
Member

Not sure I understand. What is this cast doing?

Contributor Author

It's just for appeasing mypy.
We promise the return type of the function is tuple[type, ...] because we know our keys are types.
But what we get from the underlying_graph is Any.
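
A small sketch of what `cast` does and does not do, using only the standard `typing` module. The helper `first_key` is hypothetical and exists only for illustration: `cast` changes what the type checker believes, but at runtime it returns its argument unchanged and checks nothing.

```python
from typing import Any, cast


def first_key(nodes: list[Any]) -> type:
    # cast() only informs the type checker; at runtime it returns its
    # second argument unchanged and performs no check or conversion.
    return cast(type, nodes[0])


# The annotation promises a type, but nothing enforces that promise.
value = first_key(["not a type"])
print(type(value).__name__)
```

This is why a cast is only as trustworthy as the assumption behind it.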

Contributor Author

Okay, I took another look, and the cast is actually unnecessary if we annotate the return type as tuple[Key, ...] instead.

Member

Indeed. Our keys are, in general, not type.

Contributor Author

We use type in a lot of other places instead of Key, though; that's what confused me.
For example https://github.com/scipp/sciline/blob/main/src/sciline/pipeline.py#L108-L111

Is that just a mistake or is there a difference between the contexts that makes type the right annotation there?

+            for node, degree in self.underlying_graph.out_degree
+            if degree == 0
+        ]
+        return tuple(sorted(sink_nodes, key=key_name))
+

 def get_mapped_node_names(
     graph: DataGraph, base_name: type, *, index_names: Sequence[Hashable] | None = None
10 changes: 10 additions & 0 deletions tests/pipeline_test.py
@@ -1414,3 +1414,13 @@ def bad(x: int, y: int) -> float:
     pipeline = sl.Pipeline()
     with pytest.raises(ValueError, match="Duplicate type hints"):
         pipeline.insert(bad)
+
+
+def test_leafs_method() -> None:
+    def make_float() -> float:
+        return 1.0
+
+    def make_str(x: int) -> str:
+        return "a string"
+
+    assert sl.Pipeline([make_float, make_str]).leafs() == (float, str)
8 changes: 8 additions & 0 deletions tests/visualize_test.py
@@ -40,3 +40,11 @@ class SubA(A[T]):
     assert sl.visualize._format_type(A[float]).name == 'A[float]'
     assert sl.visualize._format_type(SubA[float]).name == 'SubA[float]'
     assert sl.visualize._format_type(B[float]).name == 'B[float]'
+
+
+def test_can_visualize_graph_without_explicit_target() -> None:
+    def int_to_float(x: int) -> float:
+        return float(x)
+
+    pipeline = sl.Pipeline([int_to_float])
+    pipeline.visualize()