Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Pipeline.bind_and_call #49

Merged
merged 4 commits into from
Sep 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/sciline/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
except importlib.metadata.PackageNotFoundError:
__version__ = "0.0.0"

from . import scheduler
from .domain import Scope
from .param_table import ParamTable
from .pipeline import (
Expand All @@ -27,4 +28,5 @@
"Scope",
"UnboundTypeVar",
"UnsatisfiedRequirement",
"scheduler",
]
55 changes: 55 additions & 0 deletions src/sciline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

from collections import defaultdict
from itertools import chain
from typing import (
Any,
Callable,
Expand Down Expand Up @@ -632,3 +633,57 @@ def get(
else:
graph = self.build(keys)
return TaskGraph(graph=graph, keys=keys, scheduler=scheduler)

@overload
def bind_and_call(self, fns: Callable[..., T], /) -> T:
...

@overload
def bind_and_call(self, fns: Iterable[Callable[..., Any]], /) -> Tuple[Any, ...]:
...

def bind_and_call(
self, fns: Union[Callable[..., Any], Iterable[Callable[..., Any]]], /
) -> Any:
"""
Call the given functions with arguments provided by the pipeline.

Parameters
----------
fns:
Functions to call.
The pipeline will provide all arguments based on the function's type hints.

If this is a single callable, it is called directly.
Otherwise, ``bind_and_call`` will iterate over it and call all functions.
If will in either case call :meth:`Pipeline.compute` only once.

Returns
-------
:
The return values of the functions in the same order as the functions.
If only one function is passed, its return value
is *not* wrapped in a tuple.
"""
return_tuple = True
if callable(fns):
fns = (fns,)
return_tuple = False

arg_types_per_function = {
fn: {
name: ty for name, ty in get_type_hints(fn).items() if name != 'return'
}
for fn in fns
}
all_arg_types = tuple(
set(chain(*(a.values() for a in arg_types_per_function.values())))
)
values_per_type = self.compute(all_arg_types)
results = tuple(
fn(**{name: values_per_type[ty] for name, ty in arg_types.items()})
for fn, arg_types in arg_types_per_function.items()
)
if not return_tuple:
return results[0]
return results
131 changes: 131 additions & 0 deletions tests/pipeline_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,3 +619,134 @@ def test_get_with_NaiveScheduler() -> None:
pipeline = sl.Pipeline([int_to_float, make_int])
task = pipeline.get(float, scheduler=sl.scheduler.NaiveScheduler())
assert task.compute() == 1.5


def test_bind_and_call_no_function() -> None:
pipeline = sl.Pipeline([make_int])
assert pipeline.bind_and_call(()) == ()


def test_bind_and_call_function_without_args() -> None:
def func() -> str:
return "func"

pipeline = sl.Pipeline([make_int])
assert pipeline.bind_and_call(func) == "func"


def test_bind_and_call_function_with_1_arg() -> None:
def func(i: int) -> int:
return i * 2

pipeline = sl.Pipeline([make_int])
assert pipeline.bind_and_call(func) == 6


def test_bind_and_call_function_with_2_arg2() -> None:
def func(i: int, f: float) -> float:
return i + f

pipeline = sl.Pipeline([make_int, int_to_float])
assert pipeline.bind_and_call(func) == 4.5


def test_bind_and_call_overrides_default_args() -> None:
def func(i: int, f: float = -0.5) -> float:
return i + f

pipeline = sl.Pipeline([make_int, int_to_float])
assert pipeline.bind_and_call(func) == 4.5


def test_bind_and_call_function_in_iterator() -> None:
def func(i: int) -> int:
return i * 2

pipeline = sl.Pipeline([make_int])
assert pipeline.bind_and_call(iter((func,))) == (6,)


def test_bind_and_call_dataclass_without_args() -> None:
@dataclass
class C:
...

pipeline = sl.Pipeline([make_int])
assert pipeline.bind_and_call(C) == C()


def test_bind_and_call_dataclass_with_1_arg() -> None:
@dataclass
class C:
i: int

pipeline = sl.Pipeline([make_int])
assert pipeline.bind_and_call(C) == C(i=3)


def test_bind_and_call_dataclass_with_2_arg2() -> None:
@dataclass
class C:
i: int
f: float

pipeline = sl.Pipeline([make_int, int_to_float])
assert pipeline.bind_and_call(C) == C(i=3, f=1.5)


def test_bind_and_call_two_functions() -> None:
def func1(i: int) -> int:
return 2 * i

def func2(f: float) -> float:
return f + 1

pipeline = sl.Pipeline([make_int, int_to_float])
assert pipeline.bind_and_call((func1, func2)) == (6, 2.5)


def test_bind_and_call_two_functions_in_iterator() -> None:
def func1(i: int) -> int:
return 2 * i

def func2(f: float) -> float:
return f + 1

pipeline = sl.Pipeline([make_int, int_to_float])
assert pipeline.bind_and_call(iter((func1, func2))) == (6, 2.5)


def test_bind_and_call_function_and_dataclass() -> None:
def func(i: int) -> int:
return 2 * i

@dataclass
class C:
i: int
f: float

pipeline = sl.Pipeline([make_int, int_to_float])
assert pipeline.bind_and_call((func, C)) == (6, C(i=3, f=1.5))


def test_bind_and_call_function_without_return_annotation() -> None:
def func(i: int): # type: ignore[no-untyped-def]
return 2 * i

pipeline = sl.Pipeline([make_int])
assert pipeline.bind_and_call(func) == 6


def test_bind_and_call_generic_function() -> None:
T = TypeVar('T')
A = NewType('A', int)
B = NewType('B', int)

class G(sl.Scope[T, int], int):
...

def func(a: G[A]) -> int:
return -4 * a

pipeline = sl.Pipeline([], params={G[A]: 3, G[B]: 4})
assert pipeline.bind_and_call(func) == -12