Skip to content

Commit

Permalink
Merge pull request #44 from scipp/compute-dict
Browse files Browse the repository at this point in the history
``compute`` returns ``dict`` instead of ``tuple``.
  • Loading branch information
YooSunYoung authored Aug 31, 2023
2 parents dec803d + 804df3f commit 13b079b
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 15 deletions.
2 changes: 1 addition & 1 deletion src/sciline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,7 +570,7 @@ def compute(self, tp: Type[T]) -> T:
...

@overload
def compute(self, tp: Tuple[Type[T], ...]) -> Tuple[T, ...]:
def compute(self, tp: Tuple[Type[T], ...]) -> Dict[Type[T], T]:
...

@overload
Expand Down
6 changes: 3 additions & 3 deletions src/sciline/scheduler.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
from typing import Any, Callable, Dict, List, Optional, Protocol
from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple

from sciline.typing import Graph, Key

Expand All @@ -14,7 +14,7 @@ class Scheduler(Protocol):
Scheduler interface compatible with :py:class:`sciline.Pipeline`.
"""

def get(self, graph: Graph, keys: List[Key]) -> Any:
def get(self, graph: Graph, keys: List[Key]) -> Tuple[Any, ...]:
"""
Compute the result for given keys from the graph.
Expand All @@ -33,7 +33,7 @@ class NaiveScheduler:
:py:class:`DaskScheduler` instead.
"""

def get(self, graph: Graph, keys: List[Key]) -> Any:
def get(self, graph: Graph, keys: List[Key]) -> Tuple[Any, ...]:
import graphlib

dependencies = {tp: args for tp, (_, args) in graph.items()}
Expand Down
10 changes: 9 additions & 1 deletion src/sciline/task_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,19 @@ def compute(
Optional list of keys to compute. This can be used to override the keys
stored in the graph instance. Note that the keys must be present in the
graph as intermediate results, otherwise KeyError is raised.
Returns
-------
If ``keys`` is a single type, returns the single result that was computed.
If ``keys`` is a tuple of types, returns a dictionary with type as keys
and the corresponding results as values.
"""
if keys is None:
keys = self._keys
if isinstance(keys, tuple):
return self._scheduler.get(self._graph, list(keys))
results = self._scheduler.get(self._graph, list(keys))
return dict(zip(keys, results))
else:
return self._scheduler.get(self._graph, [keys])[0]

Expand Down
10 changes: 5 additions & 5 deletions tests/pipeline_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def provide_int() -> int:
return 3

pipeline = sl.Pipeline([int_to_float, provide_int, int_float_to_str])
assert pipeline.compute((float, str)) == (1.5, "3;1.5")
assert pipeline.compute((float, str)) == {float: 1.5, str: "3;1.5"}
assert ncall == 1


Expand All @@ -77,7 +77,7 @@ def func2(x: int) -> str:
return f"{x}"

pipeline = sl.Pipeline([provide_int, func1, func2])
assert pipeline.compute((float, str)) == (1.5, "3")
assert pipeline.compute((float, str)) == {float: 1.5, str: "3"}
assert ncall == 1


Expand Down Expand Up @@ -588,10 +588,10 @@ def test_get_with_single_key_return_task_graph_that_computes_value() -> None:
assert task.compute() == '3;1.5'


def test_get_with_key_tuple_return_task_graph_that_computes_tuple_of_values() -> None:
def test_get_with_key_tuple_return_task_graph_that_computes_dict_of_values() -> None:
pipeline = sl.Pipeline([int_to_float, make_int])
task = pipeline.get((float, int))
assert task.compute() == (1.5, 3)
assert task.compute() == {float: 1.5, int: 3}


def test_task_graph_compute_can_override_single_key() -> None:
Expand All @@ -603,7 +603,7 @@ def test_task_graph_compute_can_override_single_key() -> None:
def test_task_graph_compute_can_override_key_tuple() -> None:
pipeline = sl.Pipeline([int_to_float, make_int])
task = pipeline.get(float)
assert task.compute((int, float)) == (3, 1.5)
assert task.compute((int, float)) == {int: 3, float: 1.5}


def test_task_graph_compute_raises_if_override_keys_outside_graph() -> None:
Expand Down
10 changes: 5 additions & 5 deletions tests/task_graph_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,10 @@ def test_compute_returns_value_when_initialized_with_single_key() -> None:
assert tg.compute() == 0.5


def test_compute_returns_tuple_when_initialized_with_key_tuple() -> None:
def test_compute_returns_dict_when_initialized_with_key_tuple() -> None:
graph = make_task_graph()
assert TaskGraph(graph=graph, keys=(float,)).compute() == (0.5,)
assert TaskGraph(graph=graph, keys=(float, int)).compute() == (0.5, 1)
assert TaskGraph(graph=graph, keys=(float,)).compute() == {float: 0.5}
assert TaskGraph(graph=graph, keys=(float, int)).compute() == {float: 0.5, int: 1}


def test_compute_returns_value_when_provided_with_single_key() -> None:
Expand All @@ -40,10 +40,10 @@ def test_compute_returns_value_when_provided_with_single_key() -> None:
assert tg.compute(int) == 1


def test_compute_returns_tuple_when_provided_with_key_tuple() -> None:
def test_compute_returns_dict_when_provided_with_key_tuple() -> None:
graph = make_task_graph()
tg = TaskGraph(graph=graph, keys=float)
assert tg.compute((int, float)) == (1, 0.5)
assert tg.compute((int, float)) == {int: 1, float: 0.5}


def test_compute_raises_when_provided_with_key_not_in_graph() -> None:
Expand Down

0 comments on commit 13b079b

Please sign in to comment.