diff --git a/docs/api-reference/index.md b/docs/api-reference/index.md
index 883ecd98..0cba8d9c 100644
--- a/docs/api-reference/index.md
+++ b/docs/api-reference/index.md
@@ -10,8 +10,10 @@
    :template: class-template.rst
    :recursive:
 
+   ParamTable
    Pipeline
    Scope
+   Series
    scheduler.Scheduler
    scheduler.DaskScheduler
    scheduler.NaiveScheduler
diff --git a/docs/user-guide/getting-started.ipynb b/docs/user-guide/getting-started.ipynb
index f9d08cc4..981f0d2a 100644
--- a/docs/user-guide/getting-started.ipynb
+++ b/docs/user-guide/getting-started.ipynb
@@ -101,6 +101,9 @@
     "from typing import NewType\n",
     "import sciline\n",
     "\n",
+    "_fake_filesystem = {'data.txt': [1, 2, float('nan'), 3]}\n",
+    "\n",
+    "\n",
     "# 1. Define domain types\n",
     "\n",
     "Filename = NewType('Filename', str)\n",
@@ -115,19 +118,20 @@
     "\n",
     "def load(filename: Filename) -> RawData:\n",
     "    \"\"\"Load the data from the filename.\"\"\"\n",
-    "    return {'data': [1, 2, float('nan'), 3], 'meta': {'filename': filename}}\n",
+    "    data = _fake_filesystem[filename]\n",
+    "    return RawData({'data': data, 'meta': {'filename': filename}})\n",
     "\n",
     "\n",
     "def clean(raw_data: RawData) -> CleanedData:\n",
     "    \"\"\"Clean the data, removing NaNs.\"\"\"\n",
     "    import math\n",
     "\n",
-    "    return [x for x in raw_data['data'] if not math.isnan(x)]\n",
+    "    return CleanedData([x for x in raw_data['data'] if not math.isnan(x)])\n",
     "\n",
     "\n",
     "def process(data: CleanedData, param: ScaleFactor) -> Result:\n",
     "    \"\"\"Process the data, multiplying the sum by the scale factor.\"\"\"\n",
-    "    return sum(data) * param\n",
+    "    return Result(sum(data) * param)\n",
     "\n",
     "\n",
     "# 3. Create pipeline\n",
diff --git a/docs/user-guide/index.md b/docs/user-guide/index.md
index ca13f1a2..3baa2628 100644
--- a/docs/user-guide/index.md
+++ b/docs/user-guide/index.md
@@ -6,4 +6,5 @@ maxdepth: 2
 ---
 
 getting-started
+parameter-tables
 ```
diff --git a/docs/user-guide/parameter-tables.ipynb b/docs/user-guide/parameter-tables.ipynb
new file mode 100644
index 00000000..8ed44e3b
--- /dev/null
+++ b/docs/user-guide/parameter-tables.ipynb
@@ -0,0 +1,486 @@
+{
+ "cells": [
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Parameter Tables\n",
+    "\n",
+    "## Overview\n",
+    "\n",
+    "Parameter tables provide a mechanism for repeating parts or all of a computation with different values for one or more parameters.\n",
+    "This allows for a variety of use cases, similar to *map*, *reduce*, and *groupby* operations in other systems.\n",
+    "We illustrate each of these in the following three chapters."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Computing results for series of parameters\n",
+    "\n",
+    "This chapter illustrates how to implement *map* operations with Sciline.\n",
+    "\n",
+    "Starting with the model workflow introduced in [Getting Started](getting-started.ipynb), we can replace the fixed `Filename` parameter with a series of filenames listed in a [ParamTable](../generated/classes/ParamTable.rst):"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from typing import NewType\n",
+    "import sciline\n",
+    "\n",
+    "_fake_filesystem = {\n",
+    "    'file102.txt': [1, 2, float('nan'), 3],\n",
+    "    'file103.txt': [1, 2, 3, 4],\n",
+    "    'file104.txt': [1, 2, 3, 4, 5],\n",
+    "    'file105.txt': [1, 2, 3],\n",
+    "}\n",
+    "\n",
+    "# 1. Define domain types\n",
+    "\n",
+    "Filename = NewType('Filename', str)\n",
+    "RawData = NewType('RawData', dict)\n",
+    "CleanedData = NewType('CleanedData', list)\n",
+    "ScaleFactor = NewType('ScaleFactor', float)\n",
+    "Result = NewType('Result', float)\n",
+    "\n",
+    "\n",
+    "# 2. Define providers\n",
+    "\n",
+    "\n",
+    "def load(filename: Filename) -> RawData:\n",
+    "    \"\"\"Load the data from the filename.\"\"\"\n",
+    "\n",
+    "    data = _fake_filesystem[filename]\n",
+    "    return RawData({'data': data, 'meta': {'filename': filename}})\n",
+    "\n",
+    "\n",
+    "def clean(raw_data: RawData) -> CleanedData:\n",
+    "    \"\"\"Clean the data, removing NaNs.\"\"\"\n",
+    "    import math\n",
+    "\n",
+    "    return CleanedData([x for x in raw_data['data'] if not math.isnan(x)])\n",
+    "\n",
+    "\n",
+    "def process(data: CleanedData, param: ScaleFactor) -> Result:\n",
+    "    \"\"\"Process the data, multiplying the sum by the scale factor.\"\"\"\n",
+    "    return Result(sum(data) * param)\n",
+    "\n",
+    "\n",
+    "# 3. Create pipeline\n",
+    "\n",
+    "# 3.a Providers and normal parameters\n",
+    "providers = [load, clean, process]\n",
+    "params = {ScaleFactor: 2.0}\n",
+    "\n",
+    "# 3.b Parameter table\n",
+    "RunID = NewType('RunID', int)\n",
+    "run_ids = [102, 103, 104, 105]\n",
+    "filenames = [f'file{i}.txt' for i in run_ids]\n",
+    "param_table = sciline.ParamTable(RunID, {Filename: filenames}, index=run_ids)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Note how steps 1 and 2 are identical to those from the example without a parameter table.\n",
+    "Above we have created the following parameter table:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "param_table"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "We can now create the pipeline and set the parameter table:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# 3.c Setup pipeline\n",
+    "pipeline = sciline.Pipeline(providers, params=params)\n",
+    "pipeline.set_param_table(param_table)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Then we can compute `Result` for each index in the parameter table:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "pipeline.compute(sciline.Series[RunID, Result])"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "\n",
+    "`sciline.Series` is a special `dict`-like type that signals to Sciline that the values of the series are based on values from one or more columns of a parameter table.\n",
+    "The parameter table is identified using the first argument to `Series`, in this case `RunID`.\n",
+    "The second argument specifies the result to be computed.\n",
+    "\n",
+    "We can also visualize the task graph for computing the series of `Result` values:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "pipeline.visualize(sciline.Series[RunID, Result])"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Nodes that depend on values from a parameter table are drawn as \"3D boxes\" with the parameter index name (the row dimension of the parameter table) and value given in parentheses.\n",
+    "The dashed arrow indicates an internal transformation that gathers results from each branch and combines them into a single output, here `Series[RunID, Result]`.\n",
+    "\n",
+ "
\n", + "\n", + "Note\n", + "\n", + "With long parameter tables, graphs can get messy and hard to read.\n", + "Try using `visualize(..., compact=True)`.\n", + "\n", + "The `compact=True` option to yields a much more compact representation.\n", + "Instead of drawing every intermediate result and provider for each parameter, we then represent each parameter-dependent result as a single \"3D box\" node, representing all nodes for different values of the respective parameter.\n", + "\n", + "
" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Combining intermediate results from series of parameters\n", + "\n", + "This chapter illustrates how to implement *reduce* operations with Sciline.\n", + "\n", + "Instead of requesting a series of results as above, we can also build pipelines with providers that depend on such series.\n", + "We can create a new pipeline, or extend the existing one by inserting a new provider:\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "MergedResult = NewType('MergedResult', float)\n", + "\n", + "\n", + "def merge_runs(runs: sciline.Series[RunID, Result]) -> MergedResult:\n", + " return MergedResult(sum(runs.values()))\n", + "\n", + "\n", + "pipeline.insert(merge_runs)\n", + "graph = pipeline.get(MergedResult)\n", + "graph.visualize()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Note that this is identical to the example in the previous section, except for the last two nodes in the graph.\n", + "The computation now returns a single result:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "graph.compute()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This is useful if we need to continue computation after gathering results without setting up a second pipeline." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Grouping intermediate results based on secondary parameters\n", + "\n", + "This chapter illustrates how to implement *groupby* operations with Sciline.\n", + "\n", + "Continuing from the examples for *map* and *reduce*, we can introduce a secondary parameter in the table, such as the material of the sample:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "Material = NewType('Material', str)\n", + "\n", + "# 3.a Providers and normal parameters\n", + "providers = [load, clean, process, merge_runs]\n", + "params = {ScaleFactor: 2.0}\n", + "\n", + "# 3.b Parameter table\n", + "run_ids = [102, 103, 104, 105]\n", + "sample = ['diamond', 'graphite', 'graphite', 'graphite']\n", + "filenames = [f'file{i}.txt' for i in run_ids]\n", + "param_table = sciline.ParamTable(\n", + " RunID, {Filename: filenames, Material: sample}, index=run_ids\n", + ")\n", + "param_table" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# 3.c Setup pipeline\n", + "pipeline = sciline.Pipeline(providers, params=params)\n", + "pipeline.set_param_table(param_table)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can now compute `MergedResult` for a series of \"materials\":" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "pipeline.compute(sciline.Series[Material, MergedResult])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The computation looks as show below.\n", + "Note how the initial steps of the computation depend on the `RunID` parameter, while later steps depend on `Material`:\n", + "The files for each run ID have been grouped by their material and then merged:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "pipeline.visualize(sciline.Series[Material, MergedResult])" + ] + }, + { + "cell_type": "markdown", + 
"metadata": {}, + "source": [ + "## More examples" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Combining multiple parameters from same table" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import sciline as sl\n", + "\n", + "Sum = NewType(\"Sum\", float)\n", + "Param1 = NewType(\"Param1\", int)\n", + "Param2 = NewType(\"Param2\", int)\n", + "Row = NewType(\"Run\", int)\n", + "\n", + "\n", + "def gather(\n", + " x: sl.Series[Row, float],\n", + ") -> Sum:\n", + " return Sum(sum(x.values()))\n", + "\n", + "\n", + "def product(x: Param1, y: Param2) -> float:\n", + " return x / y\n", + "\n", + "\n", + "pl = sl.Pipeline([gather, product])\n", + "pl.set_param_table(sl.ParamTable(Row, {Param1: [1, 4, 9], Param2: [1, 2, 3]}))\n", + "\n", + "pl.visualize(Sum)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "pl.compute(Sum)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Diamond graphs" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "Sum = NewType(\"Sum\", float)\n", + "Param = NewType(\"Param\", int)\n", + "Param1 = NewType(\"Param1\", int)\n", + "Param2 = NewType(\"Param2\", int)\n", + "Row = NewType(\"Run\", int)\n", + "\n", + "\n", + "def gather(x: sl.Series[Row, float]) -> Sum:\n", + " return Sum(sum(x.values()))\n", + "\n", + "\n", + "def to_param1(x: Param) -> Param1:\n", + " return Param1(x)\n", + "\n", + "\n", + "def to_param2(x: Param) -> Param2:\n", + " return Param2(x)\n", + "\n", + "\n", + "def product(x: Param1, y: Param2) -> float:\n", + " return x * y\n", + "\n", + "\n", + "pl = sl.Pipeline([gather, product, to_param1, to_param2])\n", + "pl.set_param_table(sl.ParamTable(Row, {Param: [1, 2, 3]}))\n", + "pl.visualize(Sum)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Combining parameters from different tables" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import sciline as sl\n", + "\n", + "List1 = NewType(\"List1\", float)\n", + "List2 = NewType(\"List2\", float)\n", + "Param1 = NewType(\"Param1\", int)\n", + "Param2 = NewType(\"Param2\", int)\n", + "Row1 = NewType(\"Row1\", int)\n", + "Row2 = NewType(\"Row2\", int)\n", + "\n", + "\n", + "def gather1(x: sl.Series[Row1, float]) -> List1:\n", + " return List1(list(x.values()))\n", + "\n", + "\n", + "def gather2(x: sl.Series[Row2, List1]) -> List2:\n", + " return List2(list(x.values()))\n", + "\n", + "\n", + "def product(x: Param1, y: Param2) -> float:\n", + " return x * y\n", + "\n", + "\n", + "pl = sl.Pipeline([gather1, gather2, product])\n", + "pl.set_param_table(sl.ParamTable(Row1, {Param1: [1, 4, 9]}))\n", + "pl.set_param_table(sl.ParamTable(Row2, {Param2: [1, 2]}))\n", + "\n", + "pl.visualize(List2)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Note how intermediates such as `float(Row1, Row2)` depend on two parameters, i.e., we are dealing with a 2-D array of branches in the graph." 
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "pl.compute(List2)"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.12"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/src/sciline/__init__.py b/src/sciline/__init__.py
index 2f43d450..eb74611c 100644
--- a/src/sciline/__init__.py
+++ b/src/sciline/__init__.py
@@ -10,15 +10,19 @@
 __version__ = "0.0.0"
 
 from .domain import Scope
+from .param_table import ParamTable
 from .pipeline import (
     AmbiguousProvider,
     Pipeline,
     UnboundTypeVar,
     UnsatisfiedRequirement,
 )
+from .series import Series
 
 __all__ = [
     "AmbiguousProvider",
+    "ParamTable",
     "Pipeline",
     "Scope",
+    "Series",
     "UnboundTypeVar",
diff --git a/src/sciline/param_table.py b/src/sciline/param_table.py
new file mode 100644
index 00000000..585f7e8e
--- /dev/null
+++ b/src/sciline/param_table.py
@@ -0,0 +1,86 @@
+# SPDX-License-Identifier: BSD-3-Clause
+# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
+from __future__ import annotations
+
+from typing import Any, Collection, Dict, Mapping, Optional
+
+
+class ParamTable(Mapping[type, Collection[Any]]):
+    """A table of parameters with a row index and named row dimension."""
+
+    def __init__(
+        self,
+        row_dim: type,
+        columns: Dict[type, Collection[Any]],
+        *,
+        index: Optional[Collection[Any]] = None,
+    ):
+        """
+        Create a new param table.
+
+        Parameters
+        ----------
+        row_dim:
+            The row dimension. This must be a type or a type-alias (not an instance),
+            and is used by :py:class:`sciline.Pipeline` to identify each parameter
+            table.
+        columns:
+            The columns of the table. The keys (column names) must be types or type-
+            aliases matching the values in the respective columns.
+        index:
+            The row index of the table. If not given, a default index will be
+            generated: the integer range of the column length.
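+
+        Examples
+        --------
+        A minimal sketch (the ``Run`` and ``Filename`` aliases are purely
+        illustrative)::
+
+            from typing import NewType
+
+            Run = NewType('Run', int)
+            Filename = NewType('Filename', str)
+            table = ParamTable(Run, {Filename: ['a.txt', 'b.txt']}, index=[10, 11])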
+ """ + sizes = set(len(v) for v in columns.values()) + if len(sizes) != 1: + raise ValueError( + f"Columns in param table must all have same size, got {sizes}" + ) + size = sizes.pop() + if index is not None: + if len(index) != size: + raise ValueError( + f"Index length not matching columns, got {len(index)} and {size}" + ) + if len(set(index)) != len(index): + raise ValueError(f"Index must be unique, got {index}") + self._row_dim = row_dim + self._columns = columns + self._index = index or list(range(size)) + + @property + def row_dim(self) -> type: + """The row dimension of the table.""" + return self._row_dim + + @property + def index(self) -> Collection[Any]: + """The row index of the table.""" + return self._index + + def __contains__(self, key: Any) -> bool: + return self._columns.__contains__(key) + + def __getitem__(self, key: Any) -> Any: + return self._columns.__getitem__(key) + + def __iter__(self) -> Any: + return self._columns.__iter__() + + def __len__(self) -> int: + return self._columns.__len__() + + def __repr__(self) -> str: + return f"ParamTable(row_dim={self.row_dim}, columns={self._columns})" + + def _repr_html_(self) -> str: + return ( + f"" + + "".join(f"" for k in self._columns.keys()) + + "" + + "".join( + f"" + "".join(f"" for v in row) + "" + for idx, row in zip(self.index, zip(*self._columns.values())) + ) + + "
{self.row_dim.__name__}{k.__name__}
{idx}{v}
" + ) diff --git a/src/sciline/pipeline.py b/src/sciline/pipeline.py index 4dbd104f..75d8a527 100644 --- a/src/sciline/pipeline.py +++ b/src/sciline/pipeline.py @@ -2,15 +2,21 @@ # Copyright (c) 2023 Scipp contributors (https://github.com/scipp) from __future__ import annotations +from collections import defaultdict from typing import ( Any, Callable, + Collection, Dict, + Generic, + Iterable, List, + Mapping, Optional, Tuple, Type, TypeVar, + Union, get_args, get_origin, get_type_hints, @@ -20,9 +26,16 @@ from sciline.task_graph import TaskGraph from .domain import Scope -from .scheduler import Graph, Scheduler +from .param_table import ParamTable +from .scheduler import Scheduler +from .series import Series +from .typing import Graph, Item, Key, Label, Provider T = TypeVar('T') +KeyType = TypeVar('KeyType') +ValueType = TypeVar('ValueType') +IndexType = TypeVar('IndexType') +LabelType = TypeVar('LabelType') class UnsatisfiedRequirement(Exception): @@ -39,10 +52,6 @@ class AmbiguousProvider(Exception): """Raised when multiple providers are found for a type.""" -Provider = Callable[..., Any] -Key = type - - def _is_compatible_type_tuple( requested: tuple[Key, ...], provided: tuple[Key | TypeVar, ...], @@ -74,6 +83,195 @@ def _bind_free_typevars(tp: TypeVar | Key, bound: Dict[TypeVar, Key]) -> Key: return tp +def _find_all_paths( + dependencies: Mapping[T, Collection[T]], start: T, end: T +) -> List[List[T]]: + """Find all paths from start to end in a DAG.""" + if start == end: + return [[start]] + if start not in dependencies: + return [] + paths = [] + for node in dependencies[start]: + for path in _find_all_paths(dependencies, node, end): + paths.append([start] + path) + return paths + + +def _find_nodes_in_paths( + graph: Mapping[T, Tuple[Callable[..., Any], Collection[T]]], end: T +) -> List[T]: + """ + Helper for Pipeline. Finds all nodes that need to be duplicated since they depend + on a value from a param table. + """ + start = next(iter(graph)) + # 0 is the provider, 1 is the args + dependencies = {k: v[1] for k, v in graph.items()} + paths = _find_all_paths(dependencies, start, end) + nodes = set() + for path in paths: + nodes.update(path) + return list(nodes) + + +class ReplicatorBase(Generic[IndexType]): + def __init__(self, index_name: type, index: Iterable[IndexType], path: List[Key]): + self._index_name = index_name + self.index = index + self._path = path + + def __contains__(self, key: Key) -> bool: + return key in self._path + + def replicate( + self, + key: Key, + value: Any, + get_provider: Callable[..., Tuple[Provider, Dict[TypeVar, Key]]], + ) -> Graph: + graph: Graph = {} + provider, args = value + for idx in self.index: + subkey = self.key(idx, key) + if provider == _param_sentinel: + graph[subkey] = (get_provider(subkey)[0], ()) + else: + graph[subkey] = self._copy_node(key, provider, args, idx) + return graph + + def _copy_node( + self, + key: Key, + provider: Union[Provider, SeriesProvider[IndexType]], + args: Tuple[Key, ...], + idx: IndexType, + ) -> Tuple[Provider, Tuple[Key, ...]]: + return ( + provider, + tuple(self.key(idx, arg) if arg in self else arg for arg in args), + ) + + def key(self, i: IndexType, value_name: Union[Type[T], Item[T]]) -> Item[T]: + label = Label(self._index_name, i) + if isinstance(value_name, Item): + return Item(value_name.label + (label,), value_name.tp) + else: + return Item((label,), value_name) + + +class Replicator(ReplicatorBase[IndexType]): + r""" + Helper for rewriting the graph to map over a given index. 
+ + See Pipeline._build_series for context. Given a graph template, this makes a + transformation as follows: + + S P1[0] P1[1] P1[2] + | | | | + A -> A[0] A[1] A[2] + | | | | + B B[0] B[1] B[2] + + Where S is a sentinel value, P1 are parameters from a parameter table, 0,1,2 + are indices of the param table rows, and A and B are arbitrary nodes in the graph. + """ + + def __init__(self, param_table: ParamTable, graph_template: Graph) -> None: + index_name = param_table.row_dim + super().__init__( + index_name=index_name, + index=param_table.index, + path=_find_nodes_in_paths(graph_template, index_name), + ) + + +class GroupingReplicator(ReplicatorBase[LabelType], Generic[IndexType, LabelType]): + r""" + Helper for rewriting the graph to group by a given index. + + See Pipeline._build_series for context. Given a graph template, this makes a + transformation as follows: + + P1[0] P1[1] P1[2] P1[0] P1[1] P1[2] + | | | | | | + A[0] A[1] A[2] A[0] A[1] A[2] + | | | | | | + B[0] B[1] B[2] -> B[0] B[1] B[2] + \______|______/ \______/ | + | | | + SB SB[x] SB[y] + | | | + C C[x] C[y] + + Where SB is Series[Idx,B]. Here, the upper half of the graph originates from a + prior transformation of a graph template using `Replicator`. The output of this + combined with further nodes is the graph template passed to this class. x and y + are the labels used in a grouping operation, based on the values of a ParamTable + column P2. + """ + + def __init__( + self, param_table: ParamTable, graph_template: Graph, label_name: type + ) -> None: + self._label_name = label_name + self._group_node = self._find_grouping_node(param_table.row_dim, graph_template) + self._groups: Dict[LabelType, List[IndexType]] = defaultdict(list) + for idx, label in zip(param_table.index, param_table[label_name]): + self._groups[label].append(idx) + super().__init__( + index_name=label_name, + index=self._groups, + path=_find_nodes_in_paths(graph_template, self._group_node), + ) + + def _copy_node( + self, + key: Key, + provider: Union[Provider, SeriesProvider[IndexType]], + args: Tuple[Key, ...], + idx: LabelType, + ) -> Tuple[Provider, Tuple[Key, ...]]: + if (not isinstance(provider, SeriesProvider)) or key != self._group_node: + return super()._copy_node(key, provider, args, idx) + labels = self._groups[idx] + if set(labels) - set(provider.labels): + raise ValueError(f'{labels} is not a subset of {provider.labels}') + selected = { + label: arg for label, arg in zip(provider.labels, args) if label in labels + } + split_provider = SeriesProvider(selected, provider.row_dim) + return (split_provider, tuple(selected.values())) + + def _find_grouping_node(self, index_name: Key, subgraph: Graph) -> type: + ends: List[type] = [] + for key in subgraph: + if get_origin(key) == Series and get_args(key)[0] == index_name: + # Because of the succeeded get_origin we know it is a type + ends.append(key) # type: ignore[arg-type] + if len(ends) == 1: + return ends[0] + raise ValueError(f"Could not find unique grouping node, found {ends}") + + +class SeriesProvider(Generic[KeyType]): + """ + Internal provider for combining results obtained based on different rows in a + param table into a single object. + """ + + def __init__(self, labels: Iterable[KeyType], row_dim: type) -> None: + self.labels = labels + self.row_dim = row_dim + + def __call__(self, *vals: ValueType) -> Series[KeyType, ValueType]: + return Series(self.row_dim, dict(zip(self.labels, vals))) + + +class _param_sentinel: + ... 
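+
+
+# A sketch of the mechanism around ``_param_sentinel`` (see ``Pipeline.build`` and
+# ``ReplicatorBase.replicate``): a param-table column has no unique provider, so
+# ``build`` records ``(_param_sentinel, (table_key,))`` for it, and the replicators
+# later substitute one concrete provider per table row.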
+ + class Pipeline: """A container for providers that can be assembled into a task graph.""" @@ -81,7 +279,7 @@ def __init__( self, providers: Optional[List[Provider]] = None, *, - params: Optional[Dict[Key, Any]] = None, + params: Optional[Dict[type, Any]] = None, ): """ Setup a Pipeline from a list providers @@ -96,6 +294,8 @@ def __init__( """ self._providers: Dict[Key, Provider] = {} self._subproviders: Dict[type, Dict[Tuple[Key | TypeVar, ...], Provider]] = {} + self._param_tables: Dict[Key, ParamTable] = {} + self._param_name_to_table_key: Dict[Key, Key] = {} for provider in providers or []: self.insert(provider) for tp, param in (params or {}).items(): @@ -153,10 +353,48 @@ def __setitem__(self, key: Type[T], param: T) -> None: ) self._set_provider(key, lambda: param) - def _set_provider(self, key: Type[T], provider: Callable[..., T]) -> None: + def set_param_table(self, params: ParamTable) -> None: + """ + Set a parameter table for a row dimension. + + Values in the parameter table provide concrete values for a type given by the + respective column header. + + A pipeline can have multiple parameter tables, but only one per row dimension. + Column names must be unique across all parameter tables. + + Parameters + ---------- + params: + Parameter table to set. + """ + if params.row_dim in self._param_tables: + raise ValueError(f'Parameter table for {params.row_dim} already set') + for param_name in params: + if param_name in self._param_name_to_table_key: + raise ValueError(f'Parameter {param_name} already set') + self._param_tables[params.row_dim] = params + for param_name in params: + self._param_name_to_table_key[param_name] = params.row_dim + for param_name, values in params.items(): + for index, label in zip(params.index, values): + self._set_provider( + Item((Label(tp=params.row_dim, index=index),), param_name), + lambda label=label: label, + ) + + def _set_provider( + self, key: Union[Type[T], Item[T]], provider: Callable[..., T] + ) -> None: # isinstance does not work here and types.NoneType available only in 3.10+ if key == type(None): # noqa: E721 raise ValueError(f'Provider {provider} returning `None` is not allowed') + if get_origin(key) == Series: + raise ValueError( + f'Provider {provider} returning a sciline.Series is not allowed. ' + 'Series is a special container reserved for use in conjunction with ' + 'sciline.ParamTable and must not be provided directly.' + ) if (origin := get_origin(key)) is not None: subproviders = self._subproviders.setdefault(origin, {}) args = get_args(key) @@ -168,7 +406,9 @@ def _set_provider(self, key: Type[T], provider: Callable[..., T]) -> None: raise ValueError(f'Provider for {key} already exists') self._providers[key] = provider - def _get_provider(self, tp: Type[T]) -> Tuple[Callable[..., T], Dict[TypeVar, Key]]: + def _get_provider( + self, tp: Union[Type[T], Item[T]] + ) -> Tuple[Callable[..., T], Dict[TypeVar, Key]]: if (provider := self._providers.get(tp)) is not None: return provider, {} elif (origin := get_origin(tp)) is not None and ( @@ -192,7 +432,9 @@ def _get_provider(self, tp: Type[T]) -> Tuple[Callable[..., T], Dict[TypeVar, Ke raise AmbiguousProvider("Multiple providers found for type", tp) raise UnsatisfiedRequirement("No provider found for type", tp) - def build(self, tp: Type[T], /) -> Graph: + def build( + self, tp: Union[Type[T], Item[T]], /, search_param_tables: bool = False + ) -> Graph: """ Return a dict of providers required for building the requested type `tp`. 
@@ -206,25 +448,123 @@
         ----------
         tp:
             Type to build the graph for.
+        search_param_tables:
+            Whether to search parameter tables for concrete keys.
         """
         graph: Graph = {}
-        stack: List[Type[T]] = [tp]
+        stack: List[Union[Type[T], Item[T]]] = [tp]
         while stack:
             tp = stack.pop()
+            if search_param_tables and tp in self._param_name_to_table_key:
+                graph[tp] = (_param_sentinel, (self._param_name_to_table_key[tp],))
+                continue
+            if get_origin(tp) == Series:
+                graph.update(self._build_series(tp))  # type: ignore[arg-type]
+                continue
             provider: Callable[..., T]
             provider, bound = self._get_provider(tp)
             tps = get_type_hints(provider)
-            args = {
-                name: _bind_free_typevars(t, bound=bound)
+            args = tuple(
+                _bind_free_typevars(t, bound=bound)
                 for name, t in tps.items()
                 if name != 'return'
-            }
+            )
             graph[tp] = (provider, args)
-            for arg in args.values():
+            for arg in args:
                 if arg not in graph:
                     stack.append(arg)
         return graph
 
+    def _build_series(self, tp: Type[Series[KeyType, ValueType]]) -> Graph:
+        """
+        Build (sub)graph for a Series type implementing ParamTable-based functionality.
+
+        We illustrate this with an example. Given a ParamTable with row_dim 'Idx':
+
+            Idx | P1 | P2
+            0   | a  | x
+            1   | b  | x
+            2   | c  | y
+
+        and providers for A depending on P1 and B depending on A. Calling
+        build(Series[Idx,B]) will call _build_series(Series[Idx,B]). This results in
+        the following procedure:
+
+        1. Call build(B), resulting, e.g., in a graph S->A->B, where S is a sentinel.
+           The sentinel is used because build() cannot find a unique P1, since it is
+           not a single value but a column in a table.
+        2. Instantiation of `Replicator`, which will be used to replicate the
+           relevant parts of the graph (see illustration there).
+        3. Insert a special `SeriesProvider` node, which will gather the duplicates of
+           the 'B' node and provide the requested Series[Idx,B].
+        4. Replicate the graph. Nodes that do not directly or indirectly depend on P1
+           are not replicated.
+
+        Conceptually, the final result will be {
+            0: B(A(a)),
+            1: B(A(b)),
+            2: B(A(c))
+        }.
+
+        In more complex cases, we may be dealing with multiple levels of Series,
+        which are used for grouping operations. Consider the above example, but with
+        an additional provider for C depending on Series[Idx,B]. Calling
+        build(Series[P2,C]) will call _build_series(Series[P2,C]). This results in
+        the following procedure:
+
+        a. Call build(C), which results in the procedure above, i.e., a nested call
+           to _build_series(Series[Idx,B]) and the resulting graph as explained above.
+        b. Instantiation of `GroupingReplicator`, which will be used to replicate the
+           relevant parts of the graph (see illustration there).
+        c. Insert a special `SeriesProvider` node, which will gather the duplicates of
+           the 'C' node and provide the requested Series[P2,C].
+        d. Replicate the graph. Nodes that do not directly or indirectly depend on
+           the special `SeriesProvider` node (from step 3.) are not replicated.
+
+        Conceptually, the final result will be {
+            x: C({
+                0: B(A(a)),
+                1: B(A(b))
+            }),
+            y: C({
+                2: B(A(c))
+            })
+        }.
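+
+        As a usage sketch (types as in the example above), the first procedure
+        serves a request like ``pipeline.compute(Series[Idx, B])`` and the second
+        one a request like ``pipeline.compute(Series[P2, C])``.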
+ """ + index_name: Type[KeyType] + value_type: Type[ValueType] + index_name, value_type = get_args(tp) + + subgraph = self.build(value_type, search_param_tables=True) + + replicator: ReplicatorBase[KeyType] + if ( + index_name not in self._param_name_to_table_key + and (params := self._param_tables.get(index_name)) is not None + ): + replicator = Replicator(param_table=params, graph_template=subgraph) + elif (table_key := self._param_name_to_table_key.get(index_name)) is not None: + replicator = GroupingReplicator( + param_table=self._param_tables[table_key], + graph_template=subgraph, + label_name=index_name, + ) + else: + raise KeyError(f'No parameter table found for label {index_name}') + + graph: Graph = {} + graph[tp] = ( + SeriesProvider(list(replicator.index), index_name), + tuple(replicator.key(idx, value_type) for idx in replicator.index), + ) + + for key, value in subgraph.items(): + if key in replicator: + graph.update(replicator.replicate(key, value, self._get_provider)) + else: + graph[key] = value + return graph + @overload def compute(self, tp: Type[T]) -> T: ... @@ -233,7 +573,11 @@ def compute(self, tp: Type[T]) -> T: def compute(self, tp: Tuple[Type[T], ...]) -> Tuple[T, ...]: ... - def compute(self, tp: type | Tuple[type, ...]) -> Any: + @overload + def compute(self, tp: Item[T]) -> T: + ... + + def compute(self, tp: type | Tuple[type, ...] | Item[T]) -> Any: """ Compute result for the given keys. @@ -264,7 +608,10 @@ def visualize( return self.get(tp).visualize(**kwargs) def get( - self, keys: type | Tuple[type, ...], *, scheduler: Optional[Scheduler] = None + self, + keys: type | Tuple[type, ...] | Item[T], + *, + scheduler: Optional[Scheduler] = None, ) -> TaskGraph: """ Return a TaskGraph for the given keys. diff --git a/src/sciline/scheduler.py b/src/sciline/scheduler.py index c1b0c886..49213c84 100644 --- a/src/sciline/scheduler.py +++ b/src/sciline/scheduler.py @@ -1,12 +1,8 @@ # SPDX-License-Identifier: BSD-3-Clause # Copyright (c) 2023 Scipp contributors (https://github.com/scipp) -from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple +from typing import Any, Callable, Dict, List, Optional, Protocol -Key = type -Graph = Dict[ - Key, - Tuple[Callable[..., Any], Dict[str, Key]], -] +from sciline.typing import Graph, Key class CycleError(Exception): @@ -18,7 +14,7 @@ class Scheduler(Protocol): Scheduler interface compatible with :py:class:`sciline.Pipeline`. """ - def get(self, graph: Graph, keys: List[type]) -> Any: + def get(self, graph: Graph, keys: List[Key]) -> Any: """ Compute the result for given keys from the graph. @@ -37,21 +33,20 @@ class NaiveScheduler: :py:class:`DaskScheduler` instead. 
""" - def get(self, graph: Graph, keys: List[type]) -> Any: + def get(self, graph: Graph, keys: List[Key]) -> Any: import graphlib - dependencies = {tp: set(args.values()) for tp, (_, args) in graph.items()} + dependencies = {tp: args for tp, (_, args) in graph.items()} ts = graphlib.TopologicalSorter(dependencies) try: # Create list from generator to force early exception if there is a cycle tasks = list(ts.static_order()) except graphlib.CycleError as e: raise CycleError from e - results: Dict[type, Any] = {} + results: Dict[Key, Any] = {} for t in tasks: provider, args = graph[t] - args = {name: results[arg] for name, arg in args.items()} - results[t] = provider(**args) + results[t] = provider(*[results[arg] for arg in args]) return tuple(results[key] for key in keys) @@ -77,8 +72,8 @@ def __init__(self, scheduler: Optional[Callable[..., Any]] = None) -> None: else: self._dask_get = scheduler - def get(self, graph: Graph, keys: List[type]) -> Any: - dsk = {tp: (provider, *args.values()) for tp, (provider, args) in graph.items()} + def get(self, graph: Graph, keys: List[Key]) -> Any: + dsk = {tp: (provider, *args) for tp, (provider, args) in graph.items()} try: return self._dask_get(dsk, keys) except RuntimeError as e: diff --git a/src/sciline/series.py b/src/sciline/series.py new file mode 100644 index 00000000..fb82ca0b --- /dev/null +++ b/src/sciline/series.py @@ -0,0 +1,60 @@ +# SPDX-License-Identifier: BSD-3-Clause +# Copyright (c) 2023 Scipp contributors (https://github.com/scipp) +from __future__ import annotations + +from typing import Iterator, Mapping, Type, TypeVar + +Key = TypeVar('Key') +Value = TypeVar('Value') + + +class Series(Mapping[Key, Value]): + """A series of values with labels (row index) and named row dimension.""" + + def __init__(self, row_dim: Type[Key], items: Mapping[Key, Value]) -> None: + """ + Create a new series. + + Parameters + ---------- + row_dim: + The row dimension. This must be a type or a type-alias (not an instance). + items: + The items of the series. + """ + self._row_dim = row_dim + self._map: Mapping[Key, Value] = items + + @property + def row_dim(self) -> type: + """The row dimension of the series.""" + return self._row_dim + + def __contains__(self, item: object) -> bool: + return item in self._map + + def __iter__(self) -> Iterator[Key]: + return iter(self._map) + + def __len__(self) -> int: + return len(self._map) + + def __getitem__(self, key: Key) -> Value: + return self._map[key] + + def __repr__(self) -> str: + return f"Series(row_dim={self.row_dim}, {self._map})" + + def _repr_html_(self) -> str: + return ( + f"" + + "".join( + f"" for k, v in self._map.items() + ) + + "
{self.row_dim.__name__}Value
{k}{v}
" + ) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, Series): + return NotImplemented + return self.row_dim == other.row_dim and self._map == other._map diff --git a/src/sciline/task_graph.py b/src/sciline/task_graph.py index e832fcda..cebb12cb 100644 --- a/src/sciline/task_graph.py +++ b/src/sciline/task_graph.py @@ -2,9 +2,12 @@ # Copyright (c) 2023 Scipp contributors (https://github.com/scipp) from __future__ import annotations -from typing import Any, Optional, Tuple, Union +from typing import Any, Optional, Tuple, TypeVar, Union -from sciline.scheduler import DaskScheduler, Graph, NaiveScheduler, Scheduler +from .scheduler import DaskScheduler, NaiveScheduler, Scheduler +from .typing import Graph, Item + +T = TypeVar("T") class TaskGraph: @@ -19,7 +22,7 @@ def __init__( self, *, graph: Graph, - keys: Union[type, Tuple[type, ...]], + keys: Union[type, Tuple[type, ...], Item[T], Tuple[Item[T], ...]], scheduler: Optional[Scheduler] = None, ) -> None: self._graph = graph @@ -31,7 +34,12 @@ def __init__( scheduler = NaiveScheduler() self._scheduler = scheduler - def compute(self, keys: Optional[Union[type, Tuple[type, ...]]] = None) -> Any: + def compute( + self, + keys: Optional[ + Union[type, Tuple[type, ...], Item[T], Tuple[Item[T], ...]] + ] = None, + ) -> Any: """ Compute the result of the graph. diff --git a/src/sciline/typing.py b/src/sciline/typing.py new file mode 100644 index 00000000..f6a941c3 --- /dev/null +++ b/src/sciline/typing.py @@ -0,0 +1,26 @@ +# SPDX-License-Identifier: BSD-3-Clause +# Copyright (c) 2023 Scipp contributors (https://github.com/scipp) +from dataclasses import dataclass +from typing import Any, Callable, Dict, Generic, Tuple, Type, TypeVar, Union + + +@dataclass(frozen=True) +class Label: + tp: type + index: Any + + +T = TypeVar('T') + + +@dataclass(frozen=True) +class Item(Generic[T]): + label: Tuple[Label, ...] + tp: Type[T] + + +Provider = Callable[..., Any] + + +Key = Union[type, Item] +Graph = Dict[Key, Tuple[Provider, Tuple[Key, ...]]] diff --git a/src/sciline/visualize.py b/src/sciline/visualize.py index 32ed0709..f2c90e33 100644 --- a/src/sciline/visualize.py +++ b/src/sciline/visualize.py @@ -1,13 +1,32 @@ # SPDX-License-Identifier: BSD-3-Clause # Copyright (c) 2023 Scipp contributors (https://github.com/scipp) -from typing import Any, Callable, Dict, List, Tuple, get_args, get_origin +from dataclasses import dataclass +from typing import ( + Any, + Callable, + Dict, + List, + Tuple, + Type, + TypeVar, + Union, + get_args, + get_origin, +) from graphviz import Digraph -from .scheduler import Graph +from .pipeline import Pipeline, SeriesProvider +from .typing import Graph, Item, Key -def to_graphviz(graph: Graph, **kwargs: Any) -> Digraph: +@dataclass +class Node: + name: str + collapsed: bool = False + + +def to_graphviz(graph: Graph, compact: bool = False, **kwargs: Any) -> Digraph: """ Convert output of :py:class:`sciline.Pipeline.get_graph` to a graphviz graph. @@ -15,39 +34,70 @@ def to_graphviz(graph: Graph, **kwargs: Any) -> Digraph: ---------- graph: Output of :py:class:`sciline.Pipeline.get_graph`. + compact: + If True, parameter-table-dependent branches are collapsed into a single copy + of the branch. Recommendend for large graphs with long parameter tables. kwargs: Keyword arguments passed to :py:class:`graphviz.Digraph`. 
""" dot = Digraph(strict=True, **kwargs) - for p, (p_name, args, ret) in _format_graph(graph).items(): - dot.node(ret, ret, shape='rectangle') + for p, (p_name, args, ret) in _format_graph(graph, compact=compact).items(): + dot.node(ret.name, ret.name, shape='box3d' if ret.collapsed else 'rectangle') # Do not draw dummy providers created by Pipeline when setting instances - if p_name == 'Pipeline.__setitem__..': + if p_name in ( + f'{_qualname(Pipeline.__setitem__)}..', + f'{_qualname(Pipeline.set_param_table)}..', + ): continue - dot.node(p, p_name, shape='ellipse') - for arg in args: - dot.node(arg, arg, shape='rectangle') - dot.edge(arg, p) - dot.edge(p, ret) + # Do not draw the internal provider gathering index-dependent results into + # a dict + if p_name == _qualname(SeriesProvider): + for arg in args: + dot.edge(arg.name, ret.name, style='dashed') + else: + dot.node(p, p_name, shape='ellipse') + for arg in args: + dot.edge(arg.name, p) + dot.edge(p, ret.name) return dot -def _format_graph(graph: Graph) -> Dict[str, Tuple[str, List[str], str]]: +def _qualname(obj: Any) -> Any: + return ( + obj.__qualname__ if hasattr(obj, '__qualname__') else obj.__class__.__qualname__ + ) + + +def _format_graph( + graph: Graph, compact: bool +) -> Dict[str, Tuple[str, List[Node], Node]]: return { - _format_provider(provider, ret): ( - provider.__qualname__, - [_format_type(a) for a in args.values()], - _format_type(ret), + _format_provider(provider, ret, compact=compact): ( + _qualname(provider), + [_format_type(a, compact=compact) for a in args], + _format_type(ret, compact=compact), ) for ret, (provider, args) in graph.items() } -def _format_provider(provider: Callable[..., Any], ret: type) -> str: - return f'{provider.__qualname__}_{_format_type(ret)}' +def _format_provider(provider: Callable[..., Any], ret: Key, compact: bool) -> str: + return f'{_qualname(provider)}_{_format_type(ret, compact=compact).name}' + + +T = TypeVar('T') -def _format_type(tp: type) -> str: +def _extract_type_and_labels( + key: Union[Item[T], Type[T]], compact: bool +) -> Tuple[Type[T], List[Union[type, Tuple[type, Any]]]]: + if isinstance(key, Item): + label = key.label + return key.tp, [lb.tp if compact else (lb.tp, lb.index) for lb in label] + return key, [] + + +def _format_type(tp: Key, compact: bool = False) -> Node: """ Helper for _format_graph. @@ -56,11 +106,27 @@ def _format_type(tp: type) -> str: We may make this configurable in the future. 
""" + tp, labels = _extract_type_and_labels(tp, compact=compact) + def get_base(tp: type) -> str: return tp.__name__ if hasattr(tp, '__name__') else str(tp).split('.')[-1] + def format_label(label: Union[type, Tuple[type, Any]]) -> str: + if isinstance(label, tuple): + tp, index = label + return f'{get_base(tp)}={index}' + return get_base(label) + + def with_labels(base: str) -> Node: + if labels: + return Node( + name=f'{base}({", ".join([format_label(l) for l in labels])})', + collapsed=compact, + ) + return Node(name=base) + if (origin := get_origin(tp)) is not None: - params = [_format_type(param) for param in get_args(tp)] - return f'{get_base(origin)}[{", ".join(params)}]' + params = [_format_type(param).name for param in get_args(tp)] + return with_labels(f'{get_base(origin)}[{", ".join(params)}]') else: - return get_base(tp) + return with_labels(get_base(tp)) diff --git a/tests/param_table_test.py b/tests/param_table_test.py new file mode 100644 index 00000000..4a57ad18 --- /dev/null +++ b/tests/param_table_test.py @@ -0,0 +1,47 @@ +# SPDX-License-Identifier: BSD-3-Clause +# Copyright (c) 2023 Scipp contributors (https://github.com/scipp) +import pytest + +import sciline as sl + + +def test_raises_with_zero_columns() -> None: + with pytest.raises(ValueError): + sl.ParamTable(row_dim=int, columns={}) + + +def test_raises_with_inconsistent_column_sizes() -> None: + with pytest.raises(ValueError): + sl.ParamTable(row_dim=int, columns={int: [1, 2, 3], float: [1.0, 2.0]}) + + +def test_raises_with_inconsistent_index_length() -> None: + with pytest.raises(ValueError): + sl.ParamTable(row_dim=int, columns={float: [1.0, 2.0]}, index=[1, 2, 3]) + + +def test_raises_with_non_unique_index() -> None: + with pytest.raises(ValueError): + sl.ParamTable(int, {float: [1.0, 2.0, 3.0]}, index=[1, 1, 2]) + + +def test_contains_includes_all_columns() -> None: + pt = sl.ParamTable(row_dim=int, columns={int: [1, 2, 3], float: [1.0, 2.0, 3.0]}) + assert int in pt + assert float in pt + assert str not in pt + + +def test_contains_does_not_include_index() -> None: + pt = sl.ParamTable(row_dim=int, columns={float: [1.0, 2.0, 3.0]}) + assert int not in pt + + +def test_len_is_number_of_columns() -> None: + pt = sl.ParamTable(row_dim=int, columns={int: [1, 2, 3], float: [1.0, 2.0, 3.0]}) + assert len(pt) == 2 + + +def test_defaults_to_range_index() -> None: + pt = sl.ParamTable(row_dim=int, columns={float: [1.0, 2.0, 3.0]}) + assert pt.index == [0, 1, 2] diff --git a/tests/pipeline_with_param_table_test.py b/tests/pipeline_with_param_table_test.py new file mode 100644 index 00000000..88f8a718 --- /dev/null +++ b/tests/pipeline_with_param_table_test.py @@ -0,0 +1,583 @@ +# SPDX-License-Identifier: BSD-3-Clause +# Copyright (c) 2023 Scipp contributors (https://github.com/scipp) +from typing import List, NewType, Optional, TypeVar + +import pytest + +import sciline as sl +from sciline.typing import Item, Label + + +def test_set_param_table_raises_if_param_names_are_duplicate() -> None: + pl = sl.Pipeline() + pl.set_param_table(sl.ParamTable(int, {float: [1.0, 2.0, 3.0]})) + with pytest.raises(ValueError): + pl.set_param_table(sl.ParamTable(str, {float: [4.0, 5.0, 6.0]})) + assert pl.compute(Item((Label(int, 1),), float)) == 2.0 + with pytest.raises(sl.UnsatisfiedRequirement): + pl.compute(Item((Label(str, 1),), float)) + + +def test_set_param_table_raises_if_row_dim_is_duplicate() -> None: + pl = sl.Pipeline() + pl.set_param_table(sl.ParamTable(int, {float: [1.0, 2.0, 3.0]})) + with pytest.raises(ValueError): + 
pl.set_param_table(sl.ParamTable(int, {str: ['a', 'b', 'c']}))
+    assert pl.compute(Item((Label(int, 1),), float)) == 2.0
+    with pytest.raises(sl.UnsatisfiedRequirement):
+        pl.compute(Item((Label(int, 1),), str))
+
+
+def test_can_get_elements_of_param_table() -> None:
+    pl = sl.Pipeline()
+    pl.set_param_table(sl.ParamTable(int, {float: [1.0, 2.0, 3.0]}))
+    assert pl.compute(Item((Label(int, 1),), float)) == 2.0
+
+
+def test_can_get_elements_of_param_table_with_explicit_index() -> None:
+    pl = sl.Pipeline()
+    pl.set_param_table(sl.ParamTable(int, {float: [1.0, 2.0, 3.0]}, index=[11, 12, 13]))
+    assert pl.compute(Item((Label(int, 12),), float)) == 2.0
+
+
+def test_can_depend_on_elements_of_param_table() -> None:
+    # This is not a valid type annotation, not sure why it works with get_type_hints
+    def use_elem(x: Item((Label(int, 1),), float)) -> str:  # type: ignore[valid-type]
+        return str(x)
+
+    pl = sl.Pipeline([use_elem])
+    pl.set_param_table(sl.ParamTable(int, {float: [1.0, 2.0, 3.0]}))
+    assert pl.compute(str) == "2.0"
+
+
+def test_can_compute_series_of_param_values() -> None:
+    pl = sl.Pipeline()
+    pl.set_param_table(sl.ParamTable(int, {float: [1.0, 2.0, 3.0]}))
+    assert pl.compute(sl.Series[int, float]) == sl.Series(int, {0: 1.0, 1: 2.0, 2: 3.0})
+
+
+def test_can_compute_series_of_derived_values() -> None:
+    def process(x: float) -> str:
+        return str(x)
+
+    pl = sl.Pipeline([process])
+    pl.set_param_table(sl.ParamTable(int, {float: [1.0, 2.0, 3.0]}))
+    assert pl.compute(sl.Series[int, str]) == sl.Series(
+        int, {0: "1.0", 1: "2.0", 2: "3.0"}
+    )
+
+
+def test_creating_pipeline_with_provider_of_series_raises() -> None:
+    series = sl.Series(int, {0: 1.0, 1: 2.0, 2: 3.0})
+
+    def make_series() -> sl.Series[int, float]:
+        return series
+
+    with pytest.raises(ValueError):
+        sl.Pipeline([make_series])
+
+
+def test_creating_pipeline_with_series_param_raises() -> None:
+    series = sl.Series(int, {0: 1.0, 1: 2.0, 2: 3.0})
+
+    with pytest.raises(ValueError):
+        sl.Pipeline([], params={sl.Series[int, float]: series})
+
+
+def test_explicit_index_of_param_table_is_forwarded_correctly() -> None:
+    def process(x: float) -> int:
+        return int(x)
+
+    pl = sl.Pipeline([process])
+    pl.set_param_table(
+        sl.ParamTable(str, {float: [1.0, 2.0, 3.0]}, index=['a', 'b', 'c'])
+    )
+    assert pl.compute(sl.Series[str, int]) == sl.Series(str, {'a': 1, 'b': 2, 'c': 3})
+
+
+def test_can_gather_index() -> None:
+    Sum = NewType("Sum", float)
+    Name = NewType("Name", str)
+
+    def gather(x: sl.Series[Name, float]) -> Sum:
+        return Sum(sum(x.values()))
+
+    def make_float(x: str) -> float:
+        return float(x)
+
+    pl = sl.Pipeline([gather, make_float])
+    pl.set_param_table(sl.ParamTable(Name, {str: ["1.0", "2.0", "3.0"]}))
+    assert pl.compute(Sum) == 6.0
+
+
+def test_can_zip() -> None:
+    Sum = NewType("Sum", str)
+    Str = NewType("Str", str)
+    Run = NewType("Run", int)
+
+    def gather_zip(x: sl.Series[Run, Str], y: sl.Series[Run, int]) -> Sum:
+        z = [f'{x_}{y_}' for x_, y_ in zip(x.values(), y.values())]
+        return Sum(str(z))
+
+    def use_str(x: str) -> Str:
+        return Str(x)
+
+    pl = sl.Pipeline([gather_zip, use_str])
+    pl.set_param_table(sl.ParamTable(Run, {str: ['a', 'a', 'ccc'], int: [1, 2, 3]}))
+
+    assert pl.compute(Sum) == "['a1', 'a2', 'ccc3']"
+
+
+def test_diamond_dependency_pulls_values_from_columns_in_same_param_table() -> None:
+    Sum = NewType("Sum", float)
+    Param1 = NewType("Param1", int)
+    Param2 = NewType("Param2", int)
+    Row = NewType("Row", int)
+
+    def gather(x: sl.Series[Row, float]) -> Sum:
+        return 
Sum(sum(x.values())) + + def join(x: Param1, y: Param2) -> float: + return x / y + + pl = sl.Pipeline([gather, join]) + pl.set_param_table(sl.ParamTable(Row, {Param1: [1, 4, 9], Param2: [1, 2, 3]})) + + assert pl.compute(Sum) == Sum(6) + + +def test_diamond_dependency_on_same_column() -> None: + Sum = NewType("Sum", float) + Param = NewType("Param", int) + Param1 = NewType("Param1", int) + Param2 = NewType("Param2", int) + Row = NewType("Row", int) + + def gather(x: sl.Series[Row, float]) -> Sum: + return Sum(sum(x.values())) + + def to_param1(x: Param) -> Param1: + return Param1(x) + + def to_param2(x: Param) -> Param2: + return Param2(x) + + def join(x: Param1, y: Param2) -> float: + return x / y + + pl = sl.Pipeline([gather, join, to_param1, to_param2]) + pl.set_param_table(sl.ParamTable(Row, {Param: [1, 2, 3]})) + + assert pl.compute(Sum) == Sum(3) + + +def test_dependencies_on_different_param_tables_broadcast() -> None: + Row1 = NewType("Row1", int) + Row2 = NewType("Row2", int) + Param1 = NewType("Param1", int) + Param2 = NewType("Param2", int) + Product = NewType("Product", str) + + def gather_both(x: sl.Series[Row1, Param1], y: sl.Series[Row2, Param2]) -> Product: + broadcast = [[x_, y_] for x_ in x.values() for y_ in y.values()] + return Product(str(broadcast)) + + pl = sl.Pipeline([gather_both]) + pl.set_param_table(sl.ParamTable(Row1, {Param1: [1, 2, 3]})) + pl.set_param_table(sl.ParamTable(Row2, {Param2: [4, 5]})) + assert pl.compute(Product) == "[[1, 4], [1, 5], [2, 4], [2, 5], [3, 4], [3, 5]]" + + +def test_dependency_on_other_param_table_in_parent_broadcasts_branch() -> None: + Row1 = NewType("Row1", int) + Row2 = NewType("Row2", int) + Param1 = NewType("Param1", int) + Param2 = NewType("Param2", int) + Summed2 = NewType("Summed2", int) + Product = NewType("Product", str) + + def gather2_and_combine(x: Param1, y: sl.Series[Row2, Param2]) -> Summed2: + return Summed2(x * sum(y.values())) + + def gather1(x: sl.Series[Row1, Summed2]) -> Product: + return Product(str(list(x.values()))) + + pl = sl.Pipeline([gather1, gather2_and_combine]) + pl.set_param_table(sl.ParamTable(Row1, {Param1: [1, 2, 3]})) + pl.set_param_table(sl.ParamTable(Row2, {Param2: [4, 5]})) + assert pl.compute(Product) == "[9, 18, 27]" + + +def test_dependency_on_other_param_table_in_grandparent_broadcasts_branch() -> None: + Row1 = NewType("Row1", int) + Row2 = NewType("Row2", int) + Param1 = NewType("Param1", int) + Param2 = NewType("Param2", int) + Summed2 = NewType("Summed2", int) + Combined = NewType("Combined", int) + Product = NewType("Product", str) + + def gather2(x: sl.Series[Row2, Param2]) -> Summed2: + return Summed2(sum(x.values())) + + def combine(x: Param1, y: Summed2) -> Combined: + return Combined(x * y) + + def gather1(x: sl.Series[Row1, Combined]) -> Product: + return Product(str(list(x.values()))) + + pl = sl.Pipeline([gather1, gather2, combine]) + pl.set_param_table(sl.ParamTable(Row1, {Param1: [1, 2, 3]})) + pl.set_param_table(sl.ParamTable(Row2, {Param2: [4, 5]})) + assert pl.compute(Product) == "[9, 18, 27]" + + +def test_nested_dependencies_on_different_param_tables() -> None: + Row1 = NewType("Row1", int) + Row2 = NewType("Row2", int) + Param1 = NewType("Param1", int) + Param2 = NewType("Param2", int) + Combined = NewType("Combined", int) + + def combine(x: Param1, y: Param2) -> Combined: + return Combined(x * y) + + pl = sl.Pipeline([combine]) + pl.set_param_table(sl.ParamTable(Row1, {Param1: [1, 2, 3]})) + pl.set_param_table(sl.ParamTable(Row2, {Param2: [4, 5]})) + assert 
pl.compute(sl.Series[Row1, sl.Series[Row2, Combined]]) == sl.Series(
+        Row1,
+        {
+            0: sl.Series(Row2, {0: 4, 1: 5}),
+            1: sl.Series(Row2, {0: 8, 1: 10}),
+            2: sl.Series(Row2, {0: 12, 1: 15}),
+        },
+    )
+    assert pl.compute(sl.Series[Row2, sl.Series[Row1, Combined]]) == sl.Series(
+        Row2,
+        {
+            0: sl.Series(Row1, {0: 4, 1: 8, 2: 12}),
+            1: sl.Series(Row1, {0: 5, 1: 10, 2: 15}),
+        },
+    )
+
+
+def test_can_groupby_by_requesting_series_of_series() -> None:
+    Row = NewType("Row", int)
+    Param1 = NewType("Param1", int)
+    Param2 = NewType("Param2", int)
+
+    pl = sl.Pipeline()
+    pl.set_param_table(sl.ParamTable(Row, {Param1: [1, 1, 3], Param2: [4, 5, 6]}))
+    expected = sl.Series(
+        Param1,
+        {1: sl.Series(Row, {0: 4, 1: 5}), 3: sl.Series(Row, {2: 6})},
+    )
+    assert pl.compute(sl.Series[Param1, sl.Series[Row, Param2]]) == expected
+
+
+def test_groupby_by_requesting_series_of_series_preserves_indices() -> None:
+    Row = NewType("Row", int)
+    Param1 = NewType("Param1", int)
+    Param2 = NewType("Param2", int)
+
+    pl = sl.Pipeline()
+    pl.set_param_table(
+        sl.ParamTable(Row, {Param1: [1, 1, 3], Param2: [4, 5, 6]}, index=[11, 12, 13])
+    )
+    assert pl.compute(sl.Series[Param1, sl.Series[Row, Param2]]) == sl.Series(
+        Param1, {1: sl.Series(Row, {11: 4, 12: 5}), 3: sl.Series(Row, {13: 6})}
+    )
+
+
+def test_can_groupby_by_param_used_in_ancestor() -> None:
+    Row = NewType("Row", int)
+    Param = NewType("Param", str)
+
+    pl = sl.Pipeline()
+    pl.set_param_table(sl.ParamTable(Row, {Param: ['x', 'x', 'y']}))
+    expected = sl.Series(
+        Param,
+        {"x": sl.Series(Row, {0: "x", 1: "x"}), "y": sl.Series(Row, {2: "y"})},
+    )
+    assert pl.compute(sl.Series[Param, sl.Series[Row, Param]]) == expected
+
+
+def test_multi_level_groupby_raises_with_params_from_same_table() -> None:
+    Row = NewType("Row", int)
+    Param1 = NewType("Param1", int)
+    Param2 = NewType("Param2", int)
+    Param3 = NewType("Param3", int)
+
+    pl = sl.Pipeline()
+    pl.set_param_table(
+        sl.ParamTable(
+            Row, {Param1: [1, 1, 1, 3], Param2: [4, 5, 5, 6], Param3: [7, 8, 9, 10]}
+        )
+    )
+    with pytest.raises(ValueError, match='Could not find unique grouping node'):
+        pl.compute(sl.Series[Param1, sl.Series[Param2, sl.Series[Row, Param3]]])
+
+
+def test_multi_level_groupby_with_params_from_different_table() -> None:
+    Row = NewType("Row", int)
+    Param1 = NewType("Param1", int)
+    Param2 = NewType("Param2", int)
+    Param3 = NewType("Param3", int)
+
+    pl = sl.Pipeline()
+    grouping1 = sl.ParamTable(Row, {Param2: [0, 1, 1, 2], Param3: [7, 8, 9, 10]})
+    # We are not providing an explicit index here, so this only happens to work because
+    # the values of Param2 match range(3).
+    grouping2 = sl.ParamTable(Param2, {Param1: [1, 1, 3]})
+    pl.set_param_table(grouping1)
+    pl.set_param_table(grouping2)
+    assert pl.compute(
+        sl.Series[Param1, sl.Series[Param2, sl.Series[Row, Param3]]]
+    ) == sl.Series(
+        Param1,
+        {
+            1: sl.Series(
+                Param2, {0: sl.Series(Row, {0: 7}), 1: sl.Series(Row, {1: 8, 2: 9})}
+            ),
+            3: sl.Series(Param2, {2: sl.Series(Row, {3: 10})}),
+        },
+    )
+
+
+def test_multi_level_groupby_with_params_from_different_table_can_select() -> None:
+    Row = NewType("Row", int)
+    Param1 = NewType("Param1", int)
+    Param2 = NewType("Param2", int)
+    Param3 = NewType("Param3", int)
+
+    pl = sl.Pipeline()
+    grouping1 = sl.ParamTable(Row, {Param2: [4, 5, 5, 6], Param3: [7, 8, 9, 10]})
+    # Note the missing index "6" here.
+ grouping2 = sl.ParamTable(Param2, {Param1: [1, 1]}, index=[4, 5]) + pl.set_param_table(grouping1) + pl.set_param_table(grouping2) + assert pl.compute( + sl.Series[Param1, sl.Series[Param2, sl.Series[Row, Param3]]] + ) == sl.Series( + Param1, + { + 1: sl.Series( + Param2, {4: sl.Series(Row, {0: 7}), 5: sl.Series(Row, {1: 8, 2: 9})} + ) + }, + ) + + +def test_multi_level_groupby_with_params_from_different_table_preserves_index() -> None: + Row = NewType("Row", int) + Param1 = NewType("Param1", int) + Param2 = NewType("Param2", int) + Param3 = NewType("Param3", int) + + pl = sl.Pipeline() + grouping1 = sl.ParamTable( + Row, {Param2: [4, 5, 5, 6], Param3: [7, 8, 9, 10]}, index=[100, 200, 300, 400] + ) + grouping2 = sl.ParamTable(Param2, {Param1: [1, 1, 3]}, index=[4, 5, 6]) + pl.set_param_table(grouping1) + pl.set_param_table(grouping2) + assert pl.compute( + sl.Series[Param1, sl.Series[Param2, sl.Series[Row, Param3]]] + ) == sl.Series( + Param1, + { + 1: sl.Series( + Param2, + {4: sl.Series(Row, {100: 7}), 5: sl.Series(Row, {200: 8, 300: 9})}, + ), + 3: sl.Series(Param2, {6: sl.Series(Row, {400: 10})}), + }, + ) + + +def test_multi_level_groupby_with_params_from_different_table_can_reorder() -> None: + Row = NewType("Row", int) + Param1 = NewType("Param1", int) + Param2 = NewType("Param2", int) + Param3 = NewType("Param3", int) + + pl = sl.Pipeline() + grouping1 = sl.ParamTable( + Row, {Param2: [4, 5, 5, 6], Param3: [7, 8, 9, 10]}, index=[100, 200, 300, 400] + ) + grouping2 = sl.ParamTable(Param2, {Param1: [1, 1, 3]}, index=[6, 5, 4]) + pl.set_param_table(grouping1) + pl.set_param_table(grouping2) + assert pl.compute( + sl.Series[Param1, sl.Series[Param2, sl.Series[Row, Param3]]] + ) == sl.Series( + Param1, + { + 1: sl.Series( + Param2, + {6: sl.Series(Row, {400: 10}), 5: sl.Series(Row, {200: 8, 300: 9})}, + ), + 3: sl.Series(Param2, {4: sl.Series(Row, {100: 7})}), + }, + ) + + +@pytest.mark.parametrize("index", [None, [4, 5, 7]]) +def test_multi_level_groupby_raises_on_index_mismatch( + index: Optional[List[int]], +) -> None: + Row = NewType("Row", int) + Param1 = NewType("Param1", int) + Param2 = NewType("Param2", int) + Param3 = NewType("Param3", int) + + pl = sl.Pipeline() + grouping1 = sl.ParamTable( + Row, {Param2: [4, 5, 5, 6], Param3: [7, 8, 9, 10]}, index=[100, 200, 300, 400] + ) + grouping2 = sl.ParamTable(Param2, {Param1: [1, 1, 3]}, index=index) + pl.set_param_table(grouping1) + pl.set_param_table(grouping2) + with pytest.raises(ValueError): + pl.compute(sl.Series[Param1, sl.Series[Param2, sl.Series[Row, Param3]]]) + + +@pytest.mark.parametrize("index", [None, [4, 5, 6]]) +def test_groupby_over_param_table(index: Optional[List[int]]) -> None: + Index = NewType("Index", int) + Name = NewType("Name", str) + Param = NewType("Param", int) + ProcessedParam = NewType("ProcessedParam", int) + SummedGroup = NewType("SummedGroup", int) + ProcessedGroup = NewType("ProcessedGroup", int) + + def process_param(x: Param) -> ProcessedParam: + return ProcessedParam(x + 1) + + def sum_group(group: sl.Series[Index, ProcessedParam]) -> SummedGroup: + return SummedGroup(sum(group.values())) + + def process(x: SummedGroup) -> ProcessedGroup: + return ProcessedGroup(2 * x) + + params = sl.ParamTable( + Index, {Param: [1, 2, 3], Name: ['a', 'a', 'b']}, index=index + ) + pl = sl.Pipeline([process_param, sum_group, process]) + pl.set_param_table(params) + + graph = pl.get(sl.Series[Name, ProcessedGroup]) + assert graph.compute() == sl.Series(Name, {'a': 10, 'b': 8}) + + +def 
test_requesting_series_index_that_is_not_in_param_table_raises() -> None: + Row = NewType("Row", int) + Param1 = NewType("Param1", int) + Param2 = NewType("Param2", int) + + pl = sl.Pipeline() + pl.set_param_table(sl.ParamTable(Row, {Param1: [1, 1, 3], Param2: [4, 5, 6]})) + + with pytest.raises(KeyError): + pl.compute(sl.Series[int, Param2]) + + +def test_requesting_series_index_that_is_a_param_raises_if_not_grouping() -> None: + Row = NewType("Row", int) + Param1 = NewType("Param1", int) + Param2 = NewType("Param2", int) + + pl = sl.Pipeline() + pl.set_param_table(sl.ParamTable(Row, {Param1: [1, 1, 3], Param2: [4, 5, 6]})) + with pytest.raises(ValueError, match='Could not find unique grouping node'): + pl.compute(sl.Series[Param1, Param2]) + + +def test_generic_providers_work_with_param_tables() -> None: + Param = TypeVar('Param') + Row = NewType("Row", int) + + class Str(sl.Scope[Param, str], str): + ... + + def parametrized(x: Param) -> Str[Param]: + return Str(f'{x}') + + def make_float() -> float: + return 1.5 + + pipeline = sl.Pipeline([make_float, parametrized]) + pipeline.set_param_table(sl.ParamTable(Row, {int: [1, 2, 3]})) + + assert pipeline.compute(Str[float]) == Str[float]('1.5') + with pytest.raises(sl.UnsatisfiedRequirement): + pipeline.compute(Str[int]) + assert pipeline.compute(sl.Series[Row, Str[int]]) == sl.Series( + Row, + { + 0: Str[int]('1'), + 1: Str[int]('2'), + 2: Str[int]('3'), + }, + ) + + +def test_generic_provider_can_depend_on_param_series() -> None: + Param = TypeVar('Param') + Row = NewType("Row", int) + + class Str(sl.Scope[Param, str], str): + ... + + def parametrized_gather(x: sl.Series[Row, Param]) -> Str[Param]: + return Str(f'{list(x.values())}') + + pipeline = sl.Pipeline([parametrized_gather]) + pipeline.set_param_table( + sl.ParamTable(Row, {int: [1, 2, 3], float: [1.5, 2.5, 3.5]}) + ) + + assert pipeline.compute(Str[int]) == Str[int]('[1, 2, 3]') + assert pipeline.compute(Str[float]) == Str[float]('[1.5, 2.5, 3.5]') + + +def test_generic_provider_can_depend_on_derived_param_series() -> None: + T = TypeVar('T') + Row = NewType("Row", int) + + class Str(sl.Scope[T, str], str): + ... + + def use_param(x: int) -> float: + return x + 0.5 + + def parametrized_gather(x: sl.Series[Row, T]) -> Str[T]: + return Str(f'{list(x.values())}') + + pipeline = sl.Pipeline([parametrized_gather, use_param]) + pipeline.set_param_table(sl.ParamTable(Row, {int: [1, 2, 3]})) + + assert pipeline.compute(Str[float]) == Str[float]('[1.5, 2.5, 3.5]') + + +def test_params_in_table_can_be_generic() -> None: + T = TypeVar('T') + Row = NewType("Row", int) + + class Str(sl.Scope[T, str], str): + ... + + class Param(sl.Scope[T, str], str): + ... 
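+
+    # Each concrete Param[T] (e.g. Param[int]) acts as a separate column of the
+    # parameter table set up below.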
+ + def parametrized_gather(x: sl.Series[Row, Param[T]]) -> Str[T]: + return Str(','.join(x.values())) + + pipeline = sl.Pipeline([parametrized_gather]) + pipeline.set_param_table( + sl.ParamTable(Row, {Param[int]: ["1", "2"], Param[float]: ["1.5", "2.5"]}) + ) + + assert pipeline.compute(Str[int]) == Str[int]('1,2') + assert pipeline.compute(Str[float]) == Str[float]('1.5,2.5') diff --git a/tests/task_graph_test.py b/tests/task_graph_test.py index 0f376049..31e24714 100644 --- a/tests/task_graph_test.py +++ b/tests/task_graph_test.py @@ -3,8 +3,8 @@ import pytest import sciline as sl -from sciline.scheduler import Graph from sciline.task_graph import TaskGraph +from sciline.typing import Graph def as_float(x: int) -> float: diff --git a/tests/visualize_test.py b/tests/visualize_test.py index 663c6dfe..8290cbc7 100644 --- a/tests/visualize_test.py +++ b/tests/visualize_test.py @@ -30,6 +30,6 @@ class B(Generic[T]): class SubA(A[T]): pass - assert sl.visualize._format_type(A[float]) == 'A[float]' - assert sl.visualize._format_type(SubA[float]) == 'SubA[float]' - assert sl.visualize._format_type(B[float]) == 'B[float]' + assert sl.visualize._format_type(A[float]).name == 'A[float]' + assert sl.visualize._format_type(SubA[float]).name == 'SubA[float]' + assert sl.visualize._format_type(B[float]).name == 'B[float]'