Skip to content

Commit

Permalink
chore: fix test
Browse files Browse the repository at this point in the history
Signed-off-by: ThibaultFy <[email protected]>
  • Loading branch information
ThibaultFy committed Oct 3, 2024
1 parent 9b936ee commit 92aaa72
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 91 deletions.
78 changes: 1 addition & 77 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,3 @@
import json
import os
import sys
from pathlib import Path
from uuid import uuid4

import pytest

import substra
Expand All @@ -13,10 +7,6 @@
from .fl_interface import FLFunctionInputs
from .fl_interface import FLFunctionOutputs
from .fl_interface import FunctionCategory
from substra.tools.task_resources import TaskResources
from substra.tools.utils import import_module
from substra.tools.workspace import FunctionWorkspace
from tests.tools.utils import OutputIdentifiers


def pytest_configure(config):
Expand All @@ -34,7 +24,7 @@ def client(tmpdir):

@pytest.fixture
def workdir(tmp_path):
d = tmp_path / "substra-workspace"
d = tmp_path / "substra-cli"
d.mkdir()
return d

Expand Down Expand Up @@ -141,69 +131,3 @@ def asset_factory():
@pytest.fixture()
def data_sample(asset_factory):
return asset_factory.create_data_sample()


@pytest.fixture(autouse=True)
def patch_cwd(monkeypatch, workdir):
# this is needed to ensure the workspace is located in a tmpdir
def getcwd():
return str(workdir)

monkeypatch.setattr(os, "getcwd", getcwd)


@pytest.fixture()
def valid_opener_code():
return """
import json
from substratools import Opener
class FakeOpener(Opener):
def get_data(self, folder):
return 'X', list(range(0, 3))
def fake_data(self, n_samples):
return ['Xfake'] * n_samples, [0] * n_samples
"""


@pytest.fixture()
def valid_opener(valid_opener_code):
import_module("opener", valid_opener_code)
yield
del sys.modules["opener"]


@pytest.fixture()
def valid_opener_script(workdir, valid_opener_code):
opener_path = workdir / "my_opener.py"
opener_path.write_text(valid_opener_code)

return str(opener_path)


@pytest.fixture(autouse=True)
def output_model_path(workdir: Path) -> str:
path = workdir / str(uuid4())
yield path
if path.exists():
os.remove(path)


@pytest.fixture(autouse=True)
def output_model_path_2(workdir: Path) -> str:
path = workdir / str(uuid4())
yield path
if path.exists():
os.remove(path)


@pytest.fixture()
def valid_function_workspace(output_model_path: str) -> FunctionWorkspace:
workspace_outputs = TaskResources(
json.dumps([{"id": OutputIdentifiers.shared, "value": str(output_model_path), "multiple": False}])
)

workspace = FunctionWorkspace(outputs=workspace_outputs)

return workspace
85 changes: 85 additions & 0 deletions tests/tools/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import json
import os
import sys
from pathlib import Path
from uuid import uuid4

import pytest

from substratools.task_resources import TaskResources
from substratools.utils import import_module
from substratools.workspace import FunctionWorkspace
from tests.tools.utils import OutputIdentifiers


@pytest.fixture
def workdir(tmp_path):
d = tmp_path / "substra-workspace"
d.mkdir()
return d


# @pytest.fixture(autouse=True)
# def patch_cwd(monkeypatch, workdir):
# # this is needed to ensure the workspace is located in a tmpdir
# def getcwd():
# return str(workdir)

# monkeypatch.setattr(os, "getcwd", getcwd)


@pytest.fixture()
def valid_opener_code():
return """
import json
from substratools import Opener
class FakeOpener(Opener):
def get_data(self, folder):
return 'X', list(range(0, 3))
def fake_data(self, n_samples):
return ['Xfake'] * n_samples, [0] * n_samples
"""


@pytest.fixture()
def valid_opener(valid_opener_code):
import_module("opener", valid_opener_code)
yield
del sys.modules["opener"]


@pytest.fixture()
def valid_opener_script(workdir, valid_opener_code):
opener_path = workdir / "my_opener.py"
opener_path.write_text(valid_opener_code)

return str(opener_path)


@pytest.fixture(autouse=True)
def output_model_path(workdir: Path) -> str:
path = workdir / str(uuid4())
yield path
if path.exists():
os.remove(path)


@pytest.fixture(autouse=True)
def output_model_path_2(workdir: Path) -> str:
path = workdir / str(uuid4())
yield path
if path.exists():
os.remove(path)


@pytest.fixture()
def valid_function_workspace(output_model_path: str) -> FunctionWorkspace:
workspace_outputs = TaskResources(
json.dumps([{"id": OutputIdentifiers.shared, "value": str(output_model_path), "multiple": False}])
)

workspace = FunctionWorkspace(outputs=workspace_outputs)

return workspace
28 changes: 14 additions & 14 deletions tests/tools/test_opener.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,26 +11,26 @@
from substratools.workspace import DEFAULT_INPUT_DATA_FOLDER_PATH


@pytest.fixture
def tmp_cwd(tmp_path):
# create a temporary current working directory
new_dir = tmp_path / "workspace"
new_dir.mkdir()
# @pytest.fixture
# def tmp_cwd(tmp_path):
# # create a temporary current working directory
# new_dir = tmp_path / "workspace"
# new_dir.mkdir()

old_dir = os.getcwd()
os.chdir(new_dir)
# old_dir = os.getcwd()
# os.chdir(new_dir)

yield new_dir
# yield new_dir

os.chdir(old_dir)
# os.chdir(old_dir)


def test_load_opener_not_found(tmp_cwd):
def test_load_opener_not_found():
with pytest.raises(ImportError):
load_from_module()


def test_load_invalid_opener(tmp_cwd):
def test_load_invalid_opener():
invalid_script = """
def get_data():
raise NotImplementedError
Expand All @@ -42,7 +42,7 @@ def get_data():
load_from_module()


def test_load_opener_as_class(tmp_cwd):
def test_load_opener_as_class():
script = """
from substratools import Opener
class MyOpener(Opener):
Expand All @@ -58,7 +58,7 @@ def fake_data(self, n_samples):
assert o.get_data() == "data_class"


def test_load_opener_from_path(tmp_cwd, valid_opener_code):
def test_load_opener_from_path(valid_opener_code):
dirpath = tmp_cwd / "myopener"
dirpath.mkdir()
path = dirpath / "my_opener.py"
Expand All @@ -74,7 +74,7 @@ def test_load_opener_from_path(tmp_cwd, valid_opener_code):
assert o.get_data()[0] == "X"


def test_opener_check_folders(tmp_cwd):
def test_opener_check_folders():
script = """
from substratools import Opener
class MyOpener(Opener):
Expand Down

0 comments on commit 92aaa72

Please sign in to comment.