Skip to content

Commit

Permalink
Merge pull request #77 from scipp/specialized-providers
Browse files Browse the repository at this point in the history
feat: prioritize specialized providers over generic
  • Loading branch information
jokasimr authored Nov 28, 2023
2 parents f04f65e + fb43e78 commit 85f1251
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 0 deletions.
10 changes: 10 additions & 0 deletions src/sciline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,16 @@ def _get_provider(
for args, subprovider in subproviders.items()
if _is_compatible_type_tuple(requested, args)
]
typevar_counts = [
sum(1 for t in args if isinstance(t, TypeVar)) for args, _ in matches
]
min_typevar_count = min(typevar_counts, default=0)
matches = [
m
for count, m in zip(typevar_counts, matches)
if count == min_typevar_count
]

if len(matches) == 1:
args, provider = matches[0]
bound = {
Expand Down
78 changes: 78 additions & 0 deletions tests/pipeline_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -829,3 +829,81 @@ def d(_f: float) -> None:
assert calls.index('b') in (0, 1)
assert calls.index('c') in (2, 3)
assert calls.index('d') in (2, 3)


def test_prioritizes_specialized_provider_over_generic() -> None:
A = NewType('A', str)
B = NewType('B', str)
V = TypeVar('V', A, B)

class H(sl.Scope[V, str], str):
pass

def p1(x: V) -> H[V]:
return H[V]("Generic")

def p2(x: B) -> H[B]:
return H[B]("Special")

pl = sl.Pipeline([p1, p2], params={A: 'A', B: 'B'})

assert str(pl.compute(H[A])) == "Generic"
assert str(pl.compute(H[B])) == "Special"


def test_prioritizes_specialized_provider_over_generic_several_typevars() -> None:
A = NewType('A', str)
B = NewType('B', str)
T1 = TypeVar('T1')
T2 = TypeVar('T2')

@dataclass
class C(Generic[T1, T2]):
first: T1
second: T2
third: str

def p1(x: T1, y: T2) -> C[T1, T2]:
return C(x, y, 'generic')

def p2(x: A, y: T2) -> C[A, T2]:
return C(x, y, 'medium generic')

def p3(x: T2, y: B) -> C[T2, B]:
return C(x, y, 'generic medium')

def p4(x: A, y: B) -> C[A, B]:
return C(x, y, 'special')

pl = sl.Pipeline([p1, p2, p3, p4], params={A: A('A'), B: B('B')})

assert pl.compute(C[B, A]) == C('B', 'A', 'generic')
assert pl.compute(C[A, A]) == C('A', 'A', 'medium generic')
assert pl.compute(C[B, B]) == C('B', 'B', 'generic medium')
assert pl.compute(C[A, B]) == C('A', 'B', 'special')


def test_prioritizes_specialized_provider_raises() -> None:
A = NewType('A', str)
B = NewType('B', str)
T1 = TypeVar('T1')
T2 = TypeVar('T2')

@dataclass
class C(Generic[T1, T2]):
first: T1
second: T2

def p1(x: A, y: T1) -> C[A, T1]:
return C(x, y)

def p2(x: T1, y: B) -> C[T1, B]:
return C(x, y)

pl = sl.Pipeline([p1, p2], params={A: A('A'), B: B('B')})

with pytest.raises(sl.AmbiguousProvider):
pl.compute(C[A, B])

with pytest.raises(sl.UnsatisfiedRequirement):
pl.compute(C[B, A])

0 comments on commit 85f1251

Please sign in to comment.