diff --git a/src/sciline/__init__.py b/src/sciline/__init__.py index eb74611c..a53e70af 100644 --- a/src/sciline/__init__.py +++ b/src/sciline/__init__.py @@ -9,6 +9,7 @@ except importlib.metadata.PackageNotFoundError: __version__ = "0.0.0" +from . import scheduler from .domain import Scope from .param_table import ParamTable from .pipeline import ( @@ -27,4 +28,5 @@ "Scope", "UnboundTypeVar", "UnsatisfiedRequirement", + "scheduler", ] diff --git a/src/sciline/pipeline.py b/src/sciline/pipeline.py index e4a20834..cbdf64c5 100644 --- a/src/sciline/pipeline.py +++ b/src/sciline/pipeline.py @@ -3,6 +3,7 @@ from __future__ import annotations from collections import defaultdict +from itertools import chain from typing import ( Any, Callable, @@ -632,3 +633,57 @@ def get( else: graph = self.build(keys) return TaskGraph(graph=graph, keys=keys, scheduler=scheduler) + + @overload + def bind_and_call(self, fns: Callable[..., T], /) -> T: + ... + + @overload + def bind_and_call(self, fns: Iterable[Callable[..., Any]], /) -> Tuple[Any, ...]: + ... + + def bind_and_call( + self, fns: Union[Callable[..., Any], Iterable[Callable[..., Any]]], / + ) -> Any: + """ + Call the given functions with arguments provided by the pipeline. + + Parameters + ---------- + fns: + Functions to call. + The pipeline will provide all arguments based on the function's type hints. + + If this is a single callable, it is called directly. + Otherwise, ``bind_and_call`` will iterate over it and call all functions. + If will in either case call :meth:`Pipeline.compute` only once. + + Returns + ------- + : + The return values of the functions in the same order as the functions. + If only one function is passed, its return value + is *not* wrapped in a tuple. + """ + return_tuple = True + if callable(fns): + fns = (fns,) + return_tuple = False + + arg_types_per_function = { + fn: { + name: ty for name, ty in get_type_hints(fn).items() if name != 'return' + } + for fn in fns + } + all_arg_types = tuple( + set(chain(*(a.values() for a in arg_types_per_function.values()))) + ) + values_per_type = self.compute(all_arg_types) + results = tuple( + fn(**{name: values_per_type[ty] for name, ty in arg_types.items()}) + for fn, arg_types in arg_types_per_function.items() + ) + if not return_tuple: + return results[0] + return results diff --git a/tests/pipeline_test.py b/tests/pipeline_test.py index 7145a39e..278be7cb 100644 --- a/tests/pipeline_test.py +++ b/tests/pipeline_test.py @@ -619,3 +619,134 @@ def test_get_with_NaiveScheduler() -> None: pipeline = sl.Pipeline([int_to_float, make_int]) task = pipeline.get(float, scheduler=sl.scheduler.NaiveScheduler()) assert task.compute() == 1.5 + + +def test_bind_and_call_no_function() -> None: + pipeline = sl.Pipeline([make_int]) + assert pipeline.bind_and_call(()) == () + + +def test_bind_and_call_function_without_args() -> None: + def func() -> str: + return "func" + + pipeline = sl.Pipeline([make_int]) + assert pipeline.bind_and_call(func) == "func" + + +def test_bind_and_call_function_with_1_arg() -> None: + def func(i: int) -> int: + return i * 2 + + pipeline = sl.Pipeline([make_int]) + assert pipeline.bind_and_call(func) == 6 + + +def test_bind_and_call_function_with_2_arg2() -> None: + def func(i: int, f: float) -> float: + return i + f + + pipeline = sl.Pipeline([make_int, int_to_float]) + assert pipeline.bind_and_call(func) == 4.5 + + +def test_bind_and_call_overrides_default_args() -> None: + def func(i: int, f: float = -0.5) -> float: + return i + f + + pipeline = sl.Pipeline([make_int, int_to_float]) + assert pipeline.bind_and_call(func) == 4.5 + + +def test_bind_and_call_function_in_iterator() -> None: + def func(i: int) -> int: + return i * 2 + + pipeline = sl.Pipeline([make_int]) + assert pipeline.bind_and_call(iter((func,))) == (6,) + + +def test_bind_and_call_dataclass_without_args() -> None: + @dataclass + class C: + ... + + pipeline = sl.Pipeline([make_int]) + assert pipeline.bind_and_call(C) == C() + + +def test_bind_and_call_dataclass_with_1_arg() -> None: + @dataclass + class C: + i: int + + pipeline = sl.Pipeline([make_int]) + assert pipeline.bind_and_call(C) == C(i=3) + + +def test_bind_and_call_dataclass_with_2_arg2() -> None: + @dataclass + class C: + i: int + f: float + + pipeline = sl.Pipeline([make_int, int_to_float]) + assert pipeline.bind_and_call(C) == C(i=3, f=1.5) + + +def test_bind_and_call_two_functions() -> None: + def func1(i: int) -> int: + return 2 * i + + def func2(f: float) -> float: + return f + 1 + + pipeline = sl.Pipeline([make_int, int_to_float]) + assert pipeline.bind_and_call((func1, func2)) == (6, 2.5) + + +def test_bind_and_call_two_functions_in_iterator() -> None: + def func1(i: int) -> int: + return 2 * i + + def func2(f: float) -> float: + return f + 1 + + pipeline = sl.Pipeline([make_int, int_to_float]) + assert pipeline.bind_and_call(iter((func1, func2))) == (6, 2.5) + + +def test_bind_and_call_function_and_dataclass() -> None: + def func(i: int) -> int: + return 2 * i + + @dataclass + class C: + i: int + f: float + + pipeline = sl.Pipeline([make_int, int_to_float]) + assert pipeline.bind_and_call((func, C)) == (6, C(i=3, f=1.5)) + + +def test_bind_and_call_function_without_return_annotation() -> None: + def func(i: int): # type: ignore[no-untyped-def] + return 2 * i + + pipeline = sl.Pipeline([make_int]) + assert pipeline.bind_and_call(func) == 6 + + +def test_bind_and_call_generic_function() -> None: + T = TypeVar('T') + A = NewType('A', int) + B = NewType('B', int) + + class G(sl.Scope[T, int], int): + ... + + def func(a: G[A]) -> int: + return -4 * a + + pipeline = sl.Pipeline([], params={G[A]: 3, G[B]: 4}) + assert pipeline.bind_and_call(func) == -12