diff --git a/tests/conftest.py b/tests/conftest.py index 828a3991..158a443b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,9 @@ +import json +import os +import sys +from pathlib import Path +from uuid import uuid4 + import pytest import substra @@ -7,6 +13,10 @@ from .fl_interface import FLFunctionInputs from .fl_interface import FLFunctionOutputs from .fl_interface import FunctionCategory +from substra.tools.task_resources import TaskResources +from substra.tools.utils import import_module +from substra.tools.workspace import FunctionWorkspace +from tests.tools.utils import OutputIdentifiers def pytest_configure(config): @@ -131,3 +141,69 @@ def asset_factory(): @pytest.fixture() def data_sample(asset_factory): return asset_factory.create_data_sample() + + +@pytest.fixture(autouse=True) +def patch_cwd(monkeypatch, workdir): + # this is needed to ensure the workspace is located in a tmpdir + def getcwd(): + return str(workdir) + + monkeypatch.setattr(os, "getcwd", getcwd) + + +@pytest.fixture() +def valid_opener_code(): + return """ +import json +from substratools import Opener + +class FakeOpener(Opener): + def get_data(self, folder): + return 'X', list(range(0, 3)) + + def fake_data(self, n_samples): + return ['Xfake'] * n_samples, [0] * n_samples +""" + + +@pytest.fixture() +def valid_opener(valid_opener_code): + import_module("opener", valid_opener_code) + yield + del sys.modules["opener"] + + +@pytest.fixture() +def valid_opener_script(workdir, valid_opener_code): + opener_path = workdir / "my_opener.py" + opener_path.write_text(valid_opener_code) + + return str(opener_path) + + +@pytest.fixture(autouse=True) +def output_model_path(workdir: Path) -> str: + path = workdir / str(uuid4()) + yield path + if path.exists(): + os.remove(path) + + +@pytest.fixture(autouse=True) +def output_model_path_2(workdir: Path) -> str: + path = workdir / str(uuid4()) + yield path + if path.exists(): + os.remove(path) + + +@pytest.fixture() +def valid_function_workspace(output_model_path: str) -> FunctionWorkspace: + workspace_outputs = TaskResources( + json.dumps([{"id": OutputIdentifiers.shared, "value": str(output_model_path), "multiple": False}]) + ) + + workspace = FunctionWorkspace(outputs=workspace_outputs) + + return workspace diff --git a/tests/tools/test_aggregatealgo.py b/tests/tools/test_aggregatealgo.py index 9b814db3..a6117cd9 100644 --- a/tests/tools/test_aggregatealgo.py +++ b/tests/tools/test_aggregatealgo.py @@ -13,8 +13,8 @@ from substratools.task_resources import TaskResources from substratools.workspace import FunctionWorkspace from tests.tools import utils -from tests.utils import InputIdentifiers -from tests.utils import OutputIdentifiers +from tests.tools.utils import InputIdentifiers +from tests.tools.utils import OutputIdentifiers @pytest.fixture(autouse=True) diff --git a/tests/tools/test_compositealgo.py b/tests/tools/test_compositealgo.py index 89beb6d9..a8f9aaf6 100644 --- a/tests/tools/test_compositealgo.py +++ b/tests/tools/test_compositealgo.py @@ -12,8 +12,8 @@ from substratools.task_resources import TaskResources from substratools.workspace import FunctionWorkspace from tests.tools import utils -from tests.utils import InputIdentifiers -from tests.utils import OutputIdentifiers +from tests.tools.utils import InputIdentifiers +from tests.tools.utils import OutputIdentifiers @pytest.fixture(autouse=True) diff --git a/tests/tools/test_function.py b/tests/tools/test_function.py index c76c6170..41c67db4 100644 --- a/tests/tools/test_function.py +++ b/tests/tools/test_function.py @@ -17,8 +17,8 @@ from substratools.task_resources import TaskResources from substratools.workspace import FunctionWorkspace from tests.tools import utils -from tests.utils import InputIdentifiers -from tests.utils import OutputIdentifiers +from tests.tools.utils import InputIdentifiers +from tests.tools.utils import OutputIdentifiers @pytest.fixture(autouse=True) diff --git a/tests/tools/test_metrics.py b/tests/tools/test_metrics.py index c8d4944c..1018d37f 100644 --- a/tests/tools/test_metrics.py +++ b/tests/tools/test_metrics.py @@ -14,8 +14,8 @@ from substratools.task_resources import TaskResources from substratools.workspace import FunctionWorkspace from tests.tools import utils -from tests.utils import InputIdentifiers -from tests.utils import OutputIdentifiers +from tests.tools.utils import InputIdentifiers +from tests.tools.utils import OutputIdentifiers @pytest.fixture() diff --git a/tests/tools/test_workflow.py b/tests/tools/test_workflow.py index 726439c5..fc924c51 100644 --- a/tests/tools/test_workflow.py +++ b/tests/tools/test_workflow.py @@ -11,8 +11,8 @@ from substratools.utils import import_module from substratools.workspace import FunctionWorkspace from tests.tools import utils -from tests.utils import InputIdentifiers -from tests.utils import OutputIdentifiers +from tests.tools.utils import InputIdentifiers +from tests.tools.utils import OutputIdentifiers @pytest.fixture diff --git a/tests/tools/tools_conftest.py b/tests/tools/tools_conftest.py index a51e6fe9..c0b69f9b 100644 --- a/tests/tools/tools_conftest.py +++ b/tests/tools/tools_conftest.py @@ -9,7 +9,7 @@ from substratools.task_resources import TaskResources from substratools.utils import import_module from substratools.workspace import FunctionWorkspace -from tests.utils import OutputIdentifiers +from tests.tools.utils import OutputIdentifiers @pytest.fixture