Skip to content

Commit

Permalink
Fix PythonNode when used as return. (#446)
Browse files Browse the repository at this point in the history
  • Loading branch information
tobiasraabe authored Oct 11, 2023
1 parent 7b6f184 commit 0fceb2a
Show file tree
Hide file tree
Showing 8 changed files with 106 additions and 63 deletions.
1 change: 1 addition & 0 deletions docs/source/changes.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ releases are available on [PyPI](https://pypi.org/project/pytask) and
- {pull}`443` ensures that `PythonNode.name` is always unique by only handling it
internally.
- {pull}`444` moves all content of `setup.cfg` to `pyproject.toml`.
- {pull}`446` refactors `create_name_of_python_node` and fixes `PythonNode`s as returns.
- {pull}`447` fixes handling multiple product annotations of a task.

## 0.4.0 - 2023-10-07
Expand Down
22 changes: 4 additions & 18 deletions src/_pytask/collect.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from typing import Iterable
from typing import TYPE_CHECKING

from _pytask.collect_utils import create_name_of_python_node
from _pytask.collect_utils import parse_dependencies_from_task_function
from _pytask.collect_utils import parse_products_from_task_function
from _pytask.config import hookimpl
Expand Down Expand Up @@ -305,7 +306,7 @@ def pytask_collect_task(

@hookimpl(trylast=True)
def pytask_collect_node(session: Session, path: Path, node_info: NodeInfo) -> PNode:
"""Collect a node of a task as a :class:`pytask.nodes.PathNode`.
"""Collect a node of a task as a :class:`pytask.PNode`.
Strings are assumed to be paths. This might be a strict assumption, but since this
hook is executed at last and possible errors will be shown, it seems reasonable and
Expand All @@ -325,8 +326,7 @@ def pytask_collect_node(session: Session, path: Path, node_info: NodeInfo) -> PN
node = node_info.value

if isinstance(node, PythonNode):
node_name = _create_name_of_python_node(node_info)
node.name = node_name
node.name = create_name_of_python_node(node_info)
return node

if isinstance(node, PPathNode) and not node.path.is_absolute():
Expand Down Expand Up @@ -354,7 +354,7 @@ def pytask_collect_node(session: Session, path: Path, node_info: NodeInfo) -> PN
)
return PathNode.from_path(node)

node_name = _create_name_of_python_node(node_info)
node_name = create_name_of_python_node(node_info)
return PythonNode(value=node, name=node_name)


Expand Down Expand Up @@ -494,17 +494,3 @@ def pytask_collect_log(
)

raise CollectionError


def _create_name_of_python_node(node_info: NodeInfo) -> str:
    """Assemble the name of a ``PythonNode`` from the given node info.

    The name consists of ``::``-separated parts: the task path in POSIX form (only
    when ``node_info.task_path`` is truthy), the task name, and the argument name.
    When ``node_info.path`` is non-empty, its entries are converted to strings,
    joined with ``-``, and appended as a final part.

    """
    if node_info.task_path:
        prefix = node_info.task_path.as_posix() + "::" + node_info.task_name
    else:
        prefix = node_info.task_name

    name = prefix + "::" + node_info.arg_name
    if not node_info.path:
        return name
    return name + "::" + "-".join(str(part) for part in node_info.path)
36 changes: 33 additions & 3 deletions src/_pytask/collect_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from typing import Iterable
from typing import TYPE_CHECKING

import attrs
from _pytask._inspect import get_annotations
from _pytask.exceptions import NodeNotCollectedError
from _pytask.mark_utils import has_mark
Expand All @@ -24,6 +25,7 @@
from _pytask.tree_util import tree_leaves
from _pytask.tree_util import tree_map
from _pytask.tree_util import tree_map_with_path
from _pytask.typing import no_default
from _pytask.typing import ProductType
from attrs import define
from attrs import field
Expand Down Expand Up @@ -327,9 +329,15 @@ def parse_dependencies_from_task_function(
isinstance(x, PythonNode) and not x.hash for x in tree_leaves(nodes)
)
if not isinstance(nodes, PNode) and are_all_nodes_python_nodes_without_hash:
prefix = task_path.as_posix() + "::" + task_name if task_path else task_name
node_name = prefix + "::" + parameter_name

node_name = create_name_of_python_node(
NodeInfo(
arg_name=parameter_name,
path=(),
value=value,
task_path=task_path,
task_name=task_name,
)
)
dependencies[parameter_name] = PythonNode(value=value, name=node_name)
else:
dependencies[parameter_name] = nodes
Expand Down Expand Up @@ -606,6 +614,13 @@ def _collect_dependency(
"""
node = node_info.value

if isinstance(node, PythonNode) and node.value is no_default:
# If a node is a dependency and its value is not set, the node is a product in
# another task and the value will be set there. Thus, we wrap the original node
# in another node to retrieve the value after it is set.
new_node = attrs.evolve(node, value=node)
node_info = node_info._replace(value=new_node)

collected_node = session.hook.pytask_collect_node(
session=session, path=path, node_info=node_info
)
Expand Down Expand Up @@ -653,10 +668,25 @@ def _collect_product(
collected_node = session.hook.pytask_collect_node(
session=session, path=path, node_info=node_info
)

if collected_node is None:
msg = (
f"{node!r} can't be parsed as a product for task {task_name!r} in {path!r}."
)
raise NodeNotCollectedError(msg)

return collected_node


def create_name_of_python_node(node_info: NodeInfo) -> str:
    """Build the name of a ``PythonNode`` from its node info.

    The result joins the following segments with ``::``:

    1. ``node_info.task_path.as_posix()``, only if ``node_info.task_path`` is truthy,
    2. ``node_info.task_name``,
    3. ``node_info.arg_name``,
    4. the entries of ``node_info.path`` joined with ``-``, only if the path is
       non-empty.

    """
    segments = [node_info.task_name, node_info.arg_name]
    if node_info.task_path:
        # The task location always leads the name when it is known.
        segments.insert(0, node_info.task_path.as_posix())
    if node_info.path:
        segments.append("-".join(str(entry) for entry in node_info.path))
    return "::".join(segments)
28 changes: 15 additions & 13 deletions src/_pytask/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from _pytask.node_protocols import PPathNode
from _pytask.node_protocols import PTask
from _pytask.node_protocols import PTaskWithPath
from _pytask.path import find_common_ancestor_of_nodes
from _pytask.nodes import PythonNode
from _pytask.report import DagReport
from _pytask.shared import reduce_names_of_multiple_nodes
from _pytask.shared import reduce_node_name
Expand Down Expand Up @@ -87,6 +87,16 @@ def pytask_dag_create_dag(tasks: list[PTask]) -> nx.DiGraph:
tree_map(lambda x: dag.add_node(x.name, node=x), task.produces)
tree_map(lambda x: dag.add_edge(task.name, x.name), task.produces)

# If a node is a PythonNode wrapped in another PythonNode, it is a product from
# another task that is a dependency in the current task. Thus, draw an edge
# connecting the two nodes.
tree_map(
lambda x: dag.add_edge(x.value.name, x.name)
if isinstance(x, PythonNode) and isinstance(x.value, PythonNode)
else None,
task.depends_on,
)

_check_if_dag_has_cycles(dag)

return dag
Expand Down Expand Up @@ -114,7 +124,7 @@ def pytask_dag_select_execution_dag(session: Session, dag: nx.DiGraph) -> None:
def pytask_dag_validate_dag(session: Session, dag: nx.DiGraph) -> None:
"""Validate the DAG."""
_check_if_root_nodes_are_available(dag, session.config["paths"])
_check_if_tasks_have_the_same_products(dag)
_check_if_tasks_have_the_same_products(dag, session.config["paths"])


def _have_task_or_neighbors_changed(
Expand Down Expand Up @@ -292,7 +302,7 @@ def _format_dictionary_to_tree(dict_: dict[str, list[str]], title: str) -> str:
return render_to_string(tree, console=console, strip_styles=True)


def _check_if_tasks_have_the_same_products(dag: nx.DiGraph) -> None:
def _check_if_tasks_have_the_same_products(dag: nx.DiGraph, paths: list[Path]) -> None:
nodes_created_by_multiple_tasks = []

for node in dag.nodes:
Expand All @@ -303,19 +313,11 @@ def _check_if_tasks_have_the_same_products(dag: nx.DiGraph) -> None:
nodes_created_by_multiple_tasks.append(node)

if nodes_created_by_multiple_tasks:
all_names = nodes_created_by_multiple_tasks + [
predecessor
for node in nodes_created_by_multiple_tasks
for predecessor in dag.predecessors(node)
]
common_ancestor = find_common_ancestor_of_nodes(*all_names)
dictionary = {}
for node in nodes_created_by_multiple_tasks:
short_node_name = reduce_node_name(
dag.nodes[node]["node"], [common_ancestor]
)
short_node_name = reduce_node_name(dag.nodes[node]["node"], paths)
short_predecessors = reduce_names_of_multiple_nodes(
dag.predecessors(node), dag, [common_ancestor]
dag.predecessors(node), dag, paths
)
dictionary[short_node_name] = short_predecessors
text = _format_dictionary_to_tree(dictionary, "Products from multiple tasks:")
Expand Down
29 changes: 16 additions & 13 deletions src/_pytask/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
from _pytask.node_protocols import PPathNode
from _pytask.node_protocols import PTask
from _pytask.node_protocols import PTaskWithPath
from _pytask.typing import no_default
from _pytask.typing import NoDefault
from attrs import define
from attrs import field

Expand Down Expand Up @@ -47,9 +49,7 @@ class TaskWithoutPath(PTask):
A list of markers attached to the task function.
report_sections
Reports with entries for when, what, and content.
Attributes
----------
attributes: dict[Any, Any]
A dictionary to store additional information of the task.
"""

Expand Down Expand Up @@ -79,6 +79,8 @@ def execute(self, **kwargs: Any) -> None:
class Task(PTaskWithPath):
"""The class for tasks which are Python functions.
Attributes
----------
base_name
The base name of the task.
path
Expand All @@ -97,9 +99,7 @@ class Task(PTaskWithPath):
A list of markers attached to the task function.
report_sections
Reports with entries for when, what, and content.
Attributes
----------
attributes: dict[Any, Any]
A dictionary to store additional information of the task.
"""
Expand Down Expand Up @@ -204,11 +204,13 @@ class PythonNode(PNode):
"""

name: str = ""
value: Any = None
value: Any | NoDefault = no_default
hash: bool | Callable[[Any], bool] = False # noqa: A003

def load(self) -> Any:
    """Return the value held by the node.

    If the stored value is itself a ``PythonNode``, the call is delegated to that
    node's ``load`` so that the innermost value is returned.

    """
    held = self.value
    if not isinstance(held, PythonNode):
        return held
    # A node wrapped in another node: ask the wrapped node for its value.
    return held.load()

def save(self, value: Any) -> None:
Expand All @@ -234,11 +236,12 @@ def state(self) -> str | None:
"""
if self.hash:
value = self.load()
if callable(self.hash):
return str(self.hash(self.value))
if isinstance(self.value, str):
return str(hashlib.sha256(self.value.encode()).hexdigest())
if isinstance(self.value, bytes):
return str(hashlib.sha256(self.value).hexdigest())
return str(hash(self.value))
return str(self.hash(value))
if isinstance(value, str):
return str(hashlib.sha256(value.encode()).hexdigest())
if isinstance(value, bytes):
return str(hashlib.sha256(value).hexdigest())
return str(hash(value))
return "0"
6 changes: 0 additions & 6 deletions src/_pytask/path.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,12 +68,6 @@ def find_closest_ancestor(path: Path, potential_ancestors: Sequence[Path]) -> Pa
return sorted(potential_closest_ancestors, key=lambda x: len(x.parts))[-1]


def find_common_ancestor_of_nodes(*names: str) -> Path:
    """Find the common ancestor of the paths embedded in task and node names.

    For every name, only the text before the first ``::`` is kept and turned into
    a ``Path``. The resulting paths are passed on to ``find_common_ancestor``.

    """
    locations = (Path(name.split("::", 1)[0]) for name in names)
    return find_common_ancestor(*locations)


def find_common_ancestor(*paths: Path) -> Path:
    """Return the longest path that is an ancestor of all given paths.

    The work is done by :func:`os.path.commonpath`; its result is wrapped in a
    ``Path``. Any exception raised by ``os.path.commonpath`` propagates unchanged.

    """
    shared_prefix = os.path.commonpath(paths)
    return Path(shared_prefix)
Expand Down
37 changes: 33 additions & 4 deletions src/_pytask/typing.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
from __future__ import annotations

import functools
from enum import Enum
from typing import Any
from typing import Final
from typing import Literal
from typing import TYPE_CHECKING

from attr import define
from attrs import define

if TYPE_CHECKING:
from typing_extensions import TypeAlias


__all__ = ["Product", "ProductType"]
Expand All @@ -18,7 +25,29 @@ class ProductType:
"""ProductType: A singleton to mark products in annotations."""


def is_task_function(obj: Any) -> bool:
    """Check if an object is a task function.

    An object qualifies if it is callable and has a ``__name__`` attribute, or if
    it is a :class:`functools.partial` whose wrapped ``func`` has a ``__name__``
    attribute.

    """
    if callable(obj) and hasattr(obj, "__name__"):
        return True
    if isinstance(obj, functools.partial):
        # The name is looked up on the wrapped callable instead of the partial.
        return hasattr(obj.func, "__name__")
    return False


class _NoDefault(Enum):
    """A singleton that marks the absence of a default value.

    The sentinel is implemented as an :class:`~enum.Enum` for two reasons:

    1. it round-trips through pickle correctly (see GH#40397),
    2. mypy does not understand singletons.

    """

    no_default = "NO_DEFAULT"

    def __repr__(self) -> str:
        """Return the fixed representation ``<no_default>``."""
        return "<no_default>"


# Module-level handle on the only member of ``_NoDefault``; ``Final`` tells type
# checkers that the name must not be rebound.
no_default: Final = _NoDefault.no_default
"""The value for missing defaults."""
# ``TypeAlias`` is imported only under ``TYPE_CHECKING``. That is safe here because
# ``from __future__ import annotations`` keeps this annotation from being evaluated
# at runtime; only the ``Literal[...]`` expression on the right-hand side runs.
NoDefault: TypeAlias = Literal[_NoDefault.no_default]
"""The type annotation."""
10 changes: 4 additions & 6 deletions tests/test_execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -714,10 +714,8 @@ def test_execute_tasks_and_pass_values_only_by_python_nodes(runner, tmp_path):
from typing_extensions import Annotated
from pathlib import Path
node_text = PythonNode(name="text")
def task_create_text() -> Annotated[int, node_text]:
return "This is the text."
Expand All @@ -743,21 +741,21 @@ def test_execute_tasks_via_functional_api(tmp_path):
from pathlib import Path
node_text = PythonNode(name="text", hash=True)
node_text = PythonNode()
def create_text() -> Annotated[int, node_text]:
return "This is the text."
node_file = PathNode.from_path(Path(__file__).parent.joinpath("file.txt"))
def create_file(text: Annotated[int, node_text]) -> Annotated[str, node_file]:
return text
def create_file(content: Annotated[str, node_text]) -> Annotated[str, node_file]:
return content
if __name__ == "__main__":
session = pytask.build(tasks=[create_file, create_text])
assert len(session.tasks) == 2
assert len(session.dag.nodes) == 4
assert len(session.dag.nodes) == 5
"""
tmp_path.joinpath("task_module.py").write_text(textwrap.dedent(source))
result = subprocess.run(
Expand Down

0 comments on commit 0fceb2a

Please sign in to comment.