Skip to content

Commit

Permalink
Merge pull request #178 from scipp/provider-eq
Browse files Browse the repository at this point in the history
Add Provider.__eq__
  • Loading branch information
jl-wynen authored Aug 20, 2024
2 parents 9521133 + 1751fd1 commit 0c193cc
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 30 deletions.
24 changes: 9 additions & 15 deletions src/sciline/_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,7 @@
ToProvider = Callable[..., Any]
"""Callable that can be converted to a provider."""

ProviderKind = Literal[
'function', 'parameter', 'series', 'table_cell', 'sentinel', 'unsatisfied'
]
ProviderKind = Literal['function', 'parameter', 'unsatisfied']
"""Identifies the kind of a provider, most are used internally."""


Expand Down Expand Up @@ -78,18 +76,6 @@ def parameter(cls, param: Any) -> Provider:
),
)

@classmethod
def table_cell(cls, param: Any) -> Provider:
"""Construct a provider that returns the label for a table row."""
return cls(
func=lambda: param,
arg_spec=ArgSpec.null(),
kind='table_cell',
location=ProviderLocation(
name=f'table_cell({type(param).__name__})', module=_module_name(param)
),
)

@classmethod
def provide_none(cls) -> Provider:
"""Provider that takes no arguments and returns None."""
Expand Down Expand Up @@ -136,6 +122,14 @@ def bind_type_vars(self, bound: dict[TypeVar, Key]) -> Provider:
kind=self._kind,
)

def __eq__(self, other: object) -> bool:
if isinstance(other, Provider):
return self._func == other._func
return NotImplemented

def __ne__(self, other: object) -> bool:
return not (self == other)

def __str__(self) -> str:
return f"Provider('{self.location.name}')"

Expand Down
2 changes: 1 addition & 1 deletion src/sciline/data_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def __setitem__(self, key: Key, value: DataGraph | Any) -> None:
----------
key:
Type to provide a value for.
param:
value:
Concrete value to provide.
"""
# This is a questionable approach: Using MyGeneric[T] as a key will actually
Expand Down
15 changes: 2 additions & 13 deletions src/sciline/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,15 +74,9 @@ def to_graphviz(


def _to_subgraphs(graph: FormattedGraph) -> dict[str, FormattedGraph]:
def get_subgraph_name(name: str, kind: str) -> str:
if kind == 'series':
# Example: Series[RowId, Material[Country]] -> RowId, Material[Country]
return name.partition('[')[-1].rpartition(']')[0]
return name.split('[')[0]

subgraphs: dict[str, FormattedGraph] = {}
for p, formatted_p in graph.items():
subgraph_name = get_subgraph_name(formatted_p.ret.name, formatted_p.kind)
subgraph_name = formatted_p.ret.name.split('[')[0]
subgraphs.setdefault(subgraph_name, {})
subgraphs[subgraph_name][p] = formatted_p
return subgraphs
Expand All @@ -105,12 +99,7 @@ def _add_subgraph(graph: FormattedGraph, dot: Digraph, subgraph: Digraph) -> Non
formatted_p.ret.name,
shape='box3d' if formatted_p.ret.collapsed else 'rectangle',
)
# Do not draw the internal provider gathering index-dependent results into
# a dict
if formatted_p.kind == 'series':
for arg in formatted_p.args:
dot.edge(arg.name, formatted_p.ret.name, style='dashed')
elif formatted_p.kind == 'function':
if formatted_p.kind == 'function':
dot.node(p, formatted_p.name, shape='ellipse')
for arg in formatted_p.args:
dot.edge(arg.name, p)
Expand Down
48 changes: 47 additions & 1 deletion tests/pipeline_setitem_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
from typing import NewType
from typing import NewType, TypeVar

import pytest

Expand All @@ -11,6 +11,9 @@
C = NewType('C', int)
D = NewType('D', int)
X = NewType('X', int)
Y = NewType('Y', int)

P = TypeVar('P', X, Y)


def a_to_b(a: A) -> B:
Expand All @@ -37,6 +40,38 @@ def test_setitem_can_compose_pipelines() -> None:
assert bc.compute(C) == C(3)


def test_setitem_with_common_function_provider() -> None:
def a() -> A:
return A(2)

abc = sl.Pipeline((a, ab_to_c))
abc[B] = sl.Pipeline((a, a_to_b))
assert abc.compute(C) == C(5)


def test_setitem_with_common_parameter_provider() -> None:
abc = sl.Pipeline((ab_to_c,), params={A: A(3)})
abc[B] = sl.Pipeline((a_to_b,), params={A: A(3)})
assert abc.compute(C) == C(7)


def test_setitem_with_generic_providers() -> None:
class GA(sl.Scope[P, int], int): ...

class GB(sl.Scope[P, int], int): ...

def ga_to_gb(ga: GA[P]) -> GB[P]:
return GB[P](ga + 1)

def gb_to_c(gbx: GB[X], gby: GB[Y]) -> C:
return C(gbx + gby)

abc = sl.Pipeline((gb_to_c,), params={GA[X]: 3, GA[Y]: 4})
abc[GB[X]] = sl.Pipeline((ga_to_gb,), params={GA[X]: 3})[GB[X]]
abc[GB[Y]] = sl.Pipeline((ga_to_gb,), params={GA[Y]: 4})[GB[Y]]
assert abc.compute(C) == C(9)


def test_setitem_raises_if_value_pipeline_has_no_unique_output() -> None:
abx = sl.Pipeline((a_to_b,))
abx[X] = 666
Expand Down Expand Up @@ -85,6 +120,17 @@ def test_setitem_with_conflicting_nodes_in_value_pipeline_raises_on_data_mismatc
abc[B] = ab


def test_setitem_with_conflicting_node_types_pipeline_raises_on_data_mismatch() -> None:
def a() -> A:
return A(100)

ab = sl.Pipeline((a, a_to_b))
abc = sl.Pipeline((ab_to_c,))
abc[A] = 666
with pytest.raises(ValueError, match="Node data differs"):
abc[B] = ab


def test_setitem_with_conflicting_nodes_in_value_pipeline_accepts_on_data_match() -> (
None
):
Expand Down

0 comments on commit 0c193cc

Please sign in to comment.