From 0187be18aec2c406761b6f81783b75d416357c05 Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Fri, 28 Jul 2023 13:01:50 +0200 Subject: [PATCH 01/87] First working demo --- src/sciline/graph.py | 34 ++++ src/sciline/variadic.py | 44 +++++ tests/graph_test.py | 129 +++++++++++++++ tests/variadic_workflow_test.py | 284 ++++++++++++++++++++++++++++++++ 4 files changed, 491 insertions(+) create mode 100644 src/sciline/graph.py create mode 100644 src/sciline/variadic.py create mode 100644 tests/graph_test.py create mode 100644 tests/variadic_workflow_test.py diff --git a/src/sciline/graph.py b/src/sciline/graph.py new file mode 100644 index 00000000..aa735b9a --- /dev/null +++ b/src/sciline/graph.py @@ -0,0 +1,34 @@ +from sciline.pipeline import Key + + +from typing import List + + +def find_path(graph, start, end) -> List[Key]: + """Find a path from start to end in a DAG.""" + if start == end: + return [start] + for node in graph[start]: + path = find_path(graph, node, end) + if path: + return [start] + path + return [] + + +def find_unique_path(graph, start, end) -> List[Key]: + """Find a path from start to end in a DAG. + + Like find_path, but raises if more than one path found + """ + if start == end: + return [start] + if start not in graph: + return [] + paths = [] + for node in graph[start]: + path = find_unique_path(graph, node, end) + if path: + paths.append([start] + path) + if len(paths) > 1: + raise RuntimeError(f"Multiple paths found from {start} to {end}") + return paths[0] if paths else [] diff --git a/src/sciline/variadic.py b/src/sciline/variadic.py new file mode 100644 index 00000000..98bf5fb9 --- /dev/null +++ b/src/sciline/variadic.py @@ -0,0 +1,44 @@ +# SPDX-License-Identifier: BSD-3-Clause +# Copyright (c) 2023 Scipp contributors (https://github.com/scipp) +from __future__ import annotations + +from typing import TypeVar, Iterator, Generic, List + +from collections.abc import Collection, Mapping + +T = TypeVar('T') + + +class Stack(Collection, Generic[T]): + def __init__(self, values: List[T]) -> None: + self._stack: List[T] = values + + def __contains__(self, item: object) -> bool: + return item in self._stack + + def __iter__(self) -> Iterator[T]: + return iter(self._stack) + + def __len__(self) -> int: + return len(self._stack) + + +Key = TypeVar('Key') +Value = TypeVar('Value') + + +class Map(Mapping[Key, Value]): + def __init__(self, values: Mapping[Key, Value]) -> None: + self._map: Mapping[Key, Value] = values + + def __contains__(self, item: object) -> bool: + return item in self._map + + def __iter__(self) -> Iterator[Key]: + return iter(self._map) + + def __len__(self) -> int: + return len(self._map) + + def __getitem__(self, key: Key) -> Value: + return self._map[key] diff --git a/tests/graph_test.py b/tests/graph_test.py new file mode 100644 index 00000000..f64d4af6 --- /dev/null +++ b/tests/graph_test.py @@ -0,0 +1,129 @@ +# SPDX-License-Identifier: BSD-3-Clause +# Copyright (c) 2023 Scipp contributors (https://github.com/scipp) +import pytest + +import sciline as sl +from sciline.graph import find_path, find_unique_path + + +def test_find_path(): + graph = {"D": ["B", "C"], "C": ["A"], "B": ["A"]} + assert find_path(graph, "D", "A") == ["D", "B", "A"] + + +def test_find_unique_path(): + graph = {"D": ["B", "C"], "C": ["A"], "B": ["A"]} + with pytest.raises(RuntimeError): + find_unique_path(graph, "D", "A") + graph = {"D": ["B", "C"], "C": ["A"], "B": ["aux"]} + assert find_unique_path(graph, "D", "A") == ["D", "C", "A"] + + +# 1. starting from join point, search dependencies recursively, until +# we find the fork point. +# 2. use find_unique_path to find the path from fork point to join point. +# 3. replace all nodes inside path by nodes mapped of size of fork +# 4. insert special fork and join nodes + +# Do we need to be more explicit in what we combine? +# def combine(images: Multi[Image]) -> float: +# return sum(images) +# ... does not specify the "dimension" we reduce over. +# +# filenames = Structured('subrun', ['x.dat', 'y.dat']) + +from typing import Literal, Generic, TypeVar, List, Iterator +from collections.abc import Collection + +T = TypeVar('T') +DIM = TypeVar('DIM') + + +class Group(Collection[T], Generic[DIM, T]): + def __init__(self, value: List[T]) -> None: + self._stack = value + + def __contains__(self, item: object) -> bool: + return item in self._stack + + def __iter__(self) -> Iterator[T]: + return iter(self._stack) + + def __len__(self) -> int: + return len(self._stack) + + +def test_literal_param() -> None: + X = Literal['x'] + + def fork() -> Group[X, int]: + return Group[X, int]([1, 2, 3]) + + def process(a: int) -> float: + return 0.5 * a + + def join(a: Group[X, float]) -> float: + return sum(a) + + assert join(Group([1.0, 2.0, 3.0])) == 6.0 + + +def test_literal_comp(): + assert Literal['x'] == Literal['x'] + + +from typing import get_type_hints + + +def wrap(func): + arg_types = get_type_hints(func) + + def wrapper(x: List[arg_types['x']]) -> List[arg_types['return']]: + return [func(x_) for x_ in x] + + return wrapper + + +def test_decorator(): + """ + Given a funcion + f(a: A, b: B, c: C, ...) -> Ret: ... + and tps=(B, Ret) return a new function + f(a: A, b: List[B], c: C, ...) -> List[Ret]: ... + """ + + def f(x: int, y: str) -> float: + return 0.5 * x + + assert wrap(f)([1, 2, 3]) == [0.5, 1.0, 1.5] + g = wrap(f) + assert get_type_hints(g)['x'] == List[int] + assert get_type_hints(g)['return'] == List[float] + + +def test_pipeline(): + from typing import NewType + + Filename = NewType('Filename', str) + Data = NewType('Data', float) + Param = NewType('Param', str) + Run = TypeVar('Run') + Raw = NewType('Raw', float) + Clean = NewType('Clean', float) + SampleRun = NewType('SampleRun', int) + + pl = sl.Pipeline() + filenames = ['x.dat', 'y.dat'] + params = ['a', 'b'] + pl.set_mapping_keys(Filename, filenames) + # pl.indices[Filename] = filenames ?? + # pl[Filename, 'x.dat'] # returns new pipeline, restricted to single filename (or range) + + def clean(raw: Raw[Run]) -> Clean[Run]: + return Clean(raw.data) + + def combine(data: sl.Mapping[Filename, Clean[SampleRun]]) -> float: + return sum(data.values()) + + # pipeline has Filename index, so when building graph and looking for + # sl.Mapping[Filename, Data] it will know how many tasks to create. diff --git a/tests/variadic_workflow_test.py b/tests/variadic_workflow_test.py new file mode 100644 index 00000000..5cb95afa --- /dev/null +++ b/tests/variadic_workflow_test.py @@ -0,0 +1,284 @@ +# SPDX-License-Identifier: BSD-3-Clause +# Copyright (c) 2023 Scipp contributors (https://github.com/scipp) +from dataclasses import dataclass +from typing import ( + NewType, + TypeVar, + Generic, + Iterable, + Tuple, + Dict, + Callable, + get_type_hints, + get_origin, + get_args, + List, + Type, + Any, +) + + +import sciline as sl +from sciline.variadic import Stack, Map +from sciline.graph import find_path, find_unique_path + + +T = TypeVar('T') + + +class Multi(Generic[T]): + def __init__(self, values: Iterable[T]) -> None: + self.values = values + + +def test_literal_param() -> None: + Filename = NewType('Filename', str) + Image = NewType('Image', float) + + def read_file(filename: Filename) -> Image: + print(filename) + return Image(float(filename[-1])) + + def combine(images: Multi[Image]) -> float: + return sum(images.values) + + filenames = [f'file{i}' for i in range(10)] + # pipeline = sl.Pipeline([read_file, combine]) + # pipeline[Multi[Filename]] = Multi(filenames) + # assert pipeline.compute(float) == 45.0 + + # -> Multi[Image] not found + # -> look for Image provider + # -> look for Filename provider + # -> look for Multi[Filename] provider + + @dataclass(frozen=True) + class Key: + tp: type + index: int + + # How do write down graph for this? + # - need dummy __getitem__ tasks + # - need dummy to_list tasks + graph = {} + graph[Multi[Filename]] = Multi(filenames) + size = len(filenames) + for i in range(size): + graph[Key(Filename, i)] = (lambda x, j: x.values[j], Multi[Filename], i) + graph[Key(Image, i)] = (read_file, Key(Filename, i)) + graph[Multi[Image]] = ( + lambda *args: Multi(args), + *[Key(Image, i) for i in range(size)], + ) + graph[float] = (combine, Multi[Image]) + + import dask + + assert dask.get(graph, Multi[Filename]).values == filenames + assert dask.get(graph, float) == 45.0 + + +from inspect import getfullargspec, signature, Parameter +from typing import get_type_hints + + +def test_args_hints() -> None: + def f(*myargs: int) -> None: + pass + + args_name = getfullargspec(f)[1] + params = signature(f).parameters + for p in params.values(): + if p.kind == Parameter.VAR_POSITIONAL: + args_name = p.name + args_type = p.annotation + break + assert args_name == 'myargs' + assert args_type == int + # args_name = signature(f).parameters['args'].name + + assert get_type_hints(f)[args_name] == int + + +def test_Stack_dependency_uses_Stack_provider(): + Filename = NewType('Filename', str) + + def combine(names: Stack[Filename]) -> str: + return ';'.join(names) + + filenames = [f'file{i}' for i in range(4)] + pipeline = sl.Pipeline([combine], params={Stack[Filename]: Stack(filenames)}) + assert pipeline.compute(str) == ';'.join(filenames) + + +def test_Stack_dependency_maps_provider_over_Stack_provider(): + Filename = NewType('Filename', str) + Image = NewType('Image', float) + + def read(filename: Filename) -> Image: + return Image(float(filename[-1])) + + def combine(images: Stack[Image]) -> float: + return sum(images) + + filenames = tuple(f'file{i}' for i in range(10)) + pipeline = sl.Pipeline([read, combine], params={Stack[Filename]: Stack(filenames)}) + assert pipeline.compute(float) == 45.0 + + +@dataclass(frozen=True) +class Key: + key: type + tp: type + index: int + + +def build_old( + providers: Dict[type, Callable[..., Any]], + indices: Dict[type, Iterable], + tp: Type[T], +) -> Dict[type, Tuple[Callable[..., Any], ...]]: + graph = {} + stack: List[Type[T]] = [tp] + while stack: + tp = stack.pop() + if get_origin(tp) == Map: + Index, Value = get_args(tp) + + def provider(*values: Value) -> tp: + return tp(dict(zip(indices[Key], values))) + + size = len(indices[Index]) + args = {Value: Key(Index, Value, i) for i in range(size)} + graph[tp] = (provider, *args.values()) + # Is it easier to insert a single task, and branch when inserting into graph? + # might be easier to track dependencies + + # Need: Map[Filename, Image] + # write combining provider and insert into graph + # get Image provider + # add all into graph... args? + value_provider = providers[Value] + args = get_type_hints(value_provider) + del args['return'] + # rewrite args to use Key... btu which args!??? + # this does not work... first have to find path to indices, then rewrite args + # that track to an index with to array of tasks + for i in range(size): + graph[Key(Index, Value, i)] = (value_provider,) + + for arg in args.values(): + if arg not in graph: + stack.append(arg) + # for key, value in graph.items(): + # print(key, value) + elif isinstance(tp, Key): + provider = providers[tp.tp] + size = len(indices[tp.key]) + # graph[tp] = (provider, Key(tp.key, tp.index, i) for i in range(size)]) + # TODO + + elif (provider := providers.get(tp)) is not None: + args = get_type_hints(provider) + del args['return'] + graph[tp] = (provider, *args.values()) + for arg in args.values(): + if arg not in graph: + stack.append(arg) + else: + raise RuntimeError(f'No provider for {tp}') + return graph + + +def build( + providers: Dict[type, Callable[..., Any]], + indices: Dict[type, Iterable], + tp: Type[T], +) -> Dict[type, Tuple[Callable[..., Any], ...]]: + graph = {} + stack: List[Type[T]] = [tp] + while stack: + tp = stack.pop() + if tp in indices: + pass + elif get_origin(tp) == Map: + Index, Value = get_args(tp) + + def provider(*values: Value) -> tp: + return tp(dict(zip(indices[Index], values))) + + size = len(indices[Index]) + args = [Key(Index, Value, i) for i in range(size)] + graph[tp] = (provider, *args) + + subgraph = build(providers, indices, Value) + path = find_unique_path(subgraph, Value, Index) + for key, value in subgraph.items(): + if key in path: + for i in range(size): + provider, *args = value + args = [ + Key(Index, arg, i) if arg in path else arg for arg in args + ] + graph[Key(Index, key, i)] = (provider, *args) + else: + graph[key] = value + for i, index in enumerate(indices[Index]): + graph[Key(Index, Index, i)] = index + elif (provider := providers.get(tp)) is not None: + args = get_type_hints(provider) + del args['return'] + graph[tp] = (provider, *args.values()) + for arg in args.values(): + if arg not in graph: + stack.append(arg) + else: + raise RuntimeError(f'No provider for {tp}') + return graph + + +def test_Map(): + Filename = NewType('Filename', str) + Image = NewType('Image', float) + CleanedImage = NewType('CleanedImage', float) + ScaledImage = NewType('ScaledImage', float) + Param = NewType('Param', float) + + def read(filename: Filename) -> Image: + return Image(float(filename[-1])) + + def clean(x: Image) -> CleanedImage: + return x + + def scale(x: CleanedImage, param: Param) -> ScaledImage: + return x * param + + def combine(images: Map[Filename, ScaledImage]) -> float: + return sum(images.values()) + + def make_int() -> int: + return 2 + + def make_param() -> Param: + return 2.0 + + filenames = tuple(f'file{i}' for i in range(3)) + indices = {Filename: filenames} + providers = { + Image: read, + CleanedImage: clean, + ScaledImage: scale, + float: combine, + int: make_int, + Param: make_param, + } + + import dask + + graph = build(providers, indices, int) + assert dask.get(graph, int) == 2 + graph = build(providers, indices, float) + assert dask.get(graph, float) == 6.0 + from dask.delayed import Delayed + + Delayed(float, graph).visualize(filename='graph.png') From 80d079b1ff1ee95cee96b0c42c87f59350aee4cd Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Fri, 28 Jul 2023 13:04:51 +0200 Subject: [PATCH 02/87] Remove unused --- tests/variadic_workflow_test.py | 61 ++------------------------------- 1 file changed, 2 insertions(+), 59 deletions(-) diff --git a/tests/variadic_workflow_test.py b/tests/variadic_workflow_test.py index 5cb95afa..f9db2f03 100644 --- a/tests/variadic_workflow_test.py +++ b/tests/variadic_workflow_test.py @@ -133,63 +133,6 @@ class Key: index: int -def build_old( - providers: Dict[type, Callable[..., Any]], - indices: Dict[type, Iterable], - tp: Type[T], -) -> Dict[type, Tuple[Callable[..., Any], ...]]: - graph = {} - stack: List[Type[T]] = [tp] - while stack: - tp = stack.pop() - if get_origin(tp) == Map: - Index, Value = get_args(tp) - - def provider(*values: Value) -> tp: - return tp(dict(zip(indices[Key], values))) - - size = len(indices[Index]) - args = {Value: Key(Index, Value, i) for i in range(size)} - graph[tp] = (provider, *args.values()) - # Is it easier to insert a single task, and branch when inserting into graph? - # might be easier to track dependencies - - # Need: Map[Filename, Image] - # write combining provider and insert into graph - # get Image provider - # add all into graph... args? - value_provider = providers[Value] - args = get_type_hints(value_provider) - del args['return'] - # rewrite args to use Key... btu which args!??? - # this does not work... first have to find path to indices, then rewrite args - # that track to an index with to array of tasks - for i in range(size): - graph[Key(Index, Value, i)] = (value_provider,) - - for arg in args.values(): - if arg not in graph: - stack.append(arg) - # for key, value in graph.items(): - # print(key, value) - elif isinstance(tp, Key): - provider = providers[tp.tp] - size = len(indices[tp.key]) - # graph[tp] = (provider, Key(tp.key, tp.index, i) for i in range(size)]) - # TODO - - elif (provider := providers.get(tp)) is not None: - args = get_type_hints(provider) - del args['return'] - graph[tp] = (provider, *args.values()) - for arg in args.values(): - if arg not in graph: - stack.append(arg) - else: - raise RuntimeError(f'No provider for {tp}') - return graph - - def build( providers: Dict[type, Callable[..., Any]], indices: Dict[type, Iterable], @@ -262,7 +205,7 @@ def make_int() -> int: def make_param() -> Param: return 2.0 - filenames = tuple(f'file{i}' for i in range(3)) + filenames = tuple(f'file{i}' for i in range(6)) indices = {Filename: filenames} providers = { Image: read, @@ -278,7 +221,7 @@ def make_param() -> Param: graph = build(providers, indices, int) assert dask.get(graph, int) == 2 graph = build(providers, indices, float) - assert dask.get(graph, float) == 6.0 from dask.delayed import Delayed Delayed(float, graph).visualize(filename='graph.png') + assert dask.get(graph, float) == 30.0 From 3031bc533d5c232d2a173667eb616134ef85812e Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Fri, 28 Jul 2023 13:26:29 +0200 Subject: [PATCH 03/87] More complex examples (zip!) --- src/sciline/graph.py | 13 +++++++++ tests/variadic_workflow_test.py | 47 ++++++++++++++++++++------------- 2 files changed, 42 insertions(+), 18 deletions(-) diff --git a/src/sciline/graph.py b/src/sciline/graph.py index aa735b9a..f1f96e90 100644 --- a/src/sciline/graph.py +++ b/src/sciline/graph.py @@ -32,3 +32,16 @@ def find_unique_path(graph, start, end) -> List[Key]: if len(paths) > 1: raise RuntimeError(f"Multiple paths found from {start} to {end}") return paths[0] if paths else [] + + +def find_all_paths(graph, start, end) -> List[Key]: + """Find all paths from start to end in a DAG.""" + if start == end: + return [[start]] + if start not in graph: + return [] + paths = [] + for node in graph[start]: + for path in find_all_paths(graph, node, end): + paths.append([start] + path) + return paths diff --git a/tests/variadic_workflow_test.py b/tests/variadic_workflow_test.py index f9db2f03..128aec82 100644 --- a/tests/variadic_workflow_test.py +++ b/tests/variadic_workflow_test.py @@ -2,26 +2,24 @@ # Copyright (c) 2023 Scipp contributors (https://github.com/scipp) from dataclasses import dataclass from typing import ( - NewType, - TypeVar, + Any, + Callable, + Dict, Generic, Iterable, - Tuple, - Dict, - Callable, - get_type_hints, - get_origin, - get_args, List, + NewType, + Tuple, Type, - Any, + TypeVar, + get_args, + get_origin, + get_type_hints, ) - import sciline as sl -from sciline.variadic import Stack, Map -from sciline.graph import find_path, find_unique_path - +from sciline.graph import find_path, find_unique_path, find_all_paths +from sciline.variadic import Map, Stack T = TypeVar('T') @@ -78,7 +76,7 @@ class Key: assert dask.get(graph, float) == 45.0 -from inspect import getfullargspec, signature, Parameter +from inspect import Parameter, getfullargspec, signature from typing import get_type_hints @@ -155,7 +153,11 @@ def provider(*values: Value) -> tp: graph[tp] = (provider, *args) subgraph = build(providers, indices, Value) - path = find_unique_path(subgraph, Value, Index) + paths = find_all_paths(subgraph, Value, Index) + # flatten paths, remove duplicates + path = set() + for p in paths: + path.update(p) for key, value in subgraph.items(): if key in path: for i in range(size): @@ -186,17 +188,25 @@ def test_Map(): CleanedImage = NewType('CleanedImage', float) ScaledImage = NewType('ScaledImage', float) Param = NewType('Param', float) + ImageParam = NewType('ImageParam', float) def read(filename: Filename) -> Image: return Image(float(filename[-1])) - def clean(x: Image) -> CleanedImage: - return x + def image_param(filename: Filename) -> ImageParam: + return ImageParam(sum(ord(c) for c in filename)) + + def clean(x: Image, param: ImageParam) -> CleanedImage: + return x * param def scale(x: CleanedImage, param: Param) -> ScaledImage: return x * param - def combine(images: Map[Filename, ScaledImage]) -> float: + def combine( + images: Map[Filename, ScaledImage], params: Map[Filename, ImageParam] + ) -> float: + print(list(images.values())) + print(list(params.values())) return sum(images.values()) def make_int() -> int: @@ -214,6 +224,7 @@ def make_param() -> Param: float: combine, int: make_int, Param: make_param, + ImageParam: image_param, } import dask From 247674fdc64ab877828e77691169aa660e696e3c Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Fri, 28 Jul 2023 13:48:57 +0200 Subject: [PATCH 04/87] Avoid issue with declaring function in loop --- src/sciline/graph.py | 5 +- tests/variadic_workflow_test.py | 130 ++++---------------------------- 2 files changed, 18 insertions(+), 117 deletions(-) diff --git a/src/sciline/graph.py b/src/sciline/graph.py index f1f96e90..bd1575c8 100644 --- a/src/sciline/graph.py +++ b/src/sciline/graph.py @@ -1,8 +1,7 @@ -from sciline.pipeline import Key - - from typing import List +from sciline.pipeline import Key + def find_path(graph, start, end) -> List[Key]: """Find a path from start to end in a DAG.""" diff --git a/tests/variadic_workflow_test.py b/tests/variadic_workflow_test.py index 128aec82..3a4a8abe 100644 --- a/tests/variadic_workflow_test.py +++ b/tests/variadic_workflow_test.py @@ -5,7 +5,6 @@ Any, Callable, Dict, - Generic, Iterable, List, NewType, @@ -17,111 +16,12 @@ get_type_hints, ) -import sciline as sl -from sciline.graph import find_path, find_unique_path, find_all_paths -from sciline.variadic import Map, Stack +import dask -T = TypeVar('T') - - -class Multi(Generic[T]): - def __init__(self, values: Iterable[T]) -> None: - self.values = values - - -def test_literal_param() -> None: - Filename = NewType('Filename', str) - Image = NewType('Image', float) - - def read_file(filename: Filename) -> Image: - print(filename) - return Image(float(filename[-1])) - - def combine(images: Multi[Image]) -> float: - return sum(images.values) - - filenames = [f'file{i}' for i in range(10)] - # pipeline = sl.Pipeline([read_file, combine]) - # pipeline[Multi[Filename]] = Multi(filenames) - # assert pipeline.compute(float) == 45.0 - - # -> Multi[Image] not found - # -> look for Image provider - # -> look for Filename provider - # -> look for Multi[Filename] provider - - @dataclass(frozen=True) - class Key: - tp: type - index: int - - # How do write down graph for this? - # - need dummy __getitem__ tasks - # - need dummy to_list tasks - graph = {} - graph[Multi[Filename]] = Multi(filenames) - size = len(filenames) - for i in range(size): - graph[Key(Filename, i)] = (lambda x, j: x.values[j], Multi[Filename], i) - graph[Key(Image, i)] = (read_file, Key(Filename, i)) - graph[Multi[Image]] = ( - lambda *args: Multi(args), - *[Key(Image, i) for i in range(size)], - ) - graph[float] = (combine, Multi[Image]) - - import dask - - assert dask.get(graph, Multi[Filename]).values == filenames - assert dask.get(graph, float) == 45.0 - - -from inspect import Parameter, getfullargspec, signature -from typing import get_type_hints - - -def test_args_hints() -> None: - def f(*myargs: int) -> None: - pass - - args_name = getfullargspec(f)[1] - params = signature(f).parameters - for p in params.values(): - if p.kind == Parameter.VAR_POSITIONAL: - args_name = p.name - args_type = p.annotation - break - assert args_name == 'myargs' - assert args_type == int - # args_name = signature(f).parameters['args'].name +from sciline.graph import find_all_paths +from sciline.variadic import Map - assert get_type_hints(f)[args_name] == int - - -def test_Stack_dependency_uses_Stack_provider(): - Filename = NewType('Filename', str) - - def combine(names: Stack[Filename]) -> str: - return ';'.join(names) - - filenames = [f'file{i}' for i in range(4)] - pipeline = sl.Pipeline([combine], params={Stack[Filename]: Stack(filenames)}) - assert pipeline.compute(str) == ';'.join(filenames) - - -def test_Stack_dependency_maps_provider_over_Stack_provider(): - Filename = NewType('Filename', str) - Image = NewType('Image', float) - - def read(filename: Filename) -> Image: - return Image(float(filename[-1])) - - def combine(images: Stack[Image]) -> float: - return sum(images) - - filenames = tuple(f'file{i}' for i in range(10)) - pipeline = sl.Pipeline([read, combine], params={Stack[Filename]: Stack(filenames)}) - assert pipeline.compute(float) == 45.0 +T = TypeVar('T') @dataclass(frozen=True) @@ -131,6 +31,13 @@ class Key: index: int +def _make_mapping_provider(values, Value, tp, Index): + def provider(*args: Value) -> tp: + return tp(dict(zip(values[Index], args))) + + return provider + + def build( providers: Dict[type, Callable[..., Any]], indices: Dict[type, Iterable], @@ -144,17 +51,13 @@ def build( pass elif get_origin(tp) == Map: Index, Value = get_args(tp) - - def provider(*values: Value) -> tp: - return tp(dict(zip(indices[Index], values))) - size = len(indices[Index]) + provider = _make_mapping_provider(indices, Value, tp, Index) args = [Key(Index, Value, i) for i in range(size)] graph[tp] = (provider, *args) subgraph = build(providers, indices, Value) paths = find_all_paths(subgraph, Value, Index) - # flatten paths, remove duplicates path = set() for p in paths: path.update(p) @@ -196,17 +99,18 @@ def read(filename: Filename) -> Image: def image_param(filename: Filename) -> ImageParam: return ImageParam(sum(ord(c) for c in filename)) - def clean(x: Image, param: ImageParam) -> CleanedImage: + def clean2(x: Image, param: ImageParam) -> CleanedImage: return x * param + def clean(x: Image) -> CleanedImage: + return x + def scale(x: CleanedImage, param: Param) -> ScaledImage: return x * param def combine( images: Map[Filename, ScaledImage], params: Map[Filename, ImageParam] ) -> float: - print(list(images.values())) - print(list(params.values())) return sum(images.values()) def make_int() -> int: @@ -227,8 +131,6 @@ def make_param() -> Param: ImageParam: image_param, } - import dask - graph = build(providers, indices, int) assert dask.get(graph, int) == 2 graph = build(providers, indices, float) From 6e4a507d015b68b055f711c39c9bdaa7610ab1b7 Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Fri, 28 Jul 2023 13:52:49 +0200 Subject: [PATCH 05/87] Cleanup --- src/sciline/variadic.py | 22 +------- tests/graph_test.py | 111 ---------------------------------------- 2 files changed, 2 insertions(+), 131 deletions(-) diff --git a/src/sciline/variadic.py b/src/sciline/variadic.py index 98bf5fb9..d635cd53 100644 --- a/src/sciline/variadic.py +++ b/src/sciline/variadic.py @@ -2,26 +2,8 @@ # Copyright (c) 2023 Scipp contributors (https://github.com/scipp) from __future__ import annotations -from typing import TypeVar, Iterator, Generic, List - -from collections.abc import Collection, Mapping - -T = TypeVar('T') - - -class Stack(Collection, Generic[T]): - def __init__(self, values: List[T]) -> None: - self._stack: List[T] = values - - def __contains__(self, item: object) -> bool: - return item in self._stack - - def __iter__(self) -> Iterator[T]: - return iter(self._stack) - - def __len__(self) -> int: - return len(self._stack) - +from collections.abc import Mapping +from typing import Iterator, TypeVar Key = TypeVar('Key') Value = TypeVar('Value') diff --git a/tests/graph_test.py b/tests/graph_test.py index f64d4af6..0b3b2848 100644 --- a/tests/graph_test.py +++ b/tests/graph_test.py @@ -2,7 +2,6 @@ # Copyright (c) 2023 Scipp contributors (https://github.com/scipp) import pytest -import sciline as sl from sciline.graph import find_path, find_unique_path @@ -17,113 +16,3 @@ def test_find_unique_path(): find_unique_path(graph, "D", "A") graph = {"D": ["B", "C"], "C": ["A"], "B": ["aux"]} assert find_unique_path(graph, "D", "A") == ["D", "C", "A"] - - -# 1. starting from join point, search dependencies recursively, until -# we find the fork point. -# 2. use find_unique_path to find the path from fork point to join point. -# 3. replace all nodes inside path by nodes mapped of size of fork -# 4. insert special fork and join nodes - -# Do we need to be more explicit in what we combine? -# def combine(images: Multi[Image]) -> float: -# return sum(images) -# ... does not specify the "dimension" we reduce over. -# -# filenames = Structured('subrun', ['x.dat', 'y.dat']) - -from typing import Literal, Generic, TypeVar, List, Iterator -from collections.abc import Collection - -T = TypeVar('T') -DIM = TypeVar('DIM') - - -class Group(Collection[T], Generic[DIM, T]): - def __init__(self, value: List[T]) -> None: - self._stack = value - - def __contains__(self, item: object) -> bool: - return item in self._stack - - def __iter__(self) -> Iterator[T]: - return iter(self._stack) - - def __len__(self) -> int: - return len(self._stack) - - -def test_literal_param() -> None: - X = Literal['x'] - - def fork() -> Group[X, int]: - return Group[X, int]([1, 2, 3]) - - def process(a: int) -> float: - return 0.5 * a - - def join(a: Group[X, float]) -> float: - return sum(a) - - assert join(Group([1.0, 2.0, 3.0])) == 6.0 - - -def test_literal_comp(): - assert Literal['x'] == Literal['x'] - - -from typing import get_type_hints - - -def wrap(func): - arg_types = get_type_hints(func) - - def wrapper(x: List[arg_types['x']]) -> List[arg_types['return']]: - return [func(x_) for x_ in x] - - return wrapper - - -def test_decorator(): - """ - Given a funcion - f(a: A, b: B, c: C, ...) -> Ret: ... - and tps=(B, Ret) return a new function - f(a: A, b: List[B], c: C, ...) -> List[Ret]: ... - """ - - def f(x: int, y: str) -> float: - return 0.5 * x - - assert wrap(f)([1, 2, 3]) == [0.5, 1.0, 1.5] - g = wrap(f) - assert get_type_hints(g)['x'] == List[int] - assert get_type_hints(g)['return'] == List[float] - - -def test_pipeline(): - from typing import NewType - - Filename = NewType('Filename', str) - Data = NewType('Data', float) - Param = NewType('Param', str) - Run = TypeVar('Run') - Raw = NewType('Raw', float) - Clean = NewType('Clean', float) - SampleRun = NewType('SampleRun', int) - - pl = sl.Pipeline() - filenames = ['x.dat', 'y.dat'] - params = ['a', 'b'] - pl.set_mapping_keys(Filename, filenames) - # pl.indices[Filename] = filenames ?? - # pl[Filename, 'x.dat'] # returns new pipeline, restricted to single filename (or range) - - def clean(raw: Raw[Run]) -> Clean[Run]: - return Clean(raw.data) - - def combine(data: sl.Mapping[Filename, Clean[SampleRun]]) -> float: - return sum(data.values()) - - # pipeline has Filename index, so when building graph and looking for - # sl.Mapping[Filename, Data] it will know how many tasks to create. From e8a28eca0bdb868ac6e6087a9793bc77c9b4b596 Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Mon, 31 Jul 2023 07:27:41 +0200 Subject: [PATCH 06/87] Add "product" example --- tests/variadic_workflow_test.py | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/tests/variadic_workflow_test.py b/tests/variadic_workflow_test.py index 3a4a8abe..b8bc41a0 100644 --- a/tests/variadic_workflow_test.py +++ b/tests/variadic_workflow_test.py @@ -87,11 +87,13 @@ def build( def test_Map(): Filename = NewType('Filename', str) + Config = NewType('Config', int) Image = NewType('Image', float) CleanedImage = NewType('CleanedImage', float) ScaledImage = NewType('ScaledImage', float) Param = NewType('Param', float) ImageParam = NewType('ImageParam', float) + Result = NewType('Result', float) def read(filename: Filename) -> Image: return Image(float(filename[-1])) @@ -105,14 +107,20 @@ def clean2(x: Image, param: ImageParam) -> CleanedImage: def clean(x: Image) -> CleanedImage: return x - def scale(x: CleanedImage, param: Param) -> ScaledImage: - return x * param + def scale(x: CleanedImage, param: Param, config: Config) -> ScaledImage: + return x * param + config - def combine( + def combine_old( images: Map[Filename, ScaledImage], params: Map[Filename, ImageParam] ) -> float: return sum(images.values()) + def combine(images: Map[Filename, ScaledImage]) -> float: + return sum(images.values()) + + def combine_configs(x: Map[Config, float]) -> Result: + return Result(sum(x.values())) + def make_int() -> int: return 2 @@ -120,7 +128,8 @@ def make_param() -> Param: return 2.0 filenames = tuple(f'file{i}' for i in range(6)) - indices = {Filename: filenames} + configs = tuple(range(2)) + indices = {Filename: filenames, Config: configs} providers = { Image: read, CleanedImage: clean, @@ -129,12 +138,13 @@ def make_param() -> Param: int: make_int, Param: make_param, ImageParam: image_param, + Result: combine_configs, } graph = build(providers, indices, int) assert dask.get(graph, int) == 2 - graph = build(providers, indices, float) + graph = build(providers, indices, Result) from dask.delayed import Delayed - Delayed(float, graph).visualize(filename='graph.png') - assert dask.get(graph, float) == 30.0 + Delayed(Result, graph).visualize(filename='graph.png') + assert dask.get(graph, Result) == 66.0 From 3976ab43eafe04faea6342e36ab59924ef11c2c9 Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Mon, 31 Jul 2023 08:10:56 +0200 Subject: [PATCH 07/87] Do not use keyword args --- src/sciline/pipeline.py | 8 ++++---- src/sciline/scheduler.py | 9 ++++----- src/sciline/visualize.py | 2 +- 3 files changed, 9 insertions(+), 10 deletions(-) diff --git a/src/sciline/pipeline.py b/src/sciline/pipeline.py index 4dbd104f..2bc0c270 100644 --- a/src/sciline/pipeline.py +++ b/src/sciline/pipeline.py @@ -214,13 +214,13 @@ def build(self, tp: Type[T], /) -> Graph: provider: Callable[..., T] provider, bound = self._get_provider(tp) tps = get_type_hints(provider) - args = { - name: _bind_free_typevars(t, bound=bound) + args = tuple( + _bind_free_typevars(t, bound=bound) for name, t in tps.items() if name != 'return' - } + ) graph[tp] = (provider, args) - for arg in args.values(): + for arg in args: if arg not in graph: stack.append(arg) return graph diff --git a/src/sciline/scheduler.py b/src/sciline/scheduler.py index c1b0c886..963e1e8a 100644 --- a/src/sciline/scheduler.py +++ b/src/sciline/scheduler.py @@ -5,7 +5,7 @@ Key = type Graph = Dict[ Key, - Tuple[Callable[..., Any], Dict[str, Key]], + Tuple[Callable[..., Any], Tuple[Key, ...]], ] @@ -40,7 +40,7 @@ class NaiveScheduler: def get(self, graph: Graph, keys: List[type]) -> Any: import graphlib - dependencies = {tp: set(args.values()) for tp, (_, args) in graph.items()} + dependencies = {tp: args for tp, (_, args) in graph.items()} ts = graphlib.TopologicalSorter(dependencies) try: # Create list from generator to force early exception if there is a cycle @@ -50,8 +50,7 @@ def get(self, graph: Graph, keys: List[type]) -> Any: results: Dict[type, Any] = {} for t in tasks: provider, args = graph[t] - args = {name: results[arg] for name, arg in args.items()} - results[t] = provider(**args) + results[t] = provider(*[results[arg] for arg in args]) return tuple(results[key] for key in keys) @@ -78,7 +77,7 @@ def __init__(self, scheduler: Optional[Callable[..., Any]] = None) -> None: self._dask_get = scheduler def get(self, graph: Graph, keys: List[type]) -> Any: - dsk = {tp: (provider, *args.values()) for tp, (provider, args) in graph.items()} + dsk = {tp: (provider, *args) for tp, (provider, args) in graph.items()} try: return self._dask_get(dsk, keys) except RuntimeError as e: diff --git a/src/sciline/visualize.py b/src/sciline/visualize.py index 32ed0709..a16eb930 100644 --- a/src/sciline/visualize.py +++ b/src/sciline/visualize.py @@ -36,7 +36,7 @@ def _format_graph(graph: Graph) -> Dict[str, Tuple[str, List[str], str]]: return { _format_provider(provider, ret): ( provider.__qualname__, - [_format_type(a) for a in args.values()], + [_format_type(a) for a in args], _format_type(ret), ) for ret, (provider, args) in graph.items() From 6281f164558b144716b8e9301069f3f245f38052 Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Mon, 31 Jul 2023 09:10:13 +0200 Subject: [PATCH 08/87] Prepare for Pipeline integration --- src/sciline/graph.py | 11 +++++- tests/variadic_workflow_test.py | 61 ++++++++++++++++++++------------- 2 files changed, 47 insertions(+), 25 deletions(-) diff --git a/src/sciline/graph.py b/src/sciline/graph.py index bd1575c8..e441ed25 100644 --- a/src/sciline/graph.py +++ b/src/sciline/graph.py @@ -40,7 +40,16 @@ def find_all_paths(graph, start, end) -> List[Key]: if start not in graph: return [] paths = [] - for node in graph[start]: + # 0 is the provider, 1 is the args + for node in graph[start][1]: for path in find_all_paths(graph, node, end): paths.append([start] + path) return paths + + +def find_nodes_in_paths(graph, start: Key, end: Key) -> List[Key]: + paths = find_all_paths(graph, start, end) + nodes = set() + for path in paths: + nodes.update(path) + return list(nodes) diff --git a/tests/variadic_workflow_test.py b/tests/variadic_workflow_test.py index b8bc41a0..bfc63bca 100644 --- a/tests/variadic_workflow_test.py +++ b/tests/variadic_workflow_test.py @@ -16,9 +16,8 @@ get_type_hints, ) -import dask - -from sciline.graph import find_all_paths +from sciline.graph import find_nodes_in_paths +from sciline.task_graph import TaskGraph from sciline.variadic import Map T = TypeVar('T') @@ -26,7 +25,7 @@ @dataclass(frozen=True) class Key: - key: type + label: type tp: type index: int @@ -38,6 +37,13 @@ def provider(*args: Value) -> tp: return provider +def _make_instance_provider(value): + def provider() -> type(value): + return value + + return provider + + def build( providers: Dict[type, Callable[..., Any]], indices: Dict[type, Iterable], @@ -50,33 +56,30 @@ def build( if tp in indices: pass elif get_origin(tp) == Map: - Index, Value = get_args(tp) - size = len(indices[Index]) - provider = _make_mapping_provider(indices, Value, tp, Index) - args = [Key(Index, Value, i) for i in range(size)] - graph[tp] = (provider, *args) + Label, Value = get_args(tp) + size = len(indices[Label]) + provider = _make_mapping_provider(indices, Value, tp, Label) + args = [Key(Label, Value, i) for i in range(size)] + graph[tp] = (provider, args) subgraph = build(providers, indices, Value) - paths = find_all_paths(subgraph, Value, Index) - path = set() - for p in paths: - path.update(p) + path = find_nodes_in_paths(subgraph, Value, Label) for key, value in subgraph.items(): if key in path: for i in range(size): - provider, *args = value + provider, args = value args = [ - Key(Index, arg, i) if arg in path else arg for arg in args + Key(Label, arg, i) if arg in path else arg for arg in args ] - graph[Key(Index, key, i)] = (provider, *args) + graph[Key(Label, key, i)] = (provider, args) else: graph[key] = value - for i, index in enumerate(indices[Index]): - graph[Key(Index, Index, i)] = index + for i, label in enumerate(indices[Label]): + graph[Key(Label, Label, i)] = (_make_instance_provider(label), ()) elif (provider := providers.get(tp)) is not None: args = get_type_hints(provider) del args['return'] - graph[tp] = (provider, *args.values()) + graph[tp] = (provider, tuple(args.values())) for arg in args.values(): if arg not in graph: stack.append(arg) @@ -85,6 +88,15 @@ def build( return graph +def get( + providers: Dict[type, Callable[..., Any]], + indices: Dict[type, Iterable], + tp: Type[T], +) -> TaskGraph: + graph = build(providers, indices, tp) + return TaskGraph(graph=graph, keys=tp) + + def test_Map(): Filename = NewType('Filename', str) Config = NewType('Config', int) @@ -141,10 +153,11 @@ def make_param() -> Param: Result: combine_configs, } - graph = build(providers, indices, int) - assert dask.get(graph, int) == 2 - graph = build(providers, indices, Result) + graph = get(providers, indices, int) + assert graph.compute() == 2 + graph = get(providers, indices, Result) + assert graph.compute() == 66.0 + # graph.visualize().render('graph', format='png') from dask.delayed import Delayed - Delayed(Result, graph).visualize(filename='graph.png') - assert dask.get(graph, Result) == 66.0 + Delayed(Result, graph._graph).visualize(filename='graph.png') From 2634ee11fae3d012993e89dca4d0e705fa3eede1 Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Mon, 31 Jul 2023 10:22:25 +0200 Subject: [PATCH 09/87] Add support for indices to Pipeline --- src/sciline/__init__.py | 2 ++ src/sciline/graph.py | 12 +++---- src/sciline/pipeline.py | 61 +++++++++++++++++++++++++++++++++ tests/variadic_pipeline_test.py | 41 ++++++++++++++++++++++ tests/variadic_workflow_test.py | 5 +-- 5 files changed, 111 insertions(+), 10 deletions(-) create mode 100644 tests/variadic_pipeline_test.py diff --git a/src/sciline/__init__.py b/src/sciline/__init__.py index 2f43d450..79dc10f0 100644 --- a/src/sciline/__init__.py +++ b/src/sciline/__init__.py @@ -16,9 +16,11 @@ UnboundTypeVar, UnsatisfiedRequirement, ) +from .variadic import Map __all__ = [ "AmbiguousProvider", + "Map", "Pipeline", "Scope", "UnboundTypeVar", diff --git a/src/sciline/graph.py b/src/sciline/graph.py index e441ed25..2b2dfded 100644 --- a/src/sciline/graph.py +++ b/src/sciline/graph.py @@ -1,9 +1,9 @@ -from typing import List +from typing import List, TypeVar -from sciline.pipeline import Key +T = TypeVar("T") -def find_path(graph, start, end) -> List[Key]: +def find_path(graph, start: T, end: T) -> List[T]: """Find a path from start to end in a DAG.""" if start == end: return [start] @@ -14,7 +14,7 @@ def find_path(graph, start, end) -> List[Key]: return [] -def find_unique_path(graph, start, end) -> List[Key]: +def find_unique_path(graph, start: T, end: T) -> List[T]: """Find a path from start to end in a DAG. Like find_path, but raises if more than one path found @@ -33,7 +33,7 @@ def find_unique_path(graph, start, end) -> List[Key]: return paths[0] if paths else [] -def find_all_paths(graph, start, end) -> List[Key]: +def find_all_paths(graph, start: T, end: T) -> List[List[T]]: """Find all paths from start to end in a DAG.""" if start == end: return [[start]] @@ -47,7 +47,7 @@ def find_all_paths(graph, start, end) -> List[Key]: return paths -def find_nodes_in_paths(graph, start: Key, end: Key) -> List[Key]: +def find_nodes_in_paths(graph, start: T, end: T) -> List[T]: paths = find_all_paths(graph, start, end) nodes = set() for path in paths: diff --git a/src/sciline/pipeline.py b/src/sciline/pipeline.py index 2bc0c270..bd6a05f5 100644 --- a/src/sciline/pipeline.py +++ b/src/sciline/pipeline.py @@ -2,10 +2,12 @@ # Copyright (c) 2023 Scipp contributors (https://github.com/scipp) from __future__ import annotations +from dataclasses import dataclass from typing import ( Any, Callable, Dict, + Iterable, List, Optional, Tuple, @@ -20,7 +22,9 @@ from sciline.task_graph import TaskGraph from .domain import Scope +from .graph import find_nodes_in_paths from .scheduler import Graph, Scheduler +from .variadic import Map T = TypeVar('T') @@ -74,6 +78,23 @@ def _bind_free_typevars(tp: TypeVar | Key, bound: Dict[TypeVar, Key]) -> Key: return tp +@dataclass(frozen=True) +class Label: + tp: type + index: int + + +@dataclass(frozen=True) +class Item: + label: label + tp: type + + +def _indexed_key(index_name, i, value_name): + label = Label(index_name, i) + return label if index_name == value_name else Item(label, value_name) + + class Pipeline: """A container for providers that can be assembled into a task graph.""" @@ -96,6 +117,7 @@ def __init__( """ self._providers: Dict[Key, Provider] = {} self._subproviders: Dict[type, Dict[Tuple[Key | TypeVar, ...], Provider]] = {} + self._indices: Dict[Key, Iterable[Any]] = {} for provider in providers or []: self.insert(provider) for tp, param in (params or {}).items(): @@ -153,6 +175,11 @@ def __setitem__(self, key: Type[T], param: T) -> None: ) self._set_provider(key, lambda: param) + def set_index(self, key: Type[T], index: Iterable[T]) -> None: + self._indices[key] = index + for i, label in enumerate(index): + self._set_provider(Label(tp=key, index=i), lambda label=label: label) + def _set_provider(self, key: Type[T], provider: Callable[..., T]) -> None: # isinstance does not work here and types.NoneType available only in 3.10+ if key == type(None): # noqa: E721 @@ -192,6 +219,12 @@ def _get_provider(self, tp: Type[T]) -> Tuple[Callable[..., T], Dict[TypeVar, Ke raise AmbiguousProvider("Multiple providers found for type", tp) raise UnsatisfiedRequirement("No provider found for type", tp) + def _make_mapping_provider(self, value_type, mapping_type, index_name): + def provider(*args: value_type) -> mapping_type: + return mapping_type(dict(zip(self._indices[index_name], args))) + + return provider + def build(self, tp: Type[T], /) -> Graph: """ Return a dict of providers required for building the requested type `tp`. @@ -211,6 +244,34 @@ def build(self, tp: Type[T], /) -> Graph: stack: List[Type[T]] = [tp] while stack: tp = stack.pop() + if tp in self._indices: + continue + if get_origin(tp) == Map: + index_name, value_type = get_args(tp) + size = len(self._indices[index_name]) + provider = self._make_mapping_provider( + value_type=value_type, mapping_type=tp, index_name=index_name + ) + args = [Item(Label(index_name, i), value_type) for i in range(size)] + graph[tp] = (provider, args) + + subgraph = self.build(value_type) + path = find_nodes_in_paths(subgraph, value_type, index_name) + for key, value in subgraph.items(): + if key in path: + for i in range(size): + provider, args = value + args = tuple( + _indexed_key(index_name, i, arg) if arg in path else arg + for arg in args + ) + graph[_indexed_key(index_name, i, key)] = (provider, args) + else: + graph[key] = value + for i in range(len(self._indices[index_name])): + provider, _ = self._get_provider(Label(index_name, i)) + graph[Label(index_name, i)] = (provider, ()) + continue provider: Callable[..., T] provider, bound = self._get_provider(tp) tps = get_type_hints(provider) diff --git a/tests/variadic_pipeline_test.py b/tests/variadic_pipeline_test.py new file mode 100644 index 00000000..fce3750c --- /dev/null +++ b/tests/variadic_pipeline_test.py @@ -0,0 +1,41 @@ +# SPDX-License-Identifier: BSD-3-Clause +# Copyright (c) 2023 Scipp contributors (https://github.com/scipp) +from typing import NewType + +import sciline as sl +from sciline.pipeline import Label + + +def test_set_index_sets_up_providers_with_indexable_instances(): + pl = sl.Pipeline() + pl.set_index(float, [1.0, 2.0, 3.0]) + assert pl.compute(Label(float, 0)) == 1.0 + assert pl.compute(Label(float, 1)) == 2.0 + assert pl.compute(Label(float, 2)) == 3.0 + + +def test_can_depend_on_index_elements(): + def use_index_elem(x: Label(float, 1)) -> int: + return int(x) + + pl = sl.Pipeline([use_index_elem]) + pl.set_index(float, [1.0, 2.0, 3.0]) + assert pl.compute(int) == 2 + + +def test_can_gather_index(): + Sum = NewType("Sum", float) + + def gather(x: sl.Map[str, float]) -> Sum: + print(list(x.items())) + return Sum(sum(x.values())) + + def make_float(x: str) -> float: + return float(x) + + pl = sl.Pipeline([gather, make_float]) + pl.set_index(str, ["1.0", "2.0", "3.0"]) + graph = pl.build(Sum) + for key, value in graph.items(): + print(key, value) + assert pl.compute(Sum) == 6.0 diff --git a/tests/variadic_workflow_test.py b/tests/variadic_workflow_test.py index bfc63bca..9d5fb3aa 100644 --- a/tests/variadic_workflow_test.py +++ b/tests/variadic_workflow_test.py @@ -38,10 +38,7 @@ def provider(*args: Value) -> tp: def _make_instance_provider(value): - def provider() -> type(value): - return value - - return provider + return lambda: value def build( From 9b1b6be74b035e7386329c7e7ac6fe0eba56c58f Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Mon, 31 Jul 2023 10:28:01 +0200 Subject: [PATCH 10/87] Cleanup --- src/sciline/pipeline.py | 65 ++++++++++++++++++++++------------------- 1 file changed, 35 insertions(+), 30 deletions(-) diff --git a/src/sciline/pipeline.py b/src/sciline/pipeline.py index bd6a05f5..1bbbd26f 100644 --- a/src/sciline/pipeline.py +++ b/src/sciline/pipeline.py @@ -219,12 +219,6 @@ def _get_provider(self, tp: Type[T]) -> Tuple[Callable[..., T], Dict[TypeVar, Ke raise AmbiguousProvider("Multiple providers found for type", tp) raise UnsatisfiedRequirement("No provider found for type", tp) - def _make_mapping_provider(self, value_type, mapping_type, index_name): - def provider(*args: value_type) -> mapping_type: - return mapping_type(dict(zip(self._indices[index_name], args))) - - return provider - def build(self, tp: Type[T], /) -> Graph: """ Return a dict of providers required for building the requested type `tp`. @@ -247,30 +241,7 @@ def build(self, tp: Type[T], /) -> Graph: if tp in self._indices: continue if get_origin(tp) == Map: - index_name, value_type = get_args(tp) - size = len(self._indices[index_name]) - provider = self._make_mapping_provider( - value_type=value_type, mapping_type=tp, index_name=index_name - ) - args = [Item(Label(index_name, i), value_type) for i in range(size)] - graph[tp] = (provider, args) - - subgraph = self.build(value_type) - path = find_nodes_in_paths(subgraph, value_type, index_name) - for key, value in subgraph.items(): - if key in path: - for i in range(size): - provider, args = value - args = tuple( - _indexed_key(index_name, i, arg) if arg in path else arg - for arg in args - ) - graph[_indexed_key(index_name, i, key)] = (provider, args) - else: - graph[key] = value - for i in range(len(self._indices[index_name])): - provider, _ = self._get_provider(Label(index_name, i)) - graph[Label(index_name, i)] = (provider, ()) + graph.update(self._build_indexed_subgraph(tp)) continue provider: Callable[..., T] provider, bound = self._get_provider(tp) @@ -286,6 +257,40 @@ def build(self, tp: Type[T], /) -> Graph: stack.append(arg) return graph + def _build_indexed_subgraph(self, tp: Type[T]) -> Graph: + index_name, value_type = get_args(tp) + size = len(self._indices[index_name]) + provider = self._make_mapping_provider( + value_type=value_type, mapping_type=tp, index_name=index_name + ) + args = [_indexed_key(index_name, i, value_type) for i in range(size)] + graph: Graph = {} + graph[tp] = (provider, args) + + subgraph = self.build(value_type) + path = find_nodes_in_paths(subgraph, value_type, index_name) + for key, value in subgraph.items(): + if key in path: + for i in range(size): + provider, args = value + args = tuple( + _indexed_key(index_name, i, arg) if arg in path else arg + for arg in args + ) + graph[_indexed_key(index_name, i, key)] = (provider, args) + else: + graph[key] = value + for i in range(len(self._indices[index_name])): + provider, _ = self._get_provider(Label(index_name, i)) + graph[Label(index_name, i)] = (provider, ()) + return graph + + def _make_mapping_provider(self, value_type, mapping_type, index_name): + def provider(*args: value_type) -> mapping_type: + return mapping_type(dict(zip(self._indices[index_name], args))) + + return provider + @overload def compute(self, tp: Type[T]) -> T: ... From 7167a43cc01c404902093961ff2f90a58a66ec49 Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Mon, 31 Jul 2023 10:32:52 +0200 Subject: [PATCH 11/87] Remove replcaed prototype --- tests/variadic_workflow_test.py | 132 ++++++-------------------------- 1 file changed, 23 insertions(+), 109 deletions(-) diff --git a/tests/variadic_workflow_test.py b/tests/variadic_workflow_test.py index 9d5fb3aa..30e116f5 100644 --- a/tests/variadic_workflow_test.py +++ b/tests/variadic_workflow_test.py @@ -1,97 +1,8 @@ # SPDX-License-Identifier: BSD-3-Clause # Copyright (c) 2023 Scipp contributors (https://github.com/scipp) -from dataclasses import dataclass -from typing import ( - Any, - Callable, - Dict, - Iterable, - List, - NewType, - Tuple, - Type, - TypeVar, - get_args, - get_origin, - get_type_hints, -) - -from sciline.graph import find_nodes_in_paths -from sciline.task_graph import TaskGraph -from sciline.variadic import Map - -T = TypeVar('T') - - -@dataclass(frozen=True) -class Key: - label: type - tp: type - index: int - - -def _make_mapping_provider(values, Value, tp, Index): - def provider(*args: Value) -> tp: - return tp(dict(zip(values[Index], args))) - - return provider - - -def _make_instance_provider(value): - return lambda: value - - -def build( - providers: Dict[type, Callable[..., Any]], - indices: Dict[type, Iterable], - tp: Type[T], -) -> Dict[type, Tuple[Callable[..., Any], ...]]: - graph = {} - stack: List[Type[T]] = [tp] - while stack: - tp = stack.pop() - if tp in indices: - pass - elif get_origin(tp) == Map: - Label, Value = get_args(tp) - size = len(indices[Label]) - provider = _make_mapping_provider(indices, Value, tp, Label) - args = [Key(Label, Value, i) for i in range(size)] - graph[tp] = (provider, args) - - subgraph = build(providers, indices, Value) - path = find_nodes_in_paths(subgraph, Value, Label) - for key, value in subgraph.items(): - if key in path: - for i in range(size): - provider, args = value - args = [ - Key(Label, arg, i) if arg in path else arg for arg in args - ] - graph[Key(Label, key, i)] = (provider, args) - else: - graph[key] = value - for i, label in enumerate(indices[Label]): - graph[Key(Label, Label, i)] = (_make_instance_provider(label), ()) - elif (provider := providers.get(tp)) is not None: - args = get_type_hints(provider) - del args['return'] - graph[tp] = (provider, tuple(args.values())) - for arg in args.values(): - if arg not in graph: - stack.append(arg) - else: - raise RuntimeError(f'No provider for {tp}') - return graph - - -def get( - providers: Dict[type, Callable[..., Any]], - indices: Dict[type, Iterable], - tp: Type[T], -) -> TaskGraph: - graph = build(providers, indices, tp) - return TaskGraph(graph=graph, keys=tp) +from typing import NewType + +import sciline as sl def test_Map(): @@ -120,14 +31,14 @@ def scale(x: CleanedImage, param: Param, config: Config) -> ScaledImage: return x * param + config def combine_old( - images: Map[Filename, ScaledImage], params: Map[Filename, ImageParam] + images: sl.Map[Filename, ScaledImage], params: sl.Map[Filename, ImageParam] ) -> float: return sum(images.values()) - def combine(images: Map[Filename, ScaledImage]) -> float: + def combine(images: sl.Map[Filename, ScaledImage]) -> float: return sum(images.values()) - def combine_configs(x: Map[Config, float]) -> Result: + def combine_configs(x: sl.Map[Config, float]) -> Result: return Result(sum(x.values())) def make_int() -> int: @@ -138,21 +49,24 @@ def make_param() -> Param: filenames = tuple(f'file{i}' for i in range(6)) configs = tuple(range(2)) - indices = {Filename: filenames, Config: configs} - providers = { - Image: read, - CleanedImage: clean, - ScaledImage: scale, - float: combine, - int: make_int, - Param: make_param, - ImageParam: image_param, - Result: combine_configs, - } - - graph = get(providers, indices, int) + pipeline = sl.Pipeline( + [ + read, + clean, + scale, + combine, + combine_configs, + make_int, + make_param, + image_param, + ] + ) + pipeline.set_index(Filename, filenames) + pipeline.set_index(Config, configs) + + graph = pipeline.get(int) assert graph.compute() == 2 - graph = get(providers, indices, Result) + graph = pipeline.get(Result) assert graph.compute() == 66.0 # graph.visualize().render('graph', format='png') from dask.delayed import Delayed From 6313d03f223734117f966c4c2c4bf239a9cd2a45 Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Mon, 31 Jul 2023 10:59:29 +0200 Subject: [PATCH 12/87] How to visualize variadic graphs? --- src/sciline/pipeline.py | 7 +++++++ src/sciline/visualize.py | 6 +++++- tests/variadic_workflow_test.py | 5 +++-- 3 files changed, 15 insertions(+), 3 deletions(-) diff --git a/src/sciline/pipeline.py b/src/sciline/pipeline.py index 1bbbd26f..392eac3e 100644 --- a/src/sciline/pipeline.py +++ b/src/sciline/pipeline.py @@ -83,12 +83,19 @@ class Label: tp: type index: int + def __str__(self) -> str: + return f'{self.tp.__name__}[{self.index}]' + @dataclass(frozen=True) class Item: label: label tp: type + def __str__(self) -> str: + name = self.tp.__name__ if hasattr(self.tp, '__name__') else str(self.tp) + return f'{name}[{self.label}]' + def _indexed_key(index_name, i, value_name): label = Label(index_name, i) diff --git a/src/sciline/visualize.py b/src/sciline/visualize.py index a16eb930..2ee194ab 100644 --- a/src/sciline/visualize.py +++ b/src/sciline/visualize.py @@ -22,7 +22,10 @@ def to_graphviz(graph: Graph, **kwargs: Any) -> Digraph: for p, (p_name, args, ret) in _format_graph(graph).items(): dot.node(ret, ret, shape='rectangle') # Do not draw dummy providers created by Pipeline when setting instances - if p_name == 'Pipeline.__setitem__..': + if p_name in ( + 'Pipeline.__setitem__..', + 'Pipeline.set_index..', + ): continue dot.node(p, p_name, shape='ellipse') for arg in args: @@ -57,6 +60,7 @@ def _format_type(tp: type) -> str: """ def get_base(tp: type) -> str: + return tp.__name__ if hasattr(tp, '__name__') else str(tp) return tp.__name__ if hasattr(tp, '__name__') else str(tp).split('.')[-1] if (origin := get_origin(tp)) is not None: diff --git a/tests/variadic_workflow_test.py b/tests/variadic_workflow_test.py index 30e116f5..8d8007af 100644 --- a/tests/variadic_workflow_test.py +++ b/tests/variadic_workflow_test.py @@ -68,7 +68,8 @@ def make_param() -> Param: assert graph.compute() == 2 graph = pipeline.get(Result) assert graph.compute() == 66.0 - # graph.visualize().render('graph', format='png') + graph.visualize().render('graph', format='png') from dask.delayed import Delayed - Delayed(Result, graph._graph).visualize(filename='graph.png') + dsk = {key: (value, *args) for key, (value, args) in graph._graph.items()} + # Delayed(Result, dsk).visualize(filename='graph.png') From bc4acafba506d50e95fd0f6ee3e839f947daeb3c Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Mon, 31 Jul 2023 17:30:20 +0200 Subject: [PATCH 13/87] Try compact visualization --- src/sciline/pipeline.py | 7 ----- src/sciline/visualize.py | 46 ++++++++++++++++++++++++++------- tests/variadic_workflow_test.py | 5 ++-- 3 files changed, 38 insertions(+), 20 deletions(-) diff --git a/src/sciline/pipeline.py b/src/sciline/pipeline.py index 392eac3e..1bbbd26f 100644 --- a/src/sciline/pipeline.py +++ b/src/sciline/pipeline.py @@ -83,19 +83,12 @@ class Label: tp: type index: int - def __str__(self) -> str: - return f'{self.tp.__name__}[{self.index}]' - @dataclass(frozen=True) class Item: label: label tp: type - def __str__(self) -> str: - name = self.tp.__name__ if hasattr(self.tp, '__name__') else str(self.tp) - return f'{name}[{self.label}]' - def _indexed_key(index_name, i, value_name): label = Label(index_name, i) diff --git a/src/sciline/visualize.py b/src/sciline/visualize.py index 2ee194ab..3c9ec236 100644 --- a/src/sciline/visualize.py +++ b/src/sciline/visualize.py @@ -1,9 +1,10 @@ # SPDX-License-Identifier: BSD-3-Clause # Copyright (c) 2023 Scipp contributors (https://github.com/scipp) -from typing import Any, Callable, Dict, List, Tuple, get_args, get_origin +from typing import Any, Callable, Dict, List, Tuple, Union, get_args, get_origin from graphviz import Digraph +from .pipeline import Item, Label from .scheduler import Graph @@ -20,18 +21,26 @@ def to_graphviz(graph: Graph, **kwargs: Any) -> Digraph: """ dot = Digraph(strict=True, **kwargs) for p, (p_name, args, ret) in _format_graph(graph).items(): - dot.node(ret, ret, shape='rectangle') + shape = 'rectangle' if '(' not in ret else 'box3d' + dot.node(ret, ret, shape=shape) # Do not draw dummy providers created by Pipeline when setting instances if p_name in ( 'Pipeline.__setitem__..', 'Pipeline.set_index..', ): continue - dot.node(p, p_name, shape='ellipse') - for arg in args: - dot.node(arg, arg, shape='rectangle') - dot.edge(arg, p) - dot.edge(p, ret) + if p_name.startswith('Pipeline._make_mapping_provider.'): + for arg in args: + shape = 'rectangle' if '(' not in arg else 'box3d' + dot.node(arg, arg, shape=shape) + dot.edge(arg, ret) + else: + dot.node(p, p_name, shape='ellipse') + for arg in args: + shape = 'rectangle' if '(' not in arg else 'box3d' + dot.node(arg, arg, shape=shape) + dot.edge(arg, p) + dot.edge(p, ret) return dot @@ -50,6 +59,16 @@ def _format_provider(provider: Callable[..., Any], ret: type) -> str: return f'{provider.__qualname__}_{_format_type(ret)}' +def _extract_type_and_labels(key: Union[Item, type]) -> Tuple[type, List[type]]: + if isinstance(key, Item): + tp, labels = _extract_type_and_labels(key.tp) + return tp, [key.label.tp] + labels + if isinstance(key, Label): + tp, labels = _extract_type_and_labels(key.tp) + return tp, [key.tp] + labels + return key, [] + + def _format_type(tp: type) -> str: """ Helper for _format_graph. @@ -59,12 +78,19 @@ def _format_type(tp: type) -> str: We may make this configurable in the future. """ + print(_extract_type_and_labels(tp)) + tp, labels = _extract_type_and_labels(tp) + def get_base(tp: type) -> str: - return tp.__name__ if hasattr(tp, '__name__') else str(tp) return tp.__name__ if hasattr(tp, '__name__') else str(tp).split('.')[-1] + def with_labels(base: str) -> str: + if labels: + return f'{base}({", ".join([get_base(l) for l in labels])})' + return base + if (origin := get_origin(tp)) is not None: params = [_format_type(param) for param in get_args(tp)] - return f'{get_base(origin)}[{", ".join(params)}]' + return with_labels(f'{get_base(origin)}[{", ".join(params)}]') else: - return get_base(tp) + return with_labels(get_base(tp)) diff --git a/tests/variadic_workflow_test.py b/tests/variadic_workflow_test.py index 8d8007af..45ef18db 100644 --- a/tests/variadic_workflow_test.py +++ b/tests/variadic_workflow_test.py @@ -69,7 +69,6 @@ def make_param() -> Param: graph = pipeline.get(Result) assert graph.compute() == 66.0 graph.visualize().render('graph', format='png') - from dask.delayed import Delayed - - dsk = {key: (value, *args) for key, (value, args) in graph._graph.items()} + # from dask.delayed import Delayed + # dsk = {key: (value, *args) for key, (value, args) in graph._graph.items()} # Delayed(Result, dsk).visualize(filename='graph.png') From 5f8f32b88806c14e7840949a0d7d06b86e0225a8 Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Tue, 1 Aug 2023 08:56:01 +0200 Subject: [PATCH 14/87] Cleanup --- src/sciline/visualize.py | 13 ++++++------- tests/variadic_pipeline_test.py | 6 +++--- tests/variadic_workflow_test.py | 2 +- 3 files changed, 10 insertions(+), 11 deletions(-) diff --git a/src/sciline/visualize.py b/src/sciline/visualize.py index 3c9ec236..aa5feaa0 100644 --- a/src/sciline/visualize.py +++ b/src/sciline/visualize.py @@ -21,24 +21,24 @@ def to_graphviz(graph: Graph, **kwargs: Any) -> Digraph: """ dot = Digraph(strict=True, **kwargs) for p, (p_name, args, ret) in _format_graph(graph).items(): - shape = 'rectangle' if '(' not in ret else 'box3d' - dot.node(ret, ret, shape=shape) + if '(' in ret: + dot.node(ret, ret, shape='box3d') + else: + dot.node(ret, ret, shape='rectangle') # Do not draw dummy providers created by Pipeline when setting instances if p_name in ( 'Pipeline.__setitem__..', 'Pipeline.set_index..', ): continue + # Do not draw the internal provider gathering index-dependent results into + # a dict if p_name.startswith('Pipeline._make_mapping_provider.'): for arg in args: - shape = 'rectangle' if '(' not in arg else 'box3d' - dot.node(arg, arg, shape=shape) dot.edge(arg, ret) else: dot.node(p, p_name, shape='ellipse') for arg in args: - shape = 'rectangle' if '(' not in arg else 'box3d' - dot.node(arg, arg, shape=shape) dot.edge(arg, p) dot.edge(p, ret) return dot @@ -78,7 +78,6 @@ def _format_type(tp: type) -> str: We may make this configurable in the future. """ - print(_extract_type_and_labels(tp)) tp, labels = _extract_type_and_labels(tp) def get_base(tp: type) -> str: diff --git a/tests/variadic_pipeline_test.py b/tests/variadic_pipeline_test.py index fce3750c..69b178f6 100644 --- a/tests/variadic_pipeline_test.py +++ b/tests/variadic_pipeline_test.py @@ -6,7 +6,7 @@ from sciline.pipeline import Label -def test_set_index_sets_up_providers_with_indexable_instances(): +def test_set_index_sets_up_providers_with_indexable_instances() -> None: pl = sl.Pipeline() pl.set_index(float, [1.0, 2.0, 3.0]) assert pl.compute(Label(float, 0)) == 1.0 @@ -14,7 +14,7 @@ def test_set_index_sets_up_providers_with_indexable_instances(): assert pl.compute(Label(float, 2)) == 3.0 -def test_can_depend_on_index_elements(): +def test_can_depend_on_index_elements() -> None: def use_index_elem(x: Label(float, 1)) -> int: return int(x) @@ -23,7 +23,7 @@ def use_index_elem(x: Label(float, 1)) -> int: assert pl.compute(int) == 2 -def test_can_gather_index(): +def test_can_gather_index() -> None: Sum = NewType("Sum", float) def gather(x: sl.Map[str, float]) -> Sum: diff --git a/tests/variadic_workflow_test.py b/tests/variadic_workflow_test.py index 45ef18db..d74962cf 100644 --- a/tests/variadic_workflow_test.py +++ b/tests/variadic_workflow_test.py @@ -5,7 +5,7 @@ import sciline as sl -def test_Map(): +def test_Map() -> None: Filename = NewType('Filename', str) Config = NewType('Config', int) Image = NewType('Image', float) From 918423ea069941e490d5f143c212a331b10e858f Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Tue, 1 Aug 2023 12:02:07 +0200 Subject: [PATCH 15/87] Switch from nested labeling to Label tuples --- src/sciline/pipeline.py | 74 ++++++++++++++++++++++++---------------- src/sciline/variadic.py | 6 ++-- src/sciline/visualize.py | 6 ++-- 3 files changed, 51 insertions(+), 35 deletions(-) diff --git a/src/sciline/pipeline.py b/src/sciline/pipeline.py index 1bbbd26f..e3807dfc 100644 --- a/src/sciline/pipeline.py +++ b/src/sciline/pipeline.py @@ -6,13 +6,15 @@ from typing import ( Any, Callable, + Collection, Dict, - Iterable, + Generic, List, Optional, Tuple, Type, TypeVar, + Union, get_args, get_origin, get_type_hints, @@ -43,8 +45,30 @@ class AmbiguousProvider(Exception): """Raised when multiple providers are found for a type.""" +@dataclass(frozen=True) +class Label(Generic[T]): + tp: Type[T] + index: int + + +@dataclass(frozen=True) +class Item(Generic[T]): + label: Tuple[Label[T], ...] + tp: type + + +def _indexed_key(index_name: Any, i: int, value_name: Any) -> Union[Label, Item]: + if index_name == value_name: + return Label(index_name, i) + label = Label(index_name, i) + if isinstance(value_name, Item): + return Item(value_name.label + (label,), value_name.tp) + else: + return Item((label,), value_name) + + Provider = Callable[..., Any] -Key = type +Key = Union[type, Label, Item] def _is_compatible_type_tuple( @@ -78,23 +102,6 @@ def _bind_free_typevars(tp: TypeVar | Key, bound: Dict[TypeVar, Key]) -> Key: return tp -@dataclass(frozen=True) -class Label: - tp: type - index: int - - -@dataclass(frozen=True) -class Item: - label: label - tp: type - - -def _indexed_key(index_name, i, value_name): - label = Label(index_name, i) - return label if index_name == value_name else Item(label, value_name) - - class Pipeline: """A container for providers that can be assembled into a task graph.""" @@ -117,7 +124,7 @@ def __init__( """ self._providers: Dict[Key, Provider] = {} self._subproviders: Dict[type, Dict[Tuple[Key | TypeVar, ...], Provider]] = {} - self._indices: Dict[Key, Iterable[Any]] = {} + self._indices: Dict[Key, Collection[Any]] = {} for provider in providers or []: self.insert(provider) for tp, param in (params or {}).items(): @@ -137,7 +144,7 @@ def insert(self, provider: Provider, /) -> None: raise ValueError(f'Provider {provider} lacks type-hint for return value') self._set_provider(key, provider) - def __setitem__(self, key: Type[T], param: T) -> None: + def __setitem__(self, key: Union[Type[T], Label[T]], param: T) -> None: """ Provide a concrete value for a type. @@ -175,12 +182,14 @@ def __setitem__(self, key: Type[T], param: T) -> None: ) self._set_provider(key, lambda: param) - def set_index(self, key: Type[T], index: Iterable[T]) -> None: + def set_index(self, key: Type[T], index: Collection[T]) -> None: self._indices[key] = index for i, label in enumerate(index): self._set_provider(Label(tp=key, index=i), lambda label=label: label) - def _set_provider(self, key: Type[T], provider: Callable[..., T]) -> None: + def _set_provider( + self, key: Union[Type[T], Label[T]], provider: Callable[..., T] + ) -> None: # isinstance does not work here and types.NoneType available only in 3.10+ if key == type(None): # noqa: E721 raise ValueError(f'Provider {provider} returning `None` is not allowed') @@ -195,7 +204,9 @@ def _set_provider(self, key: Type[T], provider: Callable[..., T]) -> None: raise ValueError(f'Provider for {key} already exists') self._providers[key] = provider - def _get_provider(self, tp: Type[T]) -> Tuple[Callable[..., T], Dict[TypeVar, Key]]: + def _get_provider( + self, tp: Union[Type[T], Label[T], Item] + ) -> Tuple[Callable[..., T], Dict[TypeVar, Key]]: if (provider := self._providers.get(tp)) is not None: return provider, {} elif (origin := get_origin(tp)) is not None and ( @@ -235,7 +246,7 @@ def build(self, tp: Type[T], /) -> Graph: Type to build the graph for. """ graph: Graph = {} - stack: List[Type[T]] = [tp] + stack: List[Union[Type[T], Label[T]]] = [tp] while stack: tp = stack.pop() if tp in self._indices: @@ -273,11 +284,14 @@ def _build_indexed_subgraph(self, tp: Type[T]) -> Graph: if key in path: for i in range(size): provider, args = value - args = tuple( + args_with_index = tuple( _indexed_key(index_name, i, arg) if arg in path else arg for arg in args ) - graph[_indexed_key(index_name, i, key)] = (provider, args) + graph[_indexed_key(index_name, i, key)] = ( + provider, + args_with_index, + ) else: graph[key] = value for i in range(len(self._indices[index_name])): @@ -285,8 +299,10 @@ def _build_indexed_subgraph(self, tp: Type[T]) -> Graph: graph[Label(index_name, i)] = (provider, ()) return graph - def _make_mapping_provider(self, value_type, mapping_type, index_name): - def provider(*args: value_type) -> mapping_type: + def _make_mapping_provider( + self, value_type: type, mapping_type: type, index_name: type + ) -> Callable[..., Any]: + def provider(*args): # type: ignore[no-untyped-def] return mapping_type(dict(zip(self._indices[index_name], args))) return provider diff --git a/src/sciline/variadic.py b/src/sciline/variadic.py index d635cd53..57ef2947 100644 --- a/src/sciline/variadic.py +++ b/src/sciline/variadic.py @@ -2,14 +2,14 @@ # Copyright (c) 2023 Scipp contributors (https://github.com/scipp) from __future__ import annotations -from collections.abc import Mapping -from typing import Iterator, TypeVar +from collections import abc +from typing import Generic, Iterator, Mapping, TypeVar Key = TypeVar('Key') Value = TypeVar('Value') -class Map(Mapping[Key, Value]): +class Map(abc.Mapping, Generic[Key, Value]): def __init__(self, values: Mapping[Key, Value]) -> None: self._map: Mapping[Key, Value] = values diff --git a/src/sciline/visualize.py b/src/sciline/visualize.py index aa5feaa0..9ac326ac 100644 --- a/src/sciline/visualize.py +++ b/src/sciline/visualize.py @@ -59,10 +59,10 @@ def _format_provider(provider: Callable[..., Any], ret: type) -> str: return f'{provider.__qualname__}_{_format_type(ret)}' -def _extract_type_and_labels(key: Union[Item, type]) -> Tuple[type, List[type]]: +def _extract_type_and_labels(key: Union[Item, Label, type]) -> Tuple[type, List[type]]: if isinstance(key, Item): - tp, labels = _extract_type_and_labels(key.tp) - return tp, [key.label.tp] + labels + label = key.label + return key.tp, [lb.tp for lb in label] if isinstance(key, Label): tp, labels = _extract_type_and_labels(key.tp) return tp, [key.tp] + labels From e970da79f5f6716e11a8a4289d4b50b606c01cdc Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Wed, 2 Aug 2023 08:27:48 +0200 Subject: [PATCH 16/87] Begin refactor to param table instead of bare index --- src/sciline/pipeline.py | 36 ++++++++++++++++--- ...ne_test.py => pipeline_with_index_test.py} | 27 ++++++++++---- 2 files changed, 52 insertions(+), 11 deletions(-) rename tests/{variadic_pipeline_test.py => pipeline_with_index_test.py} (59%) diff --git a/src/sciline/pipeline.py b/src/sciline/pipeline.py index e3807dfc..8984dcc6 100644 --- a/src/sciline/pipeline.py +++ b/src/sciline/pipeline.py @@ -105,6 +105,8 @@ def _bind_free_typevars(tp: TypeVar | Key, bound: Dict[TypeVar, Key]) -> Key: class Pipeline: """A container for providers that can be assembled into a task graph.""" + _param_sentinel = object() + def __init__( self, providers: Optional[List[Provider]] = None, @@ -125,6 +127,7 @@ def __init__( self._providers: Dict[Key, Provider] = {} self._subproviders: Dict[type, Dict[Tuple[Key | TypeVar, ...], Provider]] = {} self._indices: Dict[Key, Collection[Any]] = {} + self._param_series: Dict[Key, Key] = {} for provider in providers or []: self.insert(provider) for tp, param in (params or {}).items(): @@ -187,6 +190,19 @@ def set_index(self, key: Type[T], index: Collection[T]) -> None: for i, label in enumerate(index): self._set_provider(Label(tp=key, index=i), lambda label=label: label) + def set_param_table( + self, key: Type[T], value_series: Dict[type, Collection] + ) -> None: + self._indices[key] = list(range(len(next(iter(value_series.values()))))) + for param_type in value_series: + self._param_series[param_type] = key + for param_type, values in value_series.items(): + for i, label in enumerate(values): + self._set_provider( + Item((Label(tp=key, index=i),), param_type), + lambda label=label: label, + ) + def _set_provider( self, key: Union[Type[T], Label[T]], provider: Callable[..., T] ) -> None: @@ -249,7 +265,8 @@ def build(self, tp: Type[T], /) -> Graph: stack: List[Union[Type[T], Label[T]]] = [tp] while stack: tp = stack.pop() - if tp in self._indices: + if tp in self._param_series: + graph[tp] = (self._param_sentinel, ()) continue if get_origin(tp) == Map: graph.update(self._build_indexed_subgraph(tp)) @@ -279,11 +296,23 @@ def _build_indexed_subgraph(self, tp: Type[T]) -> Graph: graph[tp] = (provider, args) subgraph = self.build(value_type) - path = find_nodes_in_paths(subgraph, value_type, index_name) + param_name = None + for k, (provider, _) in subgraph.items(): + if provider == self._param_sentinel: + if param_name is not None: + raise ValueError( + f'Found multiple param names in subgraph: {param_name}, {k}' + ) + param_name = k + path = find_nodes_in_paths(subgraph, value_type, param_name) for key, value in subgraph.items(): if key in path: for i in range(size): provider, args = value + if provider == self._param_sentinel: + provider, _ = self._get_provider( + _indexed_key(index_name, i, key) + ) args_with_index = tuple( _indexed_key(index_name, i, arg) if arg in path else arg for arg in args @@ -294,9 +323,6 @@ def _build_indexed_subgraph(self, tp: Type[T]) -> Graph: ) else: graph[key] = value - for i in range(len(self._indices[index_name])): - provider, _ = self._get_provider(Label(index_name, i)) - graph[Label(index_name, i)] = (provider, ()) return graph def _make_mapping_provider( diff --git a/tests/variadic_pipeline_test.py b/tests/pipeline_with_index_test.py similarity index 59% rename from tests/variadic_pipeline_test.py rename to tests/pipeline_with_index_test.py index 69b178f6..c3196660 100644 --- a/tests/variadic_pipeline_test.py +++ b/tests/pipeline_with_index_test.py @@ -25,17 +25,32 @@ def use_index_elem(x: Label(float, 1)) -> int: def test_can_gather_index() -> None: Sum = NewType("Sum", float) + Name = NewType("Name", str) - def gather(x: sl.Map[str, float]) -> Sum: - print(list(x.items())) + def gather(x: sl.Map[Name, float]) -> Sum: return Sum(sum(x.values())) def make_float(x: str) -> float: return float(x) pl = sl.Pipeline([gather, make_float]) - pl.set_index(str, ["1.0", "2.0", "3.0"]) - graph = pl.build(Sum) - for key, value in graph.items(): - print(key, value) + pl.set_param_table(Name, {str: ["1.0", "2.0", "3.0"]}) assert pl.compute(Sum) == 6.0 + + +def test_can_zip() -> None: + Sum = NewType("Sum", float) + Str = NewType("Str", str) + Run = NewType("Run", int) + + def gather_zip(x: sl.Map[Run, Str], y: sl.Map[Run, int]) -> Sum: + z = [f'{x_}{y_}' for x_, y_ in zip(x.values(), y.values())] + return Sum(str(z)) + + def use_str(x: str) -> Str: + return Str(x) + + pl = sl.Pipeline([gather_zip, use_str]) + pl.set_param_table(Run, {str: ['a', 'a', 'ccc'], int: [1, 2, 3]}) + + assert pl.compute(Sum) == "['a1', 'a2', 'ccc3']" From aee94078d834825c637ce27a82bde4e0157ab97f Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Wed, 2 Aug 2023 08:37:39 +0200 Subject: [PATCH 17/87] Readability --- src/sciline/pipeline.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/src/sciline/pipeline.py b/src/sciline/pipeline.py index 8984dcc6..70ab40cb 100644 --- a/src/sciline/pipeline.py +++ b/src/sciline/pipeline.py @@ -309,18 +309,14 @@ def _build_indexed_subgraph(self, tp: Type[T]) -> Graph: if key in path: for i in range(size): provider, args = value + subkey = _indexed_key(index_name, i, key) if provider == self._param_sentinel: - provider, _ = self._get_provider( - _indexed_key(index_name, i, key) - ) + provider, _ = self._get_provider(subkey) args_with_index = tuple( _indexed_key(index_name, i, arg) if arg in path else arg for arg in args ) - graph[_indexed_key(index_name, i, key)] = ( - provider, - args_with_index, - ) + graph[subkey] = (provider, args_with_index) else: graph[key] = value return graph From e0875c70408a3c1f09552a936f7c5f8215408032 Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Wed, 2 Aug 2023 08:47:37 +0200 Subject: [PATCH 18/87] Support diamond dependency on same index --- src/sciline/pipeline.py | 13 +++---------- tests/pipeline_with_index_test.py | 20 ++++++++++++++++++++ 2 files changed, 23 insertions(+), 10 deletions(-) diff --git a/src/sciline/pipeline.py b/src/sciline/pipeline.py index 70ab40cb..20160748 100644 --- a/src/sciline/pipeline.py +++ b/src/sciline/pipeline.py @@ -266,7 +266,7 @@ def build(self, tp: Type[T], /) -> Graph: while stack: tp = stack.pop() if tp in self._param_series: - graph[tp] = (self._param_sentinel, ()) + graph[tp] = (self._param_sentinel, (self._param_series[tp],)) continue if get_origin(tp) == Map: graph.update(self._build_indexed_subgraph(tp)) @@ -296,15 +296,7 @@ def _build_indexed_subgraph(self, tp: Type[T]) -> Graph: graph[tp] = (provider, args) subgraph = self.build(value_type) - param_name = None - for k, (provider, _) in subgraph.items(): - if provider == self._param_sentinel: - if param_name is not None: - raise ValueError( - f'Found multiple param names in subgraph: {param_name}, {k}' - ) - param_name = k - path = find_nodes_in_paths(subgraph, value_type, param_name) + path = find_nodes_in_paths(subgraph, value_type, index_name) for key, value in subgraph.items(): if key in path: for i in range(size): @@ -312,6 +304,7 @@ def _build_indexed_subgraph(self, tp: Type[T]) -> Graph: subkey = _indexed_key(index_name, i, key) if provider == self._param_sentinel: provider, _ = self._get_provider(subkey) + args = () args_with_index = tuple( _indexed_key(index_name, i, arg) if arg in path else arg for arg in args diff --git a/tests/pipeline_with_index_test.py b/tests/pipeline_with_index_test.py index c3196660..75fc0c4b 100644 --- a/tests/pipeline_with_index_test.py +++ b/tests/pipeline_with_index_test.py @@ -54,3 +54,23 @@ def use_str(x: str) -> Str: pl.set_param_table(Run, {str: ['a', 'a', 'ccc'], int: [1, 2, 3]}) assert pl.compute(Sum) == "['a1', 'a2', 'ccc3']" + + +def test_diamond_dependency_works() -> None: + Sum = NewType("Sum", float) + Param1 = NewType("Param1", int) + Param2 = NewType("Param2", int) + Row = NewType("Run", int) + + def gather( + x: sl.Map[Row, float], + ) -> Sum: + return Sum(sum(x.values())) + + def join(x: Param1, y: Param2) -> float: + return x / y + + pl = sl.Pipeline([gather, join]) + pl.set_param_table(Row, {Param1: [1, 4, 9], Param2: [1, 2, 3]}) + + assert pl.compute(Sum) == 6 From 4fda7cda0ebd9c1394717555836ec4821b42b684 Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Wed, 2 Aug 2023 09:08:46 +0200 Subject: [PATCH 19/87] Remove old API --- src/sciline/pipeline.py | 5 ----- tests/pipeline_with_index_test.py | 24 +++++++++++------------- 2 files changed, 11 insertions(+), 18 deletions(-) diff --git a/src/sciline/pipeline.py b/src/sciline/pipeline.py index 20160748..4d0663f2 100644 --- a/src/sciline/pipeline.py +++ b/src/sciline/pipeline.py @@ -185,11 +185,6 @@ def __setitem__(self, key: Union[Type[T], Label[T]], param: T) -> None: ) self._set_provider(key, lambda: param) - def set_index(self, key: Type[T], index: Collection[T]) -> None: - self._indices[key] = index - for i, label in enumerate(index): - self._set_provider(Label(tp=key, index=i), lambda label=label: label) - def set_param_table( self, key: Type[T], value_series: Dict[type, Collection] ) -> None: diff --git a/tests/pipeline_with_index_test.py b/tests/pipeline_with_index_test.py index 75fc0c4b..6d7b18f3 100644 --- a/tests/pipeline_with_index_test.py +++ b/tests/pipeline_with_index_test.py @@ -3,24 +3,22 @@ from typing import NewType import sciline as sl -from sciline.pipeline import Label +from sciline.pipeline import Item, Label -def test_set_index_sets_up_providers_with_indexable_instances() -> None: +def test_can_get_elements_of_param_table() -> None: pl = sl.Pipeline() - pl.set_index(float, [1.0, 2.0, 3.0]) - assert pl.compute(Label(float, 0)) == 1.0 - assert pl.compute(Label(float, 1)) == 2.0 - assert pl.compute(Label(float, 2)) == 3.0 + pl.set_param_table(int, {float: [1.0, 2.0, 3.0]}) + assert pl.compute(Item((Label(int, 1),), float)) == 2.0 -def test_can_depend_on_index_elements() -> None: - def use_index_elem(x: Label(float, 1)) -> int: - return int(x) +def test_can_depend_on_elements_of_param_table() -> None: + def use_elem(x: Item((Label(int, 1),), float)) -> str: + return str(x) - pl = sl.Pipeline([use_index_elem]) - pl.set_index(float, [1.0, 2.0, 3.0]) - assert pl.compute(int) == 2 + pl = sl.Pipeline([use_elem]) + pl.set_param_table(int, {float: [1.0, 2.0, 3.0]}) + assert pl.compute(str) == "2.0" def test_can_gather_index() -> None: @@ -56,7 +54,7 @@ def use_str(x: str) -> Str: assert pl.compute(Sum) == "['a1', 'a2', 'ccc3']" -def test_diamond_dependency_works() -> None: +def test_diamond_dependency_pulls_values_from_columns_in_same_param_table() -> None: Sum = NewType("Sum", float) Param1 = NewType("Param1", int) Param2 = NewType("Param2", int) From 267d7d98c7586b78795c55247a58f278069526f0 Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Wed, 2 Aug 2023 09:38:43 +0200 Subject: [PATCH 20/87] Add ParamTable --- src/sciline/__init__.py | 2 ++ src/sciline/param_table.py | 47 +++++++++++++++++++++++++++++++ src/sciline/pipeline.py | 22 +++++++-------- tests/param_table_test.py | 37 ++++++++++++++++++++++++ tests/pipeline_with_index_test.py | 10 +++---- 5 files changed, 101 insertions(+), 17 deletions(-) create mode 100644 src/sciline/param_table.py create mode 100644 tests/param_table_test.py diff --git a/src/sciline/__init__.py b/src/sciline/__init__.py index 79dc10f0..6c103291 100644 --- a/src/sciline/__init__.py +++ b/src/sciline/__init__.py @@ -10,6 +10,7 @@ __version__ = "0.0.0" from .domain import Scope +from .param_table import ParamTable from .pipeline import ( AmbiguousProvider, Pipeline, @@ -21,6 +22,7 @@ __all__ = [ "AmbiguousProvider", "Map", + "ParamTable", "Pipeline", "Scope", "UnboundTypeVar", diff --git a/src/sciline/param_table.py b/src/sciline/param_table.py new file mode 100644 index 00000000..7bac7955 --- /dev/null +++ b/src/sciline/param_table.py @@ -0,0 +1,47 @@ +# SPDX-License-Identifier: BSD-3-Clause +# Copyright (c) 2023 Scipp contributors (https://github.com/scipp) +from __future__ import annotations + +from collections import abc +from typing import Any, Collection, Dict, Optional + + +class ParamTable(abc.Mapping): + def __init__( + self, + row_dim: type, + columns: Dict[type, Collection[Any]], + *, + index: Optional[Collection[Any]] = None, + ): + sizes = set(len(v) for v in columns.values()) + if len(sizes) != 1: + raise ValueError( + f"Columns in param table must all have same size, got {sizes}" + ) + self._row_dim = row_dim + self._columns = columns + self._index = index or list(range(sizes.pop())) + + @property + def row_dim(self) -> type: + return self._row_dim + + @property + def index(self) -> Collection[Any]: + return self._index + + def __contains__(self, __key: object) -> bool: + return self._columns.__contains__(__key) + + def __getitem__(self, __key: object) -> Any: + return self._columns.__getitem__(__key) + + def __iter__(self) -> Any: + return self._columns.__iter__() + + def __len__(self) -> int: + return self._columns.__len__() + + def __repr__(self) -> str: + return f"ParamTable(row_dim={self.row_dim}, columns={self.columns})" diff --git a/src/sciline/pipeline.py b/src/sciline/pipeline.py index 4d0663f2..1b625039 100644 --- a/src/sciline/pipeline.py +++ b/src/sciline/pipeline.py @@ -6,7 +6,6 @@ from typing import ( Any, Callable, - Collection, Dict, Generic, List, @@ -25,6 +24,7 @@ from .domain import Scope from .graph import find_nodes_in_paths +from .param_table import ParamTable from .scheduler import Graph, Scheduler from .variadic import Map @@ -126,7 +126,7 @@ def __init__( """ self._providers: Dict[Key, Provider] = {} self._subproviders: Dict[type, Dict[Tuple[Key | TypeVar, ...], Provider]] = {} - self._indices: Dict[Key, Collection[Any]] = {} + self._param_tables: Dict[Key, ParamTable] = {} self._param_series: Dict[Key, Key] = {} for provider in providers or []: self.insert(provider) @@ -185,16 +185,14 @@ def __setitem__(self, key: Union[Type[T], Label[T]], param: T) -> None: ) self._set_provider(key, lambda: param) - def set_param_table( - self, key: Type[T], value_series: Dict[type, Collection] - ) -> None: - self._indices[key] = list(range(len(next(iter(value_series.values()))))) - for param_type in value_series: - self._param_series[param_type] = key - for param_type, values in value_series.items(): + def set_param_table(self, params: ParamTable) -> None: + self._param_tables[params.row_dim] = params + for param_type in params: + self._param_series[param_type] = params.row_dim + for param_type, values in params.items(): for i, label in enumerate(values): self._set_provider( - Item((Label(tp=key, index=i),), param_type), + Item((Label(tp=params.row_dim, index=i),), param_type), lambda label=label: label, ) @@ -282,7 +280,7 @@ def build(self, tp: Type[T], /) -> Graph: def _build_indexed_subgraph(self, tp: Type[T]) -> Graph: index_name, value_type = get_args(tp) - size = len(self._indices[index_name]) + size = len(self._param_tables[index_name].index) provider = self._make_mapping_provider( value_type=value_type, mapping_type=tp, index_name=index_name ) @@ -313,7 +311,7 @@ def _make_mapping_provider( self, value_type: type, mapping_type: type, index_name: type ) -> Callable[..., Any]: def provider(*args): # type: ignore[no-untyped-def] - return mapping_type(dict(zip(self._indices[index_name], args))) + return mapping_type(dict(zip(self._param_tables[index_name].index, args))) return provider diff --git a/tests/param_table_test.py b/tests/param_table_test.py new file mode 100644 index 00000000..b8775204 --- /dev/null +++ b/tests/param_table_test.py @@ -0,0 +1,37 @@ +# SPDX-License-Identifier: BSD-3-Clause +# Copyright (c) 2023 Scipp contributors (https://github.com/scipp) +import pytest + +import sciline as sl + + +def test_raises_with_zero_columns() -> None: + with pytest.raises(ValueError): + sl.ParamTable(row_dim=int, columns={}) + + +def test_raises_with_inconsistent_column_sizes() -> None: + with pytest.raises(ValueError): + sl.ParamTable(row_dim=int, columns={int: [1, 2, 3], float: [1.0, 2.0]}) + + +def test_contains_includes_all_columns() -> None: + pt = sl.ParamTable(row_dim=int, columns={int: [1, 2, 3], float: [1.0, 2.0, 3.0]}) + assert int in pt + assert float in pt + assert str not in pt + + +def test_contains_does_not_include_index() -> None: + pt = sl.ParamTable(row_dim=int, columns={float: [1.0, 2.0, 3.0]}) + assert int not in pt + + +def test_len_is_number_of_columns() -> None: + pt = sl.ParamTable(row_dim=int, columns={int: [1, 2, 3], float: [1.0, 2.0, 3.0]}) + assert len(pt) == 2 + + +def test_defaults_to_range_index() -> None: + pt = sl.ParamTable(row_dim=int, columns={float: [1.0, 2.0, 3.0]}) + assert pt.index == [0, 1, 2] diff --git a/tests/pipeline_with_index_test.py b/tests/pipeline_with_index_test.py index 6d7b18f3..6c04102c 100644 --- a/tests/pipeline_with_index_test.py +++ b/tests/pipeline_with_index_test.py @@ -8,7 +8,7 @@ def test_can_get_elements_of_param_table() -> None: pl = sl.Pipeline() - pl.set_param_table(int, {float: [1.0, 2.0, 3.0]}) + pl.set_param_table(sl.ParamTable(int, {float: [1.0, 2.0, 3.0]})) assert pl.compute(Item((Label(int, 1),), float)) == 2.0 @@ -17,7 +17,7 @@ def use_elem(x: Item((Label(int, 1),), float)) -> str: return str(x) pl = sl.Pipeline([use_elem]) - pl.set_param_table(int, {float: [1.0, 2.0, 3.0]}) + pl.set_param_table(sl.ParamTable(int, {float: [1.0, 2.0, 3.0]})) assert pl.compute(str) == "2.0" @@ -32,7 +32,7 @@ def make_float(x: str) -> float: return float(x) pl = sl.Pipeline([gather, make_float]) - pl.set_param_table(Name, {str: ["1.0", "2.0", "3.0"]}) + pl.set_param_table(sl.ParamTable(Name, {str: ["1.0", "2.0", "3.0"]})) assert pl.compute(Sum) == 6.0 @@ -49,7 +49,7 @@ def use_str(x: str) -> Str: return Str(x) pl = sl.Pipeline([gather_zip, use_str]) - pl.set_param_table(Run, {str: ['a', 'a', 'ccc'], int: [1, 2, 3]}) + pl.set_param_table(sl.ParamTable(Run, {str: ['a', 'a', 'ccc'], int: [1, 2, 3]})) assert pl.compute(Sum) == "['a1', 'a2', 'ccc3']" @@ -69,6 +69,6 @@ def join(x: Param1, y: Param2) -> float: return x / y pl = sl.Pipeline([gather, join]) - pl.set_param_table(Row, {Param1: [1, 4, 9], Param2: [1, 2, 3]}) + pl.set_param_table(sl.ParamTable(Row, {Param1: [1, 4, 9], Param2: [1, 2, 3]})) assert pl.compute(Sum) == 6 From de81ba85153eda205278ef0c5c2ed350c61c0828 Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Wed, 2 Aug 2023 09:42:54 +0200 Subject: [PATCH 21/87] Check index length --- src/sciline/param_table.py | 7 ++++++- tests/param_table_test.py | 5 +++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/src/sciline/param_table.py b/src/sciline/param_table.py index 7bac7955..8f9d0715 100644 --- a/src/sciline/param_table.py +++ b/src/sciline/param_table.py @@ -19,9 +19,14 @@ def __init__( raise ValueError( f"Columns in param table must all have same size, got {sizes}" ) + size = sizes.pop() + if index is not None and len(index) != sizes: + raise ValueError( + f"Index and columns must have same size, got {len(index)} and {size}" + ) self._row_dim = row_dim self._columns = columns - self._index = index or list(range(sizes.pop())) + self._index = index or list(range(size)) @property def row_dim(self) -> type: diff --git a/tests/param_table_test.py b/tests/param_table_test.py index b8775204..6bcfd482 100644 --- a/tests/param_table_test.py +++ b/tests/param_table_test.py @@ -15,6 +15,11 @@ def test_raises_with_inconsistent_column_sizes() -> None: sl.ParamTable(row_dim=int, columns={int: [1, 2, 3], float: [1.0, 2.0]}) +def test_raises_with_inconsistent_index_length() -> None: + with pytest.raises(ValueError): + sl.ParamTable(row_dim=int, columns={float: [1.0, 2.0]}, index=[1, 2, 3]) + + def test_contains_includes_all_columns() -> None: pt = sl.ParamTable(row_dim=int, columns={int: [1, 2, 3], float: [1.0, 2.0, 3.0]}) assert int in pt From 37f24a160f431db6f20c1dbd5229f7fbca284161 Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Wed, 2 Aug 2023 09:56:23 +0200 Subject: [PATCH 22/87] Prevent breaking set_param_table --- src/sciline/pipeline.py | 17 +++++++++++++---- tests/pipeline_with_index_test.py | 18 ++++++++++++++++++ 2 files changed, 31 insertions(+), 4 deletions(-) diff --git a/src/sciline/pipeline.py b/src/sciline/pipeline.py index 1b625039..6805ef9c 100644 --- a/src/sciline/pipeline.py +++ b/src/sciline/pipeline.py @@ -185,14 +185,23 @@ def __setitem__(self, key: Union[Type[T], Label[T]], param: T) -> None: ) self._set_provider(key, lambda: param) + @property + def param_tables(self) -> Dict[Key, ParamTable]: + return dict(self._param_tables) + def set_param_table(self, params: ParamTable) -> None: + if params.row_dim in self._param_tables: + raise ValueError(f'Parameter table for {params.row_dim} already set') + for param_name in params: + if param_name in self._param_series: + raise ValueError(f'Parameter {param_name} already set') self._param_tables[params.row_dim] = params - for param_type in params: - self._param_series[param_type] = params.row_dim - for param_type, values in params.items(): + for param_name in params: + self._param_series[param_name] = params.row_dim + for param_name, values in params.items(): for i, label in enumerate(values): self._set_provider( - Item((Label(tp=params.row_dim, index=i),), param_type), + Item((Label(tp=params.row_dim, index=i),), param_name), lambda label=label: label, ) diff --git a/tests/pipeline_with_index_test.py b/tests/pipeline_with_index_test.py index 6c04102c..128051fd 100644 --- a/tests/pipeline_with_index_test.py +++ b/tests/pipeline_with_index_test.py @@ -2,10 +2,28 @@ # Copyright (c) 2023 Scipp contributors (https://github.com/scipp) from typing import NewType +import pytest + import sciline as sl from sciline.pipeline import Item, Label +def test_set_param_table_raises_if_param_names_are_duplicate(): + pl = sl.Pipeline() + pl.set_param_table(sl.ParamTable(int, {float: [1.0, 2.0, 3.0]})) + with pytest.raises(ValueError): + pl.set_param_table(sl.ParamTable(str, {float: [1.0, 2.0, 3.0]})) + assert str not in pl.param_tables + + +def test_set_param_table_raises_if_row_dim_is_duplicate(): + pl = sl.Pipeline() + pl.set_param_table(sl.ParamTable(int, {float: [1.0, 2.0, 3.0]})) + with pytest.raises(ValueError): + pl.set_param_table(sl.ParamTable(int, {str: ['a', 'b', 'c']})) + assert pl.param_tables[int] == sl.ParamTable(int, {float: [1.0, 2.0, 3.0]}) + + def test_can_get_elements_of_param_table() -> None: pl = sl.Pipeline() pl.set_param_table(sl.ParamTable(int, {float: [1.0, 2.0, 3.0]})) From d3a56ea1dab090be14bb635f3be1a29525000adb Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Wed, 2 Aug 2023 10:05:15 +0200 Subject: [PATCH 23/87] Simplify code --- src/sciline/pipeline.py | 16 +++------------- 1 file changed, 3 insertions(+), 13 deletions(-) diff --git a/src/sciline/pipeline.py b/src/sciline/pipeline.py index 6805ef9c..5ea2fb71 100644 --- a/src/sciline/pipeline.py +++ b/src/sciline/pipeline.py @@ -289,13 +289,11 @@ def build(self, tp: Type[T], /) -> Graph: def _build_indexed_subgraph(self, tp: Type[T]) -> Graph: index_name, value_type = get_args(tp) - size = len(self._param_tables[index_name].index) - provider = self._make_mapping_provider( - value_type=value_type, mapping_type=tp, index_name=index_name - ) + index = self._param_tables[index_name].index + size = len(index) args = [_indexed_key(index_name, i, value_type) for i in range(size)] graph: Graph = {} - graph[tp] = (provider, args) + graph[tp] = (lambda *values: Map(dict(zip(index, values))), args) subgraph = self.build(value_type) path = find_nodes_in_paths(subgraph, value_type, index_name) @@ -316,14 +314,6 @@ def _build_indexed_subgraph(self, tp: Type[T]) -> Graph: graph[key] = value return graph - def _make_mapping_provider( - self, value_type: type, mapping_type: type, index_name: type - ) -> Callable[..., Any]: - def provider(*args): # type: ignore[no-untyped-def] - return mapping_type(dict(zip(self._param_tables[index_name].index, args))) - - return provider - @overload def compute(self, tp: Type[T]) -> T: ... From f2389480e2fb49db7362b6b41182527956b538cc Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Wed, 2 Aug 2023 10:24:05 +0200 Subject: [PATCH 24/87] Refactor remaining tests --- src/sciline/pipeline.py | 6 +++++- src/sciline/visualize.py | 4 ++-- tests/variadic_workflow_test.py | 12 +++++++----- 3 files changed, 14 insertions(+), 8 deletions(-) diff --git a/src/sciline/pipeline.py b/src/sciline/pipeline.py index 5ea2fb71..d2fcafdd 100644 --- a/src/sciline/pipeline.py +++ b/src/sciline/pipeline.py @@ -29,6 +29,8 @@ from .variadic import Map T = TypeVar('T') +KeyType = TypeVar('KeyType') +ValueType = TypeVar('ValueType') class UnsatisfiedRequirement(Exception): @@ -287,7 +289,9 @@ def build(self, tp: Type[T], /) -> Graph: stack.append(arg) return graph - def _build_indexed_subgraph(self, tp: Type[T]) -> Graph: + def _build_indexed_subgraph(self, tp: Type[Map[KeyType, ValueType]]) -> Graph: + index_name: Type[KeyType] + value_type: Type[ValueType] index_name, value_type = get_args(tp) index = self._param_tables[index_name].index size = len(index) diff --git a/src/sciline/visualize.py b/src/sciline/visualize.py index 9ac326ac..268ebb6e 100644 --- a/src/sciline/visualize.py +++ b/src/sciline/visualize.py @@ -28,12 +28,12 @@ def to_graphviz(graph: Graph, **kwargs: Any) -> Digraph: # Do not draw dummy providers created by Pipeline when setting instances if p_name in ( 'Pipeline.__setitem__..', - 'Pipeline.set_index..', + 'Pipeline.set_param_table..', ): continue # Do not draw the internal provider gathering index-dependent results into # a dict - if p_name.startswith('Pipeline._make_mapping_provider.'): + if p_name.startswith('Pipeline._build_indexed_subgraph.'): for arg in args: dot.edge(arg, ret) else: diff --git a/tests/variadic_workflow_test.py b/tests/variadic_workflow_test.py index d74962cf..42040c19 100644 --- a/tests/variadic_workflow_test.py +++ b/tests/variadic_workflow_test.py @@ -6,6 +6,8 @@ def test_Map() -> None: + Run = NewType('Run', int) + Setting = NewType('Setting', int) Filename = NewType('Filename', str) Config = NewType('Config', int) Image = NewType('Image', float) @@ -31,14 +33,14 @@ def scale(x: CleanedImage, param: Param, config: Config) -> ScaledImage: return x * param + config def combine_old( - images: sl.Map[Filename, ScaledImage], params: sl.Map[Filename, ImageParam] + images: sl.Map[Run, ScaledImage], params: sl.Map[Run, ImageParam] ) -> float: return sum(images.values()) - def combine(images: sl.Map[Filename, ScaledImage]) -> float: + def combine(images: sl.Map[Run, ScaledImage]) -> float: return sum(images.values()) - def combine_configs(x: sl.Map[Config, float]) -> Result: + def combine_configs(x: sl.Map[Setting, float]) -> Result: return Result(sum(x.values())) def make_int() -> int: @@ -61,8 +63,8 @@ def make_param() -> Param: image_param, ] ) - pipeline.set_index(Filename, filenames) - pipeline.set_index(Config, configs) + pipeline.set_param_table(sl.ParamTable(Run, {Filename: filenames})) + pipeline.set_param_table(sl.ParamTable(Setting, {Config: configs})) graph = pipeline.get(int) assert graph.compute() == 2 From 9cc8988241199487b0d5aec781a5a18cb36c9b3c Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Wed, 2 Aug 2023 10:39:24 +0200 Subject: [PATCH 25/87] Remove unused --- src/sciline/graph.py | 35 +++-------------------------------- tests/graph_test.py | 18 +++++------------- 2 files changed, 8 insertions(+), 45 deletions(-) diff --git a/src/sciline/graph.py b/src/sciline/graph.py index 2b2dfded..1a897156 100644 --- a/src/sciline/graph.py +++ b/src/sciline/graph.py @@ -3,36 +3,6 @@ T = TypeVar("T") -def find_path(graph, start: T, end: T) -> List[T]: - """Find a path from start to end in a DAG.""" - if start == end: - return [start] - for node in graph[start]: - path = find_path(graph, node, end) - if path: - return [start] + path - return [] - - -def find_unique_path(graph, start: T, end: T) -> List[T]: - """Find a path from start to end in a DAG. - - Like find_path, but raises if more than one path found - """ - if start == end: - return [start] - if start not in graph: - return [] - paths = [] - for node in graph[start]: - path = find_unique_path(graph, node, end) - if path: - paths.append([start] + path) - if len(paths) > 1: - raise RuntimeError(f"Multiple paths found from {start} to {end}") - return paths[0] if paths else [] - - def find_all_paths(graph, start: T, end: T) -> List[List[T]]: """Find all paths from start to end in a DAG.""" if start == end: @@ -40,14 +10,15 @@ def find_all_paths(graph, start: T, end: T) -> List[List[T]]: if start not in graph: return [] paths = [] - # 0 is the provider, 1 is the args - for node in graph[start][1]: + for node in graph[start]: for path in find_all_paths(graph, node, end): paths.append([start] + path) return paths def find_nodes_in_paths(graph, start: T, end: T) -> List[T]: + # 0 is the provider, 1 is the args + graph = {k: v[1] for k, v in graph.items()} paths = find_all_paths(graph, start, end) nodes = set() for path in paths: diff --git a/tests/graph_test.py b/tests/graph_test.py index 0b3b2848..896cdfd4 100644 --- a/tests/graph_test.py +++ b/tests/graph_test.py @@ -1,18 +1,10 @@ # SPDX-License-Identifier: BSD-3-Clause # Copyright (c) 2023 Scipp contributors (https://github.com/scipp) -import pytest +from sciline.graph import find_all_paths -from sciline.graph import find_path, find_unique_path - -def test_find_path(): - graph = {"D": ["B", "C"], "C": ["A"], "B": ["A"]} - assert find_path(graph, "D", "A") == ["D", "B", "A"] - - -def test_find_unique_path(): +def test_find_all_paths(): graph = {"D": ["B", "C"], "C": ["A"], "B": ["A"]} - with pytest.raises(RuntimeError): - find_unique_path(graph, "D", "A") - graph = {"D": ["B", "C"], "C": ["A"], "B": ["aux"]} - assert find_unique_path(graph, "D", "A") == ["D", "C", "A"] + assert find_all_paths(graph, "D", "A") == [["D", "B", "A"], ["D", "C", "A"]] + assert find_all_paths(graph, "B", "C") == [] + assert find_all_paths(graph, "B", "A") == [["B", "A"]] From cddc86f695dc8c29ffa27c3826e3489053d5f3eb Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Wed, 2 Aug 2023 12:49:45 +0200 Subject: [PATCH 26/87] Map->Series --- src/sciline/__init__.py | 4 ++-- src/sciline/pipeline.py | 8 ++++---- src/sciline/{variadic.py => series.py} | 2 +- tests/pipeline_with_index_test.py | 21 ++++++++++++++++++--- tests/variadic_workflow_test.py | 6 +++--- 5 files changed, 28 insertions(+), 13 deletions(-) rename src/sciline/{variadic.py => series.py} (93%) diff --git a/src/sciline/__init__.py b/src/sciline/__init__.py index 6c103291..eb74611c 100644 --- a/src/sciline/__init__.py +++ b/src/sciline/__init__.py @@ -17,11 +17,11 @@ UnboundTypeVar, UnsatisfiedRequirement, ) -from .variadic import Map +from .series import Series __all__ = [ "AmbiguousProvider", - "Map", + "Series", "ParamTable", "Pipeline", "Scope", diff --git a/src/sciline/pipeline.py b/src/sciline/pipeline.py index d2fcafdd..f05ae6bf 100644 --- a/src/sciline/pipeline.py +++ b/src/sciline/pipeline.py @@ -26,7 +26,7 @@ from .graph import find_nodes_in_paths from .param_table import ParamTable from .scheduler import Graph, Scheduler -from .variadic import Map +from .series import Series T = TypeVar('T') KeyType = TypeVar('KeyType') @@ -272,7 +272,7 @@ def build(self, tp: Type[T], /) -> Graph: if tp in self._param_series: graph[tp] = (self._param_sentinel, (self._param_series[tp],)) continue - if get_origin(tp) == Map: + if get_origin(tp) == Series: graph.update(self._build_indexed_subgraph(tp)) continue provider: Callable[..., T] @@ -289,7 +289,7 @@ def build(self, tp: Type[T], /) -> Graph: stack.append(arg) return graph - def _build_indexed_subgraph(self, tp: Type[Map[KeyType, ValueType]]) -> Graph: + def _build_indexed_subgraph(self, tp: Type[Series[KeyType, ValueType]]) -> Graph: index_name: Type[KeyType] value_type: Type[ValueType] index_name, value_type = get_args(tp) @@ -297,7 +297,7 @@ def _build_indexed_subgraph(self, tp: Type[Map[KeyType, ValueType]]) -> Graph: size = len(index) args = [_indexed_key(index_name, i, value_type) for i in range(size)] graph: Graph = {} - graph[tp] = (lambda *values: Map(dict(zip(index, values))), args) + graph[tp] = (lambda *values: Series(dict(zip(index, values))), args) subgraph = self.build(value_type) path = find_nodes_in_paths(subgraph, value_type, index_name) diff --git a/src/sciline/variadic.py b/src/sciline/series.py similarity index 93% rename from src/sciline/variadic.py rename to src/sciline/series.py index 57ef2947..0b2f75f9 100644 --- a/src/sciline/variadic.py +++ b/src/sciline/series.py @@ -9,7 +9,7 @@ Value = TypeVar('Value') -class Map(abc.Mapping, Generic[Key, Value]): +class Series(abc.Mapping, Generic[Key, Value]): def __init__(self, values: Mapping[Key, Value]) -> None: self._map: Mapping[Key, Value] = values diff --git a/tests/pipeline_with_index_test.py b/tests/pipeline_with_index_test.py index 128051fd..9684636f 100644 --- a/tests/pipeline_with_index_test.py +++ b/tests/pipeline_with_index_test.py @@ -39,11 +39,26 @@ def use_elem(x: Item((Label(int, 1),), float)) -> str: assert pl.compute(str) == "2.0" +def test_can_compute_map_of_param_values() -> None: + pl = sl.Pipeline() + pl.set_param_table(sl.ParamTable(int, {float: [1.0, 2.0, 3.0]})) + assert pl.compute(sl.Series[int, float]) == {0: 1.0, 1: 2.0, 2: 3.0} + + +def test_can_compute_map_of_derived_values() -> None: + def process(x: float) -> str: + return str(x) + + pl = sl.Pipeline([process]) + pl.set_param_table(sl.ParamTable(int, {float: [1.0, 2.0, 3.0]})) + assert pl.compute(sl.Series[int, str]) == {0: "1.0", 1: "2.0", 2: "3.0"} + + def test_can_gather_index() -> None: Sum = NewType("Sum", float) Name = NewType("Name", str) - def gather(x: sl.Map[Name, float]) -> Sum: + def gather(x: sl.Series[Name, float]) -> Sum: return Sum(sum(x.values())) def make_float(x: str) -> float: @@ -59,7 +74,7 @@ def test_can_zip() -> None: Str = NewType("Str", str) Run = NewType("Run", int) - def gather_zip(x: sl.Map[Run, Str], y: sl.Map[Run, int]) -> Sum: + def gather_zip(x: sl.Series[Run, Str], y: sl.Series[Run, int]) -> Sum: z = [f'{x_}{y_}' for x_, y_ in zip(x.values(), y.values())] return Sum(str(z)) @@ -79,7 +94,7 @@ def test_diamond_dependency_pulls_values_from_columns_in_same_param_table() -> N Row = NewType("Run", int) def gather( - x: sl.Map[Row, float], + x: sl.Series[Row, float], ) -> Sum: return Sum(sum(x.values())) diff --git a/tests/variadic_workflow_test.py b/tests/variadic_workflow_test.py index 42040c19..b7f55a82 100644 --- a/tests/variadic_workflow_test.py +++ b/tests/variadic_workflow_test.py @@ -33,14 +33,14 @@ def scale(x: CleanedImage, param: Param, config: Config) -> ScaledImage: return x * param + config def combine_old( - images: sl.Map[Run, ScaledImage], params: sl.Map[Run, ImageParam] + images: sl.Series[Run, ScaledImage], params: sl.Series[Run, ImageParam] ) -> float: return sum(images.values()) - def combine(images: sl.Map[Run, ScaledImage]) -> float: + def combine(images: sl.Series[Run, ScaledImage]) -> float: return sum(images.values()) - def combine_configs(x: sl.Map[Setting, float]) -> Result: + def combine_configs(x: sl.Series[Setting, float]) -> Result: return Result(sum(x.values())) def make_int() -> int: From 643b13dc4dd89561912c4999cc0e2d30387c552e Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Wed, 2 Aug 2023 13:08:24 +0200 Subject: [PATCH 27/87] Some mypy fixes --- src/sciline/graph.py | 12 +++++++----- src/sciline/param_table.py | 12 ++++++------ tests/graph_test.py | 2 +- 3 files changed, 14 insertions(+), 12 deletions(-) diff --git a/src/sciline/graph.py b/src/sciline/graph.py index 1a897156..e24a7a49 100644 --- a/src/sciline/graph.py +++ b/src/sciline/graph.py @@ -1,9 +1,9 @@ -from typing import List, TypeVar +from typing import Callable, Collection, List, Mapping, Tuple, TypeVar T = TypeVar("T") -def find_all_paths(graph, start: T, end: T) -> List[List[T]]: +def find_all_paths(graph: Mapping[T, Collection[T]], start: T, end: T) -> List[List[T]]: """Find all paths from start to end in a DAG.""" if start == end: return [[start]] @@ -16,10 +16,12 @@ def find_all_paths(graph, start: T, end: T) -> List[List[T]]: return paths -def find_nodes_in_paths(graph, start: T, end: T) -> List[T]: +def find_nodes_in_paths( + graph: Mapping[T, Tuple[Callable[..., T], Collection[T]]], start: T, end: T +) -> List[T]: # 0 is the provider, 1 is the args - graph = {k: v[1] for k, v in graph.items()} - paths = find_all_paths(graph, start, end) + dependencies = {k: v[1] for k, v in graph.items()} + paths = find_all_paths(dependencies, start, end) nodes = set() for path in paths: nodes.update(path) diff --git a/src/sciline/param_table.py b/src/sciline/param_table.py index 8f9d0715..6ee3fd04 100644 --- a/src/sciline/param_table.py +++ b/src/sciline/param_table.py @@ -20,7 +20,7 @@ def __init__( f"Columns in param table must all have same size, got {sizes}" ) size = sizes.pop() - if index is not None and len(index) != sizes: + if index is not None and len(index) != size: raise ValueError( f"Index and columns must have same size, got {len(index)} and {size}" ) @@ -36,11 +36,11 @@ def row_dim(self) -> type: def index(self) -> Collection[Any]: return self._index - def __contains__(self, __key: object) -> bool: - return self._columns.__contains__(__key) + def __contains__(self, key: Any) -> bool: + return self._columns.__contains__(key) - def __getitem__(self, __key: object) -> Any: - return self._columns.__getitem__(__key) + def __getitem__(self, key: Any) -> Any: + return self._columns.__getitem__(key) def __iter__(self) -> Any: return self._columns.__iter__() @@ -49,4 +49,4 @@ def __len__(self) -> int: return self._columns.__len__() def __repr__(self) -> str: - return f"ParamTable(row_dim={self.row_dim}, columns={self.columns})" + return f"ParamTable(row_dim={self.row_dim}, columns={self._columns})" diff --git a/tests/graph_test.py b/tests/graph_test.py index 896cdfd4..4d78a7f7 100644 --- a/tests/graph_test.py +++ b/tests/graph_test.py @@ -3,7 +3,7 @@ from sciline.graph import find_all_paths -def test_find_all_paths(): +def test_find_all_paths() -> None: graph = {"D": ["B", "C"], "C": ["A"], "B": ["A"]} assert find_all_paths(graph, "D", "A") == [["D", "B", "A"], ["D", "C", "A"]] assert find_all_paths(graph, "B", "C") == [] From d89ae317bb24e3059428e4c14900cab6cd0270bc Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Wed, 2 Aug 2023 13:41:03 +0200 Subject: [PATCH 28/87] Test with multiple param tables --- tests/pipeline_with_index_test.py | 76 ++++++++++++++++++++++++++++++- 1 file changed, 74 insertions(+), 2 deletions(-) diff --git a/tests/pipeline_with_index_test.py b/tests/pipeline_with_index_test.py index 9684636f..1dea5621 100644 --- a/tests/pipeline_with_index_test.py +++ b/tests/pipeline_with_index_test.py @@ -39,13 +39,13 @@ def use_elem(x: Item((Label(int, 1),), float)) -> str: assert pl.compute(str) == "2.0" -def test_can_compute_map_of_param_values() -> None: +def test_can_compute_series_of_param_values() -> None: pl = sl.Pipeline() pl.set_param_table(sl.ParamTable(int, {float: [1.0, 2.0, 3.0]})) assert pl.compute(sl.Series[int, float]) == {0: 1.0, 1: 2.0, 2: 3.0} -def test_can_compute_map_of_derived_values() -> None: +def test_can_compute_series_of_derived_values() -> None: def process(x: float) -> str: return str(x) @@ -54,6 +54,17 @@ def process(x: float) -> str: assert pl.compute(sl.Series[int, str]) == {0: "1.0", 1: "2.0", 2: "3.0"} +def test_explicit_index_of_param_table_is_forwarded_correctly() -> None: + def process(x: float) -> int: + return int(x) + + pl = sl.Pipeline([process]) + pl.set_param_table( + sl.ParamTable(str, {float: [1.0, 2.0, 3.0]}, index=['a', 'b', 'c']) + ) + assert pl.compute(sl.Series[str, int]) == {'a': 1, 'b': 2, 'c': 3} + + def test_can_gather_index() -> None: Sum = NewType("Sum", float) Name = NewType("Name", str) @@ -105,3 +116,64 @@ def join(x: Param1, y: Param2) -> float: pl.set_param_table(sl.ParamTable(Row, {Param1: [1, 4, 9], Param2: [1, 2, 3]})) assert pl.compute(Sum) == 6 + + +def test_dependencies_on_different_param_tables_broadcast() -> None: + Row1 = NewType("Row1", int) + Row2 = NewType("Row2", int) + Param1 = NewType("Param1", int) + Param2 = NewType("Param2", int) + Product = NewType("Product", str) + + def gather(x: sl.Series[Row1, Param1], y: sl.Series[Row2, Param2]) -> Product: + broadcast = [[x_, y_] for x_ in x.values() for y_ in y.values()] + return str(broadcast) + + pl = sl.Pipeline([gather]) + pl.set_param_table(sl.ParamTable(Row1, {Param1: [1, 2, 3]})) + pl.set_param_table(sl.ParamTable(Row2, {Param2: [4, 5]})) + assert pl.compute(Product) == "[[1, 4], [1, 5], [2, 4], [2, 5], [3, 4], [3, 5]]" + + +def test_dependency_on_other_param_table_in_parent_broadcasts_branch() -> None: + Row1 = NewType("Row1", int) + Row2 = NewType("Row2", int) + Param1 = NewType("Param1", int) + Param2 = NewType("Param2", int) + Summed2 = NewType("Summed2", int) + Product = NewType("Product", str) + + def gather2(x: Param1, y: sl.Series[Row2, Param2]) -> Summed2: + return Summed2(x * sum(y.values())) + + def gather1(x: sl.Series[Row1, Summed2]) -> Product: + return str(list(x.values())) + + pl = sl.Pipeline([gather1, gather2]) + pl.set_param_table(sl.ParamTable(Row1, {Param1: [1, 2, 3]})) + pl.set_param_table(sl.ParamTable(Row2, {Param2: [4, 5]})) + assert pl.compute(Product) == "[9, 18, 27]" + + +def test_dependency_on_other_param_table_in_grandparent_broadcasts_branch() -> None: + Row1 = NewType("Row1", int) + Row2 = NewType("Row2", int) + Param1 = NewType("Param1", int) + Param2 = NewType("Param2", int) + Summed2 = NewType("Summed2", int) + Combined = NewType("Combined", int) + Product = NewType("Product", str) + + def gather2(x: sl.Series[Row2, Param2]) -> Summed2: + return Summed2(sum(x.values())) + + def combine(x: Param1, y: Summed2) -> Combined: + return Combined(x * y) + + def gather1(x: sl.Series[Row1, Combined]) -> Product: + return str(list(x.values())) + + pl = sl.Pipeline([gather1, gather2, combine]) + pl.set_param_table(sl.ParamTable(Row1, {Param1: [1, 2, 3]})) + pl.set_param_table(sl.ParamTable(Row2, {Param2: [4, 5]})) + assert pl.compute(Product) == "[9, 18, 27]" From d62bb7ca86f8e10155d3af03b1dc623b932d5b4b Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Wed, 2 Aug 2023 14:00:47 +0200 Subject: [PATCH 29/87] Test mixing generics with param tables --- src/sciline/pipeline.py | 6 ++-- tests/pipeline_with_index_test.py | 55 ++++++++++++++++++++++++++++--- 2 files changed, 53 insertions(+), 8 deletions(-) diff --git a/src/sciline/pipeline.py b/src/sciline/pipeline.py index f05ae6bf..9ff2d306 100644 --- a/src/sciline/pipeline.py +++ b/src/sciline/pipeline.py @@ -250,7 +250,7 @@ def _get_provider( raise AmbiguousProvider("Multiple providers found for type", tp) raise UnsatisfiedRequirement("No provider found for type", tp) - def build(self, tp: Type[T], /) -> Graph: + def build(self, tp: Type[T], /, search_param_tables: bool = False) -> Graph: """ Return a dict of providers required for building the requested type `tp`. @@ -269,7 +269,7 @@ def build(self, tp: Type[T], /) -> Graph: stack: List[Union[Type[T], Label[T]]] = [tp] while stack: tp = stack.pop() - if tp in self._param_series: + if search_param_tables and tp in self._param_series: graph[tp] = (self._param_sentinel, (self._param_series[tp],)) continue if get_origin(tp) == Series: @@ -299,7 +299,7 @@ def _build_indexed_subgraph(self, tp: Type[Series[KeyType, ValueType]]) -> Graph graph: Graph = {} graph[tp] = (lambda *values: Series(dict(zip(index, values))), args) - subgraph = self.build(value_type) + subgraph = self.build(value_type, search_param_tables=True) path = find_nodes_in_paths(subgraph, value_type, index_name) for key, value in subgraph.items(): if key in path: diff --git a/tests/pipeline_with_index_test.py b/tests/pipeline_with_index_test.py index 1dea5621..0db335ec 100644 --- a/tests/pipeline_with_index_test.py +++ b/tests/pipeline_with_index_test.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: BSD-3-Clause # Copyright (c) 2023 Scipp contributors (https://github.com/scipp) -from typing import NewType +from typing import NewType, TypeVar import pytest @@ -125,11 +125,11 @@ def test_dependencies_on_different_param_tables_broadcast() -> None: Param2 = NewType("Param2", int) Product = NewType("Product", str) - def gather(x: sl.Series[Row1, Param1], y: sl.Series[Row2, Param2]) -> Product: + def gather_both(x: sl.Series[Row1, Param1], y: sl.Series[Row2, Param2]) -> Product: broadcast = [[x_, y_] for x_ in x.values() for y_ in y.values()] return str(broadcast) - pl = sl.Pipeline([gather]) + pl = sl.Pipeline([gather_both]) pl.set_param_table(sl.ParamTable(Row1, {Param1: [1, 2, 3]})) pl.set_param_table(sl.ParamTable(Row2, {Param2: [4, 5]})) assert pl.compute(Product) == "[[1, 4], [1, 5], [2, 4], [2, 5], [3, 4], [3, 5]]" @@ -143,13 +143,13 @@ def test_dependency_on_other_param_table_in_parent_broadcasts_branch() -> None: Summed2 = NewType("Summed2", int) Product = NewType("Product", str) - def gather2(x: Param1, y: sl.Series[Row2, Param2]) -> Summed2: + def gather2_and_combine(x: Param1, y: sl.Series[Row2, Param2]) -> Summed2: return Summed2(x * sum(y.values())) def gather1(x: sl.Series[Row1, Summed2]) -> Product: return str(list(x.values())) - pl = sl.Pipeline([gather1, gather2]) + pl = sl.Pipeline([gather1, gather2_and_combine]) pl.set_param_table(sl.ParamTable(Row1, {Param1: [1, 2, 3]})) pl.set_param_table(sl.ParamTable(Row2, {Param2: [4, 5]})) assert pl.compute(Product) == "[9, 18, 27]" @@ -177,3 +177,48 @@ def gather1(x: sl.Series[Row1, Combined]) -> Product: pl.set_param_table(sl.ParamTable(Row1, {Param1: [1, 2, 3]})) pl.set_param_table(sl.ParamTable(Row2, {Param2: [4, 5]})) assert pl.compute(Product) == "[9, 18, 27]" + + +def test_generic_providers_work_with_param_tables() -> None: + Param = TypeVar('Param') + Row = NewType("Row", int) + + class Str(sl.Scope[Param, str], str): + ... + + def parametrized(x: Param) -> Str[Param]: + return Str(f'{x}') + + def make_float() -> float: + return 1.5 + + pipeline = sl.Pipeline([make_float, parametrized]) + pipeline.set_param_table(sl.ParamTable(Row, {int: [1, 2, 3]})) + + assert pipeline.compute(Str[float]) == Str[float]('1.5') + with pytest.raises(sl.UnsatisfiedRequirement): + pipeline.compute(Str[int]) + assert pipeline.compute(sl.Series[Row, Str[int]]) == { + 0: Str[int]('1'), + 1: Str[int]('2'), + 2: Str[int]('3'), + } + + +def test_generic_provider_can_depend_on_param_series() -> None: + Param = TypeVar('Param') + Row = NewType("Row", int) + + class Str(sl.Scope[Param, str], str): + ... + + def parametrized_gather(x: sl.Series[Row, Param]) -> Str[Param]: + return Str(f'{list(x.values())}') + + pipeline = sl.Pipeline([parametrized_gather]) + pipeline.set_param_table( + sl.ParamTable(Row, {int: [1, 2, 3], float: [1.5, 2.5, 3.5]}) + ) + + assert pipeline.compute(Str[int]) == Str[int]('[1, 2, 3]') + assert pipeline.compute(Str[float]) == Str[float]('[1.5, 2.5, 3.5]') From 56cd560f064197d0d71324b4f7253385cf2f9106 Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Wed, 2 Aug 2023 14:18:36 +0200 Subject: [PATCH 30/87] More tests --- tests/pipeline_with_index_test.py | 41 +++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/tests/pipeline_with_index_test.py b/tests/pipeline_with_index_test.py index 0db335ec..0b3f6f37 100644 --- a/tests/pipeline_with_index_test.py +++ b/tests/pipeline_with_index_test.py @@ -222,3 +222,44 @@ def parametrized_gather(x: sl.Series[Row, Param]) -> Str[Param]: assert pipeline.compute(Str[int]) == Str[int]('[1, 2, 3]') assert pipeline.compute(Str[float]) == Str[float]('[1.5, 2.5, 3.5]') + + +def test_generic_provider_can_depend_on_derived_param_series() -> None: + T = TypeVar('T') + Row = NewType("Row", int) + + class Str(sl.Scope[T, str], str): + ... + + def use_param(x: int) -> float: + return x + 0.5 + + def parametrized_gather(x: sl.Series[Row, T]) -> Str[T]: + return Str(f'{list(x.values())}') + + pipeline = sl.Pipeline([parametrized_gather, use_param]) + pipeline.set_param_table(sl.ParamTable(Row, {int: [1, 2, 3]})) + + assert pipeline.compute(Str[float]) == Str[float]('[1.5, 2.5, 3.5]') + + +def test_params_in_table_can_be_generic() -> None: + T = TypeVar('T') + Row = NewType("Row", int) + + class Str(sl.Scope[T, str], str): + ... + + class Param(sl.Scope[T, str], str): + ... + + def parametrized_gather(x: sl.Series[Row, Param[T]]) -> Str[T]: + return Str(','.join(x.values())) + + pipeline = sl.Pipeline([parametrized_gather]) + pipeline.set_param_table( + sl.ParamTable(Row, {Param[int]: ["1", "2"], Param[float]: ["1.5", "2.5"]}) + ) + + assert pipeline.compute(Str[int]) == Str[int]('1,2') + assert pipeline.compute(Str[float]) == Str[float]('1.5,2.5') From 2c07aea152ebadc571be1645295b0b5970ec992a Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Wed, 2 Aug 2023 15:00:40 +0200 Subject: [PATCH 31/87] Some cleanup --- src/sciline/pipeline.py | 16 +++++----------- tests/pipeline_with_index_test.py | 10 +++++++--- 2 files changed, 12 insertions(+), 14 deletions(-) diff --git a/src/sciline/pipeline.py b/src/sciline/pipeline.py index 9ff2d306..843fdfa3 100644 --- a/src/sciline/pipeline.py +++ b/src/sciline/pipeline.py @@ -59,9 +59,7 @@ class Item(Generic[T]): tp: type -def _indexed_key(index_name: Any, i: int, value_name: Any) -> Union[Label, Item]: - if index_name == value_name: - return Label(index_name, i) +def _indexed_key(index_name: Any, i: int, value_name: Any) -> Item: label = Label(index_name, i) if isinstance(value_name, Item): return Item(value_name.label + (label,), value_name.tp) @@ -70,7 +68,7 @@ def _indexed_key(index_name: Any, i: int, value_name: Any) -> Union[Label, Item] Provider = Callable[..., Any] -Key = Union[type, Label, Item] +Key = Union[type, Item] def _is_compatible_type_tuple( @@ -149,7 +147,7 @@ def insert(self, provider: Provider, /) -> None: raise ValueError(f'Provider {provider} lacks type-hint for return value') self._set_provider(key, provider) - def __setitem__(self, key: Union[Type[T], Label[T]], param: T) -> None: + def __setitem__(self, key: Type[T], param: T) -> None: """ Provide a concrete value for a type. @@ -187,10 +185,6 @@ def __setitem__(self, key: Union[Type[T], Label[T]], param: T) -> None: ) self._set_provider(key, lambda: param) - @property - def param_tables(self) -> Dict[Key, ParamTable]: - return dict(self._param_tables) - def set_param_table(self, params: ParamTable) -> None: if params.row_dim in self._param_tables: raise ValueError(f'Parameter table for {params.row_dim} already set') @@ -208,7 +202,7 @@ def set_param_table(self, params: ParamTable) -> None: ) def _set_provider( - self, key: Union[Type[T], Label[T]], provider: Callable[..., T] + self, key: Union[Type[T], Item[T]], provider: Callable[..., T] ) -> None: # isinstance does not work here and types.NoneType available only in 3.10+ if key == type(None): # noqa: E721 @@ -225,7 +219,7 @@ def _set_provider( self._providers[key] = provider def _get_provider( - self, tp: Union[Type[T], Label[T], Item] + self, tp: Union[Type[T], Item] ) -> Tuple[Callable[..., T], Dict[TypeVar, Key]]: if (provider := self._providers.get(tp)) is not None: return provider, {} diff --git a/tests/pipeline_with_index_test.py b/tests/pipeline_with_index_test.py index 0b3f6f37..a373f1a1 100644 --- a/tests/pipeline_with_index_test.py +++ b/tests/pipeline_with_index_test.py @@ -12,8 +12,10 @@ def test_set_param_table_raises_if_param_names_are_duplicate(): pl = sl.Pipeline() pl.set_param_table(sl.ParamTable(int, {float: [1.0, 2.0, 3.0]})) with pytest.raises(ValueError): - pl.set_param_table(sl.ParamTable(str, {float: [1.0, 2.0, 3.0]})) - assert str not in pl.param_tables + pl.set_param_table(sl.ParamTable(str, {float: [4.0, 5.0, 6.0]})) + assert pl.compute(Item((Label(int, 1),), float)) == 2.0 + with pytest.raises(sl.UnsatisfiedRequirement): + pl.compute(Item((Label(str, 1),), float)) def test_set_param_table_raises_if_row_dim_is_duplicate(): @@ -21,7 +23,9 @@ def test_set_param_table_raises_if_row_dim_is_duplicate(): pl.set_param_table(sl.ParamTable(int, {float: [1.0, 2.0, 3.0]})) with pytest.raises(ValueError): pl.set_param_table(sl.ParamTable(int, {str: ['a', 'b', 'c']})) - assert pl.param_tables[int] == sl.ParamTable(int, {float: [1.0, 2.0, 3.0]}) + assert pl.compute(Item((Label(int, 1),), float)) == 2.0 + with pytest.raises(sl.UnsatisfiedRequirement): + pl.compute(Item((Label(int, 1),), str)) def test_can_get_elements_of_param_table() -> None: From d3b73d4be7026cca0b948609d950192103933bdc Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Wed, 2 Aug 2023 15:02:11 +0200 Subject: [PATCH 32/87] Remove visualize from test --- tests/variadic_workflow_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/variadic_workflow_test.py b/tests/variadic_workflow_test.py index b7f55a82..1e1acb97 100644 --- a/tests/variadic_workflow_test.py +++ b/tests/variadic_workflow_test.py @@ -70,7 +70,7 @@ def make_param() -> Param: assert graph.compute() == 2 graph = pipeline.get(Result) assert graph.compute() == 66.0 - graph.visualize().render('graph', format='png') + # graph.visualize().render('graph', format='png') # from dask.delayed import Delayed # dsk = {key: (value, *args) for key, (value, args) in graph._graph.items()} # Delayed(Result, dsk).visualize(filename='graph.png') From 5817b71a682a87aa2c450de556b1eb4e54b0f471 Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Wed, 2 Aug 2023 16:20:34 +0200 Subject: [PATCH 33/87] Test actual diamond --- src/sciline/pipeline.py | 11 +++++------ src/sciline/series.py | 21 +++++++++++++++++++-- src/sciline/visualize.py | 2 +- tests/pipeline_with_index_test.py | 29 ++++++++++++++++++++++++++--- 4 files changed, 51 insertions(+), 12 deletions(-) diff --git a/src/sciline/pipeline.py b/src/sciline/pipeline.py index 843fdfa3..d755c67f 100644 --- a/src/sciline/pipeline.py +++ b/src/sciline/pipeline.py @@ -7,7 +7,6 @@ Any, Callable, Dict, - Generic, List, Optional, Tuple, @@ -48,14 +47,14 @@ class AmbiguousProvider(Exception): @dataclass(frozen=True) -class Label(Generic[T]): - tp: Type[T] +class Label: + tp: type index: int @dataclass(frozen=True) -class Item(Generic[T]): - label: Tuple[Label[T], ...] +class Item: + label: Tuple[Label, ...] tp: type @@ -291,7 +290,7 @@ def _build_indexed_subgraph(self, tp: Type[Series[KeyType, ValueType]]) -> Graph size = len(index) args = [_indexed_key(index_name, i, value_type) for i in range(size)] graph: Graph = {} - graph[tp] = (lambda *values: Series(dict(zip(index, values))), args) + graph[tp] = (lambda *values: Series(index_name, dict(zip(index, values))), args) subgraph = self.build(value_type, search_param_tables=True) path = find_nodes_in_paths(subgraph, value_type, index_name) diff --git a/src/sciline/series.py b/src/sciline/series.py index 0b2f75f9..986e7f13 100644 --- a/src/sciline/series.py +++ b/src/sciline/series.py @@ -3,16 +3,21 @@ from __future__ import annotations from collections import abc -from typing import Generic, Iterator, Mapping, TypeVar +from typing import Generic, Iterator, Mapping, Type, TypeVar Key = TypeVar('Key') Value = TypeVar('Value') class Series(abc.Mapping, Generic[Key, Value]): - def __init__(self, values: Mapping[Key, Value]) -> None: + def __init__(self, row_dim: Type[Key], values: Mapping[Key, Value]) -> None: + self._row_dim = row_dim self._map: Mapping[Key, Value] = values + @property + def row_dim(self) -> type: + return self._row_dim + def __contains__(self, item: object) -> bool: return item in self._map @@ -24,3 +29,15 @@ def __len__(self) -> int: def __getitem__(self, key: Key) -> Value: return self._map[key] + + def __repr__(self) -> str: + return f"Series(row_dim={self.row_dim}, {self._map})" + + def _repr_html_(self) -> str: + return ( + f"" + + "".join( + f"" for k, v in self._map.items() + ) + + "
{self.row_dim.__name__}Value
{k}{v}
" + ) diff --git a/src/sciline/visualize.py b/src/sciline/visualize.py index 268ebb6e..024988d2 100644 --- a/src/sciline/visualize.py +++ b/src/sciline/visualize.py @@ -35,7 +35,7 @@ def to_graphviz(graph: Graph, **kwargs: Any) -> Digraph: # a dict if p_name.startswith('Pipeline._build_indexed_subgraph.'): for arg in args: - dot.edge(arg, ret) + dot.edge(arg, ret, style='dashed') else: dot.node(p, p_name, shape='ellipse') for arg in args: diff --git a/tests/pipeline_with_index_test.py b/tests/pipeline_with_index_test.py index a373f1a1..aaa0499f 100644 --- a/tests/pipeline_with_index_test.py +++ b/tests/pipeline_with_index_test.py @@ -108,9 +108,7 @@ def test_diamond_dependency_pulls_values_from_columns_in_same_param_table() -> N Param2 = NewType("Param2", int) Row = NewType("Run", int) - def gather( - x: sl.Series[Row, float], - ) -> Sum: + def gather(x: sl.Series[Row, float]) -> Sum: return Sum(sum(x.values())) def join(x: Param1, y: Param2) -> float: @@ -122,6 +120,31 @@ def join(x: Param1, y: Param2) -> float: assert pl.compute(Sum) == 6 +def test_diamond_dependency_on_same_column() -> None: + Sum = NewType("Sum", float) + Param = NewType("Param", int) + Param1 = NewType("Param1", int) + Param2 = NewType("Param2", int) + Row = NewType("Run", int) + + def gather(x: sl.Series[Row, float]) -> Sum: + return Sum(sum(x.values())) + + def to_param1(x: Param) -> Param1: + return Param1(x) + + def to_param2(x: Param) -> Param2: + return Param2(x) + + def join(x: Param1, y: Param2) -> float: + return x / y + + pl = sl.Pipeline([gather, join, to_param1, to_param2]) + pl.set_param_table(sl.ParamTable(Row, {Param: [1, 2, 3]})) + + assert pl.compute(Sum) == 3 + + def test_dependencies_on_different_param_tables_broadcast() -> None: Row1 = NewType("Row1", int) Row2 = NewType("Row2", int) From 0879ade09f5bee1e67820d35db7c1267de42d5b4 Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Thu, 3 Aug 2023 08:36:24 +0200 Subject: [PATCH 34/87] Test getting nested series --- tests/pipeline_with_index_test.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/tests/pipeline_with_index_test.py b/tests/pipeline_with_index_test.py index aaa0499f..55779af6 100644 --- a/tests/pipeline_with_index_test.py +++ b/tests/pipeline_with_index_test.py @@ -206,6 +206,30 @@ def gather1(x: sl.Series[Row1, Combined]) -> Product: assert pl.compute(Product) == "[9, 18, 27]" +def test_nested_dependencies_on_different_param_tables() -> None: + Row1 = NewType("Row1", int) + Row2 = NewType("Row2", int) + Param1 = NewType("Param1", int) + Param2 = NewType("Param2", int) + Combined = NewType("Combined", int) + + def combine(x: Param1, y: Param2) -> Combined: + return Combined(x * y) + + pl = sl.Pipeline([combine]) + pl.set_param_table(sl.ParamTable(Row1, {Param1: [1, 2, 3]})) + pl.set_param_table(sl.ParamTable(Row2, {Param2: [4, 5]})) + assert pl.compute(sl.Series[Row1, sl.Series[Row2, Combined]]) == { + 0: {0: 4, 1: 5}, + 1: {0: 8, 1: 10}, + 2: {0: 12, 1: 15}, + } + assert pl.compute(sl.Series[Row2, sl.Series[Row1, Combined]]) == { + 0: {0: 4, 1: 8, 2: 12}, + 1: {0: 5, 1: 10, 2: 15}, + } + + def test_generic_providers_work_with_param_tables() -> None: Param = TypeVar('Param') Row = NewType("Row", int) From f7bf30df9154141585de22d31d021a3e636be888 Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Thu, 3 Aug 2023 10:16:57 +0200 Subject: [PATCH 35/87] Manuel "groupby" --- tests/pipeline_with_index_test.py | 34 ++++++++++++++++++++++++++++++- 1 file changed, 33 insertions(+), 1 deletion(-) diff --git a/tests/pipeline_with_index_test.py b/tests/pipeline_with_index_test.py index 55779af6..47551cae 100644 --- a/tests/pipeline_with_index_test.py +++ b/tests/pipeline_with_index_test.py @@ -5,7 +5,7 @@ import pytest import sciline as sl -from sciline.pipeline import Item, Label +from sciline.pipeline import Dict, Item, Label def test_set_param_table_raises_if_param_names_are_duplicate(): @@ -230,6 +230,38 @@ def combine(x: Param1, y: Param2) -> Combined: } +def test_poor_mans_groupby_over_param_table() -> None: + Index = NewType("Index", str) + Label = NewType("Label", str) + Group = NewType("Group", str) + Param = NewType("Param", int) + GroupedParam = NewType("GroupedParam", int) + + from collections import defaultdict + + # Note that this creates a synchronization point, i.e., at this point all + # intermediate results must be computed, before branching out again to groups. + # I think there is probably no way around this, without explicit support + # from Pipeline. + def groupby_label( + x: sl.Series[Index, Param], labels: sl.Series[Index, Label] + ) -> Dict[Label, GroupedParam]: + groups = defaultdict(list) + for label, param in zip(labels.values(), x.values()): + groups[label].append(param) + return groups + + def get_group(group: Group, param: Dict[Label, GroupedParam]) -> GroupedParam: + return param[group] + + pl = sl.Pipeline([groupby_label, get_group]) + pl.set_param_table(sl.ParamTable(Index, {Param: [1, 2, 3], Label: ['a', 'a', 'b']})) + groups = ['a', 'b'] + pl.set_param_table(sl.ParamTable(Label, {Group: groups}, index=groups)) + result = pl.compute(sl.Series[Label, GroupedParam]) + assert result == {'a': [1, 2], 'b': [3]} + + def test_generic_providers_work_with_param_tables() -> None: Param = TypeVar('Param') Row = NewType("Row", int) From a68da1b0e5abb980c1724e1e9cecf052408590de Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Thu, 3 Aug 2023 10:17:51 +0200 Subject: [PATCH 36/87] Rename --- src/sciline/pipeline.py | 4 ++-- src/sciline/visualize.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/sciline/pipeline.py b/src/sciline/pipeline.py index d755c67f..ed04f244 100644 --- a/src/sciline/pipeline.py +++ b/src/sciline/pipeline.py @@ -266,7 +266,7 @@ def build(self, tp: Type[T], /, search_param_tables: bool = False) -> Graph: graph[tp] = (self._param_sentinel, (self._param_series[tp],)) continue if get_origin(tp) == Series: - graph.update(self._build_indexed_subgraph(tp)) + graph.update(self._build_series(tp)) continue provider: Callable[..., T] provider, bound = self._get_provider(tp) @@ -282,7 +282,7 @@ def build(self, tp: Type[T], /, search_param_tables: bool = False) -> Graph: stack.append(arg) return graph - def _build_indexed_subgraph(self, tp: Type[Series[KeyType, ValueType]]) -> Graph: + def _build_series(self, tp: Type[Series[KeyType, ValueType]]) -> Graph: index_name: Type[KeyType] value_type: Type[ValueType] index_name, value_type = get_args(tp) diff --git a/src/sciline/visualize.py b/src/sciline/visualize.py index 024988d2..fb0687f4 100644 --- a/src/sciline/visualize.py +++ b/src/sciline/visualize.py @@ -33,7 +33,7 @@ def to_graphviz(graph: Graph, **kwargs: Any) -> Digraph: continue # Do not draw the internal provider gathering index-dependent results into # a dict - if p_name.startswith('Pipeline._build_indexed_subgraph.'): + if p_name.startswith('Pipeline._build_series.'): for arg in args: dot.edge(arg, ret, style='dashed') else: From ec65f77d6285ca75b3d569a915b7c315f84c71c4 Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Thu, 3 Aug 2023 13:17:57 +0200 Subject: [PATCH 37/87] Support "groupby"-like operation --- src/sciline/pipeline.py | 44 +++++++++++++++++++++++++------ tests/pipeline_with_index_test.py | 25 ++++++++++++++++++ 2 files changed, 61 insertions(+), 8 deletions(-) diff --git a/src/sciline/pipeline.py b/src/sciline/pipeline.py index ed04f244..72753ac6 100644 --- a/src/sciline/pipeline.py +++ b/src/sciline/pipeline.py @@ -2,6 +2,7 @@ # Copyright (c) 2023 Scipp contributors (https://github.com/scipp) from __future__ import annotations +from collections import defaultdict from dataclasses import dataclass from typing import ( Any, @@ -283,28 +284,55 @@ def build(self, tp: Type[T], /, search_param_tables: bool = False) -> Graph: return graph def _build_series(self, tp: Type[Series[KeyType, ValueType]]) -> Graph: - index_name: Type[KeyType] + label_name: Type[KeyType] value_type: Type[ValueType] - index_name, value_type = get_args(tp) - index = self._param_tables[index_name].index + label_name, value_type = get_args(tp) + if label_name in self._param_tables: + index = self._param_tables[label_name].index + index_name = None + elif (index_name := self._param_series.get(label_name)) is not None: + indices = self._param_tables[index_name].index + labels = self._param_tables[index_name][label_name] + groups = defaultdict(list) + for i, label in zip(indices, labels): + groups[label].append(i) + index = list(groups) + size = len(index) - args = [_indexed_key(index_name, i, value_type) for i in range(size)] + args = [_indexed_key(label_name, i, value_type) for i in range(size)] graph: Graph = {} - graph[tp] = (lambda *values: Series(index_name, dict(zip(index, values))), args) + graph[tp] = (lambda *values: Series(label_name, dict(zip(index, values))), args) subgraph = self.build(value_type, search_param_tables=True) - path = find_nodes_in_paths(subgraph, value_type, index_name) + if index_name is None: + path_end = label_name + else: + for key in subgraph: + if get_origin(key) == Series and get_args(key)[0] == index_name: + path_end = key + path = find_nodes_in_paths(subgraph, value_type, path_end) for key, value in subgraph.items(): if key in path: + + def _in_group(arg, group_index: int): + if len(arg.label) != 1: + raise ValueError( + f'Cannot build series with multi-index label {arg.label}' + ) + return arg.label[0].index in groups[index[group_index]] + + in_group = _in_group if key == path_end else lambda *_: True + for i in range(size): provider, args = value - subkey = _indexed_key(index_name, i, key) + subkey = _indexed_key(label_name, i, key) if provider == self._param_sentinel: provider, _ = self._get_provider(subkey) args = () args_with_index = tuple( - _indexed_key(index_name, i, arg) if arg in path else arg + _indexed_key(label_name, i, arg) if arg in path else arg for arg in args + if in_group(arg, i) ) graph[subkey] = (provider, args_with_index) else: diff --git a/tests/pipeline_with_index_test.py b/tests/pipeline_with_index_test.py index 47551cae..31c148a9 100644 --- a/tests/pipeline_with_index_test.py +++ b/tests/pipeline_with_index_test.py @@ -262,6 +262,31 @@ def get_group(group: Group, param: Dict[Label, GroupedParam]) -> GroupedParam: assert result == {'a': [1, 2], 'b': [3]} +def test_groupby_over_param_table() -> None: + Index = NewType("Index", int) + Label = NewType("Label", str) + Param = NewType("Param", int) + ProcessedParam = NewType("ProcessedParam", int) + SummedGroup = NewType("SummedGroup", int) + ProcessedGroup = NewType("ProcessedGroup", int) + + def process_param(x: Param) -> ProcessedParam: + return ProcessedParam(x + 1) + + def sum_group(group: sl.Series[Index, ProcessedParam]) -> SummedGroup: + return SummedGroup(sum(group.values())) + + def process(x: SummedGroup) -> ProcessedGroup: + return ProcessedGroup(2 * x) + + params = sl.ParamTable(Index, {Param: [1, 2, 3], Label: ['a', 'a', 'b']}) + pl = sl.Pipeline([process_param, sum_group, process]) + pl.set_param_table(params) + + graph = pl.get(sl.Series[Label, ProcessedGroup]) + assert graph.compute() == {'a': 10, 'b': 8} + + def test_generic_providers_work_with_param_tables() -> None: Param = TypeVar('Param') Row = NewType("Row", int) From 43163380167da54341e5cd3ebd3b340fa1e2981b Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Thu, 3 Aug 2023 13:38:04 +0200 Subject: [PATCH 38/87] Fix use of wrong labels --- src/sciline/pipeline.py | 41 +++++++++++++++++-------------- tests/pipeline_with_index_test.py | 7 ++++-- 2 files changed, 27 insertions(+), 21 deletions(-) diff --git a/src/sciline/pipeline.py b/src/sciline/pipeline.py index 72753ac6..f685d027 100644 --- a/src/sciline/pipeline.py +++ b/src/sciline/pipeline.py @@ -287,14 +287,13 @@ def _build_series(self, tp: Type[Series[KeyType, ValueType]]) -> Graph: label_name: Type[KeyType] value_type: Type[ValueType] label_name, value_type = get_args(tp) - if label_name in self._param_tables: - index = self._param_tables[label_name].index + if (params := self._param_tables.get(label_name)) is not None: + index = params.index index_name = None elif (index_name := self._param_series.get(label_name)) is not None: - indices = self._param_tables[index_name].index labels = self._param_tables[index_name][label_name] groups = defaultdict(list) - for i, label in zip(indices, labels): + for i, label in enumerate(labels): groups[label].append(i) index = list(groups) @@ -303,26 +302,21 @@ def _build_series(self, tp: Type[Series[KeyType, ValueType]]) -> Graph: graph: Graph = {} graph[tp] = (lambda *values: Series(label_name, dict(zip(index, values))), args) + def _in_group(arg, group_index: int): + if len(arg.label) != 1: + raise ValueError(f'Cannot group with multi-index label {arg.label}') + return arg.label[0].index in groups[index[group_index]] + subgraph = self.build(value_type, search_param_tables=True) - if index_name is None: - path_end = label_name - else: - for key in subgraph: - if get_origin(key) == Series and get_args(key)[0] == index_name: - path_end = key + path_end = ( + label_name + if index_name is None + else self._find_grouping_node(index_name, subgraph) + ) path = find_nodes_in_paths(subgraph, value_type, path_end) for key, value in subgraph.items(): if key in path: - - def _in_group(arg, group_index: int): - if len(arg.label) != 1: - raise ValueError( - f'Cannot build series with multi-index label {arg.label}' - ) - return arg.label[0].index in groups[index[group_index]] - in_group = _in_group if key == path_end else lambda *_: True - for i in range(size): provider, args = value subkey = _indexed_key(label_name, i, key) @@ -339,6 +333,15 @@ def _in_group(arg, group_index: int): graph[key] = value return graph + def _find_grouping_node(self, index_name: type, subgraph: Graph) -> type: + ends = [] + for key in subgraph: + if get_origin(key) == Series and get_args(key)[0] == index_name: + ends.append(key) + if len(ends) == 1: + return ends[0] + raise ValueError(f"Could not find unique grouping node, found {ends}") + @overload def compute(self, tp: Type[T]) -> T: ... diff --git a/tests/pipeline_with_index_test.py b/tests/pipeline_with_index_test.py index 31c148a9..7677f74b 100644 --- a/tests/pipeline_with_index_test.py +++ b/tests/pipeline_with_index_test.py @@ -262,7 +262,8 @@ def get_group(group: Group, param: Dict[Label, GroupedParam]) -> GroupedParam: assert result == {'a': [1, 2], 'b': [3]} -def test_groupby_over_param_table() -> None: +@pytest.mark.parametrize("index", [None, [4, 5, 6]]) +def test_groupby_over_param_table(index) -> None: Index = NewType("Index", int) Label = NewType("Label", str) Param = NewType("Param", int) @@ -279,7 +280,9 @@ def sum_group(group: sl.Series[Index, ProcessedParam]) -> SummedGroup: def process(x: SummedGroup) -> ProcessedGroup: return ProcessedGroup(2 * x) - params = sl.ParamTable(Index, {Param: [1, 2, 3], Label: ['a', 'a', 'b']}) + params = sl.ParamTable( + Index, {Param: [1, 2, 3], Label: ['a', 'a', 'b']}, index=index + ) pl = sl.Pipeline([process_param, sum_group, process]) pl.set_param_table(params) From 46728197bf22bacab645d8aa0b77a33486f2041a Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Thu, 3 Aug 2023 14:14:14 +0200 Subject: [PATCH 39/87] Attempt cleanup --- src/sciline/pipeline.py | 47 +++++++++++++++++++++++------------------ 1 file changed, 27 insertions(+), 20 deletions(-) diff --git a/src/sciline/pipeline.py b/src/sciline/pipeline.py index f685d027..f3f00946 100644 --- a/src/sciline/pipeline.py +++ b/src/sciline/pipeline.py @@ -287,36 +287,43 @@ def _build_series(self, tp: Type[Series[KeyType, ValueType]]) -> Graph: label_name: Type[KeyType] value_type: Type[ValueType] label_name, value_type = get_args(tp) + subgraph = self.build(value_type, search_param_tables=True) if (params := self._param_tables.get(label_name)) is not None: index = params.index - index_name = None - elif (index_name := self._param_series.get(label_name)) is not None: - labels = self._param_tables[index_name][label_name] - groups = defaultdict(list) + path = find_nodes_in_paths(subgraph, value_type, label_name) + + def in_group(*_) -> bool: + return True + + elif ((index_name := self._param_series.get(label_name)) is not None) and ( + (labels := self._param_tables[index_name].get(label_name)) is not None + ): + index = defaultdict(list) for i, label in enumerate(labels): - groups[label].append(i) - index = list(groups) + index[label].append(i) + groups = list(index.values()) + end = self._find_grouping_node(index_name, subgraph) + path = find_nodes_in_paths(subgraph, value_type, end) + + def in_group(arg, group_index: int, key: type): + if key != end: + return True + if len(arg.label) != 1: + raise ValueError(f'Cannot group with multi-index label {arg.label}') + return arg.label[0].index in groups[group_index] + + else: + raise UnsatisfiedRequirement( + f'No parameter table found for label {label_name}' + ) size = len(index) args = [_indexed_key(label_name, i, value_type) for i in range(size)] graph: Graph = {} graph[tp] = (lambda *values: Series(label_name, dict(zip(index, values))), args) - def _in_group(arg, group_index: int): - if len(arg.label) != 1: - raise ValueError(f'Cannot group with multi-index label {arg.label}') - return arg.label[0].index in groups[index[group_index]] - - subgraph = self.build(value_type, search_param_tables=True) - path_end = ( - label_name - if index_name is None - else self._find_grouping_node(index_name, subgraph) - ) - path = find_nodes_in_paths(subgraph, value_type, path_end) for key, value in subgraph.items(): if key in path: - in_group = _in_group if key == path_end else lambda *_: True for i in range(size): provider, args = value subkey = _indexed_key(label_name, i, key) @@ -326,7 +333,7 @@ def _in_group(arg, group_index: int): args_with_index = tuple( _indexed_key(label_name, i, arg) if arg in path else arg for arg in args - if in_group(arg, i) + if in_group(arg, i, key) ) graph[subkey] = (provider, args_with_index) else: From 9f336507bb2d9852fa3d13e5b3d4a71ae788b677 Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Thu, 3 Aug 2023 14:41:25 +0200 Subject: [PATCH 40/87] Extract Grouper --- src/sciline/pipeline.py | 69 +++++++++++++++++++------------ tests/pipeline_with_index_test.py | 9 ++++ 2 files changed, 51 insertions(+), 27 deletions(-) diff --git a/src/sciline/pipeline.py b/src/sciline/pipeline.py index f3f00946..5b699d9c 100644 --- a/src/sciline/pipeline.py +++ b/src/sciline/pipeline.py @@ -8,7 +8,9 @@ Any, Callable, Dict, + Iterator, List, + Literal, Optional, Tuple, Type, @@ -102,6 +104,36 @@ def _bind_free_typevars(tp: TypeVar | Key, bound: Dict[TypeVar, Key]) -> Key: return tp +class Grouper: + def __init__(self, grouping_node: Optional[type] = None, index=None) -> None: + self.grouping_node = grouping_node + if grouping_node is None: + self.index = index + else: + self.index = defaultdict(list) + for i, label in enumerate(index): + self.index[label].append(i) + self.groups = list(self.index.values()) + + @property + def size(self) -> int: + return len(self.index) + + def __iter__(self) -> Iterator: + return iter(self.index) + + def __call__(self, key: Any) -> Any: + return self.in_group if key == self.grouping_node else self.yes + + def in_group(self, arg, group_index: int) -> bool: + if len(arg.label) != 1: + raise ValueError(f'Cannot group with multi-index label {arg.label}') + return arg.label[0].index in self.groups[group_index] + + def yes(self, *_: Any) -> Literal[True]: + return True + + class Pipeline: """A container for providers that can be assembled into a task graph.""" @@ -289,42 +321,25 @@ def _build_series(self, tp: Type[Series[KeyType, ValueType]]) -> Graph: label_name, value_type = get_args(tp) subgraph = self.build(value_type, search_param_tables=True) if (params := self._param_tables.get(label_name)) is not None: - index = params.index path = find_nodes_in_paths(subgraph, value_type, label_name) - - def in_group(*_) -> bool: - return True - + grouper = Grouper(index=params.index) elif ((index_name := self._param_series.get(label_name)) is not None) and ( (labels := self._param_tables[index_name].get(label_name)) is not None ): - index = defaultdict(list) - for i, label in enumerate(labels): - index[label].append(i) - groups = list(index.values()) - end = self._find_grouping_node(index_name, subgraph) - path = find_nodes_in_paths(subgraph, value_type, end) - - def in_group(arg, group_index: int, key: type): - if key != end: - return True - if len(arg.label) != 1: - raise ValueError(f'Cannot group with multi-index label {arg.label}') - return arg.label[0].index in groups[group_index] - + grouping_node = self._find_grouping_node(index_name, subgraph) + path = find_nodes_in_paths(subgraph, value_type, grouping_node) + grouper = Grouper(index=labels, grouping_node=grouping_node) else: - raise UnsatisfiedRequirement( - f'No parameter table found for label {label_name}' - ) + raise KeyError(f'No parameter table found for label {label_name}') - size = len(index) - args = [_indexed_key(label_name, i, value_type) for i in range(size)] + args = [_indexed_key(label_name, i, value_type) for i in range(grouper.size)] graph: Graph = {} - graph[tp] = (lambda *values: Series(label_name, dict(zip(index, values))), args) + graph[tp] = (lambda *vals: Series(label_name, dict(zip(grouper, vals))), args) for key, value in subgraph.items(): if key in path: - for i in range(size): + in_group = grouper(key) + for i in range(grouper.size): provider, args = value subkey = _indexed_key(label_name, i, key) if provider == self._param_sentinel: @@ -333,7 +348,7 @@ def in_group(arg, group_index: int, key: type): args_with_index = tuple( _indexed_key(label_name, i, arg) if arg in path else arg for arg in args - if in_group(arg, i, key) + if in_group(arg, i) ) graph[subkey] = (provider, args_with_index) else: diff --git a/tests/pipeline_with_index_test.py b/tests/pipeline_with_index_test.py index 7677f74b..a34b16f4 100644 --- a/tests/pipeline_with_index_test.py +++ b/tests/pipeline_with_index_test.py @@ -28,6 +28,15 @@ def test_set_param_table_raises_if_row_dim_is_duplicate(): pl.compute(Item((Label(int, 1),), str)) +def test_non_unique_index_raises(): + pl = sl.Pipeline() + pl.set_param_table(sl.ParamTable(int, {float: [1.0, 2.0, 3.0]}, index=[1, 1, 2])) + # TODO Should we raise in ParamTable, or change Series? + assert pl.compute(sl.Series[int, float]) == {1: 2.0, 2: 3.0} + with pytest.raises(ValueError): + pl.compute(sl.Series[int, float]) + + def test_can_get_elements_of_param_table() -> None: pl = sl.Pipeline() pl.set_param_table(sl.ParamTable(int, {float: [1.0, 2.0, 3.0]})) From dcb6aee1feb2275496779bc95118234215eefa1e Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Fri, 4 Aug 2023 06:18:15 +0200 Subject: [PATCH 41/87] Avoid clashing name --- tests/pipeline_with_index_test.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/pipeline_with_index_test.py b/tests/pipeline_with_index_test.py index a34b16f4..b6639597 100644 --- a/tests/pipeline_with_index_test.py +++ b/tests/pipeline_with_index_test.py @@ -274,7 +274,7 @@ def get_group(group: Group, param: Dict[Label, GroupedParam]) -> GroupedParam: @pytest.mark.parametrize("index", [None, [4, 5, 6]]) def test_groupby_over_param_table(index) -> None: Index = NewType("Index", int) - Label = NewType("Label", str) + Name = NewType("Name", str) Param = NewType("Param", int) ProcessedParam = NewType("ProcessedParam", int) SummedGroup = NewType("SummedGroup", int) @@ -290,12 +290,12 @@ def process(x: SummedGroup) -> ProcessedGroup: return ProcessedGroup(2 * x) params = sl.ParamTable( - Index, {Param: [1, 2, 3], Label: ['a', 'a', 'b']}, index=index + Index, {Param: [1, 2, 3], Name: ['a', 'a', 'b']}, index=index ) pl = sl.Pipeline([process_param, sum_group, process]) pl.set_param_table(params) - graph = pl.get(sl.Series[Label, ProcessedGroup]) + graph = pl.get(sl.Series[Name, ProcessedGroup]) assert graph.compute() == {'a': 10, 'b': 8} From 762097fd9013789ac8316871265eebc615dafcdc Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Fri, 4 Aug 2023 07:53:10 +0200 Subject: [PATCH 42/87] Some more complex tests, uncovering issues with group indices --- src/sciline/pipeline.py | 2 + tests/pipeline_with_index_test.py | 73 +++++++++++++++++++++++++++++++ 2 files changed, 75 insertions(+) diff --git a/src/sciline/pipeline.py b/src/sciline/pipeline.py index 5b699d9c..8cda5206 100644 --- a/src/sciline/pipeline.py +++ b/src/sciline/pipeline.py @@ -334,6 +334,8 @@ def _build_series(self, tp: Type[Series[KeyType, ValueType]]) -> Graph: args = [_indexed_key(label_name, i, value_type) for i in range(grouper.size)] graph: Graph = {} + # This is wrong: We duplicate and select a group below, + # but zip from the beginning. Need to zip with correct group! graph[tp] = (lambda *vals: Series(label_name, dict(zip(grouper, vals))), args) for key, value in subgraph.items(): diff --git a/tests/pipeline_with_index_test.py b/tests/pipeline_with_index_test.py index b6639597..9f2d541a 100644 --- a/tests/pipeline_with_index_test.py +++ b/tests/pipeline_with_index_test.py @@ -28,6 +28,7 @@ def test_set_param_table_raises_if_row_dim_is_duplicate(): pl.compute(Item((Label(int, 1),), str)) +@pytest.mark.skip(reason="TODO: Should we allow this?") def test_non_unique_index_raises(): pl = sl.Pipeline() pl.set_param_table(sl.ParamTable(int, {float: [1.0, 2.0, 3.0]}, index=[1, 1, 2])) @@ -271,6 +272,55 @@ def get_group(group: Group, param: Dict[Label, GroupedParam]) -> GroupedParam: assert result == {'a': [1, 2], 'b': [3]} +def test_can_groupby_by_requesting_series_of_series() -> None: + Row = NewType("Row", int) + Param1 = NewType("Param1", int) + Param2 = NewType("Param2", int) + + pl = sl.Pipeline() + pl.set_param_table(sl.ParamTable(Row, {Param1: [1, 1, 3], Param2: [4, 5, 6]})) + assert pl.compute(sl.Series[Param1, sl.Series[Row, Param2]]) == { + 1: {0: 4, 1: 5}, + 3: {2: 6}, + } + + +def test_multi_level_groupby_raises_with_params_from_same_table() -> None: + Row = NewType("Row", int) + Param1 = NewType("Param1", int) + Param2 = NewType("Param2", int) + Param3 = NewType("Param3", int) + + pl = sl.Pipeline() + pl.set_param_table( + sl.ParamTable( + Row, {Param1: [1, 1, 1, 3], Param2: [4, 5, 5, 6], Param3: [7, 8, 9, 10]} + ) + ) + with pytest.raises(ValueError, match='Could not find unique grouping node'): + pl.compute(sl.Series[Param1, sl.Series[Param2, sl.Series[Row, Param3]]]) + + +def test_multi_level_groupby_with_params_from_different_table() -> None: + Row = NewType("Row", int) + Param1 = NewType("Param1", int) + Param2 = NewType("Param2", int) + Param3 = NewType("Param3", int) + + pl = sl.Pipeline() + grouping1 = sl.ParamTable(Row, {Param2: [4, 5, 5, 6], Param3: [7, 8, 9, 10]}) + grouping2 = sl.ParamTable(Param2, {Param1: [1, 1, 3]}) + pl.set_param_table(grouping1) + pl.set_param_table(grouping2) + # pl.compute(sl.Series[Param2, sl.Series[Row, Param3]]) + assert pl.compute(sl.Series[Param1, sl.Series[Param2, sl.Series[Row, Param3]]]) == { + 1: {4: {0: 7}, 5: {0: 8, 1: 9}}, + 3: {6: {0: 10}}, + } + # with pytest.raises(ValueError, match='Could not find unique grouping node'): + # pl.compute(sl.Series[Param1, sl.Series[Param2, sl.Series[Row, Param3]]]) + + @pytest.mark.parametrize("index", [None, [4, 5, 6]]) def test_groupby_over_param_table(index) -> None: Index = NewType("Index", int) @@ -299,6 +349,29 @@ def process(x: SummedGroup) -> ProcessedGroup: assert graph.compute() == {'a': 10, 'b': 8} +def test_requesting_series_index_that_is_not_in_param_table_raises() -> None: + Row = NewType("Row", int) + Param1 = NewType("Param1", int) + Param2 = NewType("Param2", int) + + pl = sl.Pipeline() + pl.set_param_table(sl.ParamTable(Row, {Param1: [1, 1, 3], Param2: [4, 5, 6]})) + + with pytest.raises(KeyError): + pl.compute(sl.Series[int, Param2]) + + +def test_requesting_series_index_that_is_a_param_raises_if_not_grouping() -> None: + Row = NewType("Row", int) + Param1 = NewType("Param1", int) + Param2 = NewType("Param2", int) + + pl = sl.Pipeline() + pl.set_param_table(sl.ParamTable(Row, {Param1: [1, 1, 3], Param2: [4, 5, 6]})) + with pytest.raises(ValueError, match='Could not find unique grouping node'): + pl.compute(sl.Series[Param1, Param2]) + + def test_generic_providers_work_with_param_tables() -> None: Param = TypeVar('Param') Row = NewType("Row", int) From d06c9db2f68eb8f01cdf2f56b30f54c4f61ff9e9 Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Fri, 4 Aug 2023 13:17:30 +0200 Subject: [PATCH 43/87] Fix index zip in groupby ops --- src/sciline/pipeline.py | 36 +++++++++++++++++++++------ tests/pipeline_with_index_test.py | 41 +++++++++++++++++++++++++++---- 2 files changed, 65 insertions(+), 12 deletions(-) diff --git a/src/sciline/pipeline.py b/src/sciline/pipeline.py index 8cda5206..37734fc4 100644 --- a/src/sciline/pipeline.py +++ b/src/sciline/pipeline.py @@ -125,6 +125,11 @@ def __iter__(self) -> Iterator: def __call__(self, key: Any) -> Any: return self.in_group if key == self.grouping_node else self.yes + def get_grouping(self, key: Any, group: int) -> Optional[List[int]]: + if key != self.grouping_node: + return None + return self.groups[group] + def in_group(self, arg, group_index: int) -> bool: if len(arg.label) != 1: raise ValueError(f'Cannot group with multi-index label {arg.label}') @@ -134,6 +139,21 @@ def yes(self, *_: Any) -> Literal[True]: return True +class SeriesProducer: + def __init__(self, labels: List[Any], row_dim: type) -> None: + self._labels = labels + self._row_dim = row_dim + + def __call__(self, *vals: Any) -> Series: + return Series(self._row_dim, dict(zip(self._labels, vals))) + + def restrict(self, indices: Optional[List[int]]) -> SeriesProducer: + if indices is None: + return self + labels = [self._labels[i] for i in indices] + return SeriesProducer(labels, self._row_dim) + + class Pipeline: """A container for providers that can be assembled into a task graph.""" @@ -320,12 +340,14 @@ def _build_series(self, tp: Type[Series[KeyType, ValueType]]) -> Graph: value_type: Type[ValueType] label_name, value_type = get_args(tp) subgraph = self.build(value_type, search_param_tables=True) - if (params := self._param_tables.get(label_name)) is not None: + if ( + label_name not in self._param_series + and (params := self._param_tables.get(label_name)) is not None + ): path = find_nodes_in_paths(subgraph, value_type, label_name) grouper = Grouper(index=params.index) - elif ((index_name := self._param_series.get(label_name)) is not None) and ( - (labels := self._param_tables[index_name].get(label_name)) is not None - ): + elif (index_name := self._param_series.get(label_name)) is not None: + labels = self._param_tables[index_name][label_name] grouping_node = self._find_grouping_node(index_name, subgraph) path = find_nodes_in_paths(subgraph, value_type, grouping_node) grouper = Grouper(index=labels, grouping_node=grouping_node) @@ -334,9 +356,7 @@ def _build_series(self, tp: Type[Series[KeyType, ValueType]]) -> Graph: args = [_indexed_key(label_name, i, value_type) for i in range(grouper.size)] graph: Graph = {} - # This is wrong: We duplicate and select a group below, - # but zip from the beginning. Need to zip with correct group! - graph[tp] = (lambda *vals: Series(label_name, dict(zip(grouper, vals))), args) + graph[tp] = (SeriesProducer(list(grouper), label_name), args) for key, value in subgraph.items(): if key in path: @@ -352,6 +372,8 @@ def _build_series(self, tp: Type[Series[KeyType, ValueType]]) -> Graph: for arg in args if in_group(arg, i) ) + if isinstance(provider, SeriesProducer): + provider = provider.restrict(grouper.get_grouping(key, i)) graph[subkey] = (provider, args_with_index) else: graph[key] = value diff --git a/tests/pipeline_with_index_test.py b/tests/pipeline_with_index_test.py index 9f2d541a..cf7a07c3 100644 --- a/tests/pipeline_with_index_test.py +++ b/tests/pipeline_with_index_test.py @@ -285,6 +285,21 @@ def test_can_groupby_by_requesting_series_of_series() -> None: } +def test_groupby_by_requesting_series_of_series_preserves_indices() -> None: + Row = NewType("Row", int) + Param1 = NewType("Param1", int) + Param2 = NewType("Param2", int) + + pl = sl.Pipeline() + pl.set_param_table( + sl.ParamTable(Row, {Param1: [1, 1, 3], Param2: [4, 5, 6]}, index=[11, 12, 13]) + ) + assert pl.compute(sl.Series[Param1, sl.Series[Row, Param2]]) == { + 1: {11: 4, 12: 5}, + 3: {13: 6}, + } + + def test_multi_level_groupby_raises_with_params_from_same_table() -> None: Row = NewType("Row", int) Param1 = NewType("Param1", int) @@ -312,13 +327,29 @@ def test_multi_level_groupby_with_params_from_different_table() -> None: grouping2 = sl.ParamTable(Param2, {Param1: [1, 1, 3]}) pl.set_param_table(grouping1) pl.set_param_table(grouping2) - # pl.compute(sl.Series[Param2, sl.Series[Row, Param3]]) assert pl.compute(sl.Series[Param1, sl.Series[Param2, sl.Series[Row, Param3]]]) == { - 1: {4: {0: 7}, 5: {0: 8, 1: 9}}, - 3: {6: {0: 10}}, + 1: {4: {0: 7}, 5: {1: 8, 2: 9}}, + 3: {6: {3: 10}}, + } + + +def test_multi_level_groupby_with_params_from_different_table_preserves_index() -> None: + Row = NewType("Row", int) + Param1 = NewType("Param1", int) + Param2 = NewType("Param2", int) + Param3 = NewType("Param3", int) + + pl = sl.Pipeline() + grouping1 = sl.ParamTable( + Row, {Param2: [4, 5, 5, 6], Param3: [7, 8, 9, 10]}, index=[100, 200, 300, 400] + ) + grouping2 = sl.ParamTable(Param2, {Param1: [1, 1, 3]}) + pl.set_param_table(grouping1) + pl.set_param_table(grouping2) + assert pl.compute(sl.Series[Param1, sl.Series[Param2, sl.Series[Row, Param3]]]) == { + 1: {4: {100: 7}, 5: {200: 8, 300: 9}}, + 3: {6: {400: 10}}, } - # with pytest.raises(ValueError, match='Could not find unique grouping node'): - # pl.compute(sl.Series[Param1, sl.Series[Param2, sl.Series[Row, Param3]]]) @pytest.mark.parametrize("index", [None, [4, 5, 6]]) From 14d9e9ce25401065fe3c716329136c0dd8453487 Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Mon, 7 Aug 2023 08:11:41 +0200 Subject: [PATCH 44/87] Use actual indices for keys --- src/sciline/pipeline.py | 51 ++++++++------- tests/pipeline_with_index_test.py | 104 +++++++++++++++++++----------- 2 files changed, 94 insertions(+), 61 deletions(-) diff --git a/src/sciline/pipeline.py b/src/sciline/pipeline.py index 37734fc4..905db33a 100644 --- a/src/sciline/pipeline.py +++ b/src/sciline/pipeline.py @@ -105,19 +105,16 @@ def _bind_free_typevars(tp: TypeVar | Key, bound: Dict[TypeVar, Key]) -> Key: class Grouper: - def __init__(self, grouping_node: Optional[type] = None, index=None) -> None: + def __init__( + self, grouping_node: Optional[type] = None, index=None, labels=None + ) -> None: self.grouping_node = grouping_node if grouping_node is None: self.index = index else: self.index = defaultdict(list) - for i, label in enumerate(index): - self.index[label].append(i) - self.groups = list(self.index.values()) - - @property - def size(self) -> int: - return len(self.index) + for idx, label in zip(index, labels): + self.index[label].append(idx) def __iter__(self) -> Iterator: return iter(self.index) @@ -128,12 +125,12 @@ def __call__(self, key: Any) -> Any: def get_grouping(self, key: Any, group: int) -> Optional[List[int]]: if key != self.grouping_node: return None - return self.groups[group] + return self.index[group] - def in_group(self, arg, group_index: int) -> bool: + def in_group(self, arg, group_index: Any) -> bool: if len(arg.label) != 1: raise ValueError(f'Cannot group with multi-index label {arg.label}') - return arg.label[0].index in self.groups[group_index] + return arg.label[0].index in self.index[group_index] def yes(self, *_: Any) -> Literal[True]: return True @@ -147,10 +144,13 @@ def __init__(self, labels: List[Any], row_dim: type) -> None: def __call__(self, *vals: Any) -> Series: return Series(self._row_dim, dict(zip(self._labels, vals))) - def restrict(self, indices: Optional[List[int]]) -> SeriesProducer: - if indices is None: + def restrict(self, labels: Optional[List[int]]) -> SeriesProducer: + if labels is None: return self - labels = [self._labels[i] for i in indices] + if set(labels) - set(self._labels): + raise ValueError(f'{labels} is not a subset of {self._labels}') + # Ensure that labels are in the same order as in the original series + labels = [label for label in self._labels if label in labels] return SeriesProducer(labels, self._row_dim) @@ -247,9 +247,9 @@ def set_param_table(self, params: ParamTable) -> None: for param_name in params: self._param_series[param_name] = params.row_dim for param_name, values in params.items(): - for i, label in enumerate(values): + for index, label in zip(params.index, values): self._set_provider( - Item((Label(tp=params.row_dim, index=i),), param_name), + Item((Label(tp=params.row_dim, index=index),), param_name), lambda label=label: label, ) @@ -347,33 +347,36 @@ def _build_series(self, tp: Type[Series[KeyType, ValueType]]) -> Graph: path = find_nodes_in_paths(subgraph, value_type, label_name) grouper = Grouper(index=params.index) elif (index_name := self._param_series.get(label_name)) is not None: - labels = self._param_tables[index_name][label_name] + params = self._param_tables[index_name] + labels = params[label_name] grouping_node = self._find_grouping_node(index_name, subgraph) path = find_nodes_in_paths(subgraph, value_type, grouping_node) - grouper = Grouper(index=labels, grouping_node=grouping_node) + grouper = Grouper( + index=params.index, labels=labels, grouping_node=grouping_node + ) else: raise KeyError(f'No parameter table found for label {label_name}') - args = [_indexed_key(label_name, i, value_type) for i in range(grouper.size)] + args = [_indexed_key(label_name, index, value_type) for index in grouper] graph: Graph = {} graph[tp] = (SeriesProducer(list(grouper), label_name), args) for key, value in subgraph.items(): if key in path: in_group = grouper(key) - for i in range(grouper.size): + for index in grouper: provider, args = value - subkey = _indexed_key(label_name, i, key) + subkey = _indexed_key(label_name, index, key) if provider == self._param_sentinel: provider, _ = self._get_provider(subkey) args = () args_with_index = tuple( - _indexed_key(label_name, i, arg) if arg in path else arg + _indexed_key(label_name, index, arg) if arg in path else arg for arg in args - if in_group(arg, i) + if in_group(arg, index) ) if isinstance(provider, SeriesProducer): - provider = provider.restrict(grouper.get_grouping(key, i)) + provider = provider.restrict(grouper.get_grouping(key, index)) graph[subkey] = (provider, args_with_index) else: graph[key] = value diff --git a/tests/pipeline_with_index_test.py b/tests/pipeline_with_index_test.py index cf7a07c3..2b4f5fda 100644 --- a/tests/pipeline_with_index_test.py +++ b/tests/pipeline_with_index_test.py @@ -5,7 +5,7 @@ import pytest import sciline as sl -from sciline.pipeline import Dict, Item, Label +from sciline.pipeline import Item, Label def test_set_param_table_raises_if_param_names_are_duplicate(): @@ -44,6 +44,12 @@ def test_can_get_elements_of_param_table() -> None: assert pl.compute(Item((Label(int, 1),), float)) == 2.0 +def test_can_get_elements_of_param_table_with_explicit_index() -> None: + pl = sl.Pipeline() + pl.set_param_table(sl.ParamTable(int, {float: [1.0, 2.0, 3.0]}, index=[11, 12, 13])) + assert pl.compute(Item((Label(int, 12),), float)) == 2.0 + + def test_can_depend_on_elements_of_param_table() -> None: def use_elem(x: Item((Label(int, 1),), float)) -> str: return str(x) @@ -240,38 +246,6 @@ def combine(x: Param1, y: Param2) -> Combined: } -def test_poor_mans_groupby_over_param_table() -> None: - Index = NewType("Index", str) - Label = NewType("Label", str) - Group = NewType("Group", str) - Param = NewType("Param", int) - GroupedParam = NewType("GroupedParam", int) - - from collections import defaultdict - - # Note that this creates a synchronization point, i.e., at this point all - # intermediate results must be computed, before branching out again to groups. - # I think there is probably no way around this, without explicit support - # from Pipeline. - def groupby_label( - x: sl.Series[Index, Param], labels: sl.Series[Index, Label] - ) -> Dict[Label, GroupedParam]: - groups = defaultdict(list) - for label, param in zip(labels.values(), x.values()): - groups[label].append(param) - return groups - - def get_group(group: Group, param: Dict[Label, GroupedParam]) -> GroupedParam: - return param[group] - - pl = sl.Pipeline([groupby_label, get_group]) - pl.set_param_table(sl.ParamTable(Index, {Param: [1, 2, 3], Label: ['a', 'a', 'b']})) - groups = ['a', 'b'] - pl.set_param_table(sl.ParamTable(Label, {Group: groups}, index=groups)) - result = pl.compute(sl.Series[Label, GroupedParam]) - assert result == {'a': [1, 2], 'b': [3]} - - def test_can_groupby_by_requesting_series_of_series() -> None: Row = NewType("Row", int) Param1 = NewType("Param1", int) @@ -323,13 +297,32 @@ def test_multi_level_groupby_with_params_from_different_table() -> None: Param3 = NewType("Param3", int) pl = sl.Pipeline() - grouping1 = sl.ParamTable(Row, {Param2: [4, 5, 5, 6], Param3: [7, 8, 9, 10]}) + grouping1 = sl.ParamTable(Row, {Param2: [0, 1, 1, 2], Param3: [7, 8, 9, 10]}) + # We are not providing an explicit index here, so this only happens to work because + # the values of Param2 match range(2). grouping2 = sl.ParamTable(Param2, {Param1: [1, 1, 3]}) pl.set_param_table(grouping1) pl.set_param_table(grouping2) assert pl.compute(sl.Series[Param1, sl.Series[Param2, sl.Series[Row, Param3]]]) == { - 1: {4: {0: 7}, 5: {1: 8, 2: 9}}, - 3: {6: {3: 10}}, + 1: {0: {0: 7}, 1: {1: 8, 2: 9}}, + 3: {2: {3: 10}}, + } + + +def test_multi_level_groupby_with_params_from_different_table_can_select() -> None: + Row = NewType("Row", int) + Param1 = NewType("Param1", int) + Param2 = NewType("Param2", int) + Param3 = NewType("Param3", int) + + pl = sl.Pipeline() + grouping1 = sl.ParamTable(Row, {Param2: [4, 5, 5, 6], Param3: [7, 8, 9, 10]}) + # Note the missing index "6" here. + grouping2 = sl.ParamTable(Param2, {Param1: [1, 1]}, index=[4, 5]) + pl.set_param_table(grouping1) + pl.set_param_table(grouping2) + assert pl.compute(sl.Series[Param1, sl.Series[Param2, sl.Series[Row, Param3]]]) == { + 1: {4: {0: 7}, 5: {1: 8, 2: 9}} } @@ -343,7 +336,7 @@ def test_multi_level_groupby_with_params_from_different_table_preserves_index() grouping1 = sl.ParamTable( Row, {Param2: [4, 5, 5, 6], Param3: [7, 8, 9, 10]}, index=[100, 200, 300, 400] ) - grouping2 = sl.ParamTable(Param2, {Param1: [1, 1, 3]}) + grouping2 = sl.ParamTable(Param2, {Param1: [1, 1, 3]}, index=[4, 5, 6]) pl.set_param_table(grouping1) pl.set_param_table(grouping2) assert pl.compute(sl.Series[Param1, sl.Series[Param2, sl.Series[Row, Param3]]]) == { @@ -352,6 +345,43 @@ def test_multi_level_groupby_with_params_from_different_table_preserves_index() } +def test_multi_level_groupby_with_params_from_different_table_can_reorder() -> None: + Row = NewType("Row", int) + Param1 = NewType("Param1", int) + Param2 = NewType("Param2", int) + Param3 = NewType("Param3", int) + + pl = sl.Pipeline() + grouping1 = sl.ParamTable( + Row, {Param2: [4, 5, 5, 6], Param3: [7, 8, 9, 10]}, index=[100, 200, 300, 400] + ) + grouping2 = sl.ParamTable(Param2, {Param1: [1, 1, 3]}, index=[6, 5, 4]) + pl.set_param_table(grouping1) + pl.set_param_table(grouping2) + assert pl.compute(sl.Series[Param1, sl.Series[Param2, sl.Series[Row, Param3]]]) == { + 1: {6: {400: 10}, 5: {200: 8, 300: 9}}, + 3: {4: {100: 7}}, + } + + +@pytest.mark.parametrize("index", [None, [4, 5, 7]]) +def test_multi_level_groupby_raises_on_index_mismatch(index) -> None: + Row = NewType("Row", int) + Param1 = NewType("Param1", int) + Param2 = NewType("Param2", int) + Param3 = NewType("Param3", int) + + pl = sl.Pipeline() + grouping1 = sl.ParamTable( + Row, {Param2: [4, 5, 5, 6], Param3: [7, 8, 9, 10]}, index=[100, 200, 300, 400] + ) + grouping2 = sl.ParamTable(Param2, {Param1: [1, 1, 3]}, index=index) + pl.set_param_table(grouping1) + pl.set_param_table(grouping2) + with pytest.raises(ValueError): + pl.compute(sl.Series[Param1, sl.Series[Param2, sl.Series[Row, Param3]]]) + + @pytest.mark.parametrize("index", [None, [4, 5, 6]]) def test_groupby_over_param_table(index) -> None: Index = NewType("Index", int) From 22e9a353805bc1daaaba9b3f13f3d340d31c233c Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Mon, 7 Aug 2023 08:23:35 +0200 Subject: [PATCH 45/87] Require unique index in ParamTable [no ci] --- src/sciline/param_table.py | 11 +++++++---- tests/param_table_test.py | 5 +++++ tests/pipeline_with_index_test.py | 10 ---------- 3 files changed, 12 insertions(+), 14 deletions(-) diff --git a/src/sciline/param_table.py b/src/sciline/param_table.py index 6ee3fd04..850577d1 100644 --- a/src/sciline/param_table.py +++ b/src/sciline/param_table.py @@ -20,10 +20,13 @@ def __init__( f"Columns in param table must all have same size, got {sizes}" ) size = sizes.pop() - if index is not None and len(index) != size: - raise ValueError( - f"Index and columns must have same size, got {len(index)} and {size}" - ) + if index is not None: + if len(index) != size: + raise ValueError( + f"Index length not matching columns, got {len(index)} and {size}" + ) + if len(set(index)) != len(index): + raise ValueError(f"Index must be unique, got {index}") self._row_dim = row_dim self._columns = columns self._index = index or list(range(size)) diff --git a/tests/param_table_test.py b/tests/param_table_test.py index 6bcfd482..aba84204 100644 --- a/tests/param_table_test.py +++ b/tests/param_table_test.py @@ -20,6 +20,11 @@ def test_raises_with_inconsistent_index_length() -> None: sl.ParamTable(row_dim=int, columns={float: [1.0, 2.0]}, index=[1, 2, 3]) +def test_raises_with_non_unique_index(): + with pytest.raises(ValueError): + sl.ParamTable(int, {float: [1.0, 2.0, 3.0]}, index=[1, 1, 2]) + + def test_contains_includes_all_columns() -> None: pt = sl.ParamTable(row_dim=int, columns={int: [1, 2, 3], float: [1.0, 2.0, 3.0]}) assert int in pt diff --git a/tests/pipeline_with_index_test.py b/tests/pipeline_with_index_test.py index 2b4f5fda..02f306b7 100644 --- a/tests/pipeline_with_index_test.py +++ b/tests/pipeline_with_index_test.py @@ -28,16 +28,6 @@ def test_set_param_table_raises_if_row_dim_is_duplicate(): pl.compute(Item((Label(int, 1),), str)) -@pytest.mark.skip(reason="TODO: Should we allow this?") -def test_non_unique_index_raises(): - pl = sl.Pipeline() - pl.set_param_table(sl.ParamTable(int, {float: [1.0, 2.0, 3.0]}, index=[1, 1, 2])) - # TODO Should we raise in ParamTable, or change Series? - assert pl.compute(sl.Series[int, float]) == {1: 2.0, 2: 3.0} - with pytest.raises(ValueError): - pl.compute(sl.Series[int, float]) - - def test_can_get_elements_of_param_table() -> None: pl = sl.Pipeline() pl.set_param_table(sl.ParamTable(int, {float: [1.0, 2.0, 3.0]})) From 3288d6b83d6fb10316c107a838bb554a767dc724 Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Mon, 7 Aug 2023 08:38:19 +0200 Subject: [PATCH 46/87] Update visualize [no ci] --- src/sciline/visualize.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/sciline/visualize.py b/src/sciline/visualize.py index fb0687f4..a81884f9 100644 --- a/src/sciline/visualize.py +++ b/src/sciline/visualize.py @@ -33,7 +33,7 @@ def to_graphviz(graph: Graph, **kwargs: Any) -> Digraph: continue # Do not draw the internal provider gathering index-dependent results into # a dict - if p_name.startswith('Pipeline._build_series.'): + if p_name == 'SeriesProducer': for arg in args: dot.edge(arg, ret, style='dashed') else: @@ -44,10 +44,16 @@ def to_graphviz(graph: Graph, **kwargs: Any) -> Digraph: return dot +def _qualname(obj: Any) -> str: + return ( + obj.__qualname__ if hasattr(obj, '__qualname__') else obj.__class__.__qualname__ + ) + + def _format_graph(graph: Graph) -> Dict[str, Tuple[str, List[str], str]]: return { _format_provider(provider, ret): ( - provider.__qualname__, + _qualname(provider), [_format_type(a) for a in args], _format_type(ret), ) @@ -56,7 +62,7 @@ def _format_graph(graph: Graph) -> Dict[str, Tuple[str, List[str], str]]: def _format_provider(provider: Callable[..., Any], ret: type) -> str: - return f'{provider.__qualname__}_{_format_type(ret)}' + return f'{_qualname(provider)}_{_format_type(ret)}' def _extract_type_and_labels(key: Union[Item, Label, type]) -> Tuple[type, List[type]]: From 4be8c6c42a7908c84baf4f3afc9952291ad270ef Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Mon, 7 Aug 2023 09:15:29 +0200 Subject: [PATCH 47/87] Split Grouper to make design better for mypy --- src/sciline/pipeline.py | 57 +++++++++++++++++++++++++-------------- tests/param_table_test.py | 2 +- 2 files changed, 38 insertions(+), 21 deletions(-) diff --git a/src/sciline/pipeline.py b/src/sciline/pipeline.py index 905db33a..ad5d4252 100644 --- a/src/sciline/pipeline.py +++ b/src/sciline/pipeline.py @@ -8,6 +8,7 @@ Any, Callable, Dict, + Iterable, Iterator, List, Literal, @@ -104,36 +105,52 @@ def _bind_free_typevars(tp: TypeVar | Key, bound: Dict[TypeVar, Key]) -> Key: return tp +def _yes(*_: Any) -> Literal[True]: + return True + + +class NoGrouping: + def __init__(self, index: Iterable[Any]) -> None: + self._index = index + + def __iter__(self) -> Iterator[Any]: + return iter(self._index) + + def __call__(self, key: Any) -> Callable[..., bool]: + return _yes + + def get_grouping(self, key: Any, group: int) -> None: + return None + + class Grouper: def __init__( - self, grouping_node: Optional[type] = None, index=None, labels=None + self, + *, + grouping_node: type, + index: Iterable[Any], + labels: Iterable[Any], ) -> None: self.grouping_node = grouping_node - if grouping_node is None: - self.index = index - else: - self.index = defaultdict(list) - for idx, label in zip(index, labels): - self.index[label].append(idx) + self._index = defaultdict(list) + for idx, label in zip(index, labels): + self._index[label].append(idx) - def __iter__(self) -> Iterator: - return iter(self.index) + def __iter__(self) -> Iterator[Any]: + return iter(self._index) def __call__(self, key: Any) -> Any: - return self.in_group if key == self.grouping_node else self.yes + return self.in_group if key == self.grouping_node else _yes def get_grouping(self, key: Any, group: int) -> Optional[List[int]]: if key != self.grouping_node: return None - return self.index[group] + return self._index[group] def in_group(self, arg, group_index: Any) -> bool: if len(arg.label) != 1: raise ValueError(f'Cannot group with multi-index label {arg.label}') - return arg.label[0].index in self.index[group_index] - - def yes(self, *_: Any) -> Literal[True]: - return True + return arg.label[0].index in self._index[group_index] class SeriesProducer: @@ -254,7 +271,7 @@ def set_param_table(self, params: ParamTable) -> None: ) def _set_provider( - self, key: Union[Type[T], Item[T]], provider: Callable[..., T] + self, key: Union[Type[T], Item], provider: Callable[..., T] ) -> None: # isinstance does not work here and types.NoneType available only in 3.10+ if key == type(None): # noqa: E721 @@ -312,7 +329,7 @@ def build(self, tp: Type[T], /, search_param_tables: bool = False) -> Graph: Type to build the graph for. """ graph: Graph = {} - stack: List[Union[Type[T], Label[T]]] = [tp] + stack: List[Union[Type[T], Item]] = [tp] while stack: tp = stack.pop() if search_param_tables and tp in self._param_series: @@ -345,7 +362,7 @@ def _build_series(self, tp: Type[Series[KeyType, ValueType]]) -> Graph: and (params := self._param_tables.get(label_name)) is not None ): path = find_nodes_in_paths(subgraph, value_type, label_name) - grouper = Grouper(index=params.index) + grouper = NoGrouping(index=params.index) elif (index_name := self._param_series.get(label_name)) is not None: params = self._param_tables[index_name] labels = params[label_name] @@ -357,7 +374,7 @@ def _build_series(self, tp: Type[Series[KeyType, ValueType]]) -> Graph: else: raise KeyError(f'No parameter table found for label {label_name}') - args = [_indexed_key(label_name, index, value_type) for index in grouper] + args = tuple(_indexed_key(label_name, index, value_type) for index in grouper) graph: Graph = {} graph[tp] = (SeriesProducer(list(grouper), label_name), args) @@ -382,7 +399,7 @@ def _build_series(self, tp: Type[Series[KeyType, ValueType]]) -> Graph: graph[key] = value return graph - def _find_grouping_node(self, index_name: type, subgraph: Graph) -> type: + def _find_grouping_node(self, index_name: Key, subgraph: Graph) -> type: ends = [] for key in subgraph: if get_origin(key) == Series and get_args(key)[0] == index_name: diff --git a/tests/param_table_test.py b/tests/param_table_test.py index aba84204..4a57ad18 100644 --- a/tests/param_table_test.py +++ b/tests/param_table_test.py @@ -20,7 +20,7 @@ def test_raises_with_inconsistent_index_length() -> None: sl.ParamTable(row_dim=int, columns={float: [1.0, 2.0]}, index=[1, 2, 3]) -def test_raises_with_non_unique_index(): +def test_raises_with_non_unique_index() -> None: with pytest.raises(ValueError): sl.ParamTable(int, {float: [1.0, 2.0, 3.0]}, index=[1, 1, 2]) From 9672c292183ca01b097f1f3ffc121d359790b6ab Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Mon, 7 Aug 2023 12:41:11 +0200 Subject: [PATCH 48/87] Add param table docs [no ci] --- docs/user-guide/getting-started.ipynb | 6 +- docs/user-guide/parameter-tables.ipynb | 438 +++++++++++++++++++++++++ 2 files changed, 443 insertions(+), 1 deletion(-) create mode 100644 docs/user-guide/parameter-tables.ipynb diff --git a/docs/user-guide/getting-started.ipynb b/docs/user-guide/getting-started.ipynb index f9d08cc4..b00c5c61 100644 --- a/docs/user-guide/getting-started.ipynb +++ b/docs/user-guide/getting-started.ipynb @@ -101,6 +101,9 @@ "from typing import NewType\n", "import sciline\n", "\n", + "_fake_filesytem = {'data.txt': [1, 2, float('nan'), 3]}\n", + "\n", + "\n", "# 1. Define domain types\n", "\n", "Filename = NewType('Filename', str)\n", @@ -115,7 +118,8 @@ "\n", "def load(filename: Filename) -> RawData:\n", " \"\"\"Load the data from the filename.\"\"\"\n", - " return {'data': [1, 2, float('nan'), 3], 'meta': {'filename': filename}}\n", + " data = _fake_filesytem[filename]\n", + " return {'data': data, 'meta': {'filename': filename}}\n", "\n", "\n", "def clean(raw_data: RawData) -> CleanedData:\n", diff --git a/docs/user-guide/parameter-tables.ipynb b/docs/user-guide/parameter-tables.ipynb new file mode 100644 index 00000000..416aa527 --- /dev/null +++ b/docs/user-guide/parameter-tables.ipynb @@ -0,0 +1,438 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Parameter Tables\n", + "\n", + "## Overview\n", + "\n", + "Parameter tables provide a mechanism for repeating parts of or all of a computation with different values for one or more parameters.\n", + "This allows for a variety of use cases, similar to *map*, *reduce*, and *groupby* operations in other systems.\n", + "We illustrate each of these in the follow three chapters." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Computing results for series of parameters\n", + "\n", + "This chapter illustrates how to implement *map* operations with Sciline.\n", + "\n", + "Starting with the model workflow introduced in [Getting Started](getting-started.ipynb), we can replace the fixed `Filename` parameter with a series of filenames listed in a [ParamTable](../generated/classes/ParamTable.rst):" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from typing import NewType\n", + "import sciline\n", + "\n", + "_fake_filesytem = {\n", + " 'file102.txt': [1, 2, float('nan'), 3],\n", + " 'file103.txt': [1, 2, 3, 4],\n", + " 'file104.txt': [1, 2, 3, 4, 5],\n", + "}\n", + "\n", + "# 1. Define domain types\n", + "\n", + "Filename = NewType('Filename', str)\n", + "RawData = NewType('RawData', dict)\n", + "CleanedData = NewType('CleanedData', list)\n", + "ScaleFactor = NewType('ScaleFactor', float)\n", + "Result = NewType('Result', float)\n", + "\n", + "\n", + "# 2. Define providers\n", + "\n", + "\n", + "def load(filename: Filename) -> RawData:\n", + " \"\"\"Load the data from the filename.\"\"\"\n", + "\n", + " data = _fake_filesytem[filename]\n", + " return {'data': data, 'meta': {'filename': filename}}\n", + "\n", + "\n", + "def clean(raw_data: RawData) -> CleanedData:\n", + " \"\"\"Clean the data, removing NaNs.\"\"\"\n", + " import math\n", + "\n", + " return [x for x in raw_data['data'] if not math.isnan(x)]\n", + "\n", + "\n", + "def process(data: CleanedData, param: ScaleFactor) -> Result:\n", + " \"\"\"Process the data, multiplying the sum by the scale factor.\"\"\"\n", + " return sum(data) * param\n", + "\n", + "\n", + "# 3. Create pipeline\n", + "\n", + "# 3.a Providers and normal parameters\n", + "providers = [load, clean, process]\n", + "params = {ScaleFactor: 2.0}\n", + "\n", + "# 3.b Parameter table\n", + "RunID = NewType('RunID', int)\n", + "run_ids = [102, 103, 104]\n", + "filenames = [f'file{i}.txt' for i in run_ids]\n", + "param_table = sciline.ParamTable(RunID, {Filename: filenames}, index=run_ids)\n", + "\n", + "# 3.c Setup pipeline\n", + "pipeline = sciline.Pipeline(providers, params=params)\n", + "pipeline.set_param_table(param_table)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Note how steps 1.) and 2.) are identical to those from the example without parameter table.\n", + "We can now compute `Result` for each index in the parameter table:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "pipeline.compute(sciline.Series[RunID, Result])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "`sciline.Series` is a special `dict`-like type that signals to Sciline that the values of the series are based on values from one or more columns of a parameter table.\n", + "The parameter table is identified using the first argument to `Series`, in this case `RunID`.\n", + "The second argument specifies the result to be computed." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can also visualize the task graph for computing the series of `Result` values:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "pipeline.visualize(sciline.Series[RunID, Result])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Sciline uses a compact representation of the task graph.\n", + "Instead of drawing every intermediate result and provider for each parameter, we represent parameter-dependent results as \"3D box\" nodes, with the parameter index name (the row dimension of the parameter table) given in parenthesis.\n", + "For example, above `Filename(RunID)` and `Result(RunID)` represent series of nodes, each for a different run ID.\n", + "\n", + "Note that for computing the results they are handled independently, i.e., the task graph has independent branches for each run ID.\n", + "This is important for parallelization, as it allows to run the tasks for different run IDs in parallel and avoids excessive memory use for intemediate results.\n", + "\n", + "The dashed arrow indicates and internal transformation that gathers result from each branch and combines them into a single output, here `Series[RunID, Result]`.\n", + "Note how this transitions from a \"3D box\" to a \"2D box\" as the series of keys in the graph is reduced into a single key." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Combining intermediate results from series of parameters\n", + "\n", + "This chapter illustrates how to implement *reduce* operations with Sciline.\n", + "\n", + "Instead of requesting a series of results as above, we can also build pipelines with providers that depend on such series.\n", + "We can create a new pipeline, or extend the existing one by inserting a new provider:\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "MergedResult = NewType('MergedResult', float)\n", + "\n", + "\n", + "def merge_runs(runs: sciline.Series[RunID, Result]) -> MergedResult:\n", + " return MergedResult(sum(runs.values()))\n", + "\n", + "\n", + "pipeline.insert(merge_runs)\n", + "graph = pipeline.get(MergedResult)\n", + "graph.visualize()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Note that this is identical to the example in the previous section, except for the last two nodes in the graph.\n", + "The computation now returns a single result:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "graph.compute()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This is useful if we need to continue computation after gathering results without setting up a second pipeline." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Grouping intermediate results based on secondary parameters\n", + "\n", + "This chapter illustrates how to implement *groupby* operations with Sciline.\n", + "\n", + "Continuing from the examples for *map* and *reduce*, we can introduce a secondary parameter in the table, such as the material of the sample:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "Material = NewType('Material', str)\n", + "\n", + "# 3.a Providers and normal parameters\n", + "providers = [load, clean, process, merge_runs]\n", + "params = {ScaleFactor: 2.0}\n", + "\n", + "# 3.b Parameter table\n", + "run_ids = [102, 103, 104]\n", + "sample = ['diamond', 'graphite', 'graphite']\n", + "filenames = [f'file{i}.txt' for i in run_ids]\n", + "param_table = sciline.ParamTable(\n", + " RunID, {Filename: filenames, Material: sample}, index=run_ids\n", + ")\n", + "\n", + "# 3.c Setup pipeline\n", + "pipeline = sciline.Pipeline(providers, params=params)\n", + "pipeline.set_param_table(param_table)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can now compute `MergedResult` for a series of \"materials\":" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "pipeline.compute(sciline.Series[Material, MergedResult])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The computation looks as show below.\n", + "Note how the initial steps of the computation depend on the `RunID` parameter, while later steps depend on `Material`:\n", + "The files for each run ID have been grouped by their material and then merged:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "pipeline.visualize(sciline.Series[Material, MergedResult])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## More examples" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Combining multiple parameters from same table" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import sciline as sl\n", + "\n", + "Sum = NewType(\"Sum\", float)\n", + "Param1 = NewType(\"Param1\", int)\n", + "Param2 = NewType(\"Param2\", int)\n", + "Row = NewType(\"Run\", int)\n", + "\n", + "\n", + "def gather(\n", + " x: sl.Series[Row, float],\n", + ") -> Sum:\n", + " return Sum(sum(x.values()))\n", + "\n", + "\n", + "def product(x: Param1, y: Param2) -> float:\n", + " return x / y\n", + "\n", + "\n", + "pl = sl.Pipeline([gather, product])\n", + "pl.set_param_table(sl.ParamTable(Row, {Param1: [1, 4, 9], Param2: [1, 2, 3]}))\n", + "\n", + "pl.visualize(Sum)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Diamond graphs" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "Sum = NewType(\"Sum\", float)\n", + "Param = NewType(\"Param\", int)\n", + "Param1 = NewType(\"Param1\", int)\n", + "Param2 = NewType(\"Param2\", int)\n", + "Row = NewType(\"Run\", int)\n", + "\n", + "\n", + "def gather(x: sl.Series[Row, float]) -> Sum:\n", + " return Sum(sum(x.values()))\n", + "\n", + "\n", + "def to_param1(x: Param) -> Param1:\n", + " return Param1(x)\n", + "\n", + "\n", + "def to_param2(x: Param) -> Param2:\n", + " return Param2(x)\n", + "\n", + "\n", + "def product(x: Param1, y: Param2) -> float:\n", + " return x / y\n", + "\n", + "\n", + "pl = sl.Pipeline([gather, product, to_param1, to_param2])\n", + "pl.set_param_table(sl.ParamTable(Row, {Param: [1, 2, 3]}))\n", + "pl.visualize(Sum)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Combining parameters from different tables" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import sciline as sl\n", + "\n", + "List1 = NewType(\"List1\", float)\n", + "List2 = NewType(\"List2\", float)\n", + "Param1 = NewType(\"Param1\", int)\n", + "Param2 = NewType(\"Param2\", int)\n", + "Row1 = NewType(\"Row1\", int)\n", + "Row2 = NewType(\"Row2\", int)\n", + "\n", + "\n", + "def gather1(x: sl.Series[Row1, float]) -> List1:\n", + " return List1(list(x.values()))\n", + "\n", + "\n", + "def gather2(x: sl.Series[Row2, List1]) -> List2:\n", + " return List2(list(x.values()))\n", + "\n", + "\n", + "def product(x: Param1, y: Param2) -> float:\n", + " return x * y\n", + "\n", + "\n", + "pl = sl.Pipeline([gather1, gather2, product])\n", + "pl.set_param_table(sl.ParamTable(Row1, {Param1: [1, 4, 9]}))\n", + "pl.set_param_table(sl.ParamTable(Row2, {Param2: [1, 2, 3]}))\n", + "\n", + "pl.visualize(List2)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Note how itermediates such as `float(Row1, Row2)` depend on two parameters, i.e., we are dealing with a 2-D array of branches in the graph." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "pl.compute(List2)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "base", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} From 5e82bea21c1bb4ab6b632515d6eb38fcc66f63d3 Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Mon, 7 Aug 2023 12:43:14 +0200 Subject: [PATCH 49/87] Cleanup --- docs/user-guide/parameter-tables.ipynb | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/docs/user-guide/parameter-tables.ipynb b/docs/user-guide/parameter-tables.ipynb index 416aa527..09566cdd 100644 --- a/docs/user-guide/parameter-tables.ipynb +++ b/docs/user-guide/parameter-tables.ipynb @@ -314,6 +314,15 @@ "pl.visualize(Sum)" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "pl.compute(Sum)" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -347,7 +356,7 @@ "\n", "\n", "def product(x: Param1, y: Param2) -> float:\n", - " return x / y\n", + " return x * y\n", "\n", "\n", "pl = sl.Pipeline([gather, product, to_param1, to_param2])\n", @@ -401,7 +410,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Note how itermediates such as `float(Row1, Row2)` depend on two parameters, i.e., we are dealing with a 2-D array of branches in the graph." + "Note how intermediates such as `float(Row1, Row2)` depend on two parameters, i.e., we are dealing with a 2-D array of branches in the graph." ] }, { From d9fdc160a07edf986e61ea8437e5e063fff33ad3 Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Mon, 7 Aug 2023 14:06:11 +0200 Subject: [PATCH 50/87] Some mypy fixes --- src/sciline/graph.py | 4 +-- src/sciline/pipeline.py | 66 ++++++++++++++++++++++++++++------------- 2 files changed, 48 insertions(+), 22 deletions(-) diff --git a/src/sciline/graph.py b/src/sciline/graph.py index e24a7a49..2c03b866 100644 --- a/src/sciline/graph.py +++ b/src/sciline/graph.py @@ -1,4 +1,4 @@ -from typing import Callable, Collection, List, Mapping, Tuple, TypeVar +from typing import Any, Callable, Collection, List, Mapping, Tuple, TypeVar T = TypeVar("T") @@ -17,7 +17,7 @@ def find_all_paths(graph: Mapping[T, Collection[T]], start: T, end: T) -> List[L def find_nodes_in_paths( - graph: Mapping[T, Tuple[Callable[..., T], Collection[T]]], start: T, end: T + graph: Mapping[T, Tuple[Callable[..., Any], Collection[T]]], start: T, end: T ) -> List[T]: # 0 is the provider, 1 is the args dependencies = {k: v[1] for k, v in graph.items()} diff --git a/src/sciline/pipeline.py b/src/sciline/pipeline.py index ad5d4252..fb754d8e 100644 --- a/src/sciline/pipeline.py +++ b/src/sciline/pipeline.py @@ -8,11 +8,13 @@ Any, Callable, Dict, + Generic, Iterable, Iterator, List, Literal, Optional, + Protocol, Tuple, Type, TypeVar, @@ -28,12 +30,14 @@ from .domain import Scope from .graph import find_nodes_in_paths from .param_table import ParamTable -from .scheduler import Graph, Scheduler +from .scheduler import Scheduler from .series import Series T = TypeVar('T') KeyType = TypeVar('KeyType') ValueType = TypeVar('ValueType') +IndexType = TypeVar('IndexType') +LabelType = TypeVar('LabelType') class UnsatisfiedRequirement(Exception): @@ -72,6 +76,10 @@ def _indexed_key(index_name: Any, i: int, value_name: Any) -> Item: Provider = Callable[..., Any] Key = Union[type, Item] +Graph = Dict[ + Key, + Tuple[Callable[..., Any], Tuple[Key, ...]], +] def _is_compatible_type_tuple( @@ -109,11 +117,22 @@ def _yes(*_: Any) -> Literal[True]: return True -class NoGrouping: - def __init__(self, index: Iterable[Any]) -> None: +class Grouper(Protocol): + def __iter__(self) -> Iterator[Any]: + ... + + def __call__(self, key: Any) -> Callable[..., bool]: + ... + + def get_grouping(self, key: Any, group: Any) -> Any: + ... + + +class NoGrouping(Generic[IndexType]): + def __init__(self, index: Iterable[IndexType]) -> None: self._index = index - def __iter__(self) -> Iterator[Any]: + def __iter__(self) -> Iterator[IndexType]: return iter(self._index) def __call__(self, key: Any) -> Callable[..., bool]: @@ -123,45 +142,47 @@ def get_grouping(self, key: Any, group: int) -> None: return None -class Grouper: +class GroupBy(Generic[IndexType, LabelType]): def __init__( self, *, grouping_node: type, - index: Iterable[Any], - labels: Iterable[Any], + index: Iterable[IndexType], + labels: Iterable[LabelType], ) -> None: self.grouping_node = grouping_node - self._index = defaultdict(list) + self._index: Dict[LabelType, List[IndexType]] = defaultdict(list) for idx, label in zip(index, labels): self._index[label].append(idx) - def __iter__(self) -> Iterator[Any]: + def __iter__(self) -> Iterator[LabelType]: return iter(self._index) def __call__(self, key: Any) -> Any: return self.in_group if key == self.grouping_node else _yes - def get_grouping(self, key: Any, group: int) -> Optional[List[int]]: + def get_grouping(self, key: Any, group: LabelType) -> Optional[List[IndexType]]: if key != self.grouping_node: return None return self._index[group] - def in_group(self, arg, group_index: Any) -> bool: + def in_group(self, arg: Item, group: LabelType) -> bool: if len(arg.label) != 1: raise ValueError(f'Cannot group with multi-index label {arg.label}') - return arg.label[0].index in self._index[group_index] + return arg.label[0].index in self._index[group] -class SeriesProducer: - def __init__(self, labels: List[Any], row_dim: type) -> None: +class SeriesProducer(Generic[KeyType, ValueType]): + def __init__(self, labels: Iterable[KeyType], row_dim: type) -> None: self._labels = labels self._row_dim = row_dim - def __call__(self, *vals: Any) -> Series: + def __call__(self, *vals: ValueType) -> Series[KeyType, ValueType]: return Series(self._row_dim, dict(zip(self._labels, vals))) - def restrict(self, labels: Optional[List[int]]) -> SeriesProducer: + def restrict( + self, labels: Optional[Iterable[KeyType]] + ) -> SeriesProducer[KeyType, ValueType]: if labels is None: return self if set(labels) - set(self._labels): @@ -313,7 +334,9 @@ def _get_provider( raise AmbiguousProvider("Multiple providers found for type", tp) raise UnsatisfiedRequirement("No provider found for type", tp) - def build(self, tp: Type[T], /, search_param_tables: bool = False) -> Graph: + def build( + self, tp: Union[Type[T], Item], /, search_param_tables: bool = False + ) -> Graph: """ Return a dict of providers required for building the requested type `tp`. @@ -357,6 +380,7 @@ def _build_series(self, tp: Type[Series[KeyType, ValueType]]) -> Graph: value_type: Type[ValueType] label_name, value_type = get_args(tp) subgraph = self.build(value_type, search_param_tables=True) + grouper: Grouper if ( label_name not in self._param_series and (params := self._param_tables.get(label_name)) is not None @@ -368,15 +392,17 @@ def _build_series(self, tp: Type[Series[KeyType, ValueType]]) -> Graph: labels = params[label_name] grouping_node = self._find_grouping_node(index_name, subgraph) path = find_nodes_in_paths(subgraph, value_type, grouping_node) - grouper = Grouper( + grouper = GroupBy( index=params.index, labels=labels, grouping_node=grouping_node ) else: raise KeyError(f'No parameter table found for label {label_name}') - args = tuple(_indexed_key(label_name, index, value_type) for index in grouper) graph: Graph = {} - graph[tp] = (SeriesProducer(list(grouper), label_name), args) + graph[tp] = ( + SeriesProducer(list(grouper), label_name), + tuple(_indexed_key(label_name, index, value_type) for index in grouper), + ) for key, value in subgraph.items(): if key in path: From a1259bc40b23e2e747c7e434e971598c605bbbcd Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Mon, 7 Aug 2023 14:38:14 +0200 Subject: [PATCH 51/87] More mypy --- src/sciline/param_table.py | 2 +- src/sciline/pipeline.py | 14 ++-- src/sciline/series.py | 9 ++- src/sciline/visualize.py | 2 +- tests/pipeline_with_index_test.py | 130 ++++++++++++++++++++---------- 5 files changed, 104 insertions(+), 53 deletions(-) diff --git a/src/sciline/param_table.py b/src/sciline/param_table.py index 850577d1..b800fa43 100644 --- a/src/sciline/param_table.py +++ b/src/sciline/param_table.py @@ -6,7 +6,7 @@ from typing import Any, Collection, Dict, Optional -class ParamTable(abc.Mapping): +class ParamTable(abc.Mapping[type, Collection[Any]]): def __init__( self, row_dim: type, diff --git a/src/sciline/pipeline.py b/src/sciline/pipeline.py index fb754d8e..8eae63f1 100644 --- a/src/sciline/pipeline.py +++ b/src/sciline/pipeline.py @@ -61,9 +61,9 @@ class Label: @dataclass(frozen=True) -class Item: +class Item(Generic[T]): label: Tuple[Label, ...] - tp: type + tp: Type[T] def _indexed_key(index_name: Any, i: int, value_name: Any) -> Item: @@ -201,7 +201,7 @@ def __init__( self, providers: Optional[List[Provider]] = None, *, - params: Optional[Dict[Key, Any]] = None, + params: Optional[Dict[type, Any]] = None, ): """ Setup a Pipeline from a list providers @@ -359,7 +359,7 @@ def build( graph[tp] = (self._param_sentinel, (self._param_series[tp],)) continue if get_origin(tp) == Series: - graph.update(self._build_series(tp)) + graph.update(self._build_series(tp)) # type: ignore[arg-type] continue provider: Callable[..., T] provider, bound = self._get_provider(tp) @@ -442,7 +442,11 @@ def compute(self, tp: Type[T]) -> T: def compute(self, tp: Tuple[Type[T], ...]) -> Tuple[T, ...]: ... - def compute(self, tp: type | Tuple[type, ...]) -> Any: + @overload + def compute(self, tp: Item[T]) -> T: + ... + + def compute(self, tp: type | Tuple[type, ...] | Item[T]) -> Any: """ Compute result for the given keys. diff --git a/src/sciline/series.py b/src/sciline/series.py index 986e7f13..c88f1e70 100644 --- a/src/sciline/series.py +++ b/src/sciline/series.py @@ -3,13 +3,13 @@ from __future__ import annotations from collections import abc -from typing import Generic, Iterator, Mapping, Type, TypeVar +from typing import Iterator, Mapping, Type, TypeVar Key = TypeVar('Key') Value = TypeVar('Value') -class Series(abc.Mapping, Generic[Key, Value]): +class Series(abc.Mapping[Key, Value]): def __init__(self, row_dim: Type[Key], values: Mapping[Key, Value]) -> None: self._row_dim = row_dim self._map: Mapping[Key, Value] = values @@ -41,3 +41,8 @@ def _repr_html_(self) -> str: ) + "" ) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, Series): + return NotImplemented + return self.row_dim == other.row_dim and self._map == other._map diff --git a/src/sciline/visualize.py b/src/sciline/visualize.py index a81884f9..699776b5 100644 --- a/src/sciline/visualize.py +++ b/src/sciline/visualize.py @@ -44,7 +44,7 @@ def to_graphviz(graph: Graph, **kwargs: Any) -> Digraph: return dot -def _qualname(obj: Any) -> str: +def _qualname(obj: Any) -> Any: return ( obj.__qualname__ if hasattr(obj, '__qualname__') else obj.__class__.__qualname__ ) diff --git a/tests/pipeline_with_index_test.py b/tests/pipeline_with_index_test.py index 02f306b7..a7313212 100644 --- a/tests/pipeline_with_index_test.py +++ b/tests/pipeline_with_index_test.py @@ -8,7 +8,7 @@ from sciline.pipeline import Item, Label -def test_set_param_table_raises_if_param_names_are_duplicate(): +def test_set_param_table_raises_if_param_names_are_duplicate() -> None: pl = sl.Pipeline() pl.set_param_table(sl.ParamTable(int, {float: [1.0, 2.0, 3.0]})) with pytest.raises(ValueError): @@ -18,7 +18,7 @@ def test_set_param_table_raises_if_param_names_are_duplicate(): pl.compute(Item((Label(str, 1),), float)) -def test_set_param_table_raises_if_row_dim_is_duplicate(): +def test_set_param_table_raises_if_row_dim_is_duplicate() -> None: pl = sl.Pipeline() pl.set_param_table(sl.ParamTable(int, {float: [1.0, 2.0, 3.0]})) with pytest.raises(ValueError): @@ -41,7 +41,8 @@ def test_can_get_elements_of_param_table_with_explicit_index() -> None: def test_can_depend_on_elements_of_param_table() -> None: - def use_elem(x: Item((Label(int, 1),), float)) -> str: + # This is not a valid type annotation, notsure why it works with get_type_hints + def use_elem(x: Item((Label(int, 1),), float)) -> str: # type: ignore[valid-type] return str(x) pl = sl.Pipeline([use_elem]) @@ -52,7 +53,7 @@ def use_elem(x: Item((Label(int, 1),), float)) -> str: def test_can_compute_series_of_param_values() -> None: pl = sl.Pipeline() pl.set_param_table(sl.ParamTable(int, {float: [1.0, 2.0, 3.0]})) - assert pl.compute(sl.Series[int, float]) == {0: 1.0, 1: 2.0, 2: 3.0} + assert pl.compute(sl.Series[int, float]) == sl.Series(int, {0: 1.0, 1: 2.0, 2: 3.0}) def test_can_compute_series_of_derived_values() -> None: @@ -61,7 +62,9 @@ def process(x: float) -> str: pl = sl.Pipeline([process]) pl.set_param_table(sl.ParamTable(int, {float: [1.0, 2.0, 3.0]})) - assert pl.compute(sl.Series[int, str]) == {0: "1.0", 1: "2.0", 2: "3.0"} + assert pl.compute(sl.Series[int, str]) == sl.Series( + int, {0: "1.0", 1: "2.0", 2: "3.0"} + ) def test_explicit_index_of_param_table_is_forwarded_correctly() -> None: @@ -72,7 +75,7 @@ def process(x: float) -> int: pl.set_param_table( sl.ParamTable(str, {float: [1.0, 2.0, 3.0]}, index=['a', 'b', 'c']) ) - assert pl.compute(sl.Series[str, int]) == {'a': 1, 'b': 2, 'c': 3} + assert pl.compute(sl.Series[str, int]) == sl.Series(str, {'a': 1, 'b': 2, 'c': 3}) def test_can_gather_index() -> None: @@ -225,15 +228,21 @@ def combine(x: Param1, y: Param2) -> Combined: pl = sl.Pipeline([combine]) pl.set_param_table(sl.ParamTable(Row1, {Param1: [1, 2, 3]})) pl.set_param_table(sl.ParamTable(Row2, {Param2: [4, 5]})) - assert pl.compute(sl.Series[Row1, sl.Series[Row2, Combined]]) == { - 0: {0: 4, 1: 5}, - 1: {0: 8, 1: 10}, - 2: {0: 12, 1: 15}, - } - assert pl.compute(sl.Series[Row2, sl.Series[Row1, Combined]]) == { - 0: {0: 4, 1: 8, 2: 12}, - 1: {0: 5, 1: 10, 2: 15}, - } + assert pl.compute(sl.Series[Row1, sl.Series[Row2, Combined]]) == sl.Series( + Row1, + { + 0: sl.Series(Row2, {0: 4, 1: 5}), + 1: sl.Series(Row2, {0: 8, 1: 10}), + 2: sl.Series(Row2, {0: 12, 1: 15}), + }, + ) + assert pl.compute(sl.Series[Row2, sl.Series[Row1, Combined]]) == sl.Series( + Row2, + { + 0: sl.Series(Row1, {0: 4, 1: 8, 2: 12}), + 1: sl.Series(Row1, {0: 5, 1: 10, 2: 15}), + }, + ) def test_can_groupby_by_requesting_series_of_series() -> None: @@ -243,10 +252,11 @@ def test_can_groupby_by_requesting_series_of_series() -> None: pl = sl.Pipeline() pl.set_param_table(sl.ParamTable(Row, {Param1: [1, 1, 3], Param2: [4, 5, 6]})) - assert pl.compute(sl.Series[Param1, sl.Series[Row, Param2]]) == { - 1: {0: 4, 1: 5}, - 3: {2: 6}, - } + expected = sl.Series( + Param1, + {1: sl.Series(Row, {0: 4, 1: 5}), 3: sl.Series(Row, {2: 6})}, + ) + assert pl.compute(sl.Series[Param1, sl.Series[Row, Param2]]) == expected def test_groupby_by_requesting_series_of_series_preserves_indices() -> None: @@ -258,10 +268,9 @@ def test_groupby_by_requesting_series_of_series_preserves_indices() -> None: pl.set_param_table( sl.ParamTable(Row, {Param1: [1, 1, 3], Param2: [4, 5, 6]}, index=[11, 12, 13]) ) - assert pl.compute(sl.Series[Param1, sl.Series[Row, Param2]]) == { - 1: {11: 4, 12: 5}, - 3: {13: 6}, - } + assert pl.compute(sl.Series[Param1, sl.Series[Row, Param2]]) == sl.Series( + Param1, {1: sl.Series(Row, {11: 4, 12: 5}), 3: sl.Series(Row, {13: 6})} + ) def test_multi_level_groupby_raises_with_params_from_same_table() -> None: @@ -293,10 +302,17 @@ def test_multi_level_groupby_with_params_from_different_table() -> None: grouping2 = sl.ParamTable(Param2, {Param1: [1, 1, 3]}) pl.set_param_table(grouping1) pl.set_param_table(grouping2) - assert pl.compute(sl.Series[Param1, sl.Series[Param2, sl.Series[Row, Param3]]]) == { - 1: {0: {0: 7}, 1: {1: 8, 2: 9}}, - 3: {2: {3: 10}}, - } + assert pl.compute( + sl.Series[Param1, sl.Series[Param2, sl.Series[Row, Param3]]] + ) == sl.Series( + Param1, + { + 1: sl.Series( + Param2, {0: sl.Series(Row, {0: 7}), 1: sl.Series(Row, {1: 8, 2: 9})} + ), + 3: sl.Series(Param2, {2: sl.Series(Row, {3: 10})}), + }, + ) def test_multi_level_groupby_with_params_from_different_table_can_select() -> None: @@ -311,9 +327,16 @@ def test_multi_level_groupby_with_params_from_different_table_can_select() -> No grouping2 = sl.ParamTable(Param2, {Param1: [1, 1]}, index=[4, 5]) pl.set_param_table(grouping1) pl.set_param_table(grouping2) - assert pl.compute(sl.Series[Param1, sl.Series[Param2, sl.Series[Row, Param3]]]) == { - 1: {4: {0: 7}, 5: {1: 8, 2: 9}} - } + assert pl.compute( + sl.Series[Param1, sl.Series[Param2, sl.Series[Row, Param3]]] + ) == sl.Series( + Param1, + { + 1: sl.Series( + Param2, {4: sl.Series(Row, {0: 7}), 5: sl.Series(Row, {1: 8, 2: 9})} + ) + }, + ) def test_multi_level_groupby_with_params_from_different_table_preserves_index() -> None: @@ -329,10 +352,18 @@ def test_multi_level_groupby_with_params_from_different_table_preserves_index() grouping2 = sl.ParamTable(Param2, {Param1: [1, 1, 3]}, index=[4, 5, 6]) pl.set_param_table(grouping1) pl.set_param_table(grouping2) - assert pl.compute(sl.Series[Param1, sl.Series[Param2, sl.Series[Row, Param3]]]) == { - 1: {4: {100: 7}, 5: {200: 8, 300: 9}}, - 3: {6: {400: 10}}, - } + assert pl.compute( + sl.Series[Param1, sl.Series[Param2, sl.Series[Row, Param3]]] + ) == sl.Series( + Param1, + { + 1: sl.Series( + Param2, + {4: sl.Series(Row, {100: 7}), 5: sl.Series(Row, {200: 8, 300: 9})}, + ), + 3: sl.Series(Param2, {6: sl.Series(Row, {400: 10})}), + }, + ) def test_multi_level_groupby_with_params_from_different_table_can_reorder() -> None: @@ -348,10 +379,18 @@ def test_multi_level_groupby_with_params_from_different_table_can_reorder() -> N grouping2 = sl.ParamTable(Param2, {Param1: [1, 1, 3]}, index=[6, 5, 4]) pl.set_param_table(grouping1) pl.set_param_table(grouping2) - assert pl.compute(sl.Series[Param1, sl.Series[Param2, sl.Series[Row, Param3]]]) == { - 1: {6: {400: 10}, 5: {200: 8, 300: 9}}, - 3: {4: {100: 7}}, - } + assert pl.compute( + sl.Series[Param1, sl.Series[Param2, sl.Series[Row, Param3]]] + ) == sl.Series( + Param1, + { + 1: sl.Series( + Param2, + {6: sl.Series(Row, {400: 10}), 5: sl.Series(Row, {200: 8, 300: 9})}, + ), + 3: sl.Series(Param2, {4: sl.Series(Row, {100: 7})}), + }, + ) @pytest.mark.parametrize("index", [None, [4, 5, 7]]) @@ -397,7 +436,7 @@ def process(x: SummedGroup) -> ProcessedGroup: pl.set_param_table(params) graph = pl.get(sl.Series[Name, ProcessedGroup]) - assert graph.compute() == {'a': 10, 'b': 8} + assert graph.compute() == sl.Series(Name, {'a': 10, 'b': 8}) def test_requesting_series_index_that_is_not_in_param_table_raises() -> None: @@ -442,11 +481,14 @@ def make_float() -> float: assert pipeline.compute(Str[float]) == Str[float]('1.5') with pytest.raises(sl.UnsatisfiedRequirement): pipeline.compute(Str[int]) - assert pipeline.compute(sl.Series[Row, Str[int]]) == { - 0: Str[int]('1'), - 1: Str[int]('2'), - 2: Str[int]('3'), - } + assert pipeline.compute(sl.Series[Row, Str[int]]) == sl.Series( + Row, + { + 0: Str[int]('1'), + 1: Str[int]('2'), + 2: Str[int]('3'), + }, + ) def test_generic_provider_can_depend_on_param_series() -> None: From 9be394c6d00bc615b8c8b294a62bbfdaf3a485ba Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Mon, 7 Aug 2023 15:08:15 +0200 Subject: [PATCH 52/87] Fix most mypy --- src/sciline/pipeline.py | 46 ++++++++++++------------------- src/sciline/scheduler.py | 16 ++++------- src/sciline/task_graph.py | 16 ++++++++--- src/sciline/typing.py | 26 +++++++++++++++++ src/sciline/visualize.py | 30 ++++++++++++++------ tests/pipeline_with_index_test.py | 26 +++++++++-------- tests/task_graph_test.py | 2 +- tests/variadic_workflow_test.py | 14 ++-------- 8 files changed, 100 insertions(+), 76 deletions(-) create mode 100644 src/sciline/typing.py diff --git a/src/sciline/pipeline.py b/src/sciline/pipeline.py index 8eae63f1..fdbeb5fc 100644 --- a/src/sciline/pipeline.py +++ b/src/sciline/pipeline.py @@ -3,7 +3,6 @@ from __future__ import annotations from collections import defaultdict -from dataclasses import dataclass from typing import ( Any, Callable, @@ -32,6 +31,7 @@ from .param_table import ParamTable from .scheduler import Scheduler from .series import Series +from .typing import Graph, Item, Key, Label T = TypeVar('T') KeyType = TypeVar('KeyType') @@ -54,19 +54,7 @@ class AmbiguousProvider(Exception): """Raised when multiple providers are found for a type.""" -@dataclass(frozen=True) -class Label: - tp: type - index: int - - -@dataclass(frozen=True) -class Item(Generic[T]): - label: Tuple[Label, ...] - tp: Type[T] - - -def _indexed_key(index_name: Any, i: int, value_name: Any) -> Item: +def _indexed_key(index_name: Any, i: int, value_name: Type[T] | Item[T]) -> Item[T]: label = Label(index_name, i) if isinstance(value_name, Item): return Item(value_name.label + (label,), value_name.tp) @@ -75,11 +63,6 @@ def _indexed_key(index_name: Any, i: int, value_name: Any) -> Item: Provider = Callable[..., Any] -Key = Union[type, Item] -Graph = Dict[ - Key, - Tuple[Callable[..., Any], Tuple[Key, ...]], -] def _is_compatible_type_tuple( @@ -166,7 +149,7 @@ def get_grouping(self, key: Any, group: LabelType) -> Optional[List[IndexType]]: return None return self._index[group] - def in_group(self, arg: Item, group: LabelType) -> bool: + def in_group(self, arg: Item[Any], group: LabelType) -> bool: if len(arg.label) != 1: raise ValueError(f'Cannot group with multi-index label {arg.label}') return arg.label[0].index in self._index[group] @@ -192,11 +175,13 @@ def restrict( return SeriesProducer(labels, self._row_dim) +class _param_sentinel: + ... + + class Pipeline: """A container for providers that can be assembled into a task graph.""" - _param_sentinel = object() - def __init__( self, providers: Optional[List[Provider]] = None, @@ -292,7 +277,7 @@ def set_param_table(self, params: ParamTable) -> None: ) def _set_provider( - self, key: Union[Type[T], Item], provider: Callable[..., T] + self, key: Union[Type[T], Item[T]], provider: Callable[..., T] ) -> None: # isinstance does not work here and types.NoneType available only in 3.10+ if key == type(None): # noqa: E721 @@ -309,7 +294,7 @@ def _set_provider( self._providers[key] = provider def _get_provider( - self, tp: Union[Type[T], Item] + self, tp: Union[Type[T], Item[T]] ) -> Tuple[Callable[..., T], Dict[TypeVar, Key]]: if (provider := self._providers.get(tp)) is not None: return provider, {} @@ -335,7 +320,7 @@ def _get_provider( raise UnsatisfiedRequirement("No provider found for type", tp) def build( - self, tp: Union[Type[T], Item], /, search_param_tables: bool = False + self, tp: Union[Type[T], Item[T]], /, search_param_tables: bool = False ) -> Graph: """ Return a dict of providers required for building the requested type `tp`. @@ -352,11 +337,11 @@ def build( Type to build the graph for. """ graph: Graph = {} - stack: List[Union[Type[T], Item]] = [tp] + stack: List[Union[Type[T], Item[T]]] = [tp] while stack: tp = stack.pop() if search_param_tables and tp in self._param_series: - graph[tp] = (self._param_sentinel, (self._param_series[tp],)) + graph[tp] = (_param_sentinel, (self._param_series[tp],)) continue if get_origin(tp) == Series: graph.update(self._build_series(tp)) # type: ignore[arg-type] @@ -410,7 +395,7 @@ def _build_series(self, tp: Type[Series[KeyType, ValueType]]) -> Graph: for index in grouper: provider, args = value subkey = _indexed_key(label_name, index, key) - if provider == self._param_sentinel: + if provider == _param_sentinel: provider, _ = self._get_provider(subkey) args = () args_with_index = tuple( @@ -477,7 +462,10 @@ def visualize( return self.get(tp).visualize(**kwargs) def get( - self, keys: type | Tuple[type, ...], *, scheduler: Optional[Scheduler] = None + self, + keys: type | Tuple[type, ...] | Item[T], + *, + scheduler: Optional[Scheduler] = None, ) -> TaskGraph: """ Return a TaskGraph for the given keys. diff --git a/src/sciline/scheduler.py b/src/sciline/scheduler.py index 963e1e8a..49213c84 100644 --- a/src/sciline/scheduler.py +++ b/src/sciline/scheduler.py @@ -1,12 +1,8 @@ # SPDX-License-Identifier: BSD-3-Clause # Copyright (c) 2023 Scipp contributors (https://github.com/scipp) -from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple +from typing import Any, Callable, Dict, List, Optional, Protocol -Key = type -Graph = Dict[ - Key, - Tuple[Callable[..., Any], Tuple[Key, ...]], -] +from sciline.typing import Graph, Key class CycleError(Exception): @@ -18,7 +14,7 @@ class Scheduler(Protocol): Scheduler interface compatible with :py:class:`sciline.Pipeline`. """ - def get(self, graph: Graph, keys: List[type]) -> Any: + def get(self, graph: Graph, keys: List[Key]) -> Any: """ Compute the result for given keys from the graph. @@ -37,7 +33,7 @@ class NaiveScheduler: :py:class:`DaskScheduler` instead. """ - def get(self, graph: Graph, keys: List[type]) -> Any: + def get(self, graph: Graph, keys: List[Key]) -> Any: import graphlib dependencies = {tp: args for tp, (_, args) in graph.items()} @@ -47,7 +43,7 @@ def get(self, graph: Graph, keys: List[type]) -> Any: tasks = list(ts.static_order()) except graphlib.CycleError as e: raise CycleError from e - results: Dict[type, Any] = {} + results: Dict[Key, Any] = {} for t in tasks: provider, args = graph[t] results[t] = provider(*[results[arg] for arg in args]) @@ -76,7 +72,7 @@ def __init__(self, scheduler: Optional[Callable[..., Any]] = None) -> None: else: self._dask_get = scheduler - def get(self, graph: Graph, keys: List[type]) -> Any: + def get(self, graph: Graph, keys: List[Key]) -> Any: dsk = {tp: (provider, *args) for tp, (provider, args) in graph.items()} try: return self._dask_get(dsk, keys) diff --git a/src/sciline/task_graph.py b/src/sciline/task_graph.py index e832fcda..cebb12cb 100644 --- a/src/sciline/task_graph.py +++ b/src/sciline/task_graph.py @@ -2,9 +2,12 @@ # Copyright (c) 2023 Scipp contributors (https://github.com/scipp) from __future__ import annotations -from typing import Any, Optional, Tuple, Union +from typing import Any, Optional, Tuple, TypeVar, Union -from sciline.scheduler import DaskScheduler, Graph, NaiveScheduler, Scheduler +from .scheduler import DaskScheduler, NaiveScheduler, Scheduler +from .typing import Graph, Item + +T = TypeVar("T") class TaskGraph: @@ -19,7 +22,7 @@ def __init__( self, *, graph: Graph, - keys: Union[type, Tuple[type, ...]], + keys: Union[type, Tuple[type, ...], Item[T], Tuple[Item[T], ...]], scheduler: Optional[Scheduler] = None, ) -> None: self._graph = graph @@ -31,7 +34,12 @@ def __init__( scheduler = NaiveScheduler() self._scheduler = scheduler - def compute(self, keys: Optional[Union[type, Tuple[type, ...]]] = None) -> Any: + def compute( + self, + keys: Optional[ + Union[type, Tuple[type, ...], Item[T], Tuple[Item[T], ...]] + ] = None, + ) -> Any: """ Compute the result of the graph. diff --git a/src/sciline/typing.py b/src/sciline/typing.py new file mode 100644 index 00000000..26ff8ba7 --- /dev/null +++ b/src/sciline/typing.py @@ -0,0 +1,26 @@ +# SPDX-License-Identifier: BSD-3-Clause +# Copyright (c) 2023 Scipp contributors (https://github.com/scipp) +from dataclasses import dataclass +from typing import Any, Callable, Dict, Generic, Tuple, Type, TypeVar, Union + + +@dataclass(frozen=True) +class Label: + tp: type + index: int + + +T = TypeVar('T') + + +@dataclass(frozen=True) +class Item(Generic[T]): + label: Tuple[Label, ...] + tp: Type[T] + + +Key = Union[type, Item] +Graph = Dict[ + Key, + Tuple[Callable[..., Any], Tuple[Key, ...]], +] diff --git a/src/sciline/visualize.py b/src/sciline/visualize.py index 699776b5..199cbc05 100644 --- a/src/sciline/visualize.py +++ b/src/sciline/visualize.py @@ -1,11 +1,21 @@ # SPDX-License-Identifier: BSD-3-Clause # Copyright (c) 2023 Scipp contributors (https://github.com/scipp) -from typing import Any, Callable, Dict, List, Tuple, Union, get_args, get_origin +from typing import ( + Any, + Callable, + Dict, + List, + Tuple, + Type, + TypeVar, + Union, + get_args, + get_origin, +) from graphviz import Digraph -from .pipeline import Item, Label -from .scheduler import Graph +from .typing import Graph, Item, Key def to_graphviz(graph: Graph, **kwargs: Any) -> Digraph: @@ -61,21 +71,23 @@ def _format_graph(graph: Graph) -> Dict[str, Tuple[str, List[str], str]]: } -def _format_provider(provider: Callable[..., Any], ret: type) -> str: +def _format_provider(provider: Callable[..., Any], ret: Key) -> str: return f'{_qualname(provider)}_{_format_type(ret)}' -def _extract_type_and_labels(key: Union[Item, Label, type]) -> Tuple[type, List[type]]: +T = TypeVar('T') + + +def _extract_type_and_labels( + key: Union[Item[T], Type[T]] +) -> Tuple[Type[T], List[type]]: if isinstance(key, Item): label = key.label return key.tp, [lb.tp for lb in label] - if isinstance(key, Label): - tp, labels = _extract_type_and_labels(key.tp) - return tp, [key.tp] + labels return key, [] -def _format_type(tp: type) -> str: +def _format_type(tp: Key) -> str: """ Helper for _format_graph. diff --git a/tests/pipeline_with_index_test.py b/tests/pipeline_with_index_test.py index a7313212..70f9ea61 100644 --- a/tests/pipeline_with_index_test.py +++ b/tests/pipeline_with_index_test.py @@ -1,11 +1,11 @@ # SPDX-License-Identifier: BSD-3-Clause # Copyright (c) 2023 Scipp contributors (https://github.com/scipp) -from typing import NewType, TypeVar +from typing import List, NewType, Optional, TypeVar import pytest import sciline as sl -from sciline.pipeline import Item, Label +from sciline.typing import Item, Label def test_set_param_table_raises_if_param_names_are_duplicate() -> None: @@ -94,7 +94,7 @@ def make_float(x: str) -> float: def test_can_zip() -> None: - Sum = NewType("Sum", float) + Sum = NewType("Sum", str) Str = NewType("Str", str) Run = NewType("Run", int) @@ -115,7 +115,7 @@ def test_diamond_dependency_pulls_values_from_columns_in_same_param_table() -> N Sum = NewType("Sum", float) Param1 = NewType("Param1", int) Param2 = NewType("Param2", int) - Row = NewType("Run", int) + Row = NewType("Row", int) def gather(x: sl.Series[Row, float]) -> Sum: return Sum(sum(x.values())) @@ -126,7 +126,7 @@ def join(x: Param1, y: Param2) -> float: pl = sl.Pipeline([gather, join]) pl.set_param_table(sl.ParamTable(Row, {Param1: [1, 4, 9], Param2: [1, 2, 3]})) - assert pl.compute(Sum) == 6 + assert pl.compute(Sum) == Sum(6) def test_diamond_dependency_on_same_column() -> None: @@ -134,7 +134,7 @@ def test_diamond_dependency_on_same_column() -> None: Param = NewType("Param", int) Param1 = NewType("Param1", int) Param2 = NewType("Param2", int) - Row = NewType("Run", int) + Row = NewType("Row", int) def gather(x: sl.Series[Row, float]) -> Sum: return Sum(sum(x.values())) @@ -151,7 +151,7 @@ def join(x: Param1, y: Param2) -> float: pl = sl.Pipeline([gather, join, to_param1, to_param2]) pl.set_param_table(sl.ParamTable(Row, {Param: [1, 2, 3]})) - assert pl.compute(Sum) == 3 + assert pl.compute(Sum) == Sum(3) def test_dependencies_on_different_param_tables_broadcast() -> None: @@ -163,7 +163,7 @@ def test_dependencies_on_different_param_tables_broadcast() -> None: def gather_both(x: sl.Series[Row1, Param1], y: sl.Series[Row2, Param2]) -> Product: broadcast = [[x_, y_] for x_ in x.values() for y_ in y.values()] - return str(broadcast) + return Product(str(broadcast)) pl = sl.Pipeline([gather_both]) pl.set_param_table(sl.ParamTable(Row1, {Param1: [1, 2, 3]})) @@ -183,7 +183,7 @@ def gather2_and_combine(x: Param1, y: sl.Series[Row2, Param2]) -> Summed2: return Summed2(x * sum(y.values())) def gather1(x: sl.Series[Row1, Summed2]) -> Product: - return str(list(x.values())) + return Product(str(list(x.values()))) pl = sl.Pipeline([gather1, gather2_and_combine]) pl.set_param_table(sl.ParamTable(Row1, {Param1: [1, 2, 3]})) @@ -207,7 +207,7 @@ def combine(x: Param1, y: Summed2) -> Combined: return Combined(x * y) def gather1(x: sl.Series[Row1, Combined]) -> Product: - return str(list(x.values())) + return Product(str(list(x.values()))) pl = sl.Pipeline([gather1, gather2, combine]) pl.set_param_table(sl.ParamTable(Row1, {Param1: [1, 2, 3]})) @@ -394,7 +394,9 @@ def test_multi_level_groupby_with_params_from_different_table_can_reorder() -> N @pytest.mark.parametrize("index", [None, [4, 5, 7]]) -def test_multi_level_groupby_raises_on_index_mismatch(index) -> None: +def test_multi_level_groupby_raises_on_index_mismatch( + index: Optional[List[int]], +) -> None: Row = NewType("Row", int) Param1 = NewType("Param1", int) Param2 = NewType("Param2", int) @@ -412,7 +414,7 @@ def test_multi_level_groupby_raises_on_index_mismatch(index) -> None: @pytest.mark.parametrize("index", [None, [4, 5, 6]]) -def test_groupby_over_param_table(index) -> None: +def test_groupby_over_param_table(index: Optional[List[int]]) -> None: Index = NewType("Index", int) Name = NewType("Name", str) Param = NewType("Param", int) diff --git a/tests/task_graph_test.py b/tests/task_graph_test.py index 0f376049..31e24714 100644 --- a/tests/task_graph_test.py +++ b/tests/task_graph_test.py @@ -3,8 +3,8 @@ import pytest import sciline as sl -from sciline.scheduler import Graph from sciline.task_graph import TaskGraph +from sciline.typing import Graph def as_float(x: int) -> float: diff --git a/tests/variadic_workflow_test.py b/tests/variadic_workflow_test.py index 1e1acb97..98341e93 100644 --- a/tests/variadic_workflow_test.py +++ b/tests/variadic_workflow_test.py @@ -23,19 +23,11 @@ def read(filename: Filename) -> Image: def image_param(filename: Filename) -> ImageParam: return ImageParam(sum(ord(c) for c in filename)) - def clean2(x: Image, param: ImageParam) -> CleanedImage: - return x * param - def clean(x: Image) -> CleanedImage: - return x + return CleanedImage(x) def scale(x: CleanedImage, param: Param, config: Config) -> ScaledImage: - return x * param + config - - def combine_old( - images: sl.Series[Run, ScaledImage], params: sl.Series[Run, ImageParam] - ) -> float: - return sum(images.values()) + return ScaledImage(x * param + config) def combine(images: sl.Series[Run, ScaledImage]) -> float: return sum(images.values()) @@ -47,7 +39,7 @@ def make_int() -> int: return 2 def make_param() -> Param: - return 2.0 + return Param(2.0) filenames = tuple(f'file{i}' for i in range(6)) configs = tuple(range(2)) From a0772d38d3ab7de62ea92b30ee7c8f1ba7303921 Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Tue, 8 Aug 2023 06:44:01 +0200 Subject: [PATCH 53/87] Fix mypy --- src/sciline/param_table.py | 5 ++--- src/sciline/pipeline.py | 16 +++++++++------- src/sciline/series.py | 3 +-- src/sciline/typing.py | 8 ++++---- 4 files changed, 16 insertions(+), 16 deletions(-) diff --git a/src/sciline/param_table.py b/src/sciline/param_table.py index b800fa43..74298cfb 100644 --- a/src/sciline/param_table.py +++ b/src/sciline/param_table.py @@ -2,11 +2,10 @@ # Copyright (c) 2023 Scipp contributors (https://github.com/scipp) from __future__ import annotations -from collections import abc -from typing import Any, Collection, Dict, Optional +from typing import Any, Collection, Dict, Mapping, Optional -class ParamTable(abc.Mapping[type, Collection[Any]]): +class ParamTable(Mapping[type, Collection[Any]]): def __init__( self, row_dim: type, diff --git a/src/sciline/pipeline.py b/src/sciline/pipeline.py index fdbeb5fc..1e684371 100644 --- a/src/sciline/pipeline.py +++ b/src/sciline/pipeline.py @@ -31,7 +31,7 @@ from .param_table import ParamTable from .scheduler import Scheduler from .series import Series -from .typing import Graph, Item, Key, Label +from .typing import Graph, Item, Key, Label, Provider T = TypeVar('T') KeyType = TypeVar('KeyType') @@ -62,9 +62,6 @@ def _indexed_key(index_name: Any, i: int, value_name: Type[T] | Item[T]) -> Item return Item((label,), value_name) -Provider = Callable[..., Any] - - def _is_compatible_type_tuple( requested: tuple[Key, ...], provided: tuple[Key | TypeVar, ...], @@ -404,17 +401,22 @@ def _build_series(self, tp: Type[Series[KeyType, ValueType]]) -> Graph: if in_group(arg, index) ) if isinstance(provider, SeriesProducer): - provider = provider.restrict(grouper.get_grouping(key, index)) + # For some reason mypy does not detect that SeriesProducer is + # Callable? + provider = provider.restrict( # type: ignore[unreachable] + grouper.get_grouping(key, index) + ) graph[subkey] = (provider, args_with_index) else: graph[key] = value return graph def _find_grouping_node(self, index_name: Key, subgraph: Graph) -> type: - ends = [] + ends: List[type] = [] for key in subgraph: if get_origin(key) == Series and get_args(key)[0] == index_name: - ends.append(key) + # Because if the succeeded get_origin we know it is a type + ends.append(key) # type: ignore[arg-type] if len(ends) == 1: return ends[0] raise ValueError(f"Could not find unique grouping node, found {ends}") diff --git a/src/sciline/series.py b/src/sciline/series.py index c88f1e70..4c6c9846 100644 --- a/src/sciline/series.py +++ b/src/sciline/series.py @@ -2,14 +2,13 @@ # Copyright (c) 2023 Scipp contributors (https://github.com/scipp) from __future__ import annotations -from collections import abc from typing import Iterator, Mapping, Type, TypeVar Key = TypeVar('Key') Value = TypeVar('Value') -class Series(abc.Mapping[Key, Value]): +class Series(Mapping[Key, Value]): def __init__(self, row_dim: Type[Key], values: Mapping[Key, Value]) -> None: self._row_dim = row_dim self._map: Mapping[Key, Value] = values diff --git a/src/sciline/typing.py b/src/sciline/typing.py index 26ff8ba7..6ecd2729 100644 --- a/src/sciline/typing.py +++ b/src/sciline/typing.py @@ -19,8 +19,8 @@ class Item(Generic[T]): tp: Type[T] +Provider = Callable[..., Any] + + Key = Union[type, Item] -Graph = Dict[ - Key, - Tuple[Callable[..., Any], Tuple[Key, ...]], -] +Graph = Dict[Key, Tuple[Provider, Tuple[Key, ...]]] From cd62fe916e5e16da1e1361b7fbf430ec247d03ff Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Tue, 8 Aug 2023 06:50:39 +0200 Subject: [PATCH 54/87] Remove submodule with non-public functionality --- src/sciline/graph.py | 28 ---------------------------- src/sciline/pipeline.py | 34 +++++++++++++++++++++++++++++++--- tests/graph_test.py | 8 ++++---- 3 files changed, 35 insertions(+), 35 deletions(-) delete mode 100644 src/sciline/graph.py diff --git a/src/sciline/graph.py b/src/sciline/graph.py deleted file mode 100644 index 2c03b866..00000000 --- a/src/sciline/graph.py +++ /dev/null @@ -1,28 +0,0 @@ -from typing import Any, Callable, Collection, List, Mapping, Tuple, TypeVar - -T = TypeVar("T") - - -def find_all_paths(graph: Mapping[T, Collection[T]], start: T, end: T) -> List[List[T]]: - """Find all paths from start to end in a DAG.""" - if start == end: - return [[start]] - if start not in graph: - return [] - paths = [] - for node in graph[start]: - for path in find_all_paths(graph, node, end): - paths.append([start] + path) - return paths - - -def find_nodes_in_paths( - graph: Mapping[T, Tuple[Callable[..., Any], Collection[T]]], start: T, end: T -) -> List[T]: - # 0 is the provider, 1 is the args - dependencies = {k: v[1] for k, v in graph.items()} - paths = find_all_paths(dependencies, start, end) - nodes = set() - for path in paths: - nodes.update(path) - return list(nodes) diff --git a/src/sciline/pipeline.py b/src/sciline/pipeline.py index 1e684371..3360a934 100644 --- a/src/sciline/pipeline.py +++ b/src/sciline/pipeline.py @@ -6,12 +6,14 @@ from typing import ( Any, Callable, + Collection, Dict, Generic, Iterable, Iterator, List, Literal, + Mapping, Optional, Protocol, Tuple, @@ -27,7 +29,6 @@ from sciline.task_graph import TaskGraph from .domain import Scope -from .graph import find_nodes_in_paths from .param_table import ParamTable from .scheduler import Scheduler from .series import Series @@ -93,6 +94,33 @@ def _bind_free_typevars(tp: TypeVar | Key, bound: Dict[TypeVar, Key]) -> Key: return tp +def _find_all_paths( + graph: Mapping[T, Collection[T]], start: T, end: T +) -> List[List[T]]: + """Find all paths from start to end in a DAG.""" + if start == end: + return [[start]] + if start not in graph: + return [] + paths = [] + for node in graph[start]: + for path in _find_all_paths(graph, node, end): + paths.append([start] + path) + return paths + + +def _find_nodes_in_paths( + graph: Mapping[T, Tuple[Callable[..., Any], Collection[T]]], start: T, end: T +) -> List[T]: + # 0 is the provider, 1 is the args + dependencies = {k: v[1] for k, v in graph.items()} + paths = _find_all_paths(dependencies, start, end) + nodes = set() + for path in paths: + nodes.update(path) + return list(nodes) + + def _yes(*_: Any) -> Literal[True]: return True @@ -367,13 +395,13 @@ def _build_series(self, tp: Type[Series[KeyType, ValueType]]) -> Graph: label_name not in self._param_series and (params := self._param_tables.get(label_name)) is not None ): - path = find_nodes_in_paths(subgraph, value_type, label_name) + path = _find_nodes_in_paths(subgraph, value_type, label_name) grouper = NoGrouping(index=params.index) elif (index_name := self._param_series.get(label_name)) is not None: params = self._param_tables[index_name] labels = params[label_name] grouping_node = self._find_grouping_node(index_name, subgraph) - path = find_nodes_in_paths(subgraph, value_type, grouping_node) + path = _find_nodes_in_paths(subgraph, value_type, grouping_node) grouper = GroupBy( index=params.index, labels=labels, grouping_node=grouping_node ) diff --git a/tests/graph_test.py b/tests/graph_test.py index 4d78a7f7..65705b20 100644 --- a/tests/graph_test.py +++ b/tests/graph_test.py @@ -1,10 +1,10 @@ # SPDX-License-Identifier: BSD-3-Clause # Copyright (c) 2023 Scipp contributors (https://github.com/scipp) -from sciline.graph import find_all_paths +from sciline.pipeline import _find_all_paths def test_find_all_paths() -> None: graph = {"D": ["B", "C"], "C": ["A"], "B": ["A"]} - assert find_all_paths(graph, "D", "A") == [["D", "B", "A"], ["D", "C", "A"]] - assert find_all_paths(graph, "B", "C") == [] - assert find_all_paths(graph, "B", "A") == [["B", "A"]] + assert _find_all_paths(graph, "D", "A") == [["D", "B", "A"], ["D", "C", "A"]] + assert _find_all_paths(graph, "B", "C") == [] + assert _find_all_paths(graph, "B", "A") == [["B", "A"]] From 5f6a560dc6eba66b9e77372fcb26248a0b179e66 Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Tue, 8 Aug 2023 07:04:06 +0200 Subject: [PATCH 55/87] Docstrings --- src/sciline/param_table.py | 20 ++++++++++++++++++++ src/sciline/series.py | 17 +++++++++++++++-- 2 files changed, 35 insertions(+), 2 deletions(-) diff --git a/src/sciline/param_table.py b/src/sciline/param_table.py index 74298cfb..69d80766 100644 --- a/src/sciline/param_table.py +++ b/src/sciline/param_table.py @@ -6,6 +6,8 @@ class ParamTable(Mapping[type, Collection[Any]]): + """A table of parameters with a row index and named row dimension.""" + def __init__( self, row_dim: type, @@ -13,6 +15,22 @@ def __init__( *, index: Optional[Collection[Any]] = None, ): + """ + Create a new param table. + + Parameters + ---------- + row_dim: + The row dimension. This must be a type or a type-alias (not an instance), + and is used by :py:class:`sciline.Pipeline` to identify each parameter + table. + columns: + The columns of the table. The keys (column names) must be types or type- + aliases matching the values in the respective columns. + index: + The row index of the table. If not given, a default index will be + generated, as the integer range of the column length. + """ sizes = set(len(v) for v in columns.values()) if len(sizes) != 1: raise ValueError( @@ -32,10 +50,12 @@ def __init__( @property def row_dim(self) -> type: + """The row dimension of the table.""" return self._row_dim @property def index(self) -> Collection[Any]: + """The row index of the table.""" return self._index def __contains__(self, key: Any) -> bool: diff --git a/src/sciline/series.py b/src/sciline/series.py index 4c6c9846..fb82ca0b 100644 --- a/src/sciline/series.py +++ b/src/sciline/series.py @@ -9,12 +9,25 @@ class Series(Mapping[Key, Value]): - def __init__(self, row_dim: Type[Key], values: Mapping[Key, Value]) -> None: + """A series of values with labels (row index) and named row dimension.""" + + def __init__(self, row_dim: Type[Key], items: Mapping[Key, Value]) -> None: + """ + Create a new series. + + Parameters + ---------- + row_dim: + The row dimension. This must be a type or a type-alias (not an instance). + items: + The items of the series. + """ self._row_dim = row_dim - self._map: Mapping[Key, Value] = values + self._map: Mapping[Key, Value] = items @property def row_dim(self) -> type: + """The row dimension of the series.""" return self._row_dim def __contains__(self, item: object) -> bool: From 7aa2720ab6c55f78c78ce592fc01bf1c06b334f3 Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Tue, 8 Aug 2023 07:22:58 +0200 Subject: [PATCH 56/87] Update docs --- docs/api-reference/index.md | 2 ++ docs/user-guide/index.md | 1 + src/sciline/pipeline.py | 39 ++++++++++++++++++++++++++++++++----- 3 files changed, 37 insertions(+), 5 deletions(-) diff --git a/docs/api-reference/index.md b/docs/api-reference/index.md index 883ecd98..0cba8d9c 100644 --- a/docs/api-reference/index.md +++ b/docs/api-reference/index.md @@ -10,8 +10,10 @@ :template: class-template.rst :recursive: + ParamTable Pipeline Scope + Series scheduler.Scheduler scheduler.DaskScheduler scheduler.NaiveScheduler diff --git a/docs/user-guide/index.md b/docs/user-guide/index.md index ca13f1a2..3baa2628 100644 --- a/docs/user-guide/index.md +++ b/docs/user-guide/index.md @@ -6,4 +6,5 @@ maxdepth: 2 --- getting-started +parameter-tables ``` diff --git a/src/sciline/pipeline.py b/src/sciline/pipeline.py index 3360a934..8c73fb6e 100644 --- a/src/sciline/pipeline.py +++ b/src/sciline/pipeline.py @@ -112,6 +112,10 @@ def _find_all_paths( def _find_nodes_in_paths( graph: Mapping[T, Tuple[Callable[..., Any], Collection[T]]], start: T, end: T ) -> List[T]: + """ + Find all nodes that need to be duplicated since they depend on a value from a + param table. + """ # 0 is the provider, 1 is the args dependencies = {k: v[1] for k, v in graph.items()} paths = _find_all_paths(dependencies, start, end) @@ -126,6 +130,8 @@ def _yes(*_: Any) -> Literal[True]: class Grouper(Protocol): + """Helper protocol for rewriting graphs.""" + def __iter__(self) -> Iterator[Any]: ... @@ -137,6 +143,8 @@ def get_grouping(self, key: Any, group: Any) -> Any: class NoGrouping(Generic[IndexType]): + """Helper for rewriting the graph to map over a given index.""" + def __init__(self, index: Iterable[IndexType]) -> None: self._index = index @@ -151,6 +159,8 @@ def get_grouping(self, key: Any, group: int) -> None: class GroupBy(Generic[IndexType, LabelType]): + """Helper for rewriting the graph to group by a given index.""" + def __init__( self, *, @@ -180,7 +190,12 @@ def in_group(self, arg: Item[Any], group: LabelType) -> bool: return arg.label[0].index in self._index[group] -class SeriesProducer(Generic[KeyType, ValueType]): +class SeriesProvider(Generic[KeyType, ValueType]): + """ + Internal provider for combining results obtained based on different rows in a + param table into a single object. + """ + def __init__(self, labels: Iterable[KeyType], row_dim: type) -> None: self._labels = labels self._row_dim = row_dim @@ -190,14 +205,14 @@ def __call__(self, *vals: ValueType) -> Series[KeyType, ValueType]: def restrict( self, labels: Optional[Iterable[KeyType]] - ) -> SeriesProducer[KeyType, ValueType]: + ) -> SeriesProvider[KeyType, ValueType]: if labels is None: return self if set(labels) - set(self._labels): raise ValueError(f'{labels} is not a subset of {self._labels}') # Ensure that labels are in the same order as in the original series labels = [label for label in self._labels if label in labels] - return SeriesProducer(labels, self._row_dim) + return SeriesProvider(labels, self._row_dim) class _param_sentinel: @@ -286,6 +301,20 @@ def __setitem__(self, key: Type[T], param: T) -> None: self._set_provider(key, lambda: param) def set_param_table(self, params: ParamTable) -> None: + """ + Set a parameter table for a row dimension. + + Values in the parameter table provide concrete values for a type given by the + respective column header. + + A pipeline can have multiple parameter tables, but only one per row dimension. + Column names must be unique across all parameter tables. + + Parameters + ---------- + params: + Parameter table to set. + """ if params.row_dim in self._param_tables: raise ValueError(f'Parameter table for {params.row_dim} already set') for param_name in params: @@ -410,7 +439,7 @@ def _build_series(self, tp: Type[Series[KeyType, ValueType]]) -> Graph: graph: Graph = {} graph[tp] = ( - SeriesProducer(list(grouper), label_name), + SeriesProvider(list(grouper), label_name), tuple(_indexed_key(label_name, index, value_type) for index in grouper), ) @@ -428,7 +457,7 @@ def _build_series(self, tp: Type[Series[KeyType, ValueType]]) -> Graph: for arg in args if in_group(arg, index) ) - if isinstance(provider, SeriesProducer): + if isinstance(provider, SeriesProvider): # For some reason mypy does not detect that SeriesProducer is # Callable? provider = provider.restrict( # type: ignore[unreachable] From baa1c40d077409d3d09decd73925e22d949632a3 Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Tue, 8 Aug 2023 07:32:25 +0200 Subject: [PATCH 57/87] Avoid hard-coding names of objects --- src/sciline/visualize.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/sciline/visualize.py b/src/sciline/visualize.py index 199cbc05..f79f57ca 100644 --- a/src/sciline/visualize.py +++ b/src/sciline/visualize.py @@ -15,6 +15,7 @@ from graphviz import Digraph +from .pipeline import Pipeline, SeriesProvider from .typing import Graph, Item, Key @@ -37,13 +38,13 @@ def to_graphviz(graph: Graph, **kwargs: Any) -> Digraph: dot.node(ret, ret, shape='rectangle') # Do not draw dummy providers created by Pipeline when setting instances if p_name in ( - 'Pipeline.__setitem__..', - 'Pipeline.set_param_table..', + f'{_qualname(Pipeline.__setitem__)}..', + f'{_qualname(Pipeline.set_param_table)}..', ): continue # Do not draw the internal provider gathering index-dependent results into # a dict - if p_name == 'SeriesProducer': + if p_name == _qualname(SeriesProvider): for arg in args: dot.edge(arg, ret, style='dashed') else: From 10e2bcc189e90e99ddb0fb5cc207fcf1937246f3 Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Tue, 8 Aug 2023 07:35:35 +0200 Subject: [PATCH 58/87] Remove old test file --- tests/variadic_workflow_test.py | 68 --------------------------------- 1 file changed, 68 deletions(-) delete mode 100644 tests/variadic_workflow_test.py diff --git a/tests/variadic_workflow_test.py b/tests/variadic_workflow_test.py deleted file mode 100644 index 98341e93..00000000 --- a/tests/variadic_workflow_test.py +++ /dev/null @@ -1,68 +0,0 @@ -# SPDX-License-Identifier: BSD-3-Clause -# Copyright (c) 2023 Scipp contributors (https://github.com/scipp) -from typing import NewType - -import sciline as sl - - -def test_Map() -> None: - Run = NewType('Run', int) - Setting = NewType('Setting', int) - Filename = NewType('Filename', str) - Config = NewType('Config', int) - Image = NewType('Image', float) - CleanedImage = NewType('CleanedImage', float) - ScaledImage = NewType('ScaledImage', float) - Param = NewType('Param', float) - ImageParam = NewType('ImageParam', float) - Result = NewType('Result', float) - - def read(filename: Filename) -> Image: - return Image(float(filename[-1])) - - def image_param(filename: Filename) -> ImageParam: - return ImageParam(sum(ord(c) for c in filename)) - - def clean(x: Image) -> CleanedImage: - return CleanedImage(x) - - def scale(x: CleanedImage, param: Param, config: Config) -> ScaledImage: - return ScaledImage(x * param + config) - - def combine(images: sl.Series[Run, ScaledImage]) -> float: - return sum(images.values()) - - def combine_configs(x: sl.Series[Setting, float]) -> Result: - return Result(sum(x.values())) - - def make_int() -> int: - return 2 - - def make_param() -> Param: - return Param(2.0) - - filenames = tuple(f'file{i}' for i in range(6)) - configs = tuple(range(2)) - pipeline = sl.Pipeline( - [ - read, - clean, - scale, - combine, - combine_configs, - make_int, - make_param, - image_param, - ] - ) - pipeline.set_param_table(sl.ParamTable(Run, {Filename: filenames})) - pipeline.set_param_table(sl.ParamTable(Setting, {Config: configs})) - - graph = pipeline.get(int) - assert graph.compute() == 2 - graph = pipeline.get(Result) - assert graph.compute() == 66.0 - # graph.visualize().render('graph', format='png') - # from dask.delayed import Delayed - # dsk = {key: (value, *args) for key, (value, args) in graph._graph.items()} - # Delayed(Result, dsk).visualize(filename='graph.png') From 3a48aa799073b6d5c06d84a7f3d1b09e921663c7 Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Tue, 8 Aug 2023 08:02:59 +0200 Subject: [PATCH 59/87] Comment complicated algorithm --- src/sciline/pipeline.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/src/sciline/pipeline.py b/src/sciline/pipeline.py index 8c73fb6e..f5e33040 100644 --- a/src/sciline/pipeline.py +++ b/src/sciline/pipeline.py @@ -389,6 +389,8 @@ def build( ---------- tp: Type to build the graph for. + search_param_tables: + Whether to search parameter tables for concrete keys. """ graph: Graph = {} stack: List[Union[Type[T], Item[T]]] = [tp] @@ -418,7 +420,20 @@ def _build_series(self, tp: Type[Series[KeyType, ValueType]]) -> Graph: label_name: Type[KeyType] value_type: Type[ValueType] label_name, value_type = get_args(tp) + # Step 1: + # Build a graph that can compute the value type. As we are building + # a Series, this will terminate when it reaches a parameter that is not a + # single provided value but a collection of values from a parameter table + # column. Instead of single value (which does not exist), a sentinel is + # used to mark this, for processing below. subgraph = self.build(value_type, search_param_tables=True) + # Step 2: + # Identify nodes in the graph that need to be duplicated as they lie in the + # path to a parameter from a table. In the case of grouping, note that the + # ungrouped graph (including duplicate of nodes) will have been built by a + # prior call to _build_series, so instead of duplicated everything until the + # param table is reached, we only duplicate until the node that is performing + # the grouping. grouper: Grouper if ( label_name not in self._param_series @@ -443,6 +458,8 @@ def _build_series(self, tp: Type[Series[KeyType, ValueType]]) -> Graph: tuple(_indexed_key(label_name, index, value_type) for index in grouper), ) + # Step 3: + # Duplicate nodes, replacing keys with indexed keys. for key, value in subgraph.items(): if key in path: in_group = grouper(key) From 678ef5b475aecf99e2f5a777fe8bc7d6f8befc02 Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Tue, 8 Aug 2023 08:11:05 +0200 Subject: [PATCH 60/87] Spelling --- src/sciline/pipeline.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/sciline/pipeline.py b/src/sciline/pipeline.py index f5e33040..22da57d9 100644 --- a/src/sciline/pipeline.py +++ b/src/sciline/pipeline.py @@ -475,8 +475,7 @@ def _build_series(self, tp: Type[Series[KeyType, ValueType]]) -> Graph: if in_group(arg, index) ) if isinstance(provider, SeriesProvider): - # For some reason mypy does not detect that SeriesProducer is - # Callable? + # mypy does not detect that SeriesProducer is Callable? provider = provider.restrict( # type: ignore[unreachable] grouper.get_grouping(key, index) ) @@ -489,7 +488,7 @@ def _find_grouping_node(self, index_name: Key, subgraph: Graph) -> type: ends: List[type] = [] for key in subgraph: if get_origin(key) == Series and get_args(key)[0] == index_name: - # Because if the succeeded get_origin we know it is a type + # Because of the succeeded get_origin we know it is a type ends.append(key) # type: ignore[arg-type] if len(ends) == 1: return ends[0] From d9127db0f6cae27ab0a33086c5d172552a7a6758 Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Tue, 8 Aug 2023 08:12:49 +0200 Subject: [PATCH 61/87] Grammar --- src/sciline/pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sciline/pipeline.py b/src/sciline/pipeline.py index 22da57d9..acd862e5 100644 --- a/src/sciline/pipeline.py +++ b/src/sciline/pipeline.py @@ -430,7 +430,7 @@ def _build_series(self, tp: Type[Series[KeyType, ValueType]]) -> Graph: # Step 2: # Identify nodes in the graph that need to be duplicated as they lie in the # path to a parameter from a table. In the case of grouping, note that the - # ungrouped graph (including duplicate of nodes) will have been built by a + # ungrouped graph (including duplication of nodes) will have been built by a # prior call to _build_series, so instead of duplicated everything until the # param table is reached, we only duplicate until the node that is performing # the grouping. From 2aaebf16d86bccad1e652d27574fc4c8a71959ab Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Tue, 8 Aug 2023 09:30:31 +0200 Subject: [PATCH 62/87] Improve readability, default to non-compact graph formatting --- docs/user-guide/parameter-tables.ipynb | 86 +++++++++++++++++++------- src/sciline/param_table.py | 12 ++++ src/sciline/visualize.py | 35 +++++++---- 3 files changed, 97 insertions(+), 36 deletions(-) diff --git a/docs/user-guide/parameter-tables.ipynb b/docs/user-guide/parameter-tables.ipynb index 09566cdd..53852a1c 100644 --- a/docs/user-guide/parameter-tables.ipynb +++ b/docs/user-guide/parameter-tables.ipynb @@ -38,6 +38,7 @@ " 'file102.txt': [1, 2, float('nan'), 3],\n", " 'file103.txt': [1, 2, 3, 4],\n", " 'file104.txt': [1, 2, 3, 4, 5],\n", + " 'file105.txt': [1, 2, 3],\n", "}\n", "\n", "# 1. Define domain types\n", @@ -79,13 +80,9 @@ "\n", "# 3.b Parameter table\n", "RunID = NewType('RunID', int)\n", - "run_ids = [102, 103, 104]\n", + "run_ids = [102, 103, 104, 105]\n", "filenames = [f'file{i}.txt' for i in run_ids]\n", - "param_table = sciline.ParamTable(RunID, {Filename: filenames}, index=run_ids)\n", - "\n", - "# 3.c Setup pipeline\n", - "pipeline = sciline.Pipeline(providers, params=params)\n", - "pipeline.set_param_table(param_table)" + "param_table = sciline.ParamTable(RunID, {Filename: filenames}, index=run_ids)" ] }, { @@ -93,7 +90,7 @@ "metadata": {}, "source": [ "Note how steps 1.) and 2.) are identical to those from the example without parameter table.\n", - "We can now compute `Result` for each index in the parameter table:" + "Above we have created the following parameter table:" ] }, { @@ -102,23 +99,52 @@ "metadata": {}, "outputs": [], "source": [ - "pipeline.compute(sciline.Series[RunID, Result])" + "param_table" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "\n", - "`sciline.Series` is a special `dict`-like type that signals to Sciline that the values of the series are based on values from one or more columns of a parameter table.\n", - "The parameter table is identified using the first argument to `Series`, in this case `RunID`.\n", - "The second argument specifies the result to be computed." + "We can now create the pipeline and set the parameter table:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# 3.c Setup pipeline\n", + "pipeline = sciline.Pipeline(providers, params=params)\n", + "pipeline.set_param_table(param_table)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Then we can compute `Result` for each index in the parameter table:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "pipeline.compute(sciline.Series[RunID, Result])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ + "\n", + "`sciline.Series` is a special `dict`-like type that signals to Sciline that the values of the series are based on values from one or more columns of a parameter table.\n", + "The parameter table is identified using the first argument to `Series`, in this case `RunID`.\n", + "The second argument specifies the result to be computed.\n", + "\n", "We can also visualize the task graph for computing the series of `Result` values:" ] }, @@ -135,15 +161,21 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Sciline uses a compact representation of the task graph.\n", - "Instead of drawing every intermediate result and provider for each parameter, we represent parameter-dependent results as \"3D box\" nodes, with the parameter index name (the row dimension of the parameter table) given in parenthesis.\n", - "For example, above `Filename(RunID)` and `Result(RunID)` represent series of nodes, each for a different run ID.\n", + "Nodes that depend on values from a parameter table are drawn as \"3D boxes\" with the parameter index name (the row dimension of the parameter table) and value given in parenthesis.\n", + "The dashed arrow indicates and internal transformation that gathers result from each branch and combines them into a single output, here `Series[RunID, Result]`.\n", + "Note how this transitions from a \"3D box\" to a \"2D box\" as the series of keys in the graph is reduced into a single key.\n", "\n", - "Note that for computing the results they are handled independently, i.e., the task graph has independent branches for each run ID.\n", - "This is important for parallelization, as it allows to run the tasks for different run IDs in parallel and avoids excessive memory use for intemediate results.\n", + "
\n", "\n", - "The dashed arrow indicates and internal transformation that gathers result from each branch and combines them into a single output, here `Series[RunID, Result]`.\n", - "Note how this transitions from a \"3D box\" to a \"2D box\" as the series of keys in the graph is reduced into a single key." + "Note\n", + "\n", + "With long parameter tables, graphs can get messy and hard to read.\n", + "Try using `visualize(..., compact=True)`.\n", + "\n", + "The `compact=True` option to yields a much more compact representation.\n", + "Instead of drawing every intermediate result and provider for each parameter, we then represent each parameter-dependent result as a single \"3D box\" node, representing all nodes for different values of the respective parameter.\n", + "\n", + "
" ] }, { @@ -224,13 +256,21 @@ "params = {ScaleFactor: 2.0}\n", "\n", "# 3.b Parameter table\n", - "run_ids = [102, 103, 104]\n", - "sample = ['diamond', 'graphite', 'graphite']\n", + "run_ids = [102, 103, 104, 105]\n", + "sample = ['diamond', 'graphite', 'graphite', 'graphite']\n", "filenames = [f'file{i}.txt' for i in run_ids]\n", "param_table = sciline.ParamTable(\n", " RunID, {Filename: filenames, Material: sample}, index=run_ids\n", ")\n", - "\n", + "param_table" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ "# 3.c Setup pipeline\n", "pipeline = sciline.Pipeline(providers, params=params)\n", "pipeline.set_param_table(param_table)" @@ -401,7 +441,7 @@ "\n", "pl = sl.Pipeline([gather1, gather2, product])\n", "pl.set_param_table(sl.ParamTable(Row1, {Param1: [1, 4, 9]}))\n", - "pl.set_param_table(sl.ParamTable(Row2, {Param2: [1, 2, 3]}))\n", + "pl.set_param_table(sl.ParamTable(Row2, {Param2: [1, 2]}))\n", "\n", "pl.visualize(List2)" ] diff --git a/src/sciline/param_table.py b/src/sciline/param_table.py index 69d80766..585f7e8e 100644 --- a/src/sciline/param_table.py +++ b/src/sciline/param_table.py @@ -72,3 +72,15 @@ def __len__(self) -> int: def __repr__(self) -> str: return f"ParamTable(row_dim={self.row_dim}, columns={self._columns})" + + def _repr_html_(self) -> str: + return ( + f"" + + "".join(f"" for k in self._columns.keys()) + + "" + + "".join( + f"" + "".join(f"" for v in row) + "" + for idx, row in zip(self.index, zip(*self._columns.values())) + ) + + "
{self.row_dim.__name__}{k.__name__}
{idx}{v}
" + ) diff --git a/src/sciline/visualize.py b/src/sciline/visualize.py index f79f57ca..a55c92c3 100644 --- a/src/sciline/visualize.py +++ b/src/sciline/visualize.py @@ -19,7 +19,7 @@ from .typing import Graph, Item, Key -def to_graphviz(graph: Graph, **kwargs: Any) -> Digraph: +def to_graphviz(graph: Graph, compact: bool = False, **kwargs: Any) -> Digraph: """ Convert output of :py:class:`sciline.Pipeline.get_graph` to a graphviz graph. @@ -27,11 +27,14 @@ def to_graphviz(graph: Graph, **kwargs: Any) -> Digraph: ---------- graph: Output of :py:class:`sciline.Pipeline.get_graph`. + compact: + If True, parameter-table-dependent branches are collapsed into a single copy + of the branch. Recommendend for large graphs with long parameter tables. kwargs: Keyword arguments passed to :py:class:`graphviz.Digraph`. """ dot = Digraph(strict=True, **kwargs) - for p, (p_name, args, ret) in _format_graph(graph).items(): + for p, (p_name, args, ret) in _format_graph(graph, compact=compact).items(): if '(' in ret: dot.node(ret, ret, shape='box3d') else: @@ -61,34 +64,34 @@ def _qualname(obj: Any) -> Any: ) -def _format_graph(graph: Graph) -> Dict[str, Tuple[str, List[str], str]]: +def _format_graph(graph: Graph, compact: bool) -> Dict[str, Tuple[str, List[str], str]]: return { - _format_provider(provider, ret): ( + _format_provider(provider, ret, compact=compact): ( _qualname(provider), - [_format_type(a) for a in args], - _format_type(ret), + [_format_type(a, compact=compact) for a in args], + _format_type(ret, compact=compact), ) for ret, (provider, args) in graph.items() } -def _format_provider(provider: Callable[..., Any], ret: Key) -> str: - return f'{_qualname(provider)}_{_format_type(ret)}' +def _format_provider(provider: Callable[..., Any], ret: Key, compact: bool) -> str: + return f'{_qualname(provider)}_{_format_type(ret, compact=compact)}' T = TypeVar('T') def _extract_type_and_labels( - key: Union[Item[T], Type[T]] + key: Union[Item[T], Type[T]], compact: bool ) -> Tuple[Type[T], List[type]]: if isinstance(key, Item): label = key.label - return key.tp, [lb.tp for lb in label] + return key.tp, [lb.tp if compact else (lb.tp, lb.index) for lb in label] return key, [] -def _format_type(tp: Key) -> str: +def _format_type(tp: Key, compact: bool = False) -> str: """ Helper for _format_graph. @@ -97,14 +100,20 @@ def _format_type(tp: Key) -> str: We may make this configurable in the future. """ - tp, labels = _extract_type_and_labels(tp) + tp, labels = _extract_type_and_labels(tp, compact=compact) def get_base(tp: type) -> str: return tp.__name__ if hasattr(tp, '__name__') else str(tp).split('.')[-1] + def format_label(label): + if isinstance(label, tuple): + tp, index = label + return f'{get_base(tp)}={index}' + return get_base(label) + def with_labels(base: str) -> str: if labels: - return f'{base}({", ".join([get_base(l) for l in labels])})' + return f'{base}({", ".join([format_label(l) for l in labels])})' return base if (origin := get_origin(tp)) is not None: From d001116da731f81fde5e89d077149f8579042e7d Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Tue, 8 Aug 2023 09:35:29 +0200 Subject: [PATCH 63/87] Fix type hints --- src/sciline/visualize.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/sciline/visualize.py b/src/sciline/visualize.py index a55c92c3..501888d1 100644 --- a/src/sciline/visualize.py +++ b/src/sciline/visualize.py @@ -84,7 +84,7 @@ def _format_provider(provider: Callable[..., Any], ret: Key, compact: bool) -> s def _extract_type_and_labels( key: Union[Item[T], Type[T]], compact: bool -) -> Tuple[Type[T], List[type]]: +) -> Tuple[Type[T], List[Union[type, Tuple[type, Any]]]]: if isinstance(key, Item): label = key.label return key.tp, [lb.tp if compact else (lb.tp, lb.index) for lb in label] @@ -105,7 +105,7 @@ def _format_type(tp: Key, compact: bool = False) -> str: def get_base(tp: type) -> str: return tp.__name__ if hasattr(tp, '__name__') else str(tp).split('.')[-1] - def format_label(label): + def format_label(label: Union[type, Tuple[type, Any]]) -> str: if isinstance(label, tuple): tp, index = label return f'{get_base(tp)}={index}' From cc58b6f950914c5f81536965410acc14f9777436 Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Tue, 8 Aug 2023 10:43:13 +0200 Subject: [PATCH 64/87] Do not draw box3d unless compact --- docs/user-guide/parameter-tables.ipynb | 1 - src/sciline/visualize.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/docs/user-guide/parameter-tables.ipynb b/docs/user-guide/parameter-tables.ipynb index 53852a1c..77b5af77 100644 --- a/docs/user-guide/parameter-tables.ipynb +++ b/docs/user-guide/parameter-tables.ipynb @@ -163,7 +163,6 @@ "source": [ "Nodes that depend on values from a parameter table are drawn as \"3D boxes\" with the parameter index name (the row dimension of the parameter table) and value given in parenthesis.\n", "The dashed arrow indicates and internal transformation that gathers result from each branch and combines them into a single output, here `Series[RunID, Result]`.\n", - "Note how this transitions from a \"3D box\" to a \"2D box\" as the series of keys in the graph is reduced into a single key.\n", "\n", "
\n", "\n", diff --git a/src/sciline/visualize.py b/src/sciline/visualize.py index 501888d1..ee7e9821 100644 --- a/src/sciline/visualize.py +++ b/src/sciline/visualize.py @@ -35,7 +35,7 @@ def to_graphviz(graph: Graph, compact: bool = False, **kwargs: Any) -> Digraph: """ dot = Digraph(strict=True, **kwargs) for p, (p_name, args, ret) in _format_graph(graph, compact=compact).items(): - if '(' in ret: + if '(' in ret and '=' not in ret: dot.node(ret, ret, shape='box3d') else: dot.node(ret, ret, shape='rectangle') From 0f109fb557d2aad9e1ca238a3c42e42c9bfe485a Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Tue, 8 Aug 2023 10:46:32 +0200 Subject: [PATCH 65/87] Use examples in docs that would pass mypy --- docs/user-guide/getting-started.ipynb | 6 +++--- docs/user-guide/parameter-tables.ipynb | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/docs/user-guide/getting-started.ipynb b/docs/user-guide/getting-started.ipynb index b00c5c61..981f0d2a 100644 --- a/docs/user-guide/getting-started.ipynb +++ b/docs/user-guide/getting-started.ipynb @@ -119,19 +119,19 @@ "def load(filename: Filename) -> RawData:\n", " \"\"\"Load the data from the filename.\"\"\"\n", " data = _fake_filesytem[filename]\n", - " return {'data': data, 'meta': {'filename': filename}}\n", + " return RawData({'data': data, 'meta': {'filename': filename}})\n", "\n", "\n", "def clean(raw_data: RawData) -> CleanedData:\n", " \"\"\"Clean the data, removing NaNs.\"\"\"\n", " import math\n", "\n", - " return [x for x in raw_data['data'] if not math.isnan(x)]\n", + " return CleanedData([x for x in raw_data['data'] if not math.isnan(x)])\n", "\n", "\n", "def process(data: CleanedData, param: ScaleFactor) -> Result:\n", " \"\"\"Process the data, multiplying the sum by the scale factor.\"\"\"\n", - " return sum(data) * param\n", + " return Result(sum(data) * param)\n", "\n", "\n", "# 3. Create pipeline\n", diff --git a/docs/user-guide/parameter-tables.ipynb b/docs/user-guide/parameter-tables.ipynb index 77b5af77..e28c95b8 100644 --- a/docs/user-guide/parameter-tables.ipynb +++ b/docs/user-guide/parameter-tables.ipynb @@ -57,19 +57,19 @@ " \"\"\"Load the data from the filename.\"\"\"\n", "\n", " data = _fake_filesytem[filename]\n", - " return {'data': data, 'meta': {'filename': filename}}\n", + " return RawData({'data': data, 'meta': {'filename': filename}})\n", "\n", "\n", "def clean(raw_data: RawData) -> CleanedData:\n", " \"\"\"Clean the data, removing NaNs.\"\"\"\n", " import math\n", "\n", - " return [x for x in raw_data['data'] if not math.isnan(x)]\n", + " return CleanedData([x for x in raw_data['data'] if not math.isnan(x)])\n", "\n", "\n", "def process(data: CleanedData, param: ScaleFactor) -> Result:\n", " \"\"\"Process the data, multiplying the sum by the scale factor.\"\"\"\n", - " return sum(data) * param\n", + " return Result(sum(data) * param)\n", "\n", "\n", "# 3. Create pipeline\n", From 7af8fc5c7830a98d0aa3c1d822569df860bf144f Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Tue, 8 Aug 2023 10:53:45 +0200 Subject: [PATCH 66/87] Forbid providing Series directly --- src/sciline/pipeline.py | 6 ++++++ tests/pipeline_with_index_test.py | 17 +++++++++++++++++ 2 files changed, 23 insertions(+) diff --git a/src/sciline/pipeline.py b/src/sciline/pipeline.py index acd862e5..419cce90 100644 --- a/src/sciline/pipeline.py +++ b/src/sciline/pipeline.py @@ -336,6 +336,12 @@ def _set_provider( # isinstance does not work here and types.NoneType available only in 3.10+ if key == type(None): # noqa: E721 raise ValueError(f'Provider {provider} returning `None` is not allowed') + if get_origin(key) == Series: + raise ValueError( + f'Provider {provider} returning a sciline.Series is not allowed. ' + 'Series is a special container reserved for use in conjunction with ' + 'sciline.ParamTable and may not be provided directly.' + ) if (origin := get_origin(key)) is not None: subproviders = self._subproviders.setdefault(origin, {}) args = get_args(key) diff --git a/tests/pipeline_with_index_test.py b/tests/pipeline_with_index_test.py index 70f9ea61..fc4c6968 100644 --- a/tests/pipeline_with_index_test.py +++ b/tests/pipeline_with_index_test.py @@ -67,6 +67,23 @@ def process(x: float) -> str: ) +def test_creating_pipeline_with_provider_of_series_raises() -> None: + series = sl.Series(int, {0: 1.0, 1: 2.0, 2: 3.0}) + + def make_series() -> sl.Series[int, float]: + return series + + with pytest.raises(ValueError): + sl.Pipeline([make_series]) + + +def test_creating_pipeline_with_series_param_raises() -> None: + series = sl.Series(int, {0: 1.0, 1: 2.0, 2: 3.0}) + + with pytest.raises(ValueError): + sl.Pipeline([], params={sl.Series[int, float]: series}) + + def test_explicit_index_of_param_table_is_forwarded_correctly() -> None: def process(x: float) -> int: return int(x) From 14f14397594b302e299d980b77f348a31a23394b Mon Sep 17 00:00:00 2001 From: Simon Heybrock <12912489+SimonHeybrock@users.noreply.github.com> Date: Tue, 8 Aug 2023 14:04:38 +0200 Subject: [PATCH 67/87] Apply suggestions from code review Co-authored-by: Jan-Lukas Wynen --- src/sciline/pipeline.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/sciline/pipeline.py b/src/sciline/pipeline.py index 419cce90..4090e66a 100644 --- a/src/sciline/pipeline.py +++ b/src/sciline/pipeline.py @@ -340,7 +340,7 @@ def _set_provider( raise ValueError( f'Provider {provider} returning a sciline.Series is not allowed. ' 'Series is a special container reserved for use in conjunction with ' - 'sciline.ParamTable and may not be provided directly.' + 'sciline.ParamTable and must not be provided directly.' ) if (origin := get_origin(key)) is not None: subproviders = self._subproviders.setdefault(origin, {}) @@ -437,7 +437,7 @@ def _build_series(self, tp: Type[Series[KeyType, ValueType]]) -> Graph: # Identify nodes in the graph that need to be duplicated as they lie in the # path to a parameter from a table. In the case of grouping, note that the # ungrouped graph (including duplication of nodes) will have been built by a - # prior call to _build_series, so instead of duplicated everything until the + # prior call to _build_series, so instead of duplicating everything until the # param table is reached, we only duplicate until the node that is performing # the grouping. grouper: Grouper From e4c84728b0d820ae66095699a14dd4d013a667e9 Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Tue, 8 Aug 2023 14:13:55 +0200 Subject: [PATCH 68/87] Use py38 syntax --- src/sciline/pipeline.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/sciline/pipeline.py b/src/sciline/pipeline.py index 4090e66a..8484b947 100644 --- a/src/sciline/pipeline.py +++ b/src/sciline/pipeline.py @@ -55,7 +55,9 @@ class AmbiguousProvider(Exception): """Raised when multiple providers are found for a type.""" -def _indexed_key(index_name: Any, i: int, value_name: Type[T] | Item[T]) -> Item[T]: +def _indexed_key( + index_name: Any, i: int, value_name: Union[Type[T], Item[T]] +) -> Item[T]: label = Label(index_name, i) if isinstance(value_name, Item): return Item(value_name.label + (label,), value_name.tp) From 21984677842e0720d73bbfedfbd3fabb4f92a15b Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Tue, 8 Aug 2023 14:16:42 +0200 Subject: [PATCH 69/87] Rename param name --- src/sciline/pipeline.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/sciline/pipeline.py b/src/sciline/pipeline.py index 8484b947..f1f739b3 100644 --- a/src/sciline/pipeline.py +++ b/src/sciline/pipeline.py @@ -97,16 +97,16 @@ def _bind_free_typevars(tp: TypeVar | Key, bound: Dict[TypeVar, Key]) -> Key: def _find_all_paths( - graph: Mapping[T, Collection[T]], start: T, end: T + dependencies: Mapping[T, Collection[T]], start: T, end: T ) -> List[List[T]]: """Find all paths from start to end in a DAG.""" if start == end: return [[start]] - if start not in graph: + if start not in dependencies: return [] paths = [] - for node in graph[start]: - for path in _find_all_paths(graph, node, end): + for node in dependencies[start]: + for path in _find_all_paths(dependencies, node, end): paths.append([start] + path) return paths From ade5f50e9aff8282a03fb4e6d3ce4ff7964d346f Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Tue, 8 Aug 2023 14:20:00 +0200 Subject: [PATCH 70/87] Use cleared field name --- src/sciline/pipeline.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/sciline/pipeline.py b/src/sciline/pipeline.py index f1f739b3..c76d1390 100644 --- a/src/sciline/pipeline.py +++ b/src/sciline/pipeline.py @@ -244,7 +244,7 @@ def __init__( self._providers: Dict[Key, Provider] = {} self._subproviders: Dict[type, Dict[Tuple[Key | TypeVar, ...], Provider]] = {} self._param_tables: Dict[Key, ParamTable] = {} - self._param_series: Dict[Key, Key] = {} + self._param_index_name: Dict[Key, Key] = {} for provider in providers or []: self.insert(provider) for tp, param in (params or {}).items(): @@ -320,11 +320,11 @@ def set_param_table(self, params: ParamTable) -> None: if params.row_dim in self._param_tables: raise ValueError(f'Parameter table for {params.row_dim} already set') for param_name in params: - if param_name in self._param_series: + if param_name in self._param_index_name: raise ValueError(f'Parameter {param_name} already set') self._param_tables[params.row_dim] = params for param_name in params: - self._param_series[param_name] = params.row_dim + self._param_index_name[param_name] = params.row_dim for param_name, values in params.items(): for index, label in zip(params.index, values): self._set_provider( @@ -404,8 +404,8 @@ def build( stack: List[Union[Type[T], Item[T]]] = [tp] while stack: tp = stack.pop() - if search_param_tables and tp in self._param_series: - graph[tp] = (_param_sentinel, (self._param_series[tp],)) + if search_param_tables and tp in self._param_index_name: + graph[tp] = (_param_sentinel, (self._param_index_name[tp],)) continue if get_origin(tp) == Series: graph.update(self._build_series(tp)) # type: ignore[arg-type] @@ -444,12 +444,12 @@ def _build_series(self, tp: Type[Series[KeyType, ValueType]]) -> Graph: # the grouping. grouper: Grouper if ( - label_name not in self._param_series + label_name not in self._param_index_name and (params := self._param_tables.get(label_name)) is not None ): path = _find_nodes_in_paths(subgraph, value_type, label_name) grouper = NoGrouping(index=params.index) - elif (index_name := self._param_series.get(label_name)) is not None: + elif (index_name := self._param_index_name.get(label_name)) is not None: params = self._param_tables[index_name] labels = params[label_name] grouping_node = self._find_grouping_node(index_name, subgraph) From e7ff68583893ca845f456c7b460ad9fc20376f8f Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Tue, 8 Aug 2023 14:20:43 +0200 Subject: [PATCH 71/87] Remove impl. detail test --- tests/graph_test.py | 10 ---------- 1 file changed, 10 deletions(-) delete mode 100644 tests/graph_test.py diff --git a/tests/graph_test.py b/tests/graph_test.py deleted file mode 100644 index 65705b20..00000000 --- a/tests/graph_test.py +++ /dev/null @@ -1,10 +0,0 @@ -# SPDX-License-Identifier: BSD-3-Clause -# Copyright (c) 2023 Scipp contributors (https://github.com/scipp) -from sciline.pipeline import _find_all_paths - - -def test_find_all_paths() -> None: - graph = {"D": ["B", "C"], "C": ["A"], "B": ["A"]} - assert _find_all_paths(graph, "D", "A") == [["D", "B", "A"], ["D", "C", "A"]] - assert _find_all_paths(graph, "B", "C") == [] - assert _find_all_paths(graph, "B", "A") == [["B", "A"]] From 3e9d4ac6a9de5fe433f7548dba4d46e897b4093d Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Tue, 8 Aug 2023 15:39:44 +0200 Subject: [PATCH 72/87] Clarify docstring --- src/sciline/pipeline.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/sciline/pipeline.py b/src/sciline/pipeline.py index c76d1390..aa8fecef 100644 --- a/src/sciline/pipeline.py +++ b/src/sciline/pipeline.py @@ -115,8 +115,8 @@ def _find_nodes_in_paths( graph: Mapping[T, Tuple[Callable[..., Any], Collection[T]]], start: T, end: T ) -> List[T]: """ - Find all nodes that need to be duplicated since they depend on a value from a - param table. + Helper for Pipeline. Finds all nodes that need to be duplicated since they depend + on a value from a param table. """ # 0 is the provider, 1 is the args dependencies = {k: v[1] for k, v in graph.items()} From 243ff9d1136cdd4832e86fe441c680abff1b5c39 Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Wed, 9 Aug 2023 06:15:00 +0200 Subject: [PATCH 73/87] Do not rely on string matching --- src/sciline/visualize.py | 37 +++++++++++++++++++++++-------------- tests/visualize_test.py | 6 +++--- 2 files changed, 26 insertions(+), 17 deletions(-) diff --git a/src/sciline/visualize.py b/src/sciline/visualize.py index ee7e9821..37440193 100644 --- a/src/sciline/visualize.py +++ b/src/sciline/visualize.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: BSD-3-Clause # Copyright (c) 2023 Scipp contributors (https://github.com/scipp) +from dataclasses import dataclass from typing import ( Any, Callable, @@ -19,6 +20,12 @@ from .typing import Graph, Item, Key +@dataclass +class Node: + name: str + collapsed: bool = False + + def to_graphviz(graph: Graph, compact: bool = False, **kwargs: Any) -> Digraph: """ Convert output of :py:class:`sciline.Pipeline.get_graph` to a graphviz graph. @@ -35,10 +42,7 @@ def to_graphviz(graph: Graph, compact: bool = False, **kwargs: Any) -> Digraph: """ dot = Digraph(strict=True, **kwargs) for p, (p_name, args, ret) in _format_graph(graph, compact=compact).items(): - if '(' in ret and '=' not in ret: - dot.node(ret, ret, shape='box3d') - else: - dot.node(ret, ret, shape='rectangle') + dot.node(ret.name, ret.name, shape='box3d' if ret.collapsed else 'rectangle') # Do not draw dummy providers created by Pipeline when setting instances if p_name in ( f'{_qualname(Pipeline.__setitem__)}..', @@ -49,12 +53,12 @@ def to_graphviz(graph: Graph, compact: bool = False, **kwargs: Any) -> Digraph: # a dict if p_name == _qualname(SeriesProvider): for arg in args: - dot.edge(arg, ret, style='dashed') + dot.edge(arg.name, ret.name, style='dashed') else: dot.node(p, p_name, shape='ellipse') for arg in args: - dot.edge(arg, p) - dot.edge(p, ret) + dot.edge(arg.name, p) + dot.edge(p, ret.name) return dot @@ -64,7 +68,9 @@ def _qualname(obj: Any) -> Any: ) -def _format_graph(graph: Graph, compact: bool) -> Dict[str, Tuple[str, List[str], str]]: +def _format_graph( + graph: Graph, compact: bool +) -> Dict[str, Tuple[str, List[Node], Node]]: return { _format_provider(provider, ret, compact=compact): ( _qualname(provider), @@ -76,7 +82,7 @@ def _format_graph(graph: Graph, compact: bool) -> Dict[str, Tuple[str, List[str] def _format_provider(provider: Callable[..., Any], ret: Key, compact: bool) -> str: - return f'{_qualname(provider)}_{_format_type(ret, compact=compact)}' + return f'{_qualname(provider)}_{_format_type(ret, compact=compact).name}' T = TypeVar('T') @@ -91,7 +97,7 @@ def _extract_type_and_labels( return key, [] -def _format_type(tp: Key, compact: bool = False) -> str: +def _format_type(tp: Key, compact: bool = False) -> Node: """ Helper for _format_graph. @@ -111,13 +117,16 @@ def format_label(label: Union[type, Tuple[type, Any]]) -> str: return f'{get_base(tp)}={index}' return get_base(label) - def with_labels(base: str) -> str: + def with_labels(base: str) -> Node: if labels: - return f'{base}({", ".join([format_label(l) for l in labels])})' - return base + return Node( + name=f'{base}({", ".join([format_label(l) for l in labels])})', + collapsed=True, + ) + return Node(name=base) if (origin := get_origin(tp)) is not None: - params = [_format_type(param) for param in get_args(tp)] + params = [_format_type(param).name for param in get_args(tp)] return with_labels(f'{get_base(origin)}[{", ".join(params)}]') else: return with_labels(get_base(tp)) diff --git a/tests/visualize_test.py b/tests/visualize_test.py index 663c6dfe..8290cbc7 100644 --- a/tests/visualize_test.py +++ b/tests/visualize_test.py @@ -30,6 +30,6 @@ class B(Generic[T]): class SubA(A[T]): pass - assert sl.visualize._format_type(A[float]) == 'A[float]' - assert sl.visualize._format_type(SubA[float]) == 'SubA[float]' - assert sl.visualize._format_type(B[float]) == 'B[float]' + assert sl.visualize._format_type(A[float]).name == 'A[float]' + assert sl.visualize._format_type(SubA[float]).name == 'SubA[float]' + assert sl.visualize._format_type(B[float]).name == 'B[float]' From 98105818220e3afd3035057faf89c04041c9811c Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Wed, 9 Aug 2023 06:16:08 +0200 Subject: [PATCH 74/87] Rename test file --- ...eline_with_index_test.py => pipeline_with_param_table_test.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/{pipeline_with_index_test.py => pipeline_with_param_table_test.py} (100%) diff --git a/tests/pipeline_with_index_test.py b/tests/pipeline_with_param_table_test.py similarity index 100% rename from tests/pipeline_with_index_test.py rename to tests/pipeline_with_param_table_test.py From 0e0e75d1a87174688c73635025eb87da2b496871 Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Wed, 9 Aug 2023 06:50:31 +0200 Subject: [PATCH 75/87] WIP refactor --- src/sciline/pipeline.py | 129 ++++++++++++++++++++++------------------ 1 file changed, 72 insertions(+), 57 deletions(-) diff --git a/src/sciline/pipeline.py b/src/sciline/pipeline.py index aa8fecef..9e1131b9 100644 --- a/src/sciline/pipeline.py +++ b/src/sciline/pipeline.py @@ -143,54 +143,99 @@ def __call__(self, key: Any) -> Callable[..., bool]: def get_grouping(self, key: Any, group: Any) -> Any: ... + def duplicate(self, key: Any, value: Any) -> Dict[Key, Any]: + pass + class NoGrouping(Generic[IndexType]): """Helper for rewriting the graph to map over a given index.""" - def __init__(self, index: Iterable[IndexType]) -> None: - self._index = index + def __init__(self, param_table: ParamTable, graph_template: Graph) -> None: + self._graph_template = graph_template + self._index = param_table.index + self._index_name = param_table.row_dim + self._root = next(iter(graph_template)) + self._path = _find_nodes_in_paths(graph_template, self._root, self._index_name) def __iter__(self) -> Iterator[IndexType]: return iter(self._index) - def __call__(self, key: Any) -> Callable[..., bool]: - return _yes + def duplicate(self, key: Any, value: Any, get_provider) -> Dict[Key, Any]: + return duplicate_node( + key, value, get_provider, self._index_name, self._index, self._path + ) - def get_grouping(self, key: Any, group: int) -> None: - return None + +def duplicate_node( + key: Any, value: Any, get_provider, index_name: Any, index: Any, path +) -> Dict[Key, Any]: + graph = {} + for idx in index: + provider, args = value + subkey = _indexed_key(index_name, idx, key) + if provider == _param_sentinel: + provider, _ = get_provider(subkey) + args = () + args_with_index = tuple( + _indexed_key(index_name, idx, arg) if arg in path else arg for arg in args + ) + graph[subkey] = (provider, args_with_index) + return graph class GroupBy(Generic[IndexType, LabelType]): """Helper for rewriting the graph to group by a given index.""" def __init__( - self, - *, - grouping_node: type, - index: Iterable[IndexType], - labels: Iterable[LabelType], + self, param_table: ParamTable, graph_template: Graph, label_name ) -> None: - self.grouping_node = grouping_node + self._index_name = param_table.row_dim + self._label_name = label_name + self._root = next(iter(graph_template)) + self._grouping_node = self._find_grouping_node(self._index_name, graph_template) + self._path = _find_nodes_in_paths( + graph_template, self._root, self._grouping_node + ) self._index: Dict[LabelType, List[IndexType]] = defaultdict(list) - for idx, label in zip(index, labels): + for idx, label in zip(param_table.index, param_table[label_name]): self._index[label].append(idx) def __iter__(self) -> Iterator[LabelType]: return iter(self._index) - def __call__(self, key: Any) -> Any: - return self.in_group if key == self.grouping_node else _yes - - def get_grouping(self, key: Any, group: LabelType) -> Optional[List[IndexType]]: - if key != self.grouping_node: - return None - return self._index[group] - def in_group(self, arg: Item[Any], group: LabelType) -> bool: if len(arg.label) != 1: raise ValueError(f'Cannot group with multi-index label {arg.label}') return arg.label[0].index in self._index[group] + def duplicate(self, key: Any, value: Any, get_provider) -> Dict[Key, Any]: + if key != self._grouping_node: + return duplicate_node( + key, value, get_provider, self._label_name, self._index, self._path + ) + graph = {} + for index in self._index: + provider, args = value + subkey = _indexed_key(self._label_name, index, key) + args_with_index = tuple( + _indexed_key(self._label_name, index, arg) if arg in self._path else arg + for arg in args + if self.in_group(arg, index) + ) + provider = provider.restrict(self._index[index]) + graph[subkey] = (provider, args_with_index) + return graph + + def _find_grouping_node(self, index_name: Key, subgraph: Graph) -> type: + ends: List[type] = [] + for key in subgraph: + if get_origin(key) == Series and get_args(key)[0] == index_name: + # Because of the succeeded get_origin we know it is a type + ends.append(key) # type: ignore[arg-type] + if len(ends) == 1: + return ends[0] + raise ValueError(f"Could not find unique grouping node, found {ends}") + class SeriesProvider(Generic[KeyType, ValueType]): """ @@ -447,15 +492,12 @@ def _build_series(self, tp: Type[Series[KeyType, ValueType]]) -> Graph: label_name not in self._param_index_name and (params := self._param_tables.get(label_name)) is not None ): - path = _find_nodes_in_paths(subgraph, value_type, label_name) - grouper = NoGrouping(index=params.index) + grouper = NoGrouping(param_table=params, graph_template=subgraph) elif (index_name := self._param_index_name.get(label_name)) is not None: - params = self._param_tables[index_name] - labels = params[label_name] - grouping_node = self._find_grouping_node(index_name, subgraph) - path = _find_nodes_in_paths(subgraph, value_type, grouping_node) grouper = GroupBy( - index=params.index, labels=labels, grouping_node=grouping_node + param_table=self._param_tables[index_name], + graph_template=subgraph, + label_name=label_name, ) else: raise KeyError(f'No parameter table found for label {label_name}') @@ -469,39 +511,12 @@ def _build_series(self, tp: Type[Series[KeyType, ValueType]]) -> Graph: # Step 3: # Duplicate nodes, replacing keys with indexed keys. for key, value in subgraph.items(): - if key in path: - in_group = grouper(key) - for index in grouper: - provider, args = value - subkey = _indexed_key(label_name, index, key) - if provider == _param_sentinel: - provider, _ = self._get_provider(subkey) - args = () - args_with_index = tuple( - _indexed_key(label_name, index, arg) if arg in path else arg - for arg in args - if in_group(arg, index) - ) - if isinstance(provider, SeriesProvider): - # mypy does not detect that SeriesProducer is Callable? - provider = provider.restrict( # type: ignore[unreachable] - grouper.get_grouping(key, index) - ) - graph[subkey] = (provider, args_with_index) + if key in grouper._path: + graph.update(grouper.duplicate(key, value, self._get_provider)) else: graph[key] = value return graph - def _find_grouping_node(self, index_name: Key, subgraph: Graph) -> type: - ends: List[type] = [] - for key in subgraph: - if get_origin(key) == Series and get_args(key)[0] == index_name: - # Because of the succeeded get_origin we know it is a type - ends.append(key) # type: ignore[arg-type] - if len(ends) == 1: - return ends[0] - raise ValueError(f"Could not find unique grouping node, found {ends}") - @overload def compute(self, tp: Type[T]) -> T: ... From a4ba895a70c14f79cf4607be8d8fc9bcfa749a64 Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Wed, 9 Aug 2023 08:47:32 +0200 Subject: [PATCH 76/87] Simplify --- src/sciline/pipeline.py | 54 +++++++++++++++-------------------------- 1 file changed, 20 insertions(+), 34 deletions(-) diff --git a/src/sciline/pipeline.py b/src/sciline/pipeline.py index 9e1131b9..ec46312c 100644 --- a/src/sciline/pipeline.py +++ b/src/sciline/pipeline.py @@ -112,12 +112,13 @@ def _find_all_paths( def _find_nodes_in_paths( - graph: Mapping[T, Tuple[Callable[..., Any], Collection[T]]], start: T, end: T + graph: Mapping[T, Tuple[Callable[..., Any], Collection[T]]], end: T ) -> List[T]: """ Helper for Pipeline. Finds all nodes that need to be duplicated since they depend on a value from a param table. """ + start = next(iter(graph)) # 0 is the provider, 1 is the args dependencies = {k: v[1] for k, v in graph.items()} paths = _find_all_paths(dependencies, start, end) @@ -154,12 +155,14 @@ def __init__(self, param_table: ParamTable, graph_template: Graph) -> None: self._graph_template = graph_template self._index = param_table.index self._index_name = param_table.row_dim - self._root = next(iter(graph_template)) - self._path = _find_nodes_in_paths(graph_template, self._root, self._index_name) + self._path = _find_nodes_in_paths(graph_template, self._index_name) def __iter__(self) -> Iterator[IndexType]: return iter(self._index) + def make_series(self, *vals: Any) -> Series[IndexType, Any]: + return Series(self._index, vals) + def duplicate(self, key: Any, value: Any, get_provider) -> Dict[Key, Any]: return duplicate_node( key, value, get_provider, self._index_name, self._index, self._path @@ -189,13 +192,9 @@ class GroupBy(Generic[IndexType, LabelType]): def __init__( self, param_table: ParamTable, graph_template: Graph, label_name ) -> None: - self._index_name = param_table.row_dim self._label_name = label_name - self._root = next(iter(graph_template)) - self._grouping_node = self._find_grouping_node(self._index_name, graph_template) - self._path = _find_nodes_in_paths( - graph_template, self._root, self._grouping_node - ) + self._group_node = self._find_grouping_node(param_table.row_dim, graph_template) + self._path = _find_nodes_in_paths(graph_template, self._group_node) self._index: Dict[LabelType, List[IndexType]] = defaultdict(list) for idx, label in zip(param_table.index, param_table[label_name]): self._index[label].append(idx) @@ -203,27 +202,25 @@ def __init__( def __iter__(self) -> Iterator[LabelType]: return iter(self._index) - def in_group(self, arg: Item[Any], group: LabelType) -> bool: - if len(arg.label) != 1: - raise ValueError(f'Cannot group with multi-index label {arg.label}') - return arg.label[0].index in self._index[group] - def duplicate(self, key: Any, value: Any, get_provider) -> Dict[Key, Any]: - if key != self._grouping_node: + if key != self._group_node: return duplicate_node( key, value, get_provider, self._label_name, self._index, self._path ) graph = {} + provider, args = value for index in self._index: - provider, args = value + labels = self._index[index] + if set(labels) - set(provider._labels): + raise ValueError(f'{labels} is not a subset of {provider._labels}') subkey = _indexed_key(self._label_name, index, key) - args_with_index = tuple( - _indexed_key(self._label_name, index, arg) if arg in self._path else arg - for arg in args - if self.in_group(arg, index) - ) - provider = provider.restrict(self._index[index]) - graph[subkey] = (provider, args_with_index) + selected = { + label: arg + for label, arg in zip(provider._labels, args) + if label in labels + } + split_provider = SeriesProvider(selected, provider._row_dim) + graph[subkey] = (split_provider, tuple(selected.values())) return graph def _find_grouping_node(self, index_name: Key, subgraph: Graph) -> type: @@ -250,17 +247,6 @@ def __init__(self, labels: Iterable[KeyType], row_dim: type) -> None: def __call__(self, *vals: ValueType) -> Series[KeyType, ValueType]: return Series(self._row_dim, dict(zip(self._labels, vals))) - def restrict( - self, labels: Optional[Iterable[KeyType]] - ) -> SeriesProvider[KeyType, ValueType]: - if labels is None: - return self - if set(labels) - set(self._labels): - raise ValueError(f'{labels} is not a subset of {self._labels}') - # Ensure that labels are in the same order as in the original series - labels = [label for label in self._labels if label in labels] - return SeriesProvider(labels, self._row_dim) - class _param_sentinel: ... From 395f0889c370dd417060cc117c90eb7d68a4dc4c Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Wed, 9 Aug 2023 08:58:04 +0200 Subject: [PATCH 77/87] Cleanup --- src/sciline/pipeline.py | 58 +++++++++++++++-------------------------- 1 file changed, 21 insertions(+), 37 deletions(-) diff --git a/src/sciline/pipeline.py b/src/sciline/pipeline.py index ec46312c..f29d556b 100644 --- a/src/sciline/pipeline.py +++ b/src/sciline/pipeline.py @@ -10,9 +10,7 @@ Dict, Generic, Iterable, - Iterator, List, - Literal, Mapping, Optional, Protocol, @@ -128,44 +126,30 @@ def _find_nodes_in_paths( return list(nodes) -def _yes(*_: Any) -> Literal[True]: - return True - - class Grouper(Protocol): """Helper protocol for rewriting graphs.""" - def __iter__(self) -> Iterator[Any]: - ... - - def __call__(self, key: Any) -> Callable[..., bool]: - ... - - def get_grouping(self, key: Any, group: Any) -> Any: + def __contains__(self, key: Any) -> bool: ... def duplicate(self, key: Any, value: Any) -> Dict[Key, Any]: - pass + ... class NoGrouping(Generic[IndexType]): """Helper for rewriting the graph to map over a given index.""" def __init__(self, param_table: ParamTable, graph_template: Graph) -> None: - self._graph_template = graph_template - self._index = param_table.index + self.index = param_table.index self._index_name = param_table.row_dim self._path = _find_nodes_in_paths(graph_template, self._index_name) - def __iter__(self) -> Iterator[IndexType]: - return iter(self._index) - - def make_series(self, *vals: Any) -> Series[IndexType, Any]: - return Series(self._index, vals) + def __contains__(self, key: Any) -> bool: + return key in self._path def duplicate(self, key: Any, value: Any, get_provider) -> Dict[Key, Any]: return duplicate_node( - key, value, get_provider, self._index_name, self._index, self._path + key, value, get_provider, self._index_name, self.index, self._path ) @@ -195,28 +179,28 @@ def __init__( self._label_name = label_name self._group_node = self._find_grouping_node(param_table.row_dim, graph_template) self._path = _find_nodes_in_paths(graph_template, self._group_node) - self._index: Dict[LabelType, List[IndexType]] = defaultdict(list) + self.index: Dict[LabelType, List[IndexType]] = defaultdict(list) for idx, label in zip(param_table.index, param_table[label_name]): - self._index[label].append(idx) + self.index[label].append(idx) - def __iter__(self) -> Iterator[LabelType]: - return iter(self._index) + def __contains__(self, key: Any) -> bool: + return key in self._path def duplicate(self, key: Any, value: Any, get_provider) -> Dict[Key, Any]: if key != self._group_node: return duplicate_node( - key, value, get_provider, self._label_name, self._index, self._path + key, value, get_provider, self._label_name, self.index, self._path ) graph = {} provider, args = value - for index in self._index: - labels = self._index[index] - if set(labels) - set(provider._labels): - raise ValueError(f'{labels} is not a subset of {provider._labels}') + for index in self.index: + labels = self.index[index] + if set(labels) - set(provider.labels): + raise ValueError(f'{labels} is not a subset of {provider.labels}') subkey = _indexed_key(self._label_name, index, key) selected = { label: arg - for label, arg in zip(provider._labels, args) + for label, arg in zip(provider.labels, args) if label in labels } split_provider = SeriesProvider(selected, provider._row_dim) @@ -241,11 +225,11 @@ class SeriesProvider(Generic[KeyType, ValueType]): """ def __init__(self, labels: Iterable[KeyType], row_dim: type) -> None: - self._labels = labels + self.labels = labels self._row_dim = row_dim def __call__(self, *vals: ValueType) -> Series[KeyType, ValueType]: - return Series(self._row_dim, dict(zip(self._labels, vals))) + return Series(self._row_dim, dict(zip(self.labels, vals))) class _param_sentinel: @@ -490,14 +474,14 @@ def _build_series(self, tp: Type[Series[KeyType, ValueType]]) -> Graph: graph: Graph = {} graph[tp] = ( - SeriesProvider(list(grouper), label_name), - tuple(_indexed_key(label_name, index, value_type) for index in grouper), + SeriesProvider(list(grouper.index), label_name), + tuple(_indexed_key(label_name, idx, value_type) for idx in grouper.index), ) # Step 3: # Duplicate nodes, replacing keys with indexed keys. for key, value in subgraph.items(): - if key in grouper._path: + if key in grouper: graph.update(grouper.duplicate(key, value, self._get_provider)) else: graph[key] = value From b08c9e75d10886d91e6dbb6edd69dc615aac246d Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Wed, 9 Aug 2023 09:22:47 +0200 Subject: [PATCH 78/87] More cleanup --- src/sciline/pipeline.py | 99 ++++++++++++++++++----------------------- 1 file changed, 43 insertions(+), 56 deletions(-) diff --git a/src/sciline/pipeline.py b/src/sciline/pipeline.py index f29d556b..5e7fd9a9 100644 --- a/src/sciline/pipeline.py +++ b/src/sciline/pipeline.py @@ -13,7 +13,6 @@ List, Mapping, Optional, - Protocol, Tuple, Type, TypeVar, @@ -53,16 +52,6 @@ class AmbiguousProvider(Exception): """Raised when multiple providers are found for a type.""" -def _indexed_key( - index_name: Any, i: int, value_name: Union[Type[T], Item[T]] -) -> Item[T]: - label = Label(index_name, i) - if isinstance(value_name, Item): - return Item(value_name.label + (label,), value_name.tp) - else: - return Item((label,), value_name) - - def _is_compatible_type_tuple( requested: tuple[Key, ...], provided: tuple[Key | TypeVar, ...], @@ -126,51 +115,50 @@ def _find_nodes_in_paths( return list(nodes) -class Grouper(Protocol): - """Helper protocol for rewriting graphs.""" +class Grouper: + def __init__(self, index_name, index, path): + self._index_name = index_name + self.index = index + self._path = path def __contains__(self, key: Any) -> bool: - ... + return key in self._path - def duplicate(self, key: Any, value: Any) -> Dict[Key, Any]: - ... + def duplicate(self, key: Any, value: Any, get_provider) -> Dict[Key, Any]: + graph = {} + for idx in self.index: + provider, args = value + subkey = self.key(idx, key) + if provider == _param_sentinel: + provider, _ = get_provider(subkey) + args = () + args_with_index = tuple( + self.key(idx, arg) if arg in self else arg for arg in args + ) + graph[subkey] = (provider, args_with_index) + return graph + + def key(self, i: int, value_name: Union[Type[T], Item[T]]) -> Item[T]: + label = Label(self._index_name, i) + if isinstance(value_name, Item): + return Item(value_name.label + (label,), value_name.tp) + else: + return Item((label,), value_name) -class NoGrouping(Generic[IndexType]): +class NoGrouping(Grouper, Generic[IndexType]): """Helper for rewriting the graph to map over a given index.""" def __init__(self, param_table: ParamTable, graph_template: Graph) -> None: - self.index = param_table.index - self._index_name = param_table.row_dim - self._path = _find_nodes_in_paths(graph_template, self._index_name) - - def __contains__(self, key: Any) -> bool: - return key in self._path - - def duplicate(self, key: Any, value: Any, get_provider) -> Dict[Key, Any]: - return duplicate_node( - key, value, get_provider, self._index_name, self.index, self._path + index_name = param_table.row_dim + super().__init__( + index_name=index_name, + index=param_table.index, + path=_find_nodes_in_paths(graph_template, index_name), ) -def duplicate_node( - key: Any, value: Any, get_provider, index_name: Any, index: Any, path -) -> Dict[Key, Any]: - graph = {} - for idx in index: - provider, args = value - subkey = _indexed_key(index_name, idx, key) - if provider == _param_sentinel: - provider, _ = get_provider(subkey) - args = () - args_with_index = tuple( - _indexed_key(index_name, idx, arg) if arg in path else arg for arg in args - ) - graph[subkey] = (provider, args_with_index) - return graph - - -class GroupBy(Generic[IndexType, LabelType]): +class GroupBy(Grouper, Generic[IndexType, LabelType]): """Helper for rewriting the graph to group by a given index.""" def __init__( @@ -178,26 +166,25 @@ def __init__( ) -> None: self._label_name = label_name self._group_node = self._find_grouping_node(param_table.row_dim, graph_template) - self._path = _find_nodes_in_paths(graph_template, self._group_node) - self.index: Dict[LabelType, List[IndexType]] = defaultdict(list) + index: Dict[LabelType, List[IndexType]] = defaultdict(list) for idx, label in zip(param_table.index, param_table[label_name]): - self.index[label].append(idx) - - def __contains__(self, key: Any) -> bool: - return key in self._path + index[label].append(idx) + super().__init__( + index_name=label_name, + index=index, + path=_find_nodes_in_paths(graph_template, self._group_node), + ) def duplicate(self, key: Any, value: Any, get_provider) -> Dict[Key, Any]: if key != self._group_node: - return duplicate_node( - key, value, get_provider, self._label_name, self.index, self._path - ) + return super().duplicate(key, value, get_provider) graph = {} provider, args = value for index in self.index: labels = self.index[index] if set(labels) - set(provider.labels): raise ValueError(f'{labels} is not a subset of {provider.labels}') - subkey = _indexed_key(self._label_name, index, key) + subkey = self.key(index, key) selected = { label: arg for label, arg in zip(provider.labels, args) @@ -475,7 +462,7 @@ def _build_series(self, tp: Type[Series[KeyType, ValueType]]) -> Graph: graph: Graph = {} graph[tp] = ( SeriesProvider(list(grouper.index), label_name), - tuple(_indexed_key(label_name, idx, value_type) for idx in grouper.index), + tuple(grouper.key(idx, value_type) for idx in grouper.index), ) # Step 3: From 95781f26b7e0e0c0c06a23ae74850e047d1bfd4a Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Wed, 9 Aug 2023 11:46:30 +0200 Subject: [PATCH 79/87] Fix type hints --- src/sciline/pipeline.py | 48 +++++++++++++++++++++++++---------------- src/sciline/typing.py | 2 +- 2 files changed, 30 insertions(+), 20 deletions(-) diff --git a/src/sciline/pipeline.py b/src/sciline/pipeline.py index 5e7fd9a9..5e117c36 100644 --- a/src/sciline/pipeline.py +++ b/src/sciline/pipeline.py @@ -115,17 +115,22 @@ def _find_nodes_in_paths( return list(nodes) -class Grouper: - def __init__(self, index_name, index, path): +class Grouper(Generic[IndexType]): + def __init__(self, index_name: type, index: Iterable[IndexType], path: List[Key]): self._index_name = index_name self.index = index self._path = path - def __contains__(self, key: Any) -> bool: + def __contains__(self, key: Key) -> bool: return key in self._path - def duplicate(self, key: Any, value: Any, get_provider) -> Dict[Key, Any]: - graph = {} + def duplicate( + self, + key: Key, + value: Any, + get_provider: Callable[..., Tuple[Callable[..., Any], Dict[TypeVar, Key]]], + ) -> Graph: + graph: Graph = {} for idx in self.index: provider, args = value subkey = self.key(idx, key) @@ -138,7 +143,7 @@ def duplicate(self, key: Any, value: Any, get_provider) -> Dict[Key, Any]: graph[subkey] = (provider, args_with_index) return graph - def key(self, i: int, value_name: Union[Type[T], Item[T]]) -> Item[T]: + def key(self, i: IndexType, value_name: Union[Type[T], Item[T]]) -> Item[T]: label = Label(self._index_name, i) if isinstance(value_name, Item): return Item(value_name.label + (label,), value_name.tp) @@ -146,7 +151,7 @@ def key(self, i: int, value_name: Union[Type[T], Item[T]]) -> Item[T]: return Item((label,), value_name) -class NoGrouping(Grouper, Generic[IndexType]): +class NoGrouping(Grouper[IndexType]): """Helper for rewriting the graph to map over a given index.""" def __init__(self, param_table: ParamTable, graph_template: Graph) -> None: @@ -158,33 +163,38 @@ def __init__(self, param_table: ParamTable, graph_template: Graph) -> None: ) -class GroupBy(Grouper, Generic[IndexType, LabelType]): +class GroupBy(Grouper[LabelType], Generic[IndexType, LabelType]): """Helper for rewriting the graph to group by a given index.""" def __init__( - self, param_table: ParamTable, graph_template: Graph, label_name + self, param_table: ParamTable, graph_template: Graph, label_name: type ) -> None: self._label_name = label_name self._group_node = self._find_grouping_node(param_table.row_dim, graph_template) - index: Dict[LabelType, List[IndexType]] = defaultdict(list) + self._groups: Dict[LabelType, List[IndexType]] = defaultdict(list) for idx, label in zip(param_table.index, param_table[label_name]): - index[label].append(idx) + self._groups[label].append(idx) super().__init__( index_name=label_name, - index=index, + index=self._groups, path=_find_nodes_in_paths(graph_template, self._group_node), ) - def duplicate(self, key: Any, value: Any, get_provider) -> Dict[Key, Any]: + def duplicate( + self, + key: Key, + value: Any, + get_provider: Callable[..., Tuple[Callable[..., Any], Dict[TypeVar, Key]]], + ) -> Graph: if key != self._group_node: return super().duplicate(key, value, get_provider) - graph = {} + graph: Graph = {} provider, args = value - for index in self.index: - labels = self.index[index] + for idx in self.index: + labels = self._groups[idx] if set(labels) - set(provider.labels): raise ValueError(f'{labels} is not a subset of {provider.labels}') - subkey = self.key(index, key) + subkey = self.key(idx, key) selected = { label: arg for label, arg in zip(provider.labels, args) @@ -205,7 +215,7 @@ def _find_grouping_node(self, index_name: Key, subgraph: Graph) -> type: raise ValueError(f"Could not find unique grouping node, found {ends}") -class SeriesProvider(Generic[KeyType, ValueType]): +class SeriesProvider(Generic[KeyType]): """ Internal provider for combining results obtained based on different rows in a param table into a single object. @@ -444,7 +454,7 @@ def _build_series(self, tp: Type[Series[KeyType, ValueType]]) -> Graph: # prior call to _build_series, so instead of duplicating everything until the # param table is reached, we only duplicate until the node that is performing # the grouping. - grouper: Grouper + grouper: Grouper[KeyType] if ( label_name not in self._param_index_name and (params := self._param_tables.get(label_name)) is not None diff --git a/src/sciline/typing.py b/src/sciline/typing.py index 6ecd2729..f6a941c3 100644 --- a/src/sciline/typing.py +++ b/src/sciline/typing.py @@ -7,7 +7,7 @@ @dataclass(frozen=True) class Label: tp: type - index: int + index: Any T = TypeVar('T') From ab9b4732d08c3dc35b5cc24afef4bf1c752a9d6b Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Wed, 9 Aug 2023 11:59:52 +0200 Subject: [PATCH 80/87] Readability --- src/sciline/pipeline.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/sciline/pipeline.py b/src/sciline/pipeline.py index 5e117c36..0f929eff 100644 --- a/src/sciline/pipeline.py +++ b/src/sciline/pipeline.py @@ -131,16 +131,16 @@ def duplicate( get_provider: Callable[..., Tuple[Callable[..., Any], Dict[TypeVar, Key]]], ) -> Graph: graph: Graph = {} + provider, args = value for idx in self.index: - provider, args = value subkey = self.key(idx, key) if provider == _param_sentinel: - provider, _ = get_provider(subkey) - args = () - args_with_index = tuple( - self.key(idx, arg) if arg in self else arg for arg in args - ) - graph[subkey] = (provider, args_with_index) + graph[subkey] = (get_provider(subkey)[0], ()) + else: + graph[subkey] = ( + provider, + tuple(self.key(idx, arg) if arg in self else arg for arg in args), + ) return graph def key(self, i: IndexType, value_name: Union[Type[T], Item[T]]) -> Item[T]: @@ -191,10 +191,10 @@ def duplicate( graph: Graph = {} provider, args = value for idx in self.index: + subkey = self.key(idx, key) labels = self._groups[idx] if set(labels) - set(provider.labels): raise ValueError(f'{labels} is not a subset of {provider.labels}') - subkey = self.key(idx, key) selected = { label: arg for label, arg in zip(provider.labels, args) From c41fbd80f2bd80b7c7c594427d9bbc2f237f495d Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Wed, 9 Aug 2023 12:55:37 +0200 Subject: [PATCH 81/87] Use more common code --- src/sciline/pipeline.py | 57 ++++++++++++++++++++++------------------- 1 file changed, 30 insertions(+), 27 deletions(-) diff --git a/src/sciline/pipeline.py b/src/sciline/pipeline.py index 0f929eff..faa7afc6 100644 --- a/src/sciline/pipeline.py +++ b/src/sciline/pipeline.py @@ -137,12 +137,21 @@ def duplicate( if provider == _param_sentinel: graph[subkey] = (get_provider(subkey)[0], ()) else: - graph[subkey] = ( - provider, - tuple(self.key(idx, arg) if arg in self else arg for arg in args), - ) + graph[subkey] = self._copy_node(key, provider, args, idx) return graph + def _copy_node( + self, + key: Key, + provider: Union[Provider, SeriesProvider[IndexType]], + args: Tuple[Key, ...], + idx: IndexType, + ) -> Tuple[Callable[..., Any], Tuple[Key, ...]]: + return ( + provider, + tuple(self.key(idx, arg) if arg in self else arg for arg in args), + ) + def key(self, i: IndexType, value_name: Union[Type[T], Item[T]]) -> Item[T]: label = Label(self._index_name, i) if isinstance(value_name, Item): @@ -180,29 +189,23 @@ def __init__( path=_find_nodes_in_paths(graph_template, self._group_node), ) - def duplicate( + def _copy_node( self, key: Key, - value: Any, - get_provider: Callable[..., Tuple[Callable[..., Any], Dict[TypeVar, Key]]], - ) -> Graph: - if key != self._group_node: - return super().duplicate(key, value, get_provider) - graph: Graph = {} - provider, args = value - for idx in self.index: - subkey = self.key(idx, key) - labels = self._groups[idx] - if set(labels) - set(provider.labels): - raise ValueError(f'{labels} is not a subset of {provider.labels}') - selected = { - label: arg - for label, arg in zip(provider.labels, args) - if label in labels - } - split_provider = SeriesProvider(selected, provider._row_dim) - graph[subkey] = (split_provider, tuple(selected.values())) - return graph + provider: Union[Provider, SeriesProvider[IndexType]], + args: Tuple[Key, ...], + idx: LabelType, + ) -> Tuple[Callable[..., Any], Tuple[Key, ...]]: + if (not isinstance(provider, SeriesProvider)) or key != self._group_node: + return super()._copy_node(key, provider, args, idx) + labels = self._groups[idx] + if set(labels) - set(provider.labels): + raise ValueError(f'{labels} is not a subset of {provider.labels}') + selected = { + label: arg for label, arg in zip(provider.labels, args) if label in labels + } + split_provider = SeriesProvider(selected, provider.row_dim) + return (split_provider, tuple(selected.values())) def _find_grouping_node(self, index_name: Key, subgraph: Graph) -> type: ends: List[type] = [] @@ -223,10 +226,10 @@ class SeriesProvider(Generic[KeyType]): def __init__(self, labels: Iterable[KeyType], row_dim: type) -> None: self.labels = labels - self._row_dim = row_dim + self.row_dim = row_dim def __call__(self, *vals: ValueType) -> Series[KeyType, ValueType]: - return Series(self._row_dim, dict(zip(self.labels, vals))) + return Series(self.row_dim, dict(zip(self.labels, vals))) class _param_sentinel: From 6d28ea99754b899277b35e61635c6e050d68fff8 Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Wed, 9 Aug 2023 13:57:29 +0200 Subject: [PATCH 82/87] Fix compact flag --- docs/user-guide/parameter-tables.ipynb | 4 ++-- src/sciline/visualize.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/user-guide/parameter-tables.ipynb b/docs/user-guide/parameter-tables.ipynb index e28c95b8..8ed44e3b 100644 --- a/docs/user-guide/parameter-tables.ipynb +++ b/docs/user-guide/parameter-tables.ipynb @@ -464,7 +464,7 @@ ], "metadata": { "kernelspec": { - "display_name": "base", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -482,5 +482,5 @@ } }, "nbformat": 4, - "nbformat_minor": 2 + "nbformat_minor": 4 } diff --git a/src/sciline/visualize.py b/src/sciline/visualize.py index 37440193..f2c90e33 100644 --- a/src/sciline/visualize.py +++ b/src/sciline/visualize.py @@ -121,7 +121,7 @@ def with_labels(base: str) -> Node: if labels: return Node( name=f'{base}({", ".join([format_label(l) for l in labels])})', - collapsed=True, + collapsed=compact, ) return Node(name=base) From 23f7f37205937cfe85d210c2ad0179f69452e593 Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Thu, 10 Aug 2023 10:40:41 +0200 Subject: [PATCH 83/87] Use alias --- src/sciline/pipeline.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/sciline/pipeline.py b/src/sciline/pipeline.py index faa7afc6..cba38ea4 100644 --- a/src/sciline/pipeline.py +++ b/src/sciline/pipeline.py @@ -128,7 +128,7 @@ def duplicate( self, key: Key, value: Any, - get_provider: Callable[..., Tuple[Callable[..., Any], Dict[TypeVar, Key]]], + get_provider: Callable[..., Tuple[Provider, Dict[TypeVar, Key]]], ) -> Graph: graph: Graph = {} provider, args = value @@ -146,7 +146,7 @@ def _copy_node( provider: Union[Provider, SeriesProvider[IndexType]], args: Tuple[Key, ...], idx: IndexType, - ) -> Tuple[Callable[..., Any], Tuple[Key, ...]]: + ) -> Tuple[Provider, Tuple[Key, ...]]: return ( provider, tuple(self.key(idx, arg) if arg in self else arg for arg in args), @@ -195,7 +195,7 @@ def _copy_node( provider: Union[Provider, SeriesProvider[IndexType]], args: Tuple[Key, ...], idx: LabelType, - ) -> Tuple[Callable[..., Any], Tuple[Key, ...]]: + ) -> Tuple[Provider, Tuple[Key, ...]]: if (not isinstance(provider, SeriesProvider)) or key != self._group_node: return super()._copy_node(key, provider, args, idx) labels = self._groups[idx] From cc343fa919c1902e1c6bf649dae8c74a91936f2d Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Thu, 10 Aug 2023 10:57:38 +0200 Subject: [PATCH 84/87] Illustrate complicated algorithm --- src/sciline/pipeline.py | 40 +++++++++++++++++++++++++++++++++++++--- 1 file changed, 37 insertions(+), 3 deletions(-) diff --git a/src/sciline/pipeline.py b/src/sciline/pipeline.py index cba38ea4..0cd4bd6a 100644 --- a/src/sciline/pipeline.py +++ b/src/sciline/pipeline.py @@ -161,7 +161,20 @@ def key(self, i: IndexType, value_name: Union[Type[T], Item[T]]) -> Item[T]: class NoGrouping(Grouper[IndexType]): - """Helper for rewriting the graph to map over a given index.""" + r""" + Helper for rewriting the graph to map over a given index. + + Given a graph template, this makes a transformation as follows: + + S P[0] P[1] P[2] + | | | | + A -> A[0] A[1] A[2] + | | | | + B B[0] B[1] B[2] + + Where S is a sentinel value, P are parameters from a parameter table, 0,1,2 + are indices of the param table rows, and A and B are arbitrary nodes in the graph. + """ def __init__(self, param_table: ParamTable, graph_template: Graph) -> None: index_name = param_table.row_dim @@ -173,7 +186,27 @@ def __init__(self, param_table: ParamTable, graph_template: Graph) -> None: class GroupBy(Grouper[LabelType], Generic[IndexType, LabelType]): - """Helper for rewriting the graph to group by a given index.""" + r""" + Helper for rewriting the graph to group by a given index. + + Given a graph template, this makes a transformation as follows: + + P[0] P[1] P[2] P[0] P[1] P[2] + | | | | | | + A[0] A[1] A[2] A[0] A[1] A[2] + | | | | | | + B[0] B[1] B[2] -> B[0] B[1] B[2] + \______|______/ \______/ | + | | | + C C[x] C[y] + | | | + D D[x] D[y] + + Here, the upper half of the graph originates from a prior transformation of + a graph template using `NoGrouping`. The output of this combined with further + nodes is the graph template passes to this class. x and y are the labels used + in a grouping operation. + """ def __init__( self, param_table: ParamTable, graph_template: Graph, label_name: type @@ -456,7 +489,8 @@ def _build_series(self, tp: Type[Series[KeyType, ValueType]]) -> Graph: # ungrouped graph (including duplication of nodes) will have been built by a # prior call to _build_series, so instead of duplicating everything until the # param table is reached, we only duplicate until the node that is performing - # the grouping. + # the grouping. See the docstrings of GroupBy and NoGrouping for an + # illustration. grouper: Grouper[KeyType] if ( label_name not in self._param_index_name From 26856e0cd4f261eef33ca046e89354b152efe426 Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Fri, 11 Aug 2023 07:54:11 +0200 Subject: [PATCH 85/87] Improve docstrings --- src/sciline/pipeline.py | 146 +++++++++++++++--------- tests/pipeline_with_param_table_test.py | 13 +++ 2 files changed, 108 insertions(+), 51 deletions(-) diff --git a/src/sciline/pipeline.py b/src/sciline/pipeline.py index 0cd4bd6a..9f8ebe11 100644 --- a/src/sciline/pipeline.py +++ b/src/sciline/pipeline.py @@ -115,7 +115,7 @@ def _find_nodes_in_paths( return list(nodes) -class Grouper(Generic[IndexType]): +class ReplicatorBase(Generic[IndexType]): def __init__(self, index_name: type, index: Iterable[IndexType], path: List[Key]): self._index_name = index_name self.index = index @@ -124,7 +124,7 @@ def __init__(self, index_name: type, index: Iterable[IndexType], path: List[Key] def __contains__(self, key: Key) -> bool: return key in self._path - def duplicate( + def replicate( self, key: Key, value: Any, @@ -160,19 +160,20 @@ def key(self, i: IndexType, value_name: Union[Type[T], Item[T]]) -> Item[T]: return Item((label,), value_name) -class NoGrouping(Grouper[IndexType]): +class Replicator(ReplicatorBase[IndexType]): r""" Helper for rewriting the graph to map over a given index. - Given a graph template, this makes a transformation as follows: + See Pipeline._build_series for context. Given a graph template, this makes a + transformation as follows: - S P[0] P[1] P[2] + S P1[0] P1[1] P1[2] | | | | A -> A[0] A[1] A[2] | | | | B B[0] B[1] B[2] - Where S is a sentinel value, P are parameters from a parameter table, 0,1,2 + Where S is a sentinel value, P1 are parameters from a parameter table, 0,1,2 are indices of the param table rows, and A and B are arbitrary nodes in the graph. """ @@ -185,27 +186,29 @@ def __init__(self, param_table: ParamTable, graph_template: Graph) -> None: ) -class GroupBy(Grouper[LabelType], Generic[IndexType, LabelType]): +class GroupingReplicator(ReplicatorBase[LabelType], Generic[IndexType, LabelType]): r""" Helper for rewriting the graph to group by a given index. - Given a graph template, this makes a transformation as follows: + See Pipeline._build_series for context. Given a graph template, this makes a + transformation as follows: - P[0] P[1] P[2] P[0] P[1] P[2] + P1[0] P1[1] P1[2] P1[0] P1[1] P1[2] | | | | | | A[0] A[1] A[2] A[0] A[1] A[2] | | | | | | B[0] B[1] B[2] -> B[0] B[1] B[2] \______|______/ \______/ | | | | - C C[x] C[y] + SB SB[x] SB[y] | | | - D D[x] D[y] + C C[x] C[y] - Here, the upper half of the graph originates from a prior transformation of - a graph template using `NoGrouping`. The output of this combined with further - nodes is the graph template passes to this class. x and y are the labels used - in a grouping operation. + Where SB is Series[Idx,B]. Here, the upper half of the graph originates from a + prior transformation of a graph template using `Replicator`. The output of this + combined with further nodes is the graph template passed to this class. x and y + are the labels used in a grouping operation, based on the values of a ParamTable + column P2. """ def __init__( @@ -292,7 +295,7 @@ def __init__( self._providers: Dict[Key, Provider] = {} self._subproviders: Dict[type, Dict[Tuple[Key | TypeVar, ...], Provider]] = {} self._param_tables: Dict[Key, ParamTable] = {} - self._param_index_name: Dict[Key, Key] = {} + self._param_name_to_table_key: Dict[Key, Key] = {} for provider in providers or []: self.insert(provider) for tp, param in (params or {}).items(): @@ -368,11 +371,11 @@ def set_param_table(self, params: ParamTable) -> None: if params.row_dim in self._param_tables: raise ValueError(f'Parameter table for {params.row_dim} already set') for param_name in params: - if param_name in self._param_index_name: + if param_name in self._param_name_to_table_key: raise ValueError(f'Parameter {param_name} already set') self._param_tables[params.row_dim] = params for param_name in params: - self._param_index_name[param_name] = params.row_dim + self._param_name_to_table_key[param_name] = params.row_dim for param_name, values in params.items(): for index, label in zip(params.index, values): self._set_provider( @@ -452,8 +455,8 @@ def build( stack: List[Union[Type[T], Item[T]]] = [tp] while stack: tp = stack.pop() - if search_param_tables and tp in self._param_index_name: - graph[tp] = (_param_sentinel, (self._param_index_name[tp],)) + if search_param_tables and tp in self._param_name_to_table_key: + graph[tp] = (_param_sentinel, (self._param_name_to_table_key[tp],)) continue if get_origin(tp) == Series: graph.update(self._build_series(tp)) # type: ignore[arg-type] @@ -473,50 +476,91 @@ def build( return graph def _build_series(self, tp: Type[Series[KeyType, ValueType]]) -> Graph: - label_name: Type[KeyType] + """ + Build (sub)graph for a Series type implementing ParamTable-based functionality. + + We illustrate this with an example. Given a ParamTable with row_dim 'Idx': + + Idx | P1 | P2 + 0 | a | x + 1 | b | x + 2 | c | y + + and providers for A depending on P1 and B depending on A. Calling + build(Series[Idx,B]) will call _build_series(Series[Idx,B]). This results in + the following procedure here: + + 1. Call build(P1), resulting, e.g., in a graph S->A->B, where S is a sentinel. + The sentinel is used because build() cannot find a unique P1, since it is + not a single value but a column in a table. + 2. Instantiation of `Replicator`, which will be used to replicate the + relevant parts of the graph (see illustration there). + 3. Insert a special `SeriesProvider` node, which will gather the duplicates of + the 'B' node and providers the requested Series[Idx,B]. + 4. Replicate the graph. Nodes that do not directly or indirectly depend on P1 + are not replicated. + + Conceptually, the final result will be { + 0: B(A(a)), + 1: B(A(b)), + 2: B(A(c)) + }. + + In more complex cases, we may be dealing with multiple levels of Series, + which is used for grouping operations. Consider the above example, but with + and additional provider for C depending on Series[Idx,B]. Calling + build(Series[P2,C]) will call _build_series(Series[P2,C]). This results in + the following procedure here: + + a. Call build(C), which results in the procedure above, i.e., a nested call + to _build_series(Series[Idx,B]) and the resulting graph as explained above. + b. Instantiation of `GroupingReplicator`, which will be used to replicate the + relevant parts of the graph (see illustration there). + c. Insert a special `SeriesProvider` node, which will gather the duplicates of + the 'C' node and providers the requested Series[P2,C]. + c. Replicate the graph. Nodes that do not directly or indirectly depend on + the special `SeriesProvider` node (from step 3.) are not replicated. + + Conceptually, the final result will be { + x: C({ + 0: B(A(a)), + 1: B(A(b)) + }), + y: C({ + 2: B(A(c)) + }) + }. + """ + index_name: Type[KeyType] value_type: Type[ValueType] - label_name, value_type = get_args(tp) - # Step 1: - # Build a graph that can compute the value type. As we are building - # a Series, this will terminate when it reaches a parameter that is not a - # single provided value but a collection of values from a parameter table - # column. Instead of single value (which does not exist), a sentinel is - # used to mark this, for processing below. + index_name, value_type = get_args(tp) + subgraph = self.build(value_type, search_param_tables=True) - # Step 2: - # Identify nodes in the graph that need to be duplicated as they lie in the - # path to a parameter from a table. In the case of grouping, note that the - # ungrouped graph (including duplication of nodes) will have been built by a - # prior call to _build_series, so instead of duplicating everything until the - # param table is reached, we only duplicate until the node that is performing - # the grouping. See the docstrings of GroupBy and NoGrouping for an - # illustration. - grouper: Grouper[KeyType] + + replicator: ReplicatorBase[KeyType] if ( - label_name not in self._param_index_name - and (params := self._param_tables.get(label_name)) is not None + index_name not in self._param_name_to_table_key + and (params := self._param_tables.get(index_name)) is not None ): - grouper = NoGrouping(param_table=params, graph_template=subgraph) - elif (index_name := self._param_index_name.get(label_name)) is not None: - grouper = GroupBy( - param_table=self._param_tables[index_name], + replicator = Replicator(param_table=params, graph_template=subgraph) + elif (table_key := self._param_name_to_table_key.get(index_name)) is not None: + replicator = GroupingReplicator( + param_table=self._param_tables[table_key], graph_template=subgraph, - label_name=label_name, + label_name=index_name, ) else: - raise KeyError(f'No parameter table found for label {label_name}') + raise KeyError(f'No parameter table found for label {index_name}') graph: Graph = {} graph[tp] = ( - SeriesProvider(list(grouper.index), label_name), - tuple(grouper.key(idx, value_type) for idx in grouper.index), + SeriesProvider(list(replicator.index), index_name), + tuple(replicator.key(idx, value_type) for idx in replicator.index), ) - # Step 3: - # Duplicate nodes, replacing keys with indexed keys. for key, value in subgraph.items(): - if key in grouper: - graph.update(grouper.duplicate(key, value, self._get_provider)) + if key in replicator: + graph.update(replicator.replicate(key, value, self._get_provider)) else: graph[key] = value return graph diff --git a/tests/pipeline_with_param_table_test.py b/tests/pipeline_with_param_table_test.py index fc4c6968..dd847308 100644 --- a/tests/pipeline_with_param_table_test.py +++ b/tests/pipeline_with_param_table_test.py @@ -290,6 +290,19 @@ def test_groupby_by_requesting_series_of_series_preserves_indices() -> None: ) +def test_can_groupby_by_param_used_in_ancestor() -> None: + Row = NewType("Row", int) + Param = NewType("Param1", str) + + pl = sl.Pipeline() + pl.set_param_table(sl.ParamTable(Row, {Param: ['x', 'x', 'y']})) + expected = sl.Series( + Param, + {"x": sl.Series(Row, {0: "x", 1: "x"}), "y": sl.Series(Row, {2: "y"})}, + ) + assert pl.compute(sl.Series[Param, sl.Series[Row, Param]]) == expected + + def test_multi_level_groupby_raises_with_params_from_same_table() -> None: Row = NewType("Row", int) Param1 = NewType("Param1", int) From 0a51b7038941edec189de3392f7d88011ebf99da Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Fri, 11 Aug 2023 08:07:14 +0200 Subject: [PATCH 86/87] typo --- tests/pipeline_with_param_table_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/pipeline_with_param_table_test.py b/tests/pipeline_with_param_table_test.py index dd847308..88f8a718 100644 --- a/tests/pipeline_with_param_table_test.py +++ b/tests/pipeline_with_param_table_test.py @@ -292,7 +292,7 @@ def test_groupby_by_requesting_series_of_series_preserves_indices() -> None: def test_can_groupby_by_param_used_in_ancestor() -> None: Row = NewType("Row", int) - Param = NewType("Param1", str) + Param = NewType("Param", str) pl = sl.Pipeline() pl.set_param_table(sl.ParamTable(Row, {Param: ['x', 'x', 'y']})) From cc9809269c9c423213592ecc9c06f5c66ef6e9e1 Mon Sep 17 00:00:00 2001 From: Simon Heybrock <12912489+SimonHeybrock@users.noreply.github.com> Date: Fri, 11 Aug 2023 11:38:18 +0200 Subject: [PATCH 87/87] Update src/sciline/pipeline.py Co-authored-by: Jan-Lukas Wynen --- src/sciline/pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sciline/pipeline.py b/src/sciline/pipeline.py index 9f8ebe11..75d8a527 100644 --- a/src/sciline/pipeline.py +++ b/src/sciline/pipeline.py @@ -496,7 +496,7 @@ def _build_series(self, tp: Type[Series[KeyType, ValueType]]) -> Graph: 2. Instantiation of `Replicator`, which will be used to replicate the relevant parts of the graph (see illustration there). 3. Insert a special `SeriesProvider` node, which will gather the duplicates of - the 'B' node and providers the requested Series[Idx,B]. + the 'B' node and provides the requested Series[Idx,B]. 4. Replicate the graph. Nodes that do not directly or indirectly depend on P1 are not replicated.