Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add method to list leaf nodes #188

Merged
merged 7 commits into from
Oct 31, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 19 additions & 3 deletions src/sciline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@
from typing import TYPE_CHECKING, Any, TypeVar, get_args, get_type_hints, overload

from ._provider import Provider, ToProvider
from ._utils import key_name
from .data_graph import DataGraph, to_task_graph
from .display import pipeline_html_repr
from .handler import ErrorHandler, HandleAsComputeTimeException
from .handler import ErrorHandler, HandleAsComputeTimeException, UnsatisfiedRequirement
from .scheduler import Scheduler
from .task_graph import TaskGraph
from .typing import Key
Expand Down Expand Up @@ -89,7 +90,9 @@ def compute(self, tp: type | Iterable[type] | UnionType, **kwargs: Any) -> Any:
"""
return self.get(tp, **kwargs).compute()

def visualize(self, tp: type | Iterable[type], **kwargs: Any) -> graphviz.Digraph:
def visualize(
self, tp: type | Iterable[type] | None = None, **kwargs: Any
) -> graphviz.Digraph:
"""
Return a graphviz Digraph object representing the graph for the given keys.

Expand All @@ -103,6 +106,8 @@ def visualize(self, tp: type | Iterable[type], **kwargs: Any) -> graphviz.Digrap
kwargs:
Keyword arguments passed to :py:class:`graphviz.Digraph`.
"""
if tp is None:
tp = self.output_keys()
return self.get(tp, handler=HandleAsComputeTimeException()).visualize(**kwargs)

def get(
Expand Down Expand Up @@ -135,7 +140,11 @@ def get(
targets = tuple(keys) # type: ignore[arg-type]
else:
targets = (keys,) # type: ignore[assignment]
graph = to_task_graph(self, targets=targets, handler=handler)
try:
graph = to_task_graph(self, targets=targets, handler=handler)
except UnsatisfiedRequirement as e:
output_keys = ", ".join(map(repr, self.output_keys()))
raise type(e)(f'Did you meant one of: {output_keys}?') from e
return TaskGraph(
graph=graph,
targets=targets if multi else keys, # type: ignore[arg-type]
Expand Down Expand Up @@ -200,6 +209,13 @@ def _repr_html_(self) -> str:
nodes = ((key, data) for key, data in self.underlying_graph.nodes.items())
return pipeline_html_repr(nodes)

def output_keys(self) -> tuple[Key, ...]:
"""Returns the keys that are not inputs to any other providers."""
sink_nodes = [
node for node, degree in self.underlying_graph.out_degree if degree == 0
]
return tuple(sorted(sink_nodes, key=key_name))


def get_mapped_node_names(
graph: DataGraph, base_name: type, *, index_names: Sequence[Hashable] | None = None
Expand Down
26 changes: 26 additions & 0 deletions tests/pipeline_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import pytest

import sciline as sl
from sciline._utils import key_name


def int_to_float(x: int) -> float:
Expand Down Expand Up @@ -1414,3 +1415,28 @@ def bad(x: int, y: int) -> float:
pipeline = sl.Pipeline()
with pytest.raises(ValueError, match="Duplicate type hints"):
pipeline.insert(bad)


def test_output_keys_method() -> None:
def make_float() -> float:
return 1.0

def make_str(x: int) -> str:
return "a string"

assert sl.Pipeline([make_float, make_str]).output_keys() == (float, str)


@pytest.mark.parametrize('get_method', ['get', 'compute'])
def test_output_keys_in_not_found_error_message(get_method: str) -> None:
def make_float() -> float:
return 1.0

def make_str(x: int) -> str:
return "a string"

pl = sl.Pipeline([make_float, make_str])
with pytest.raises(sl.handler.UnsatisfiedRequirement) as info:
getattr(pl, get_method)(int)
for key in pl.output_keys():
assert key_name(key) in info.value.args[0]
8 changes: 8 additions & 0 deletions tests/visualize_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,11 @@ class SubA(A[T]):
assert sl.visualize._format_type(A[float]).name == 'A[float]'
assert sl.visualize._format_type(SubA[float]).name == 'SubA[float]'
assert sl.visualize._format_type(B[float]).name == 'B[float]'


def test_can_visualize_graph_without_explicit_target() -> None:
def int_to_float(x: int) -> float:
return float(x)

pipeline = sl.Pipeline([int_to_float])
pipeline.visualize()