Skip to content

Commit

Permalink
feat: display results in 'not found' error
Browse files Browse the repository at this point in the history
  • Loading branch information
jokasimr committed Oct 30, 2024
1 parent 93af61b commit 49bc710
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 5 deletions.
18 changes: 15 additions & 3 deletions src/sciline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@
# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
from __future__ import annotations

import re
from collections.abc import Callable, Hashable, Iterable, Sequence
from itertools import chain
from types import UnionType
from typing import TYPE_CHECKING, Any, TypeVar, cast, get_args, get_type_hints, overload

from networkx import NetworkXError

from ._provider import Provider, ToProvider
from ._utils import key_name
from .data_graph import DataGraph, to_task_graph
Expand Down Expand Up @@ -107,7 +110,7 @@ def visualize(
Keyword arguments passed to :py:class:`graphviz.Digraph`.
"""
if tp is None:
tp = self.leafs()
tp = self.final_result_keys()
return self.get(tp, handler=HandleAsComputeTimeException()).visualize(**kwargs)

def get(
Expand Down Expand Up @@ -140,7 +143,16 @@ def get(
targets = tuple(keys) # type: ignore[arg-type]
else:
targets = (keys,) # type: ignore[assignment]
graph = to_task_graph(self, targets=targets, handler=handler)
try:
graph = to_task_graph(self, targets=targets, handler=handler)
except NetworkXError as e:
if re.match(r'^The node (.*) is not in the digraph.$', e.args[0]):
final_result_keys = ", ".join(map(repr, self.final_result_keys()))
raise KeyError(
f'The requested type {e.__cause__} was not found in the pipeline. '
f'Did you meant one of: {final_result_keys}?'
) from e
raise e
return TaskGraph(
graph=graph,
targets=targets if multi else keys, # type: ignore[arg-type]
Expand Down Expand Up @@ -205,7 +217,7 @@ def _repr_html_(self) -> str:
nodes = ((key, data) for key, data in self.underlying_graph.nodes.items())
return pipeline_html_repr(nodes)

def leafs(self) -> tuple[type, ...]:
def final_result_keys(self) -> tuple[type, ...]:
"""Returns the keys that are not inputs to any other providers."""
sink_nodes = [
cast(type, node)
Expand Down
4 changes: 2 additions & 2 deletions tests/pipeline_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1416,11 +1416,11 @@ def bad(x: int, y: int) -> float:
pipeline.insert(bad)


def test_leafs_method() -> None:
def test_final_result_keys_method() -> None:
def make_float() -> float:
return 1.0

def make_str(x: int) -> str:
return "a string"

assert sl.Pipeline([make_float, make_str]).leafs() == (float, str)
assert sl.Pipeline([make_float, make_str]).final_result_keys() == (float, str)

0 comments on commit 49bc710

Please sign in to comment.