Skip to content

Commit

Permalink
chore: fix test
Browse files Browse the repository at this point in the history
Signed-off-by: ThibaultFy <[email protected]>
  • Loading branch information
ThibaultFy committed Oct 3, 2024
1 parent de47dd9 commit 6f7576b
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 14 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ authors = [{ name = "Owkin, Inc." }]
[project.optional-dependencies]
dev = [
"pandas",
"numpy",
"pytest",
"pytest-cov",
"pytest-mock",
Expand Down
28 changes: 14 additions & 14 deletions tests/tools/test_opener.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,26 +11,26 @@
from substratools.workspace import DEFAULT_INPUT_DATA_FOLDER_PATH


# @pytest.fixture
# def tmp_cwd(tmp_path):
# # create a temporary current working directory
# new_dir = tmp_path / "workspace"
# new_dir.mkdir()
@pytest.fixture
def tmp_cwd(tmp_path):
# create a temporary current working directory
new_dir = tmp_path / "workspace"
new_dir.mkdir()

# old_dir = os.getcwd()
# os.chdir(new_dir)
old_dir = os.getcwd()
os.chdir(new_dir)

# yield new_dir
yield new_dir

# os.chdir(old_dir)
os.chdir(old_dir)


def test_load_opener_not_found():
def test_load_opener_not_found(tmp_cwd):
with pytest.raises(ImportError):
load_from_module()


def test_load_invalid_opener():
def test_load_invalid_opener(tmp_cwd):
invalid_script = """
def get_data():
raise NotImplementedError
Expand All @@ -42,7 +42,7 @@ def get_data():
load_from_module()


def test_load_opener_as_class():
def test_load_opener_as_class(tmp_cwd):
script = """
from substratools import Opener
class MyOpener(Opener):
Expand All @@ -58,7 +58,7 @@ def fake_data(self, n_samples):
assert o.get_data() == "data_class"


def test_load_opener_from_path(valid_opener_code):
def test_load_opener_from_path(tmp_cwd, valid_opener_code):
dirpath = tmp_cwd / "myopener"
dirpath.mkdir()
path = dirpath / "my_opener.py"
Expand All @@ -74,7 +74,7 @@ def test_load_opener_from_path(valid_opener_code):
assert o.get_data()[0] == "X"


def test_opener_check_folders():
def test_opener_check_folders(tmp_cwd):
script = """
from substratools import Opener
class MyOpener(Opener):
Expand Down

0 comments on commit 6f7576b

Please sign in to comment.