Skip to content

Commit

Permalink
Merge pull request #51 from scipp/large-graph-readability
Browse files Browse the repository at this point in the history
Cluster generic products when visualizing graph
  • Loading branch information
SimonHeybrock authored Sep 7, 2023
2 parents dea2e84 + c50f1cd commit a09740f
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 24 deletions.
15 changes: 3 additions & 12 deletions src/sciline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from .param_table import ParamTable
from .scheduler import Scheduler
from .series import Series
from .typing import Graph, Item, Key, Label, Provider
from .typing import Graph, Item, Key, Label, Provider, get_optional

T = TypeVar('T')
KeyType = TypeVar('KeyType')
Expand Down Expand Up @@ -116,15 +116,6 @@ def _find_nodes_in_paths(
return list(nodes)


def _get_optional(tp: Key) -> Optional[Any]:
if get_origin(tp) != Union:
return None
args = get_args(tp)
if len(args) != 2 or type(None) not in args:
return None
return args[0] if args[1] == type(None) else args[1] # noqa: E721


def provide_none() -> None:
return None

Expand Down Expand Up @@ -173,7 +164,7 @@ def _copy_node(
)

def key(self, i: IndexType, value_name: Union[Type[T], Item[T]]) -> Item[T]:
value_name = _get_optional(value_name) or value_name
value_name = get_optional(value_name) or value_name
label = Label(self._index_name, i)
if isinstance(value_name, Item):
return Item(value_name.label + (label,), value_name.tp)
Expand Down Expand Up @@ -488,7 +479,7 @@ def build(
if get_origin(tp) == Series:
graph.update(self._build_series(tp)) # type: ignore[arg-type]
continue
if (optional_arg := _get_optional(tp)) is not None:
if (optional_arg := get_optional(tp)) is not None:
try:
optional_subgraph = self.build(
optional_arg, search_param_tables=search_param_tables
Expand Down
23 changes: 22 additions & 1 deletion src/sciline/typing.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,19 @@
# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
from dataclasses import dataclass
from typing import Any, Callable, Dict, Generic, Tuple, Type, TypeVar, Union
from typing import (
Any,
Callable,
Dict,
Generic,
Optional,
Tuple,
Type,
TypeVar,
Union,
get_args,
get_origin,
)


@dataclass(frozen=True)
Expand All @@ -24,3 +36,12 @@ class Item(Generic[T]):

Key = Union[type, Item]
Graph = Dict[Key, Tuple[Provider, Tuple[Key, ...]]]


def get_optional(tp: Key) -> Optional[Any]:
if get_origin(tp) != Union:
return None
args = get_args(tp)
if len(args) != 2 or type(None) not in args:
return None
return args[0] if args[1] == type(None) else args[1] # noqa: E721
70 changes: 60 additions & 10 deletions src/sciline/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
Callable,
Dict,
List,
Optional,
Tuple,
Type,
TypeVar,
Expand All @@ -17,7 +18,7 @@
from graphviz import Digraph

from .pipeline import Pipeline, SeriesProvider
from .typing import Graph, Item, Key
from .typing import Graph, Item, Key, get_optional


@dataclass
Expand All @@ -26,7 +27,16 @@ class Node:
collapsed: bool = False


def to_graphviz(graph: Graph, compact: bool = False, **kwargs: Any) -> Digraph:
FormattedGraph = Dict[str, Tuple[str, List[Node], Node]]


def to_graphviz(
graph: Graph,
compact: bool = False,
cluster_generics: bool = True,
cluster_color: Optional[str] = '#f0f0ff',
**kwargs: Any,
) -> Digraph:
"""
Convert output of :py:class:`sciline.Pipeline.get_graph` to a graphviz graph.
Expand All @@ -36,13 +46,53 @@ def to_graphviz(graph: Graph, compact: bool = False, **kwargs: Any) -> Digraph:
Output of :py:class:`sciline.Pipeline.get_graph`.
compact:
If True, parameter-table-dependent branches are collapsed into a single copy
of the branch. Recommendend for large graphs with long parameter tables.
of the branch. Recommended for large graphs with long parameter tables.
cluster_generics:
If True, generic products are grouped into clusters.
cluster_color:
Background color of clusters. If None, clusters are dotted.
kwargs:
Keyword arguments passed to :py:class:`graphviz.Digraph`.
"""
dot = Digraph(strict=True, **kwargs)
for p, (p_name, args, ret) in _format_graph(graph, compact=compact).items():
dot.node(ret.name, ret.name, shape='box3d' if ret.collapsed else 'rectangle')
formatted_graph = _format_graph(graph, compact=compact)
ordered_graph = dict(
sorted(formatted_graph.items(), key=lambda item: item[1][2].name)
)
subgraphs = _to_subgraphs(ordered_graph)

for origin, subgraph in subgraphs.items():
cluster = cluster_generics and len(subgraph) > 1
name = f'cluster_{origin}' if cluster else None
with dot.subgraph(name=name) as dot_subgraph:
if cluster:
dot_subgraph.attr(rank='same')
if cluster_color is None:
dot_subgraph.attr(style='dotted')
else:
dot_subgraph.attr(style='filled', color=cluster_color)
_add_subgraph(subgraph, dot, dot_subgraph)
return dot


def _to_subgraphs(graph: FormattedGraph) -> Dict[str, FormattedGraph]:
def get_subgraph_name(name: str) -> str:
return name.split('[')[0]

subgraphs: Dict[str, FormattedGraph] = {}
for p, (p_name, args, ret) in graph.items():
subgraph_name = get_subgraph_name(ret.name)
if subgraph_name not in subgraphs:
subgraphs[subgraph_name] = {}
subgraphs[subgraph_name][p] = (p_name, args, ret)
return subgraphs


def _add_subgraph(graph: FormattedGraph, dot: Digraph, subgraph: Digraph) -> None:
for p, (p_name, args, ret) in graph.items():
subgraph.node(
ret.name, ret.name, shape='box3d' if ret.collapsed else 'rectangle'
)
# Do not draw dummy providers created by Pipeline when setting instances
if p_name in (
f'{_qualname(Pipeline.__setitem__)}.<locals>.<lambda>',
Expand All @@ -59,7 +109,6 @@ def to_graphviz(graph: Graph, compact: bool = False, **kwargs: Any) -> Digraph:
for arg in args:
dot.edge(arg.name, p)
dot.edge(p, ret.name)
return dot


def _qualname(obj: Any) -> Any:
Expand All @@ -68,9 +117,7 @@ def _qualname(obj: Any) -> Any:
)


def _format_graph(
graph: Graph, compact: bool
) -> Dict[str, Tuple[str, List[Node], Node]]:
def _format_graph(graph: Graph, compact: bool) -> FormattedGraph:
return {
_format_provider(provider, ret, compact=compact): (
_qualname(provider),
Expand Down Expand Up @@ -108,7 +155,10 @@ def _format_type(tp: Key, compact: bool = False) -> Node:

tp, labels = _extract_type_and_labels(tp, compact=compact)

def get_base(tp: type) -> str:
if (tp_ := get_optional(tp)) is not None:
tp = tp_

def get_base(tp: Key) -> str:
return tp.__name__ if hasattr(tp, '__name__') else str(tp).split('.')[-1]

def format_label(label: Union[type, Tuple[type, Any]]) -> str:
Expand Down
7 changes: 6 additions & 1 deletion tests/visualize_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
from typing import Generic, TypeVar
from typing import Generic, Optional, TypeVar

import sciline as sl
from sciline.visualize import to_graphviz
Expand Down Expand Up @@ -33,3 +33,8 @@ class SubA(A[T]):
assert sl.visualize._format_type(A[float]).name == 'A[float]'
assert sl.visualize._format_type(SubA[float]).name == 'SubA[float]'
assert sl.visualize._format_type(B[float]).name == 'B[float]'


def test_optional_types_formatted_as_their_content() -> None:
formatted = sl.visualize._format_type(Optional[float]) # type: ignore[arg-type]
assert formatted.name == 'float'

0 comments on commit a09740f

Please sign in to comment.