Skip to content

Commit

Permalink
Use error handler when building series subgraph
Browse files Browse the repository at this point in the history
  • Loading branch information
SimonHeybrock committed Dec 5, 2023
1 parent f5f0491 commit 8b8c4e0
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 5 deletions.
11 changes: 6 additions & 5 deletions src/sciline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,7 +515,8 @@ def build(
graph[tp] = (_param_sentinel, (self._param_name_to_table_key[tp],))
continue
if get_origin(tp) == Series:
graph.update(self._build_series(tp)) # type: ignore[arg-type]
sub = self._build_series(tp, handler=handler) # type: ignore[arg-type]
graph.update(sub)
continue
if (optional_arg := get_optional(tp)) is not None:
try:
Expand Down Expand Up @@ -544,7 +545,9 @@ def build(
stack.append(arg)
return graph

def _build_series(self, tp: Type[Series[KeyType, ValueType]]) -> Graph:
def _build_series(
self, tp: Type[Series[KeyType, ValueType]], handler: ErrorHandler
) -> Graph:
"""
Build (sub)graph for a Series type implementing ParamTable-based functionality.
Expand Down Expand Up @@ -604,9 +607,7 @@ def _build_series(self, tp: Type[Series[KeyType, ValueType]]) -> Graph:
value_type: Type[ValueType]
index_name, value_type = get_args(tp)

subgraph = self.build(
value_type, search_param_tables=True, handler=HandleAsBuildTimeException()
)
subgraph = self.build(value_type, search_param_tables=True, handler=handler)

replicator: ReplicatorBase[KeyType]
if (
Expand Down
11 changes: 11 additions & 0 deletions tests/pipeline_with_param_table_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,3 +594,14 @@ def parametrized_gather(x: sl.Series[Row, Param[T]]) -> Str[T]:

assert pipeline.compute(Str[int]) == Str[int]('1,2')
assert pipeline.compute(Str[float]) == Str[float]('1.5,2.5')


def test_compute_time_handler_works_alongside_param_table() -> None:
Missing = NewType("Missing", str)

def process(x: float, missing: Missing) -> str:
return str(x) + missing

pl = sl.Pipeline([process])
pl.set_param_table(sl.ParamTable(int, {float: [1.0, 2.0, 3.0]}))
pl.get(sl.Series[int, str], handler=sl.HandleAsComputeTimeException())

0 comments on commit 8b8c4e0

Please sign in to comment.