Skip to content

Commit

Permalink
Json: split function and data nodes
Browse files Browse the repository at this point in the history
  • Loading branch information
jl-wynen committed Mar 11, 2024
1 parent 9b58c9f commit 5c6e6de
Show file tree
Hide file tree
Showing 3 changed files with 182 additions and 95 deletions.
81 changes: 49 additions & 32 deletions src/sciline/serialize/_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import importlib.resources
import json
from typing import Union
from typing import Any, Union

from .._provider import Provider
from .._utils import key_full_qualname, key_name, provider_full_qualname, provider_name
Expand Down Expand Up @@ -48,7 +48,7 @@ def json_serialize_task_graph(graph: Graph) -> dict[str, Json]:
edges = []
for key, provider in graph.items():
n, e = _serialize_provider(key, provider, id_gen)
nodes.append(n)
nodes.extend(n)
edges.extend(e)
return {
'directed': True,
Expand All @@ -58,9 +58,11 @@ def json_serialize_task_graph(graph: Graph) -> dict[str, Json]:
}


# Returns a tuple (nodes, edges), each a list of JSON objects.
# `nodes` holds either a single data node, or a function node together with
# the data node for that function's output.
def _serialize_provider(
key: Key, provider: Provider, id_gen: _IdGenerator
) -> tuple[dict[str, Json], list[dict[str, Json]]]:
) -> tuple[list[dict[str, Json]], list[dict[str, Json]]]:
if provider.kind == 'function':
return _serialize_function(key, provider, id_gen)
if provider.kind == 'parameter':
Expand All @@ -72,75 +74,90 @@ def _serialize_provider(

def _serialize_param(
key: Key, id_gen: _IdGenerator
) -> tuple[dict[str, Json], list[dict[str, Json]]]:
) -> tuple[list[dict[str, Json]], list[dict[str, Json]]]:
node = {
'id': id_gen.node_id(key),
'kind': 'parameter',
'id': id_gen.data_node_id(key),
'kind': 'data',
'label': key_name(key),
'out': key_full_qualname(key),
'type': key_full_qualname(key),
}
return node, [] # type: ignore[return-value]
return [node], [] # type: ignore[return-value]


def _serialize_function(
key: Key, provider: Provider, id_gen: _IdGenerator
) -> tuple[dict[str, Json], list[dict[str, Json]]]:
node_id = id_gen.node_id(key)

) -> tuple[list[dict[str, Json]], list[dict[str, Json]]]:
edges = []
args = []
kwargs = {}
for i, arg in enumerate(provider.arg_spec.args):
edge = _serialize_edge(arg, key, i, id_gen)
edge = _serialize_edge_data_to_fn(arg, key, i, id_gen)
edges.append(edge)
args.append(edge['id'])
for name, kwarg in provider.arg_spec.kwargs:
edge = _serialize_edge(kwarg, key, name, id_gen)
edge = _serialize_edge_data_to_fn(kwarg, key, name, id_gen)
edges.append(edge)
kwargs[name] = edge['id']
edges.append(_serialize_edge_fn_to_data(key, key, id_gen))

node = {
'id': node_id,
fn_node = {
'id': id_gen.function_node_id(key),
'kind': 'function',
'label': provider_name(provider),
'function': provider_full_qualname(provider),
'out': key_full_qualname(key),
'args': args,
'kwargs': kwargs,
}
[data_node], _ = _serialize_param(key, id_gen)

return node, edges # type: ignore[return-value]
return [fn_node, data_node], edges # type: ignore[return-value]


def _serialize_edge(
def _serialize_edge_data_to_fn(
source: Key, target: Key, arg: Union[int, str], id_gen: _IdGenerator
) -> dict[str, str]:
return {
'id': id_gen.edge_id(arg, target),
'source': id_gen.node_id(source),
'target': id_gen.node_id(target),
'id': id_gen.edge_id(source, target, arg),
'source': id_gen.data_node_id(source),
'target': id_gen.function_node_id(target),
}


def _serialize_edge_fn_to_data(
source: Key, target: Key, id_gen: _IdGenerator
) -> dict[str, str]:
return {
'id': id_gen.edge_id(source, target, 0),
'source': id_gen.function_node_id(source),
'target': id_gen.data_node_id(target),
}


class _IdGenerator:
def __init__(self) -> None:
self._assigned: dict[int, str] = {}
self._assigned: dict[Any, str] = {}
self._next = 0

def node_id(self, key: Key) -> str:
def data_node_id(self, key: Key) -> str:
# Keys must be unique and are required to be hashable to construct TaskGraph.
return self._get_or_insert(key)

def function_node_id(self, key: Key) -> str:
# Use the key instead of the provider to avoid problems with
# unhashable providers, keys need to be hashable to construct TaskGraph.
return self._get_or_insert(hash(key))
# unhashable providers.
# Use tuple with 'fn' to disambiguate from data nodes.
return self._get_or_insert(('fn', key))

def edge_id(self, arg: Union[int, str], target: Key) -> str:
# Uses the arg number or kwarg name instead of the source key
def edge_id(self, source: Key, target: Key, arg: Union[int, str]) -> str:
# Uses the arg number or kwarg name
# to disambiguate arguments with the same type.
return self._get_or_insert(hash((arg, target)))
return self._get_or_insert((source, target, arg))

def _get_or_insert(self, hsh: int) -> str:
def _get_or_insert(self, hashable: Any) -> str:
try:
return self._assigned[hsh]
return self._assigned[hashable]
except KeyError:
self._assigned[hsh] = str(self._next)
id_ = str(self._next)
self._assigned[hashable] = id_
self._next += 1
return self._assigned[hsh]
return id_
29 changes: 22 additions & 7 deletions src/sciline/serialize/graph_json_schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -31,20 +31,15 @@
"type": "string",
"enum": [
"function",
"parameter"
"data"
],
"description": "Indicates what kind of object the node represents. Determines the schema of the node."
},
"out": {
"type": "string",
"description": "Fully qualified name of the type of object produced by this node."
}
},
"required": [
"id",
"label",
"kind",
"out"
"kind"
],
"allOf": [
{
Expand Down Expand Up @@ -82,6 +77,26 @@
"kwargs"
]
}
},
{
"if": {
"properties": {
"kind": {
"const": "data"
}
}
},
"then": {
"properties": {
"type": {
"type": "string",
"description": "Fully qualified name of the Python (domain) type representing this node."
}
},
"required": [
"type"
]
}
}
]
}
Expand Down
Loading

0 comments on commit 5c6e6de

Please sign in to comment.