Skip to content

Commit

Permalink
Catch objects pretending to be PTask. (#508)
Browse files Browse the repository at this point in the history
  • Loading branch information
tobiasraabe authored Dec 1, 2023
1 parent d846c5c commit a06d948
Show file tree
Hide file tree
Showing 6 changed files with 44 additions and 17 deletions.
5 changes: 4 additions & 1 deletion docs/source/changes.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ chronological order. Releases follow [semantic versioning](https://semver.org/)
releases are available on [PyPI](https://pypi.org/project/pytask) and
[Anaconda.org](https://anaconda.org/conda-forge/pytask).

## 0.4.3 - 2023-11-xx
## 0.4.3 - 2023-12-01

- {pull}`483` simplifies the teardown of a task.
- {pull}`484` raises more informative error when directories instead of files are used
Expand All @@ -26,6 +26,9 @@ releases are available on [PyPI](https://pypi.org/project/pytask) and
- {pull}`498` fixes an error when using {class}`~pytask.Task` and
{class}`~pytask.TaskWithoutPath` in task modules.
- {pull}`500` refactors the dependencies for tests.
- {pull}`501` removes `MetaNode`.
- {pull}`508` catches objects that pretend to be a {class}`~pytask.PTask`. Fixes
{issue}`507`.

## 0.4.2 - 2023-11-08

Expand Down
28 changes: 22 additions & 6 deletions src/_pytask/collect.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from _pytask.console import get_file
from _pytask.console import is_jupyter
from _pytask.exceptions import CollectionError
from _pytask.mark import MarkGenerator
from _pytask.mark_utils import get_all_marks
from _pytask.mark_utils import has_mark
from _pytask.node_protocols import PNode
Expand Down Expand Up @@ -176,10 +175,7 @@ def pytask_collect_file(

collected_reports = []
for name, obj in inspect.getmembers(mod):
# Skip mark generator since it overrides __getattr__ and seems like any
# object. Happens when people do ``from pytask import mark`` and
# ``@mark.x``.
if isinstance(obj, MarkGenerator):
if _is_filtered_object(obj):
continue

# Ensures that tasks with this decorator are only collected once.
Expand All @@ -196,6 +192,26 @@ def pytask_collect_file(
return None


def _is_filtered_object(obj: Any) -> bool:
"""Filter some objects that are only causing harm later on.
See :issue:`507`.
"""
# Filter :class:`pytask.Task` and :class:`pytask.TaskWithoutPath` objects.
if isinstance(obj, PTask) and inspect.isclass(obj):
return True

# Filter objects overwriting the ``__getattr__`` method like :class:`pytask.mark` or
# ``from ibis import _``.
attr_name = "attr_that_definitely_does_not_exist"
if hasattr(obj, attr_name) and not bool(
inspect.getattr_static(obj, attr_name, False)
):
return True
return False


@hookimpl
def pytask_collect_task_protocol(
session: Session, path: Path | None, name: str, obj: Any
Expand Down Expand Up @@ -279,7 +295,7 @@ def pytask_collect_task(
markers=markers,
attributes={"collection_id": collection_id, "after": after},
)
if isinstance(obj, PTask) and not inspect.isclass(obj):
if isinstance(obj, PTask):
return obj
return None

Expand Down
11 changes: 4 additions & 7 deletions src/_pytask/mark_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
"""
from __future__ import annotations

import inspect
from typing import Any
from typing import TYPE_CHECKING

Expand All @@ -19,12 +18,10 @@

def get_all_marks(obj_or_task: Any | PTask) -> list[Mark]:
"""Get all marks from a callable or task."""
if isinstance(obj_or_task, PTask) and not inspect.isclass(obj_or_task):
marks = obj_or_task.markers
else:
obj = obj_or_task
marks = obj.pytask_meta.markers if hasattr(obj, "pytask_meta") else []
return marks
if isinstance(obj_or_task, PTask):
return obj_or_task.markers
obj = obj_or_task
return obj.pytask_meta.markers if hasattr(obj, "pytask_meta") else []


def set_marks(obj_or_task: Any | PTask, marks: list[Mark]) -> Any | PTask:
Expand Down
2 changes: 1 addition & 1 deletion src/_pytask/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from _pytask.mark import Mark


__all__ = ["PathNode", "PythonNode", "Task", "TaskWithoutPath"]
__all__ = ["PathNode", "PickleNode", "PythonNode", "Task", "TaskWithoutPath"]


@define(kw_only=True)
Expand Down
8 changes: 7 additions & 1 deletion src/_pytask/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,13 @@
from typing_extensions import TypeAlias


__all__ = ["Product", "ProductType"]
__all__ = [
"NoDefault",
"Product",
"ProductType",
"is_task_function",
"no_default",
]


@define(frozen=True)
Expand Down
7 changes: 6 additions & 1 deletion tests/test_collect.py
Original file line number Diff line number Diff line change
Expand Up @@ -684,7 +684,12 @@ def task_mixed(): pass
@pytest.mark.end_to_end()
def test_module_can_be_collected(runner, tmp_path):
source = """
from pytask import Task, TaskWithoutPath
from pytask import Task, TaskWithoutPath, mark
class C:
def __getattr__(self, name):
return C()
c = C()
"""
tmp_path.joinpath("task_example.py").write_text(textwrap.dedent(source))

Expand Down

0 comments on commit a06d948

Please sign in to comment.