diff --git a/docs/api-reference/index.md b/docs/api-reference/index.md
index 883ecd98..0cba8d9c 100644
--- a/docs/api-reference/index.md
+++ b/docs/api-reference/index.md
@@ -10,8 +10,10 @@
:template: class-template.rst
:recursive:
+ ParamTable
Pipeline
Scope
+ Series
scheduler.Scheduler
scheduler.DaskScheduler
scheduler.NaiveScheduler
diff --git a/docs/user-guide/getting-started.ipynb b/docs/user-guide/getting-started.ipynb
index f9d08cc4..981f0d2a 100644
--- a/docs/user-guide/getting-started.ipynb
+++ b/docs/user-guide/getting-started.ipynb
@@ -101,6 +101,9 @@
"from typing import NewType\n",
"import sciline\n",
"\n",
+ "_fake_filesytem = {'data.txt': [1, 2, float('nan'), 3]}\n",
+ "\n",
+ "\n",
"# 1. Define domain types\n",
"\n",
"Filename = NewType('Filename', str)\n",
@@ -115,19 +118,20 @@
"\n",
"def load(filename: Filename) -> RawData:\n",
" \"\"\"Load the data from the filename.\"\"\"\n",
- " return {'data': [1, 2, float('nan'), 3], 'meta': {'filename': filename}}\n",
+ " data = _fake_filesytem[filename]\n",
+ " return RawData({'data': data, 'meta': {'filename': filename}})\n",
"\n",
"\n",
"def clean(raw_data: RawData) -> CleanedData:\n",
" \"\"\"Clean the data, removing NaNs.\"\"\"\n",
" import math\n",
"\n",
- " return [x for x in raw_data['data'] if not math.isnan(x)]\n",
+ " return CleanedData([x for x in raw_data['data'] if not math.isnan(x)])\n",
"\n",
"\n",
"def process(data: CleanedData, param: ScaleFactor) -> Result:\n",
" \"\"\"Process the data, multiplying the sum by the scale factor.\"\"\"\n",
- " return sum(data) * param\n",
+ " return Result(sum(data) * param)\n",
"\n",
"\n",
"# 3. Create pipeline\n",
diff --git a/docs/user-guide/index.md b/docs/user-guide/index.md
index ca13f1a2..3baa2628 100644
--- a/docs/user-guide/index.md
+++ b/docs/user-guide/index.md
@@ -6,4 +6,5 @@ maxdepth: 2
---
getting-started
+parameter-tables
```
diff --git a/docs/user-guide/parameter-tables.ipynb b/docs/user-guide/parameter-tables.ipynb
new file mode 100644
index 00000000..8ed44e3b
--- /dev/null
+++ b/docs/user-guide/parameter-tables.ipynb
@@ -0,0 +1,486 @@
+{
+ "cells": [
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Parameter Tables\n",
+ "\n",
+ "## Overview\n",
+ "\n",
+ "Parameter tables provide a mechanism for repeating parts of or all of a computation with different values for one or more parameters.\n",
+ "This allows for a variety of use cases, similar to *map*, *reduce*, and *groupby* operations in other systems.\n",
+ "We illustrate each of these in the following three chapters."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Computing results for series of parameters\n",
+ "\n",
+ "This chapter illustrates how to implement *map* operations with Sciline.\n",
+ "\n",
+ "Starting with the model workflow introduced in [Getting Started](getting-started.ipynb), we can replace the fixed `Filename` parameter with a series of filenames listed in a [ParamTable](../generated/classes/ParamTable.rst):"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from typing import NewType\n",
+ "import sciline\n",
+ "\n",
+ "_fake_filesytem = {\n",
+ " 'file102.txt': [1, 2, float('nan'), 3],\n",
+ " 'file103.txt': [1, 2, 3, 4],\n",
+ " 'file104.txt': [1, 2, 3, 4, 5],\n",
+ " 'file105.txt': [1, 2, 3],\n",
+ "}\n",
+ "\n",
+ "# 1. Define domain types\n",
+ "\n",
+ "Filename = NewType('Filename', str)\n",
+ "RawData = NewType('RawData', dict)\n",
+ "CleanedData = NewType('CleanedData', list)\n",
+ "ScaleFactor = NewType('ScaleFactor', float)\n",
+ "Result = NewType('Result', float)\n",
+ "\n",
+ "\n",
+ "# 2. Define providers\n",
+ "\n",
+ "\n",
+ "def load(filename: Filename) -> RawData:\n",
+ " \"\"\"Load the data from the filename.\"\"\"\n",
+ "\n",
+ " data = _fake_filesytem[filename]\n",
+ " return RawData({'data': data, 'meta': {'filename': filename}})\n",
+ "\n",
+ "\n",
+ "def clean(raw_data: RawData) -> CleanedData:\n",
+ " \"\"\"Clean the data, removing NaNs.\"\"\"\n",
+ " import math\n",
+ "\n",
+ " return CleanedData([x for x in raw_data['data'] if not math.isnan(x)])\n",
+ "\n",
+ "\n",
+ "def process(data: CleanedData, param: ScaleFactor) -> Result:\n",
+ " \"\"\"Process the data, multiplying the sum by the scale factor.\"\"\"\n",
+ " return Result(sum(data) * param)\n",
+ "\n",
+ "\n",
+ "# 3. Create pipeline\n",
+ "\n",
+ "# 3.a Providers and normal parameters\n",
+ "providers = [load, clean, process]\n",
+ "params = {ScaleFactor: 2.0}\n",
+ "\n",
+ "# 3.b Parameter table\n",
+ "RunID = NewType('RunID', int)\n",
+ "run_ids = [102, 103, 104, 105]\n",
+ "filenames = [f'file{i}.txt' for i in run_ids]\n",
+ "param_table = sciline.ParamTable(RunID, {Filename: filenames}, index=run_ids)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Note how steps 1.) and 2.) are identical to those from the example without a parameter table.\n",
+ "Above we have created the following parameter table:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "param_table"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We can now create the pipeline and set the parameter table:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# 3.c Setup pipeline\n",
+ "pipeline = sciline.Pipeline(providers, params=params)\n",
+ "pipeline.set_param_table(param_table)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Then we can compute `Result` for each index in the parameter table:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "pipeline.compute(sciline.Series[RunID, Result])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "\n",
+ "`sciline.Series` is a special `dict`-like type that signals to Sciline that the values of the series are based on values from one or more columns of a parameter table.\n",
+ "The parameter table is identified using the first argument to `Series`, in this case `RunID`.\n",
+ "The second argument specifies the result to be computed.\n",
+ "\n",
+ "We can also visualize the task graph for computing the series of `Result` values:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "pipeline.visualize(sciline.Series[RunID, Result])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Nodes that depend on values from a parameter table are drawn as \"3D boxes\" with the parameter index name (the row dimension of the parameter table) and value given in parentheses.\n",
+ "The dashed arrow indicates an internal transformation that gathers the results from each branch and combines them into a single output, here `Series[RunID, Result]`.\n",
+ "\n",
+ "<div class=\"alert alert-info\">\n",
+ "\n",
+ "<b>Note</b>\n",
+ "\n",
+ "With long parameter tables, graphs can get messy and hard to read.\n",
+ "Try using `visualize(..., compact=True)`.\n",
+ "\n",
+ "The `compact=True` option yields a much more compact representation.\n",
+ "Instead of drawing every intermediate result and provider for each parameter, we then represent each parameter-dependent result as a single \"3D box\" node, representing all nodes for different values of the respective parameter.\n",
+ "\n",
+ "</div>"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Combining intermediate results from series of parameters\n",
+ "\n",
+ "This chapter illustrates how to implement *reduce* operations with Sciline.\n",
+ "\n",
+ "Instead of requesting a series of results as above, we can also build pipelines with providers that depend on such series.\n",
+ "We can create a new pipeline, or extend the existing one by inserting a new provider:\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "MergedResult = NewType('MergedResult', float)\n",
+ "\n",
+ "\n",
+ "def merge_runs(runs: sciline.Series[RunID, Result]) -> MergedResult:\n",
+ " return MergedResult(sum(runs.values()))\n",
+ "\n",
+ "\n",
+ "pipeline.insert(merge_runs)\n",
+ "graph = pipeline.get(MergedResult)\n",
+ "graph.visualize()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Note that this is identical to the example in the previous section, except for the last two nodes in the graph.\n",
+ "The computation now returns a single result:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "graph.compute()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "This is useful if we need to continue computation after gathering results without setting up a second pipeline."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Grouping intermediate results based on secondary parameters\n",
+ "\n",
+ "This chapter illustrates how to implement *groupby* operations with Sciline.\n",
+ "\n",
+ "Continuing from the examples for *map* and *reduce*, we can introduce a secondary parameter in the table, such as the material of the sample:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "Material = NewType('Material', str)\n",
+ "\n",
+ "# 3.a Providers and normal parameters\n",
+ "providers = [load, clean, process, merge_runs]\n",
+ "params = {ScaleFactor: 2.0}\n",
+ "\n",
+ "# 3.b Parameter table\n",
+ "run_ids = [102, 103, 104, 105]\n",
+ "sample = ['diamond', 'graphite', 'graphite', 'graphite']\n",
+ "filenames = [f'file{i}.txt' for i in run_ids]\n",
+ "param_table = sciline.ParamTable(\n",
+ " RunID, {Filename: filenames, Material: sample}, index=run_ids\n",
+ ")\n",
+ "param_table"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# 3.c Setup pipeline\n",
+ "pipeline = sciline.Pipeline(providers, params=params)\n",
+ "pipeline.set_param_table(param_table)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We can now compute `MergedResult` for a series of \"materials\":"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "pipeline.compute(sciline.Series[Material, MergedResult])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The computation looks as shown below.\n",
+ "Note how the initial steps of the computation depend on the `RunID` parameter, while later steps depend on `Material`.\n",
+ "The files for each run ID have been grouped by their material and then merged:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "pipeline.visualize(sciline.Series[Material, MergedResult])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## More examples"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Combining multiple parameters from same table"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import sciline as sl\n",
+ "\n",
+ "Sum = NewType(\"Sum\", float)\n",
+ "Param1 = NewType(\"Param1\", int)\n",
+ "Param2 = NewType(\"Param2\", int)\n",
+ "Row = NewType(\"Run\", int)\n",
+ "\n",
+ "\n",
+ "def gather(\n",
+ " x: sl.Series[Row, float],\n",
+ ") -> Sum:\n",
+ " return Sum(sum(x.values()))\n",
+ "\n",
+ "\n",
+ "def product(x: Param1, y: Param2) -> float:\n",
+ " return x / y\n",
+ "\n",
+ "\n",
+ "pl = sl.Pipeline([gather, product])\n",
+ "pl.set_param_table(sl.ParamTable(Row, {Param1: [1, 4, 9], Param2: [1, 2, 3]}))\n",
+ "\n",
+ "pl.visualize(Sum)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "pl.compute(Sum)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Diamond graphs"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "Sum = NewType(\"Sum\", float)\n",
+ "Param = NewType(\"Param\", int)\n",
+ "Param1 = NewType(\"Param1\", int)\n",
+ "Param2 = NewType(\"Param2\", int)\n",
+ "Row = NewType(\"Run\", int)\n",
+ "\n",
+ "\n",
+ "def gather(x: sl.Series[Row, float]) -> Sum:\n",
+ " return Sum(sum(x.values()))\n",
+ "\n",
+ "\n",
+ "def to_param1(x: Param) -> Param1:\n",
+ " return Param1(x)\n",
+ "\n",
+ "\n",
+ "def to_param2(x: Param) -> Param2:\n",
+ " return Param2(x)\n",
+ "\n",
+ "\n",
+ "def product(x: Param1, y: Param2) -> float:\n",
+ " return x * y\n",
+ "\n",
+ "\n",
+ "pl = sl.Pipeline([gather, product, to_param1, to_param2])\n",
+ "pl.set_param_table(sl.ParamTable(Row, {Param: [1, 2, 3]}))\n",
+ "pl.visualize(Sum)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Combining parameters from different tables"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import sciline as sl\n",
+ "\n",
+ "List1 = NewType(\"List1\", float)\n",
+ "List2 = NewType(\"List2\", float)\n",
+ "Param1 = NewType(\"Param1\", int)\n",
+ "Param2 = NewType(\"Param2\", int)\n",
+ "Row1 = NewType(\"Row1\", int)\n",
+ "Row2 = NewType(\"Row2\", int)\n",
+ "\n",
+ "\n",
+ "def gather1(x: sl.Series[Row1, float]) -> List1:\n",
+ " return List1(list(x.values()))\n",
+ "\n",
+ "\n",
+ "def gather2(x: sl.Series[Row2, List1]) -> List2:\n",
+ " return List2(list(x.values()))\n",
+ "\n",
+ "\n",
+ "def product(x: Param1, y: Param2) -> float:\n",
+ " return x * y\n",
+ "\n",
+ "\n",
+ "pl = sl.Pipeline([gather1, gather2, product])\n",
+ "pl.set_param_table(sl.ParamTable(Row1, {Param1: [1, 4, 9]}))\n",
+ "pl.set_param_table(sl.ParamTable(Row2, {Param2: [1, 2]}))\n",
+ "\n",
+ "pl.visualize(List2)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Note how intermediates such as `float(Row1, Row2)` depend on two parameters, i.e., we are dealing with a 2-D array of branches in the graph."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "pl.compute(List2)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.12"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/src/sciline/__init__.py b/src/sciline/__init__.py
index 2f43d450..eb74611c 100644
--- a/src/sciline/__init__.py
+++ b/src/sciline/__init__.py
@@ -10,15 +10,19 @@
__version__ = "0.0.0"
from .domain import Scope
+from .param_table import ParamTable
from .pipeline import (
AmbiguousProvider,
Pipeline,
UnboundTypeVar,
UnsatisfiedRequirement,
)
+from .series import Series
__all__ = [
"AmbiguousProvider",
+ "Series",
+ "ParamTable",
"Pipeline",
"Scope",
"UnboundTypeVar",
diff --git a/src/sciline/param_table.py b/src/sciline/param_table.py
new file mode 100644
index 00000000..585f7e8e
--- /dev/null
+++ b/src/sciline/param_table.py
@@ -0,0 +1,86 @@
+# SPDX-License-Identifier: BSD-3-Clause
+# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
+from __future__ import annotations
+
+from typing import Any, Collection, Dict, Mapping, Optional
+
+
+class ParamTable(Mapping[type, Collection[Any]]):
+ """A table of parameters with a row index and named row dimension."""
+
+ def __init__(
+ self,
+ row_dim: type,
+ columns: Dict[type, Collection[Any]],
+ *,
+ index: Optional[Collection[Any]] = None,
+ ):
+ """
+ Create a new param table.
+
+ Parameters
+ ----------
+ row_dim:
+ The row dimension. This must be a type or a type-alias (not an instance),
+ and is used by :py:class:`sciline.Pipeline` to identify each parameter
+ table.
+ columns:
+ The columns of the table. The keys (column names) must be types or type-
+ aliases matching the values in the respective columns.
+ index:
+ The row index of the table. If not given, a default index will be
+ generated, as the integer range of the column length.
+ """
+ sizes = set(len(v) for v in columns.values())
+ if len(sizes) != 1:
+ raise ValueError(
+ f"Columns in param table must all have same size, got {sizes}"
+ )
+ size = sizes.pop()
+ if index is not None:
+ if len(index) != size:
+ raise ValueError(
+ f"Index length not matching columns, got {len(index)} and {size}"
+ )
+ if len(set(index)) != len(index):
+ raise ValueError(f"Index must be unique, got {index}")
+ self._row_dim = row_dim
+ self._columns = columns
+ self._index = index or list(range(size))
+
+ @property
+ def row_dim(self) -> type:
+ """The row dimension of the table."""
+ return self._row_dim
+
+ @property
+ def index(self) -> Collection[Any]:
+ """The row index of the table."""
+ return self._index
+
+ def __contains__(self, key: Any) -> bool:
+ return self._columns.__contains__(key)
+
+ def __getitem__(self, key: Any) -> Any:
+ return self._columns.__getitem__(key)
+
+ def __iter__(self) -> Any:
+ return self._columns.__iter__()
+
+ def __len__(self) -> int:
+ return self._columns.__len__()
+
+ def __repr__(self) -> str:
+ return f"ParamTable(row_dim={self.row_dim}, columns={self._columns})"
+
+ def _repr_html_(self) -> str:
+ return (
+ f"<table><tr><th>{self.row_dim.__name__}</th>"
+ + "".join(f"<th>{k.__name__}</th>" for k in self._columns.keys())
+ + "</tr>"
+ + "".join(
+ f"<tr><td>{idx}</td>" + "".join(f"<td>{v}</td>" for v in row) + "</tr>"
+ for idx, row in zip(self.index, zip(*self._columns.values()))
+ )
+ + "</table>"
+ )
diff --git a/src/sciline/pipeline.py b/src/sciline/pipeline.py
index 4dbd104f..75d8a527 100644
--- a/src/sciline/pipeline.py
+++ b/src/sciline/pipeline.py
@@ -2,15 +2,21 @@
# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
from __future__ import annotations
+from collections import defaultdict
from typing import (
Any,
Callable,
+ Collection,
Dict,
+ Generic,
+ Iterable,
List,
+ Mapping,
Optional,
Tuple,
Type,
TypeVar,
+ Union,
get_args,
get_origin,
get_type_hints,
@@ -20,9 +26,16 @@
from sciline.task_graph import TaskGraph
from .domain import Scope
-from .scheduler import Graph, Scheduler
+from .param_table import ParamTable
+from .scheduler import Scheduler
+from .series import Series
+from .typing import Graph, Item, Key, Label, Provider
T = TypeVar('T')
+KeyType = TypeVar('KeyType')
+ValueType = TypeVar('ValueType')
+IndexType = TypeVar('IndexType')
+LabelType = TypeVar('LabelType')
class UnsatisfiedRequirement(Exception):
@@ -39,10 +52,6 @@ class AmbiguousProvider(Exception):
"""Raised when multiple providers are found for a type."""
-Provider = Callable[..., Any]
-Key = type
-
-
def _is_compatible_type_tuple(
requested: tuple[Key, ...],
provided: tuple[Key | TypeVar, ...],
@@ -74,6 +83,195 @@ def _bind_free_typevars(tp: TypeVar | Key, bound: Dict[TypeVar, Key]) -> Key:
return tp
+def _find_all_paths(
+ dependencies: Mapping[T, Collection[T]], start: T, end: T
+) -> List[List[T]]:
+ """Find all paths from start to end in a DAG."""
+ if start == end:
+ return [[start]]
+ if start not in dependencies:
+ return []
+ paths = []
+ for node in dependencies[start]:
+ for path in _find_all_paths(dependencies, node, end):
+ paths.append([start] + path)
+ return paths
+
+
+def _find_nodes_in_paths(
+ graph: Mapping[T, Tuple[Callable[..., Any], Collection[T]]], end: T
+) -> List[T]:
+ """
+ Helper for Pipeline. Finds all nodes that need to be duplicated since they depend
+ on a value from a param table.
+ """
+ start = next(iter(graph))
+ # 0 is the provider, 1 is the args
+ dependencies = {k: v[1] for k, v in graph.items()}
+ paths = _find_all_paths(dependencies, start, end)
+ nodes = set()
+ for path in paths:
+ nodes.update(path)
+ return list(nodes)
+
+
+class ReplicatorBase(Generic[IndexType]):
+ def __init__(self, index_name: type, index: Iterable[IndexType], path: List[Key]):
+ self._index_name = index_name
+ self.index = index
+ self._path = path
+
+ def __contains__(self, key: Key) -> bool:
+ return key in self._path
+
+ def replicate(
+ self,
+ key: Key,
+ value: Any,
+ get_provider: Callable[..., Tuple[Provider, Dict[TypeVar, Key]]],
+ ) -> Graph:
+ graph: Graph = {}
+ provider, args = value
+ for idx in self.index:
+ subkey = self.key(idx, key)
+ if provider == _param_sentinel:
+ graph[subkey] = (get_provider(subkey)[0], ())
+ else:
+ graph[subkey] = self._copy_node(key, provider, args, idx)
+ return graph
+
+ def _copy_node(
+ self,
+ key: Key,
+ provider: Union[Provider, SeriesProvider[IndexType]],
+ args: Tuple[Key, ...],
+ idx: IndexType,
+ ) -> Tuple[Provider, Tuple[Key, ...]]:
+ return (
+ provider,
+ tuple(self.key(idx, arg) if arg in self else arg for arg in args),
+ )
+
+ def key(self, i: IndexType, value_name: Union[Type[T], Item[T]]) -> Item[T]:
+ label = Label(self._index_name, i)
+ if isinstance(value_name, Item):
+ return Item(value_name.label + (label,), value_name.tp)
+ else:
+ return Item((label,), value_name)
+
+
+class Replicator(ReplicatorBase[IndexType]):
+ r"""
+ Helper for rewriting the graph to map over a given index.
+
+ See Pipeline._build_series for context. Given a graph template, this makes a
+ transformation as follows:
+
+ S P1[0] P1[1] P1[2]
+ | | | |
+ A -> A[0] A[1] A[2]
+ | | | |
+ B B[0] B[1] B[2]
+
+ Where S is a sentinel value, P1 are parameters from a parameter table, 0,1,2
+ are indices of the param table rows, and A and B are arbitrary nodes in the graph.
+ """
+
+ def __init__(self, param_table: ParamTable, graph_template: Graph) -> None:
+ index_name = param_table.row_dim
+ super().__init__(
+ index_name=index_name,
+ index=param_table.index,
+ path=_find_nodes_in_paths(graph_template, index_name),
+ )
+
+
+class GroupingReplicator(ReplicatorBase[LabelType], Generic[IndexType, LabelType]):
+ r"""
+ Helper for rewriting the graph to group by a given index.
+
+ See Pipeline._build_series for context. Given a graph template, this makes a
+ transformation as follows:
+
+ P1[0] P1[1] P1[2] P1[0] P1[1] P1[2]
+ | | | | | |
+ A[0] A[1] A[2] A[0] A[1] A[2]
+ | | | | | |
+ B[0] B[1] B[2] -> B[0] B[1] B[2]
+ \______|______/ \______/ |
+ | | |
+ SB SB[x] SB[y]
+ | | |
+ C C[x] C[y]
+
+ Where SB is Series[Idx,B]. Here, the upper half of the graph originates from a
+ prior transformation of a graph template using `Replicator`. The output of this
+ combined with further nodes is the graph template passed to this class. x and y
+ are the labels used in a grouping operation, based on the values of a ParamTable
+ column P2.
+ """
+
+ def __init__(
+ self, param_table: ParamTable, graph_template: Graph, label_name: type
+ ) -> None:
+ self._label_name = label_name
+ self._group_node = self._find_grouping_node(param_table.row_dim, graph_template)
+ self._groups: Dict[LabelType, List[IndexType]] = defaultdict(list)
+ for idx, label in zip(param_table.index, param_table[label_name]):
+ self._groups[label].append(idx)
+ super().__init__(
+ index_name=label_name,
+ index=self._groups,
+ path=_find_nodes_in_paths(graph_template, self._group_node),
+ )
+
+ def _copy_node(
+ self,
+ key: Key,
+ provider: Union[Provider, SeriesProvider[IndexType]],
+ args: Tuple[Key, ...],
+ idx: LabelType,
+ ) -> Tuple[Provider, Tuple[Key, ...]]:
+ if (not isinstance(provider, SeriesProvider)) or key != self._group_node:
+ return super()._copy_node(key, provider, args, idx)
+ labels = self._groups[idx]
+ if set(labels) - set(provider.labels):
+ raise ValueError(f'{labels} is not a subset of {provider.labels}')
+ selected = {
+ label: arg for label, arg in zip(provider.labels, args) if label in labels
+ }
+ split_provider = SeriesProvider(selected, provider.row_dim)
+ return (split_provider, tuple(selected.values()))
+
+ def _find_grouping_node(self, index_name: Key, subgraph: Graph) -> type:
+ ends: List[type] = []
+ for key in subgraph:
+ if get_origin(key) == Series and get_args(key)[0] == index_name:
+ # get_origin succeeded above, so we know the key is a type
+ ends.append(key) # type: ignore[arg-type]
+ if len(ends) == 1:
+ return ends[0]
+ raise ValueError(f"Could not find unique grouping node, found {ends}")
+
+
+class SeriesProvider(Generic[KeyType]):
+ """
+ Internal provider for combining results obtained based on different rows in a
+ param table into a single object.
+ """
+
+ def __init__(self, labels: Iterable[KeyType], row_dim: type) -> None:
+ self.labels = labels
+ self.row_dim = row_dim
+
+ def __call__(self, *vals: ValueType) -> Series[KeyType, ValueType]:
+ return Series(self.row_dim, dict(zip(self.labels, vals)))
+
+
+class _param_sentinel:
+ ...
+
+
class Pipeline:
"""A container for providers that can be assembled into a task graph."""
@@ -81,7 +279,7 @@ def __init__(
self,
providers: Optional[List[Provider]] = None,
*,
- params: Optional[Dict[Key, Any]] = None,
+ params: Optional[Dict[type, Any]] = None,
):
"""
Setup a Pipeline from a list providers
@@ -96,6 +294,8 @@ def __init__(
"""
self._providers: Dict[Key, Provider] = {}
self._subproviders: Dict[type, Dict[Tuple[Key | TypeVar, ...], Provider]] = {}
+ self._param_tables: Dict[Key, ParamTable] = {}
+ self._param_name_to_table_key: Dict[Key, Key] = {}
for provider in providers or []:
self.insert(provider)
for tp, param in (params or {}).items():
@@ -153,10 +353,48 @@ def __setitem__(self, key: Type[T], param: T) -> None:
)
self._set_provider(key, lambda: param)
- def _set_provider(self, key: Type[T], provider: Callable[..., T]) -> None:
+ def set_param_table(self, params: ParamTable) -> None:
+ """
+ Set a parameter table for a row dimension.
+
+ Values in the parameter table provide concrete values for a type given by the
+ respective column header.
+
+ A pipeline can have multiple parameter tables, but only one per row dimension.
+ Column names must be unique across all parameter tables.
+
+ Parameters
+ ----------
+ params:
+ Parameter table to set.
+ """
+ if params.row_dim in self._param_tables:
+ raise ValueError(f'Parameter table for {params.row_dim} already set')
+ for param_name in params:
+ if param_name in self._param_name_to_table_key:
+ raise ValueError(f'Parameter {param_name} already set')
+ self._param_tables[params.row_dim] = params
+ for param_name in params:
+ self._param_name_to_table_key[param_name] = params.row_dim
+ for param_name, values in params.items():
+ for index, label in zip(params.index, values):
+ self._set_provider(
+ Item((Label(tp=params.row_dim, index=index),), param_name),
+ lambda label=label: label,
+ )
+
+ def _set_provider(
+ self, key: Union[Type[T], Item[T]], provider: Callable[..., T]
+ ) -> None:
# isinstance does not work here and types.NoneType available only in 3.10+
if key == type(None): # noqa: E721
raise ValueError(f'Provider {provider} returning `None` is not allowed')
+ if get_origin(key) == Series:
+ raise ValueError(
+ f'Provider {provider} returning a sciline.Series is not allowed. '
+ 'Series is a special container reserved for use in conjunction with '
+ 'sciline.ParamTable and must not be provided directly.'
+ )
if (origin := get_origin(key)) is not None:
subproviders = self._subproviders.setdefault(origin, {})
args = get_args(key)
@@ -168,7 +406,9 @@ def _set_provider(self, key: Type[T], provider: Callable[..., T]) -> None:
raise ValueError(f'Provider for {key} already exists')
self._providers[key] = provider
- def _get_provider(self, tp: Type[T]) -> Tuple[Callable[..., T], Dict[TypeVar, Key]]:
+ def _get_provider(
+ self, tp: Union[Type[T], Item[T]]
+ ) -> Tuple[Callable[..., T], Dict[TypeVar, Key]]:
if (provider := self._providers.get(tp)) is not None:
return provider, {}
elif (origin := get_origin(tp)) is not None and (
@@ -192,7 +432,9 @@ def _get_provider(self, tp: Type[T]) -> Tuple[Callable[..., T], Dict[TypeVar, Ke
raise AmbiguousProvider("Multiple providers found for type", tp)
raise UnsatisfiedRequirement("No provider found for type", tp)
- def build(self, tp: Type[T], /) -> Graph:
+ def build(
+ self, tp: Union[Type[T], Item[T]], /, search_param_tables: bool = False
+ ) -> Graph:
"""
Return a dict of providers required for building the requested type `tp`.
@@ -206,25 +448,123 @@ def build(self, tp: Type[T], /) -> Graph:
----------
tp:
Type to build the graph for.
+ search_param_tables:
+ Whether to search parameter tables for concrete keys.
"""
graph: Graph = {}
- stack: List[Type[T]] = [tp]
+ stack: List[Union[Type[T], Item[T]]] = [tp]
while stack:
tp = stack.pop()
+ if search_param_tables and tp in self._param_name_to_table_key:
+ graph[tp] = (_param_sentinel, (self._param_name_to_table_key[tp],))
+ continue
+ if get_origin(tp) == Series:
+ graph.update(self._build_series(tp)) # type: ignore[arg-type]
+ continue
provider: Callable[..., T]
provider, bound = self._get_provider(tp)
tps = get_type_hints(provider)
- args = {
- name: _bind_free_typevars(t, bound=bound)
+ args = tuple(
+ _bind_free_typevars(t, bound=bound)
for name, t in tps.items()
if name != 'return'
- }
+ )
graph[tp] = (provider, args)
- for arg in args.values():
+ for arg in args:
if arg not in graph:
stack.append(arg)
return graph
+ def _build_series(self, tp: Type[Series[KeyType, ValueType]]) -> Graph:
+ """
+ Build (sub)graph for a Series type implementing ParamTable-based functionality.
+
+ We illustrate this with an example. Given a ParamTable with row_dim 'Idx':
+
+ Idx | P1 | P2
+ 0 | a | x
+ 1 | b | x
+ 2 | c | y
+
+ and providers for A depending on P1 and B depending on A. Calling
+ build(Series[Idx,B]) will call _build_series(Series[Idx,B]). This results in
+ the following procedure here:
+
+ 1. Call build(B), resulting, e.g., in a graph S->A->B, where S is a sentinel.
+ The sentinel is used because build() cannot find a unique P1, since it is
+ not a single value but a column in a table.
+ 2. Instantiation of `Replicator`, which will be used to replicate the
+ relevant parts of the graph (see illustration there).
+ 3. Insert a special `SeriesProvider` node, which will gather the duplicates of
+ the 'B' node and provides the requested Series[Idx,B].
+ 4. Replicate the graph. Nodes that do not directly or indirectly depend on P1
+ are not replicated.
+
+ Conceptually, the final result will be {
+ 0: B(A(a)),
+ 1: B(A(b)),
+ 2: B(A(c))
+ }.
+
+ In more complex cases, we may be dealing with multiple levels of Series,
+ which is used for grouping operations. Consider the above example, but with
+ an additional provider for C depending on Series[Idx,B]. Calling
+ build(Series[P2,C]) will call _build_series(Series[P2,C]). This results in
+ the following procedure here:
+
+ a. Call build(C), which results in the procedure above, i.e., a nested call
+ to _build_series(Series[Idx,B]) and the resulting graph as explained above.
+ b. Instantiation of `GroupingReplicator`, which will be used to replicate the
+ relevant parts of the graph (see illustration there).
+ c. Insert a special `SeriesProvider` node, which will gather the duplicates of
+ the 'C' node and provides the requested Series[P2,C].
+ d. Replicate the graph. Nodes that do not directly or indirectly depend on
+ the special `SeriesProvider` node (from step 3.) are not replicated.
+
+ Conceptually, the final result will be {
+ x: C({
+ 0: B(A(a)),
+ 1: B(A(b))
+ }),
+ y: C({
+ 2: B(A(c))
+ })
+ }.
+ """
+ index_name: Type[KeyType]
+ value_type: Type[ValueType]
+ index_name, value_type = get_args(tp)
+
+ subgraph = self.build(value_type, search_param_tables=True)
+
+ replicator: ReplicatorBase[KeyType]
+ if (
+ index_name not in self._param_name_to_table_key
+ and (params := self._param_tables.get(index_name)) is not None
+ ):
+ replicator = Replicator(param_table=params, graph_template=subgraph)
+ elif (table_key := self._param_name_to_table_key.get(index_name)) is not None:
+ replicator = GroupingReplicator(
+ param_table=self._param_tables[table_key],
+ graph_template=subgraph,
+ label_name=index_name,
+ )
+ else:
+ raise KeyError(f'No parameter table found for label {index_name}')
+
+ graph: Graph = {}
+ graph[tp] = (
+ SeriesProvider(list(replicator.index), index_name),
+ tuple(replicator.key(idx, value_type) for idx in replicator.index),
+ )
+
+ for key, value in subgraph.items():
+ if key in replicator:
+ graph.update(replicator.replicate(key, value, self._get_provider))
+ else:
+ graph[key] = value
+ return graph
+
@overload
def compute(self, tp: Type[T]) -> T:
...
@@ -233,7 +573,11 @@ def compute(self, tp: Type[T]) -> T:
def compute(self, tp: Tuple[Type[T], ...]) -> Tuple[T, ...]:
...
- def compute(self, tp: type | Tuple[type, ...]) -> Any:
+ @overload
+ def compute(self, tp: Item[T]) -> T:
+ ...
+
+ def compute(self, tp: type | Tuple[type, ...] | Item[T]) -> Any:
"""
Compute result for the given keys.
@@ -264,7 +608,10 @@ def visualize(
return self.get(tp).visualize(**kwargs)
def get(
- self, keys: type | Tuple[type, ...], *, scheduler: Optional[Scheduler] = None
+ self,
+ keys: type | Tuple[type, ...] | Item[T],
+ *,
+ scheduler: Optional[Scheduler] = None,
) -> TaskGraph:
"""
Return a TaskGraph for the given keys.
diff --git a/src/sciline/scheduler.py b/src/sciline/scheduler.py
index c1b0c886..49213c84 100644
--- a/src/sciline/scheduler.py
+++ b/src/sciline/scheduler.py
@@ -1,12 +1,8 @@
# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
-from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple
+from typing import Any, Callable, Dict, List, Optional, Protocol
-Key = type
-Graph = Dict[
- Key,
- Tuple[Callable[..., Any], Dict[str, Key]],
-]
+from sciline.typing import Graph, Key
class CycleError(Exception):
@@ -18,7 +14,7 @@ class Scheduler(Protocol):
Scheduler interface compatible with :py:class:`sciline.Pipeline`.
"""
- def get(self, graph: Graph, keys: List[type]) -> Any:
+ def get(self, graph: Graph, keys: List[Key]) -> Any:
"""
Compute the result for given keys from the graph.
@@ -37,21 +33,20 @@ class NaiveScheduler:
:py:class:`DaskScheduler` instead.
"""
- def get(self, graph: Graph, keys: List[type]) -> Any:
+ def get(self, graph: Graph, keys: List[Key]) -> Any:
import graphlib
- dependencies = {tp: set(args.values()) for tp, (_, args) in graph.items()}
+ dependencies = {tp: args for tp, (_, args) in graph.items()}
ts = graphlib.TopologicalSorter(dependencies)
try:
# Create list from generator to force early exception if there is a cycle
tasks = list(ts.static_order())
except graphlib.CycleError as e:
raise CycleError from e
- results: Dict[type, Any] = {}
+ results: Dict[Key, Any] = {}
for t in tasks:
provider, args = graph[t]
- args = {name: results[arg] for name, arg in args.items()}
- results[t] = provider(**args)
+ results[t] = provider(*[results[arg] for arg in args])
return tuple(results[key] for key in keys)
@@ -77,8 +72,8 @@ def __init__(self, scheduler: Optional[Callable[..., Any]] = None) -> None:
else:
self._dask_get = scheduler
- def get(self, graph: Graph, keys: List[type]) -> Any:
- dsk = {tp: (provider, *args.values()) for tp, (provider, args) in graph.items()}
+ def get(self, graph: Graph, keys: List[Key]) -> Any:
+ dsk = {tp: (provider, *args) for tp, (provider, args) in graph.items()}
try:
return self._dask_get(dsk, keys)
except RuntimeError as e:
diff --git a/src/sciline/series.py b/src/sciline/series.py
new file mode 100644
index 00000000..fb82ca0b
--- /dev/null
+++ b/src/sciline/series.py
@@ -0,0 +1,60 @@
+# SPDX-License-Identifier: BSD-3-Clause
+# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
+from __future__ import annotations
+
+from typing import Iterator, Mapping, Type, TypeVar
+
+Key = TypeVar('Key')
+Value = TypeVar('Value')
+
+
+class Series(Mapping[Key, Value]):
+ """A series of values with labels (row index) and named row dimension."""
+
+ def __init__(self, row_dim: Type[Key], items: Mapping[Key, Value]) -> None:
+ """
+ Create a new series.
+
+ Parameters
+ ----------
+ row_dim:
+ The row dimension. This must be a type or a type-alias (not an instance).
+ items:
+ The items of the series.
+ """
+ self._row_dim = row_dim
+ self._map: Mapping[Key, Value] = items
+
+ @property
+ def row_dim(self) -> type:
+ """The row dimension of the series."""
+ return self._row_dim
+
+ def __contains__(self, item: object) -> bool:
+ return item in self._map
+
+ def __iter__(self) -> Iterator[Key]:
+ return iter(self._map)
+
+ def __len__(self) -> int:
+ return len(self._map)
+
+ def __getitem__(self, key: Key) -> Value:
+ return self._map[key]
+
+ def __repr__(self) -> str:
+ return f"Series(row_dim={self.row_dim}, {self._map})"
+
+ def _repr_html_(self) -> str:
+ return (
+ f"<table><tr><th>{self.row_dim.__name__}</th><th>Value</th></tr>"
+ + "".join(
+ f"<tr><td>{k}</td><td>{v}</td></tr>" for k, v in self._map.items()
+ )
+ + "</table>"
+ )
+
+ def __eq__(self, other: object) -> bool:
+ if not isinstance(other, Series):
+ return NotImplemented
+ return self.row_dim == other.row_dim and self._map == other._map
diff --git a/src/sciline/task_graph.py b/src/sciline/task_graph.py
index e832fcda..cebb12cb 100644
--- a/src/sciline/task_graph.py
+++ b/src/sciline/task_graph.py
@@ -2,9 +2,12 @@
# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
from __future__ import annotations
-from typing import Any, Optional, Tuple, Union
+from typing import Any, Optional, Tuple, TypeVar, Union
-from sciline.scheduler import DaskScheduler, Graph, NaiveScheduler, Scheduler
+from .scheduler import DaskScheduler, NaiveScheduler, Scheduler
+from .typing import Graph, Item
+
+T = TypeVar("T")
class TaskGraph:
@@ -19,7 +22,7 @@ def __init__(
self,
*,
graph: Graph,
- keys: Union[type, Tuple[type, ...]],
+ keys: Union[type, Tuple[type, ...], Item[T], Tuple[Item[T], ...]],
scheduler: Optional[Scheduler] = None,
) -> None:
self._graph = graph
@@ -31,7 +34,12 @@ def __init__(
scheduler = NaiveScheduler()
self._scheduler = scheduler
- def compute(self, keys: Optional[Union[type, Tuple[type, ...]]] = None) -> Any:
+ def compute(
+ self,
+ keys: Optional[
+ Union[type, Tuple[type, ...], Item[T], Tuple[Item[T], ...]]
+ ] = None,
+ ) -> Any:
"""
Compute the result of the graph.
diff --git a/src/sciline/typing.py b/src/sciline/typing.py
new file mode 100644
index 00000000..f6a941c3
--- /dev/null
+++ b/src/sciline/typing.py
@@ -0,0 +1,26 @@
+# SPDX-License-Identifier: BSD-3-Clause
+# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
+from dataclasses import dataclass
+from typing import Any, Callable, Dict, Generic, Tuple, Type, TypeVar, Union
+
+
+@dataclass(frozen=True)
+class Label:
+ tp: type
+ index: Any
+
+
+T = TypeVar('T')
+
+
+@dataclass(frozen=True)
+class Item(Generic[T]):
+ label: Tuple[Label, ...]
+ tp: Type[T]
+
+
+Provider = Callable[..., Any]
+
+
+Key = Union[type, Item]
+Graph = Dict[Key, Tuple[Provider, Tuple[Key, ...]]]
diff --git a/src/sciline/visualize.py b/src/sciline/visualize.py
index 32ed0709..f2c90e33 100644
--- a/src/sciline/visualize.py
+++ b/src/sciline/visualize.py
@@ -1,13 +1,32 @@
# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
-from typing import Any, Callable, Dict, List, Tuple, get_args, get_origin
+from dataclasses import dataclass
+from typing import (
+ Any,
+ Callable,
+ Dict,
+ List,
+ Tuple,
+ Type,
+ TypeVar,
+ Union,
+ get_args,
+ get_origin,
+)
from graphviz import Digraph
-from .scheduler import Graph
+from .pipeline import Pipeline, SeriesProvider
+from .typing import Graph, Item, Key
-def to_graphviz(graph: Graph, **kwargs: Any) -> Digraph:
+@dataclass
+class Node:
+ name: str
+ collapsed: bool = False
+
+
+def to_graphviz(graph: Graph, compact: bool = False, **kwargs: Any) -> Digraph:
"""
Convert output of :py:class:`sciline.Pipeline.get_graph` to a graphviz graph.
@@ -15,39 +34,70 @@ def to_graphviz(graph: Graph, **kwargs: Any) -> Digraph:
----------
graph:
Output of :py:class:`sciline.Pipeline.get_graph`.
+ compact:
+ If True, parameter-table-dependent branches are collapsed into a single copy
+ of the branch. Recommended for large graphs with long parameter tables.
kwargs:
Keyword arguments passed to :py:class:`graphviz.Digraph`.
"""
dot = Digraph(strict=True, **kwargs)
- for p, (p_name, args, ret) in _format_graph(graph).items():
- dot.node(ret, ret, shape='rectangle')
+ for p, (p_name, args, ret) in _format_graph(graph, compact=compact).items():
+ dot.node(ret.name, ret.name, shape='box3d' if ret.collapsed else 'rectangle')
# Do not draw dummy providers created by Pipeline when setting instances
- if p_name == 'Pipeline.__setitem__.<locals>.<lambda>':
+ if p_name in (
+ f'{_qualname(Pipeline.__setitem__)}.<locals>.<lambda>',
+ f'{_qualname(Pipeline.set_param_table)}.<locals>.<lambda>',
+ ):
continue
- dot.node(p, p_name, shape='ellipse')
- for arg in args:
- dot.node(arg, arg, shape='rectangle')
- dot.edge(arg, p)
- dot.edge(p, ret)
+ # Do not draw the internal provider gathering index-dependent results into
+ # a dict
+ if p_name == _qualname(SeriesProvider):
+ for arg in args:
+ dot.edge(arg.name, ret.name, style='dashed')
+ else:
+ dot.node(p, p_name, shape='ellipse')
+ for arg in args:
+ dot.edge(arg.name, p)
+ dot.edge(p, ret.name)
return dot
-def _format_graph(graph: Graph) -> Dict[str, Tuple[str, List[str], str]]:
+def _qualname(obj: Any) -> Any:
+ return (
+ obj.__qualname__ if hasattr(obj, '__qualname__') else obj.__class__.__qualname__
+ )
+
+
+def _format_graph(
+ graph: Graph, compact: bool
+) -> Dict[str, Tuple[str, List[Node], Node]]:
return {
- _format_provider(provider, ret): (
- provider.__qualname__,
- [_format_type(a) for a in args.values()],
- _format_type(ret),
+ _format_provider(provider, ret, compact=compact): (
+ _qualname(provider),
+ [_format_type(a, compact=compact) for a in args],
+ _format_type(ret, compact=compact),
)
for ret, (provider, args) in graph.items()
}
-def _format_provider(provider: Callable[..., Any], ret: type) -> str:
- return f'{provider.__qualname__}_{_format_type(ret)}'
+def _format_provider(provider: Callable[..., Any], ret: Key, compact: bool) -> str:
+ return f'{_qualname(provider)}_{_format_type(ret, compact=compact).name}'
+
+
+T = TypeVar('T')
-def _format_type(tp: type) -> str:
+def _extract_type_and_labels(
+ key: Union[Item[T], Type[T]], compact: bool
+) -> Tuple[Type[T], List[Union[type, Tuple[type, Any]]]]:
+ if isinstance(key, Item):
+ label = key.label
+ return key.tp, [lb.tp if compact else (lb.tp, lb.index) for lb in label]
+ return key, []
+
+
+def _format_type(tp: Key, compact: bool = False) -> Node:
"""
Helper for _format_graph.
@@ -56,11 +106,27 @@ def _format_type(tp: type) -> str:
We may make this configurable in the future.
"""
+ tp, labels = _extract_type_and_labels(tp, compact=compact)
+
def get_base(tp: type) -> str:
return tp.__name__ if hasattr(tp, '__name__') else str(tp).split('.')[-1]
+ def format_label(label: Union[type, Tuple[type, Any]]) -> str:
+ if isinstance(label, tuple):
+ tp, index = label
+ return f'{get_base(tp)}={index}'
+ return get_base(label)
+
+ def with_labels(base: str) -> Node:
+ if labels:
+ return Node(
+ name=f'{base}({", ".join([format_label(l) for l in labels])})',
+ collapsed=compact,
+ )
+ return Node(name=base)
+
if (origin := get_origin(tp)) is not None:
- params = [_format_type(param) for param in get_args(tp)]
- return f'{get_base(origin)}[{", ".join(params)}]'
+ params = [_format_type(param).name for param in get_args(tp)]
+ return with_labels(f'{get_base(origin)}[{", ".join(params)}]')
else:
- return get_base(tp)
+ return with_labels(get_base(tp))
diff --git a/tests/param_table_test.py b/tests/param_table_test.py
new file mode 100644
index 00000000..4a57ad18
--- /dev/null
+++ b/tests/param_table_test.py
@@ -0,0 +1,47 @@
+# SPDX-License-Identifier: BSD-3-Clause
+# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
+import pytest
+
+import sciline as sl
+
+
+def test_raises_with_zero_columns() -> None:
+ with pytest.raises(ValueError):
+ sl.ParamTable(row_dim=int, columns={})
+
+
+def test_raises_with_inconsistent_column_sizes() -> None:
+ with pytest.raises(ValueError):
+ sl.ParamTable(row_dim=int, columns={int: [1, 2, 3], float: [1.0, 2.0]})
+
+
+def test_raises_with_inconsistent_index_length() -> None:
+ with pytest.raises(ValueError):
+ sl.ParamTable(row_dim=int, columns={float: [1.0, 2.0]}, index=[1, 2, 3])
+
+
+def test_raises_with_non_unique_index() -> None:
+ with pytest.raises(ValueError):
+ sl.ParamTable(int, {float: [1.0, 2.0, 3.0]}, index=[1, 1, 2])
+
+
+def test_contains_includes_all_columns() -> None:
+ pt = sl.ParamTable(row_dim=int, columns={int: [1, 2, 3], float: [1.0, 2.0, 3.0]})
+ assert int in pt
+ assert float in pt
+ assert str not in pt
+
+
+def test_contains_does_not_include_index() -> None:
+ pt = sl.ParamTable(row_dim=int, columns={float: [1.0, 2.0, 3.0]})
+ assert int not in pt
+
+
+def test_len_is_number_of_columns() -> None:
+ pt = sl.ParamTable(row_dim=int, columns={int: [1, 2, 3], float: [1.0, 2.0, 3.0]})
+ assert len(pt) == 2
+
+
+def test_defaults_to_range_index() -> None:
+ pt = sl.ParamTable(row_dim=int, columns={float: [1.0, 2.0, 3.0]})
+ assert pt.index == [0, 1, 2]
diff --git a/tests/pipeline_with_param_table_test.py b/tests/pipeline_with_param_table_test.py
new file mode 100644
index 00000000..88f8a718
--- /dev/null
+++ b/tests/pipeline_with_param_table_test.py
@@ -0,0 +1,583 @@
+# SPDX-License-Identifier: BSD-3-Clause
+# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
+from typing import List, NewType, Optional, TypeVar
+
+import pytest
+
+import sciline as sl
+from sciline.typing import Item, Label
+
+
+def test_set_param_table_raises_if_param_names_are_duplicate() -> None:
+ pl = sl.Pipeline()
+ pl.set_param_table(sl.ParamTable(int, {float: [1.0, 2.0, 3.0]}))
+ with pytest.raises(ValueError):
+ pl.set_param_table(sl.ParamTable(str, {float: [4.0, 5.0, 6.0]}))
+ assert pl.compute(Item((Label(int, 1),), float)) == 2.0
+ with pytest.raises(sl.UnsatisfiedRequirement):
+ pl.compute(Item((Label(str, 1),), float))
+
+
+def test_set_param_table_raises_if_row_dim_is_duplicate() -> None:
+ pl = sl.Pipeline()
+ pl.set_param_table(sl.ParamTable(int, {float: [1.0, 2.0, 3.0]}))
+ with pytest.raises(ValueError):
+ pl.set_param_table(sl.ParamTable(int, {str: ['a', 'b', 'c']}))
+ assert pl.compute(Item((Label(int, 1),), float)) == 2.0
+ with pytest.raises(sl.UnsatisfiedRequirement):
+ pl.compute(Item((Label(int, 1),), str))
+
+
+def test_can_get_elements_of_param_table() -> None:
+ pl = sl.Pipeline()
+ pl.set_param_table(sl.ParamTable(int, {float: [1.0, 2.0, 3.0]}))
+ assert pl.compute(Item((Label(int, 1),), float)) == 2.0
+
+
+def test_can_get_elements_of_param_table_with_explicit_index() -> None:
+ pl = sl.Pipeline()
+ pl.set_param_table(sl.ParamTable(int, {float: [1.0, 2.0, 3.0]}, index=[11, 12, 13]))
+ assert pl.compute(Item((Label(int, 12),), float)) == 2.0
+
+
+def test_can_depend_on_elements_of_param_table() -> None:
+ # This is not a valid type annotation, not sure why it works with get_type_hints
+ def use_elem(x: Item((Label(int, 1),), float)) -> str: # type: ignore[valid-type]
+ return str(x)
+
+ pl = sl.Pipeline([use_elem])
+ pl.set_param_table(sl.ParamTable(int, {float: [1.0, 2.0, 3.0]}))
+ assert pl.compute(str) == "2.0"
+
+
+def test_can_compute_series_of_param_values() -> None:
+ pl = sl.Pipeline()
+ pl.set_param_table(sl.ParamTable(int, {float: [1.0, 2.0, 3.0]}))
+ assert pl.compute(sl.Series[int, float]) == sl.Series(int, {0: 1.0, 1: 2.0, 2: 3.0})
+
+
+def test_can_compute_series_of_derived_values() -> None:
+ def process(x: float) -> str:
+ return str(x)
+
+ pl = sl.Pipeline([process])
+ pl.set_param_table(sl.ParamTable(int, {float: [1.0, 2.0, 3.0]}))
+ assert pl.compute(sl.Series[int, str]) == sl.Series(
+ int, {0: "1.0", 1: "2.0", 2: "3.0"}
+ )
+
+
+def test_creating_pipeline_with_provider_of_series_raises() -> None:
+ series = sl.Series(int, {0: 1.0, 1: 2.0, 2: 3.0})
+
+ def make_series() -> sl.Series[int, float]:
+ return series
+
+ with pytest.raises(ValueError):
+ sl.Pipeline([make_series])
+
+
+def test_creating_pipeline_with_series_param_raises() -> None:
+ series = sl.Series(int, {0: 1.0, 1: 2.0, 2: 3.0})
+
+ with pytest.raises(ValueError):
+ sl.Pipeline([], params={sl.Series[int, float]: series})
+
+
+def test_explicit_index_of_param_table_is_forwarded_correctly() -> None:
+ def process(x: float) -> int:
+ return int(x)
+
+ pl = sl.Pipeline([process])
+ pl.set_param_table(
+ sl.ParamTable(str, {float: [1.0, 2.0, 3.0]}, index=['a', 'b', 'c'])
+ )
+ assert pl.compute(sl.Series[str, int]) == sl.Series(str, {'a': 1, 'b': 2, 'c': 3})
+
+
+def test_can_gather_index() -> None:
+ Sum = NewType("Sum", float)
+ Name = NewType("Name", str)
+
+ def gather(x: sl.Series[Name, float]) -> Sum:
+ return Sum(sum(x.values()))
+
+ def make_float(x: str) -> float:
+ return float(x)
+
+ pl = sl.Pipeline([gather, make_float])
+ pl.set_param_table(sl.ParamTable(Name, {str: ["1.0", "2.0", "3.0"]}))
+ assert pl.compute(Sum) == 6.0
+
+
+def test_can_zip() -> None:
+ Sum = NewType("Sum", str)
+ Str = NewType("Str", str)
+ Run = NewType("Run", int)
+
+ def gather_zip(x: sl.Series[Run, Str], y: sl.Series[Run, int]) -> Sum:
+ z = [f'{x_}{y_}' for x_, y_ in zip(x.values(), y.values())]
+ return Sum(str(z))
+
+ def use_str(x: str) -> Str:
+ return Str(x)
+
+ pl = sl.Pipeline([gather_zip, use_str])
+ pl.set_param_table(sl.ParamTable(Run, {str: ['a', 'a', 'ccc'], int: [1, 2, 3]}))
+
+ assert pl.compute(Sum) == "['a1', 'a2', 'ccc3']"
+
+
+def test_diamond_dependency_pulls_values_from_columns_in_same_param_table() -> None:
+ Sum = NewType("Sum", float)
+ Param1 = NewType("Param1", int)
+ Param2 = NewType("Param2", int)
+ Row = NewType("Row", int)
+
+ def gather(x: sl.Series[Row, float]) -> Sum:
+ return Sum(sum(x.values()))
+
+ def join(x: Param1, y: Param2) -> float:
+ return x / y
+
+ pl = sl.Pipeline([gather, join])
+ pl.set_param_table(sl.ParamTable(Row, {Param1: [1, 4, 9], Param2: [1, 2, 3]}))
+
+ assert pl.compute(Sum) == Sum(6)
+
+
+def test_diamond_dependency_on_same_column() -> None:
+ Sum = NewType("Sum", float)
+ Param = NewType("Param", int)
+ Param1 = NewType("Param1", int)
+ Param2 = NewType("Param2", int)
+ Row = NewType("Row", int)
+
+ def gather(x: sl.Series[Row, float]) -> Sum:
+ return Sum(sum(x.values()))
+
+ def to_param1(x: Param) -> Param1:
+ return Param1(x)
+
+ def to_param2(x: Param) -> Param2:
+ return Param2(x)
+
+ def join(x: Param1, y: Param2) -> float:
+ return x / y
+
+ pl = sl.Pipeline([gather, join, to_param1, to_param2])
+ pl.set_param_table(sl.ParamTable(Row, {Param: [1, 2, 3]}))
+
+ assert pl.compute(Sum) == Sum(3)
+
+
+def test_dependencies_on_different_param_tables_broadcast() -> None:
+ Row1 = NewType("Row1", int)
+ Row2 = NewType("Row2", int)
+ Param1 = NewType("Param1", int)
+ Param2 = NewType("Param2", int)
+ Product = NewType("Product", str)
+
+ def gather_both(x: sl.Series[Row1, Param1], y: sl.Series[Row2, Param2]) -> Product:
+ broadcast = [[x_, y_] for x_ in x.values() for y_ in y.values()]
+ return Product(str(broadcast))
+
+ pl = sl.Pipeline([gather_both])
+ pl.set_param_table(sl.ParamTable(Row1, {Param1: [1, 2, 3]}))
+ pl.set_param_table(sl.ParamTable(Row2, {Param2: [4, 5]}))
+ assert pl.compute(Product) == "[[1, 4], [1, 5], [2, 4], [2, 5], [3, 4], [3, 5]]"
+
+
+def test_dependency_on_other_param_table_in_parent_broadcasts_branch() -> None:
+ Row1 = NewType("Row1", int)
+ Row2 = NewType("Row2", int)
+ Param1 = NewType("Param1", int)
+ Param2 = NewType("Param2", int)
+ Summed2 = NewType("Summed2", int)
+ Product = NewType("Product", str)
+
+ def gather2_and_combine(x: Param1, y: sl.Series[Row2, Param2]) -> Summed2:
+ return Summed2(x * sum(y.values()))
+
+ def gather1(x: sl.Series[Row1, Summed2]) -> Product:
+ return Product(str(list(x.values())))
+
+ pl = sl.Pipeline([gather1, gather2_and_combine])
+ pl.set_param_table(sl.ParamTable(Row1, {Param1: [1, 2, 3]}))
+ pl.set_param_table(sl.ParamTable(Row2, {Param2: [4, 5]}))
+ assert pl.compute(Product) == "[9, 18, 27]"
+
+
+def test_dependency_on_other_param_table_in_grandparent_broadcasts_branch() -> None:
+ Row1 = NewType("Row1", int)
+ Row2 = NewType("Row2", int)
+ Param1 = NewType("Param1", int)
+ Param2 = NewType("Param2", int)
+ Summed2 = NewType("Summed2", int)
+ Combined = NewType("Combined", int)
+ Product = NewType("Product", str)
+
+ def gather2(x: sl.Series[Row2, Param2]) -> Summed2:
+ return Summed2(sum(x.values()))
+
+ def combine(x: Param1, y: Summed2) -> Combined:
+ return Combined(x * y)
+
+ def gather1(x: sl.Series[Row1, Combined]) -> Product:
+ return Product(str(list(x.values())))
+
+ pl = sl.Pipeline([gather1, gather2, combine])
+ pl.set_param_table(sl.ParamTable(Row1, {Param1: [1, 2, 3]}))
+ pl.set_param_table(sl.ParamTable(Row2, {Param2: [4, 5]}))
+ assert pl.compute(Product) == "[9, 18, 27]"
+
+
+def test_nested_dependencies_on_different_param_tables() -> None:
+ Row1 = NewType("Row1", int)
+ Row2 = NewType("Row2", int)
+ Param1 = NewType("Param1", int)
+ Param2 = NewType("Param2", int)
+ Combined = NewType("Combined", int)
+
+ def combine(x: Param1, y: Param2) -> Combined:
+ return Combined(x * y)
+
+ pl = sl.Pipeline([combine])
+ pl.set_param_table(sl.ParamTable(Row1, {Param1: [1, 2, 3]}))
+ pl.set_param_table(sl.ParamTable(Row2, {Param2: [4, 5]}))
+ assert pl.compute(sl.Series[Row1, sl.Series[Row2, Combined]]) == sl.Series(
+ Row1,
+ {
+ 0: sl.Series(Row2, {0: 4, 1: 5}),
+ 1: sl.Series(Row2, {0: 8, 1: 10}),
+ 2: sl.Series(Row2, {0: 12, 1: 15}),
+ },
+ )
+ assert pl.compute(sl.Series[Row2, sl.Series[Row1, Combined]]) == sl.Series(
+ Row2,
+ {
+ 0: sl.Series(Row1, {0: 4, 1: 8, 2: 12}),
+ 1: sl.Series(Row1, {0: 5, 1: 10, 2: 15}),
+ },
+ )
+
+
+def test_can_groupby_by_requesting_series_of_series() -> None:
+ Row = NewType("Row", int)
+ Param1 = NewType("Param1", int)
+ Param2 = NewType("Param2", int)
+
+ pl = sl.Pipeline()
+ pl.set_param_table(sl.ParamTable(Row, {Param1: [1, 1, 3], Param2: [4, 5, 6]}))
+ expected = sl.Series(
+ Param1,
+ {1: sl.Series(Row, {0: 4, 1: 5}), 3: sl.Series(Row, {2: 6})},
+ )
+ assert pl.compute(sl.Series[Param1, sl.Series[Row, Param2]]) == expected
+
+
+def test_groupby_by_requesting_series_of_series_preserves_indices() -> None:
+ Row = NewType("Row", int)
+ Param1 = NewType("Param1", int)
+ Param2 = NewType("Param2", int)
+
+ pl = sl.Pipeline()
+ pl.set_param_table(
+ sl.ParamTable(Row, {Param1: [1, 1, 3], Param2: [4, 5, 6]}, index=[11, 12, 13])
+ )
+ assert pl.compute(sl.Series[Param1, sl.Series[Row, Param2]]) == sl.Series(
+ Param1, {1: sl.Series(Row, {11: 4, 12: 5}), 3: sl.Series(Row, {13: 6})}
+ )
+
+
+def test_can_groupby_by_param_used_in_ancestor() -> None:
+ Row = NewType("Row", int)
+ Param = NewType("Param", str)
+
+ pl = sl.Pipeline()
+ pl.set_param_table(sl.ParamTable(Row, {Param: ['x', 'x', 'y']}))
+ expected = sl.Series(
+ Param,
+ {"x": sl.Series(Row, {0: "x", 1: "x"}), "y": sl.Series(Row, {2: "y"})},
+ )
+ assert pl.compute(sl.Series[Param, sl.Series[Row, Param]]) == expected
+
+
+def test_multi_level_groupby_raises_with_params_from_same_table() -> None:
+ Row = NewType("Row", int)
+ Param1 = NewType("Param1", int)
+ Param2 = NewType("Param2", int)
+ Param3 = NewType("Param3", int)
+
+ pl = sl.Pipeline()
+ pl.set_param_table(
+ sl.ParamTable(
+ Row, {Param1: [1, 1, 1, 3], Param2: [4, 5, 5, 6], Param3: [7, 8, 9, 10]}
+ )
+ )
+ with pytest.raises(ValueError, match='Could not find unique grouping node'):
+ pl.compute(sl.Series[Param1, sl.Series[Param2, sl.Series[Row, Param3]]])
+
+
+def test_multi_level_groupby_with_params_from_different_table() -> None:
+ Row = NewType("Row", int)
+ Param1 = NewType("Param1", int)
+ Param2 = NewType("Param2", int)
+ Param3 = NewType("Param3", int)
+
+ pl = sl.Pipeline()
+ grouping1 = sl.ParamTable(Row, {Param2: [0, 1, 1, 2], Param3: [7, 8, 9, 10]})
+ # We are not providing an explicit index here, so this only happens to work because
+ # the values of Param2 match range(3).
+ grouping2 = sl.ParamTable(Param2, {Param1: [1, 1, 3]})
+ pl.set_param_table(grouping1)
+ pl.set_param_table(grouping2)
+ assert pl.compute(
+ sl.Series[Param1, sl.Series[Param2, sl.Series[Row, Param3]]]
+ ) == sl.Series(
+ Param1,
+ {
+ 1: sl.Series(
+ Param2, {0: sl.Series(Row, {0: 7}), 1: sl.Series(Row, {1: 8, 2: 9})}
+ ),
+ 3: sl.Series(Param2, {2: sl.Series(Row, {3: 10})}),
+ },
+ )
+
+
+def test_multi_level_groupby_with_params_from_different_table_can_select() -> None:
+ Row = NewType("Row", int)
+ Param1 = NewType("Param1", int)
+ Param2 = NewType("Param2", int)
+ Param3 = NewType("Param3", int)
+
+ pl = sl.Pipeline()
+ grouping1 = sl.ParamTable(Row, {Param2: [4, 5, 5, 6], Param3: [7, 8, 9, 10]})
+ # Note the missing index "6" here.
+ grouping2 = sl.ParamTable(Param2, {Param1: [1, 1]}, index=[4, 5])
+ pl.set_param_table(grouping1)
+ pl.set_param_table(grouping2)
+ assert pl.compute(
+ sl.Series[Param1, sl.Series[Param2, sl.Series[Row, Param3]]]
+ ) == sl.Series(
+ Param1,
+ {
+ 1: sl.Series(
+ Param2, {4: sl.Series(Row, {0: 7}), 5: sl.Series(Row, {1: 8, 2: 9})}
+ )
+ },
+ )
+
+
+def test_multi_level_groupby_with_params_from_different_table_preserves_index() -> None:
+ Row = NewType("Row", int)
+ Param1 = NewType("Param1", int)
+ Param2 = NewType("Param2", int)
+ Param3 = NewType("Param3", int)
+
+ pl = sl.Pipeline()
+ grouping1 = sl.ParamTable(
+ Row, {Param2: [4, 5, 5, 6], Param3: [7, 8, 9, 10]}, index=[100, 200, 300, 400]
+ )
+ grouping2 = sl.ParamTable(Param2, {Param1: [1, 1, 3]}, index=[4, 5, 6])
+ pl.set_param_table(grouping1)
+ pl.set_param_table(grouping2)
+ assert pl.compute(
+ sl.Series[Param1, sl.Series[Param2, sl.Series[Row, Param3]]]
+ ) == sl.Series(
+ Param1,
+ {
+ 1: sl.Series(
+ Param2,
+ {4: sl.Series(Row, {100: 7}), 5: sl.Series(Row, {200: 8, 300: 9})},
+ ),
+ 3: sl.Series(Param2, {6: sl.Series(Row, {400: 10})}),
+ },
+ )
+
+
+def test_multi_level_groupby_with_params_from_different_table_can_reorder() -> None:
+ Row = NewType("Row", int)
+ Param1 = NewType("Param1", int)
+ Param2 = NewType("Param2", int)
+ Param3 = NewType("Param3", int)
+
+ pl = sl.Pipeline()
+ grouping1 = sl.ParamTable(
+ Row, {Param2: [4, 5, 5, 6], Param3: [7, 8, 9, 10]}, index=[100, 200, 300, 400]
+ )
+ grouping2 = sl.ParamTable(Param2, {Param1: [1, 1, 3]}, index=[6, 5, 4])
+ pl.set_param_table(grouping1)
+ pl.set_param_table(grouping2)
+ assert pl.compute(
+ sl.Series[Param1, sl.Series[Param2, sl.Series[Row, Param3]]]
+ ) == sl.Series(
+ Param1,
+ {
+ 1: sl.Series(
+ Param2,
+ {6: sl.Series(Row, {400: 10}), 5: sl.Series(Row, {200: 8, 300: 9})},
+ ),
+ 3: sl.Series(Param2, {4: sl.Series(Row, {100: 7})}),
+ },
+ )
+
+
+@pytest.mark.parametrize("index", [None, [4, 5, 7]])
+def test_multi_level_groupby_raises_on_index_mismatch(
+ index: Optional[List[int]],
+) -> None:
+ Row = NewType("Row", int)
+ Param1 = NewType("Param1", int)
+ Param2 = NewType("Param2", int)
+ Param3 = NewType("Param3", int)
+
+ pl = sl.Pipeline()
+ grouping1 = sl.ParamTable(
+ Row, {Param2: [4, 5, 5, 6], Param3: [7, 8, 9, 10]}, index=[100, 200, 300, 400]
+ )
+ grouping2 = sl.ParamTable(Param2, {Param1: [1, 1, 3]}, index=index)
+ pl.set_param_table(grouping1)
+ pl.set_param_table(grouping2)
+ with pytest.raises(ValueError):
+ pl.compute(sl.Series[Param1, sl.Series[Param2, sl.Series[Row, Param3]]])
+
+
+@pytest.mark.parametrize("index", [None, [4, 5, 6]])
+def test_groupby_over_param_table(index: Optional[List[int]]) -> None:
+ Index = NewType("Index", int)
+ Name = NewType("Name", str)
+ Param = NewType("Param", int)
+ ProcessedParam = NewType("ProcessedParam", int)
+ SummedGroup = NewType("SummedGroup", int)
+ ProcessedGroup = NewType("ProcessedGroup", int)
+
+ def process_param(x: Param) -> ProcessedParam:
+ return ProcessedParam(x + 1)
+
+ def sum_group(group: sl.Series[Index, ProcessedParam]) -> SummedGroup:
+ return SummedGroup(sum(group.values()))
+
+ def process(x: SummedGroup) -> ProcessedGroup:
+ return ProcessedGroup(2 * x)
+
+ params = sl.ParamTable(
+ Index, {Param: [1, 2, 3], Name: ['a', 'a', 'b']}, index=index
+ )
+ pl = sl.Pipeline([process_param, sum_group, process])
+ pl.set_param_table(params)
+
+ graph = pl.get(sl.Series[Name, ProcessedGroup])
+ assert graph.compute() == sl.Series(Name, {'a': 10, 'b': 8})
+
+
+def test_requesting_series_index_that_is_not_in_param_table_raises() -> None:
+ Row = NewType("Row", int)
+ Param1 = NewType("Param1", int)
+ Param2 = NewType("Param2", int)
+
+ pl = sl.Pipeline()
+ pl.set_param_table(sl.ParamTable(Row, {Param1: [1, 1, 3], Param2: [4, 5, 6]}))
+
+ with pytest.raises(KeyError):
+ pl.compute(sl.Series[int, Param2])
+
+
+def test_requesting_series_index_that_is_a_param_raises_if_not_grouping() -> None:
+ Row = NewType("Row", int)
+ Param1 = NewType("Param1", int)
+ Param2 = NewType("Param2", int)
+
+ pl = sl.Pipeline()
+ pl.set_param_table(sl.ParamTable(Row, {Param1: [1, 1, 3], Param2: [4, 5, 6]}))
+ with pytest.raises(ValueError, match='Could not find unique grouping node'):
+ pl.compute(sl.Series[Param1, Param2])
+
+
+def test_generic_providers_work_with_param_tables() -> None:
+ Param = TypeVar('Param')
+ Row = NewType("Row", int)
+
+ class Str(sl.Scope[Param, str], str):
+ ...
+
+ def parametrized(x: Param) -> Str[Param]:
+ return Str(f'{x}')
+
+ def make_float() -> float:
+ return 1.5
+
+ pipeline = sl.Pipeline([make_float, parametrized])
+ pipeline.set_param_table(sl.ParamTable(Row, {int: [1, 2, 3]}))
+
+ assert pipeline.compute(Str[float]) == Str[float]('1.5')
+ with pytest.raises(sl.UnsatisfiedRequirement):
+ pipeline.compute(Str[int])
+ assert pipeline.compute(sl.Series[Row, Str[int]]) == sl.Series(
+ Row,
+ {
+ 0: Str[int]('1'),
+ 1: Str[int]('2'),
+ 2: Str[int]('3'),
+ },
+ )
+
+
+def test_generic_provider_can_depend_on_param_series() -> None:
+ Param = TypeVar('Param')
+ Row = NewType("Row", int)
+
+ class Str(sl.Scope[Param, str], str):
+ ...
+
+ def parametrized_gather(x: sl.Series[Row, Param]) -> Str[Param]:
+ return Str(f'{list(x.values())}')
+
+ pipeline = sl.Pipeline([parametrized_gather])
+ pipeline.set_param_table(
+ sl.ParamTable(Row, {int: [1, 2, 3], float: [1.5, 2.5, 3.5]})
+ )
+
+ assert pipeline.compute(Str[int]) == Str[int]('[1, 2, 3]')
+ assert pipeline.compute(Str[float]) == Str[float]('[1.5, 2.5, 3.5]')
+
+
+def test_generic_provider_can_depend_on_derived_param_series() -> None:
+ T = TypeVar('T')
+ Row = NewType("Row", int)
+
+ class Str(sl.Scope[T, str], str):
+ ...
+
+ def use_param(x: int) -> float:
+ return x + 0.5
+
+ def parametrized_gather(x: sl.Series[Row, T]) -> Str[T]:
+ return Str(f'{list(x.values())}')
+
+ pipeline = sl.Pipeline([parametrized_gather, use_param])
+ pipeline.set_param_table(sl.ParamTable(Row, {int: [1, 2, 3]}))
+
+ assert pipeline.compute(Str[float]) == Str[float]('[1.5, 2.5, 3.5]')
+
+
+def test_params_in_table_can_be_generic() -> None:
+ T = TypeVar('T')
+ Row = NewType("Row", int)
+
+ class Str(sl.Scope[T, str], str):
+ ...
+
+ class Param(sl.Scope[T, str], str):
+ ...
+
+ def parametrized_gather(x: sl.Series[Row, Param[T]]) -> Str[T]:
+ return Str(','.join(x.values()))
+
+ pipeline = sl.Pipeline([parametrized_gather])
+ pipeline.set_param_table(
+ sl.ParamTable(Row, {Param[int]: ["1", "2"], Param[float]: ["1.5", "2.5"]})
+ )
+
+ assert pipeline.compute(Str[int]) == Str[int]('1,2')
+ assert pipeline.compute(Str[float]) == Str[float]('1.5,2.5')
diff --git a/tests/task_graph_test.py b/tests/task_graph_test.py
index 0f376049..31e24714 100644
--- a/tests/task_graph_test.py
+++ b/tests/task_graph_test.py
@@ -3,8 +3,8 @@
import pytest
import sciline as sl
-from sciline.scheduler import Graph
from sciline.task_graph import TaskGraph
+from sciline.typing import Graph
def as_float(x: int) -> float:
diff --git a/tests/visualize_test.py b/tests/visualize_test.py
index 663c6dfe..8290cbc7 100644
--- a/tests/visualize_test.py
+++ b/tests/visualize_test.py
@@ -30,6 +30,6 @@ class B(Generic[T]):
class SubA(A[T]):
pass
- assert sl.visualize._format_type(A[float]) == 'A[float]'
- assert sl.visualize._format_type(SubA[float]) == 'SubA[float]'
- assert sl.visualize._format_type(B[float]) == 'B[float]'
+ assert sl.visualize._format_type(A[float]).name == 'A[float]'
+ assert sl.visualize._format_type(SubA[float]).name == 'SubA[float]'
+ assert sl.visualize._format_type(B[float]).name == 'B[float]'