Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add method to list leaf nodes #188

Merged
merged 7 commits into from
Oct 31, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 22 additions & 4 deletions src/sciline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@
from collections.abc import Callable, Hashable, Iterable, Sequence
from itertools import chain
from types import UnionType
from typing import TYPE_CHECKING, Any, TypeVar, get_args, get_type_hints, overload
from typing import TYPE_CHECKING, Any, TypeVar, cast, get_args, get_type_hints, overload

from ._provider import Provider, ToProvider
from ._utils import key_name
from .data_graph import DataGraph, to_task_graph
from .display import pipeline_html_repr
from .handler import ErrorHandler, HandleAsComputeTimeException
from .handler import ErrorHandler, HandleAsComputeTimeException, UnsatisfiedRequirement
from .scheduler import Scheduler
from .task_graph import TaskGraph
from .typing import Key
Expand Down Expand Up @@ -89,7 +90,9 @@ def compute(self, tp: type | Iterable[type] | UnionType, **kwargs: Any) -> Any:
"""
return self.get(tp, **kwargs).compute()

def visualize(self, tp: type | Iterable[type], **kwargs: Any) -> graphviz.Digraph:
def visualize(
self, tp: type | Iterable[type] | None = None, **kwargs: Any
) -> graphviz.Digraph:
"""
Return a graphviz Digraph object representing the graph for the given keys.

Expand All @@ -103,6 +106,8 @@ def visualize(self, tp: type | Iterable[type], **kwargs: Any) -> graphviz.Digrap
kwargs:
Keyword arguments passed to :py:class:`graphviz.Digraph`.
"""
if tp is None:
tp = self.final_result_keys()
return self.get(tp, handler=HandleAsComputeTimeException()).visualize(**kwargs)

def get(
Expand Down Expand Up @@ -135,7 +140,11 @@ def get(
targets = tuple(keys) # type: ignore[arg-type]
else:
targets = (keys,) # type: ignore[assignment]
graph = to_task_graph(self, targets=targets, handler=handler)
try:
graph = to_task_graph(self, targets=targets, handler=handler)
except UnsatisfiedRequirement as e:
final_result_keys = ", ".join(map(repr, self.final_result_keys()))
raise type(e)(f'Did you meant one of: {final_result_keys}?') from e
return TaskGraph(
graph=graph,
targets=targets if multi else keys, # type: ignore[arg-type]
Expand Down Expand Up @@ -200,6 +209,15 @@ def _repr_html_(self) -> str:
nodes = ((key, data) for key, data in self.underlying_graph.nodes.items())
return pipeline_html_repr(nodes)

def final_result_keys(self) -> tuple[type, ...]:
"""Returns the keys that are not inputs to any other providers."""
sink_nodes = [
cast(type, node)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure I understand. What is this cast doing?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's just for appeasing mypy.
We promise the return type of the function is tuple[type, ...] because we know our keys are types.
But what we get from the underlying_graph is Any.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay took another look now and it was actually not necessary to cast in case we just specified the return as tuple[Key, ...] instead

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indeed. Our keys are, in general not type.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We are using type in a lot of other places instead of Key though, that's what confused me.
For example https://github.com/scipp/sciline/blob/main/src/sciline/pipeline.py#L108-L111

Is that just a mistake or is there a difference between the contexts that makes type the right annotation there?

for node, degree in self.underlying_graph.out_degree
if degree == 0
]
return tuple(sorted(sink_nodes, key=key_name))


def get_mapped_node_names(
graph: DataGraph, base_name: type, *, index_names: Sequence[Hashable] | None = None
Expand Down
26 changes: 26 additions & 0 deletions tests/pipeline_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import pytest

import sciline as sl
from sciline._utils import key_name


def int_to_float(x: int) -> float:
Expand Down Expand Up @@ -1414,3 +1415,28 @@ def bad(x: int, y: int) -> float:
pipeline = sl.Pipeline()
with pytest.raises(ValueError, match="Duplicate type hints"):
pipeline.insert(bad)


def test_final_result_keys_method() -> None:
def make_float() -> float:
return 1.0

def make_str(x: int) -> str:
return "a string"

assert sl.Pipeline([make_float, make_str]).final_result_keys() == (float, str)


@pytest.mark.parametrize('get_method', ['get', 'compute'])
def test_final_result_keys_in_not_found_error_message(get_method) -> None:
def make_float() -> float:
return 1.0

def make_str(x: int) -> str:
return "a string"

pl = sl.Pipeline([make_float, make_str])
with pytest.raises(sl.handler.UnsatisfiedRequirement) as info:
getattr(pl, get_method)(int)
for key in pl.final_result_keys():
assert key_name(key) in info.value.args[0]
8 changes: 8 additions & 0 deletions tests/visualize_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,11 @@ class SubA(A[T]):
assert sl.visualize._format_type(A[float]).name == 'A[float]'
assert sl.visualize._format_type(SubA[float]).name == 'SubA[float]'
assert sl.visualize._format_type(B[float]).name == 'B[float]'


def test_can_visualize_graph_without_explicit_target() -> None:
def int_to_float(x: int) -> float:
return float(x)

pipeline = sl.Pipeline([int_to_float])
pipeline.visualize()
Loading