Skip to content

Commit

Permalink
Merge pull request #49 from scipp/inject-into-function
Browse files Browse the repository at this point in the history
Add Pipeline.bind_and_call
  • Loading branch information
jl-wynen authored Sep 1, 2023
2 parents 13b079b + effd1af commit 478ba64
Show file tree
Hide file tree
Showing 3 changed files with 188 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/sciline/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
except importlib.metadata.PackageNotFoundError:
__version__ = "0.0.0"

from . import scheduler
from .domain import Scope
from .param_table import ParamTable
from .pipeline import (
Expand All @@ -27,4 +28,5 @@
"Scope",
"UnboundTypeVar",
"UnsatisfiedRequirement",
"scheduler",
]
55 changes: 55 additions & 0 deletions src/sciline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

from collections import defaultdict
from itertools import chain
from typing import (
Any,
Callable,
Expand Down Expand Up @@ -632,3 +633,57 @@ def get(
else:
graph = self.build(keys)
return TaskGraph(graph=graph, keys=keys, scheduler=scheduler)

@overload
def bind_and_call(self, fns: Callable[..., T], /) -> T:
...

@overload
def bind_and_call(self, fns: Iterable[Callable[..., Any]], /) -> Tuple[Any, ...]:
...

def bind_and_call(
self, fns: Union[Callable[..., Any], Iterable[Callable[..., Any]]], /
) -> Any:
"""
Call the given functions with arguments provided by the pipeline.
Parameters
----------
fns:
Functions to call.
The pipeline will provide all arguments based on the function's type hints.
If this is a single callable, it is called directly.
Otherwise, ``bind_and_call`` will iterate over it and call all functions.
If will in either case call :meth:`Pipeline.compute` only once.
Returns
-------
:
The return values of the functions in the same order as the functions.
If only one function is passed, its return value
is *not* wrapped in a tuple.
"""
return_tuple = True
if callable(fns):
fns = (fns,)
return_tuple = False

arg_types_per_function = {
fn: {
name: ty for name, ty in get_type_hints(fn).items() if name != 'return'
}
for fn in fns
}
all_arg_types = tuple(
set(chain(*(a.values() for a in arg_types_per_function.values())))
)
values_per_type = self.compute(all_arg_types)
results = tuple(
fn(**{name: values_per_type[ty] for name, ty in arg_types.items()})
for fn, arg_types in arg_types_per_function.items()
)
if not return_tuple:
return results[0]
return results
131 changes: 131 additions & 0 deletions tests/pipeline_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,3 +619,134 @@ def test_get_with_NaiveScheduler() -> None:
pipeline = sl.Pipeline([int_to_float, make_int])
task = pipeline.get(float, scheduler=sl.scheduler.NaiveScheduler())
assert task.compute() == 1.5


def test_bind_and_call_no_function() -> None:
pipeline = sl.Pipeline([make_int])
assert pipeline.bind_and_call(()) == ()


def test_bind_and_call_function_without_args() -> None:
def func() -> str:
return "func"

pipeline = sl.Pipeline([make_int])
assert pipeline.bind_and_call(func) == "func"


def test_bind_and_call_function_with_1_arg() -> None:
def func(i: int) -> int:
return i * 2

pipeline = sl.Pipeline([make_int])
assert pipeline.bind_and_call(func) == 6


def test_bind_and_call_function_with_2_arg2() -> None:
def func(i: int, f: float) -> float:
return i + f

pipeline = sl.Pipeline([make_int, int_to_float])
assert pipeline.bind_and_call(func) == 4.5


def test_bind_and_call_overrides_default_args() -> None:
def func(i: int, f: float = -0.5) -> float:
return i + f

pipeline = sl.Pipeline([make_int, int_to_float])
assert pipeline.bind_and_call(func) == 4.5


def test_bind_and_call_function_in_iterator() -> None:
def func(i: int) -> int:
return i * 2

pipeline = sl.Pipeline([make_int])
assert pipeline.bind_and_call(iter((func,))) == (6,)


def test_bind_and_call_dataclass_without_args() -> None:
@dataclass
class C:
...

pipeline = sl.Pipeline([make_int])
assert pipeline.bind_and_call(C) == C()


def test_bind_and_call_dataclass_with_1_arg() -> None:
@dataclass
class C:
i: int

pipeline = sl.Pipeline([make_int])
assert pipeline.bind_and_call(C) == C(i=3)


def test_bind_and_call_dataclass_with_2_arg2() -> None:
@dataclass
class C:
i: int
f: float

pipeline = sl.Pipeline([make_int, int_to_float])
assert pipeline.bind_and_call(C) == C(i=3, f=1.5)


def test_bind_and_call_two_functions() -> None:
def func1(i: int) -> int:
return 2 * i

def func2(f: float) -> float:
return f + 1

pipeline = sl.Pipeline([make_int, int_to_float])
assert pipeline.bind_and_call((func1, func2)) == (6, 2.5)


def test_bind_and_call_two_functions_in_iterator() -> None:
def func1(i: int) -> int:
return 2 * i

def func2(f: float) -> float:
return f + 1

pipeline = sl.Pipeline([make_int, int_to_float])
assert pipeline.bind_and_call(iter((func1, func2))) == (6, 2.5)


def test_bind_and_call_function_and_dataclass() -> None:
def func(i: int) -> int:
return 2 * i

@dataclass
class C:
i: int
f: float

pipeline = sl.Pipeline([make_int, int_to_float])
assert pipeline.bind_and_call((func, C)) == (6, C(i=3, f=1.5))


def test_bind_and_call_function_without_return_annotation() -> None:
def func(i: int): # type: ignore[no-untyped-def]
return 2 * i

pipeline = sl.Pipeline([make_int])
assert pipeline.bind_and_call(func) == 6


def test_bind_and_call_generic_function() -> None:
T = TypeVar('T')
A = NewType('A', int)
B = NewType('B', int)

class G(sl.Scope[T, int], int):
...

def func(a: G[A]) -> int:
return -4 * a

pipeline = sl.Pipeline([], params={G[A]: 3, G[B]: 4})
assert pipeline.bind_and_call(func) == -12

0 comments on commit 478ba64

Please sign in to comment.