Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Use pytest for testing instead of nose #59

Merged
merged 2 commits into from
Dec 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test_package.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -59,4 +59,4 @@ jobs:

- name: Run tests
run: |
nosetests -v
pytest -v
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ doc = [
]
test = [
"coverage",
"nose",
"pytest",
]

[project.urls]
Expand Down
41 changes: 20 additions & 21 deletions tract_querier/tests/test_query_eval.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from .. import query_processor
from nose.tools import assert_true, assert_equal

from numpy import random
import ast
Expand Down Expand Up @@ -39,7 +38,7 @@ def __init__(
def test_assign():
query_evaluator = query_processor.EvaluateQueries(dummy_spatial_indexing)
query_evaluator.visit(ast.parse("A=0"))
assert_true((
assert ((
'A' in query_evaluator.queries_to_save and
query_evaluator.evaluated_queries_info['A'].tracts == labels_tracts[0] and
query_evaluator.evaluated_queries_info['A'].labels == set((0,))
Expand All @@ -49,7 +48,7 @@ def test_assign():
def test_assign_attr():
query_evaluator = query_processor.EvaluateQueries(dummy_spatial_indexing)
query_evaluator.visit(ast.parse("a.left=0"))
assert_true((
assert ((
'a.left' in query_evaluator.queries_to_save and
query_evaluator.evaluated_queries_info['a.left'].tracts == labels_tracts[0] and
query_evaluator.evaluated_queries_info['a.left'].labels == set((0,))
Expand Down Expand Up @@ -87,8 +86,8 @@ def test_assign_side():

query_evaluator.visit(ast.parse(query))

assert_equal({k: v.labels for k, v in query_evaluator.evaluated_queries_info.items()}, queries_labels)
assert_equal({k: v.tracts for k, v in query_evaluator.evaluated_queries_info.items()}, queries_tracts)
assert {k: v.labels for k, v in query_evaluator.evaluated_queries_info.items()} == queries_labels
assert {k: v.tracts for k, v in query_evaluator.evaluated_queries_info.items()} == queries_tracts


def test_assign_str():
Expand Down Expand Up @@ -120,8 +119,8 @@ def test_assign_str():

query_evaluator.visit(ast.parse(query))

assert_equal({k: v.labels for k, v in query_evaluator.evaluated_queries_info.items()}, queries_labels)
assert_equal({k: v.tracts for k, v in query_evaluator.evaluated_queries_info.items()}, queries_tracts)
assert {k: v.labels for k, v in query_evaluator.evaluated_queries_info.items()} == queries_labels
assert {k: v.tracts for k, v in query_evaluator.evaluated_queries_info.items()} == queries_tracts


def test_for_list():
Expand Down Expand Up @@ -151,7 +150,7 @@ def test_for_list():

query_evaluator.visit(ast.parse(query))

assert_equal({k: v.tracts for k, v in query_evaluator.evaluated_queries_info.items()}, queries_tracts)
assert {k: v.tracts for k, v in query_evaluator.evaluated_queries_info.items()} == queries_tracts


def test_for_str():
Expand Down Expand Up @@ -181,13 +180,13 @@ def test_for_str():

query_evaluator.visit(ast.parse(query))

assert_equal({k: v.tracts for k, v in query_evaluator.evaluated_queries_info.items()}, queries_tracts)
assert {k: v.tracts for k, v in query_evaluator.evaluated_queries_info.items()} == queries_tracts


def test_add():
query_evaluator = query_processor.EvaluateQueries(dummy_spatial_indexing)
query_evaluator.visit(ast.parse("A=0+1"))
assert_true((
assert ((
'A' in query_evaluator.queries_to_save and
query_evaluator.evaluated_queries_info['A'].tracts == labels_tracts[0].union(labels_tracts[1]) and
query_evaluator.evaluated_queries_info['A'].labels == set((0, 1))
Expand All @@ -197,7 +196,7 @@ def test_add():
def test_mult():
query_evaluator = query_processor.EvaluateQueries(dummy_spatial_indexing)
query_evaluator.visit(ast.parse("A=0 * 1"))
assert_true((
assert ((
'A' in query_evaluator.queries_to_save and
query_evaluator.evaluated_queries_info['A'].tracts == labels_tracts[0].intersection(labels_tracts[1]) and
query_evaluator.evaluated_queries_info['A'].labels == set((0, 1))
Expand All @@ -207,7 +206,7 @@ def test_mult():
def test_sub():
query_evaluator = query_processor.EvaluateQueries(dummy_spatial_indexing)
query_evaluator.visit(ast.parse("A=(0 + 1) - 1"))
assert_true((
assert ((
'A' in query_evaluator.queries_to_save and
query_evaluator.evaluated_queries_info['A'].tracts == labels_tracts[0].difference(labels_tracts[1]) and
query_evaluator.evaluated_queries_info['A'].labels == set((0,))
Expand All @@ -217,7 +216,7 @@ def test_sub():
def test_or():
query_evaluator = query_processor.EvaluateQueries(dummy_spatial_indexing)
query_evaluator.visit(ast.parse("A=0 or 1"))
assert_true((
assert ((
'A' in query_evaluator.queries_to_save and
query_evaluator.evaluated_queries_info['A'].tracts == labels_tracts[0].union(labels_tracts[1]) and
query_evaluator.evaluated_queries_info['A'].labels == set((0, 1))
Expand All @@ -227,7 +226,7 @@ def test_or():
def test_and():
query_evaluator = query_processor.EvaluateQueries(dummy_spatial_indexing)
query_evaluator.visit(ast.parse("A=0 and 1"))
assert_true((
assert ((
'A' in query_evaluator.queries_to_save and
query_evaluator.evaluated_queries_info['A'].tracts == labels_tracts[0].intersection(labels_tracts[1]) and
query_evaluator.evaluated_queries_info['A'].labels == set((0, 1))
Expand All @@ -237,7 +236,7 @@ def test_and():
def test_not_in():
query_evaluator = query_processor.EvaluateQueries(dummy_spatial_indexing)
query_evaluator.visit(ast.parse("A=0 or 1 not in 1"))
assert_true((
assert ((
'A' in query_evaluator.queries_to_save and
query_evaluator.evaluated_queries_info['A'].tracts == labels_tracts[0].difference(labels_tracts[1]) and
query_evaluator.evaluated_queries_info['A'].labels == set((0,))
Expand All @@ -247,7 +246,7 @@ def test_not_in():
def test_only_sign():
query_evaluator = query_processor.EvaluateQueries(dummy_spatial_indexing)
query_evaluator.visit(ast.parse("A=~0"))
assert_true((
assert ((
'A' in query_evaluator.queries_to_save and
query_evaluator.evaluated_queries_info['A'].tracts == tract_in_label_0_uniquely and
query_evaluator.evaluated_queries_info['A'].labels == set((0,))
Expand All @@ -257,7 +256,7 @@ def test_only_sign():
def test_only():
query_evaluator = query_processor.EvaluateQueries(dummy_spatial_indexing)
query_evaluator.visit(ast.parse("A=only(0)"))
assert_true((
assert ((
'A' in query_evaluator.queries_to_save and
query_evaluator.evaluated_queries_info['A'].tracts == tract_in_label_0_uniquely and
query_evaluator.evaluated_queries_info['A'].labels == set((0,))
Expand All @@ -267,7 +266,7 @@ def test_only():
def test_unsaved_query():
query_evaluator = query_processor.EvaluateQueries(dummy_spatial_indexing)
query_evaluator.visit(ast.parse("A|=0"))
assert_true((
assert ((
'A' not in query_evaluator.queries_to_save and
query_evaluator.evaluated_queries_info['A'].tracts == labels_tracts[0] and
query_evaluator.evaluated_queries_info['A'].labels == set((0,))
Expand All @@ -277,7 +276,7 @@ def test_unsaved_query():
def test_symbolic_assignment():
query_evaluator = query_processor.EvaluateQueries(dummy_spatial_indexing)
query_evaluator.visit(ast.parse("A=0; B=A"))
assert_true((
assert ((
'B' in query_evaluator.queries_to_save and
query_evaluator.evaluated_queries_info['B'].tracts == labels_tracts[0] and
query_evaluator.evaluated_queries_info['B'].labels == set((0,))
Expand All @@ -287,7 +286,7 @@ def test_symbolic_assignment():
def test_unarySub():
query_evaluator = query_processor.EvaluateQueries(dummy_spatial_indexing)
query_evaluator.visit(ast.parse("B=0; A=-B"))
assert_true((
assert ((
'A' in query_evaluator.queries_to_save and
query_evaluator.evaluated_queries_info['A'].tracts == tracts_in_all_but_0 and
query_evaluator.evaluated_queries_info['A'].labels == set(labels_tracts.keys()).difference((0,))
Expand All @@ -297,7 +296,7 @@ def test_unarySub():
def test_not():
query_evaluator = query_processor.EvaluateQueries(dummy_spatial_indexing)
query_evaluator.visit(ast.parse("A= not 0"))
assert_true((
assert ((
'A' in query_evaluator.queries_to_save and
query_evaluator.evaluated_queries_info['A'].tracts == tracts_in_all_but_0 and
query_evaluator.evaluated_queries_info['A'].labels == set(labels_tracts.keys()).difference((0,))
Expand Down
19 changes: 11 additions & 8 deletions tract_querier/tests/test_query_files.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,22 @@
from .. import queries_preprocess, queries_syntax_check
from nose.tools import nottest

import os
import fnmatch
import pytest


def test_query_files(
folder=os.path.join(os.path.dirname(__file__), '..', 'data')
):
files = fnmatch.filter(os.listdir(folder), '*qry')
for f in files:
yield query_file_test, os.path.join(folder, f), [folder]
@pytest.fixture
def data_folder():
return os.path.join(os.path.dirname(__file__), '..', 'data')

@pytest.mark.parametrize("filename", [
pytest.param(os.path.join(os.path.join(os.path.dirname(__file__), '..', 'data'), f), id=f)
for f in fnmatch.filter(os.listdir(os.path.join(os.path.dirname(__file__), '..', 'data')), '*qry')
])
def test_query_files(data_folder, filename):
query_file_test(filename, [data_folder])


@nottest
def query_file_test(filename, include_folders):
buf = open(filename).read()
query_body = queries_preprocess(
Expand Down
19 changes: 7 additions & 12 deletions tract_querier/tests/test_query_rewrite.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,9 @@
from .. import query_processor

from nose.tools import assert_equal, assert_not_equal
from unittest import expectedFailure, skip
import pytest

import ast

import parser
import token
import symbol


def match(pattern, data, vars=None):
if vars is None:
Expand All @@ -27,7 +22,7 @@ def match(pattern, data, vars=None):
return same, vars


@skip
@pytest.mark.skip()
def test_rewrite_notin_precedence():
code1 = "a and b not in c"
code2 = "(a and b) not in c"
Expand All @@ -48,10 +43,10 @@ def test_rewrite_notin_precedence():
rw.visit(tree2_rw)
rw.visit(tree3_rw)

assert_not_equal(ast.dump(tree1), ast.dump(tree2))
assert_equal(ast.dump(tree2), ast.dump(tree2_rw))
assert_equal(ast.dump(tree1_rw), ast.dump(tree2))
assert ast.dump(tree1) != ast.dump(tree2)
assert ast.dump(tree2) == ast.dump(tree2_rw)
assert ast.dump(tree1_rw) == ast.dump(tree2)

assert_equal(ast.dump(tree3), ast.dump(tree3_rw))
assert ast.dump(tree3) == ast.dump(tree3_rw)

assert_equal(ast.dump(tree1), ast.dump(tree3_rw))
assert ast.dump(tree1) == ast.dump(tree3_rw)
22 changes: 10 additions & 12 deletions tract_querier/tests/test_scripts.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from nose.tools import assert_equal, assert_greater, assert_in, assert_is_not_none, assert_true

import os
from os import path
import re
Expand Down Expand Up @@ -38,8 +36,8 @@ def test_tract_querier_help():
)
popen.wait()
stderr_text = ''.join(popen.stderr.readlines())
assert_in('error: incorrect number of arguments', stderr_text)
assert_greater(popen.returncode, 0)
assert 'error: incorrect number of arguments' in stderr_text
assert popen.returncode > 0

def test_tract_math_help():
popen = subprocess.Popen(
Expand All @@ -49,8 +47,8 @@ def test_tract_math_help():
)
popen.wait()
stderr_text = ''.join(popen.stderr.readlines())
assert_in('error: too few arguments', stderr_text)
assert_greater(popen.returncode, 0)
assert 'error: too few arguments' in stderr_text
assert popen.returncode > 0

def test_tract_math_count():
popen = subprocess.Popen(
Expand All @@ -60,8 +58,8 @@ def test_tract_math_count():
)
popen.wait()
stdout_text = ''.join(popen.stdout.readlines())
assert_is_not_none(re.search('[^0-9]6783[^0-9]', stdout_text))
assert_equal(popen.returncode, 0)
assert re.search('[^0-9]6783[^0-9]', stdout_text) is not None
assert popen.returncode == 0

def test_tract_querier_query():
output_prefix = '%s/test' % TEST_DATA.dirname
Expand All @@ -74,9 +72,9 @@ def test_tract_querier_query():
)
popen.wait()
stdout_text = ''.join(popen.stdout.readlines())
assert_in('uncinate.left: 000102', stdout_text)
assert_in('uncinate.right: 000000', stdout_text)
assert_true(path.exists(output_prefix + '_uncinate.left.trk'))
assert_equal(popen.returncode, 0)
assert 'uncinate.left: 000102' in stdout_text
assert 'uncinate.right: 000000' in stdout_text
assert path.exists(output_prefix + '_uncinate.left.trk')
assert popen.returncode == 0
if path.exists(output_prefix + '_uncinate.left.trk'):
os.remove(output_prefix + '_uncinate.left.trk')
12 changes: 3 additions & 9 deletions tract_querier/tractography/tests/test_tractography.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
except ImportError:
VTK = False

from nose.tools import with_setup
import pytest
import copy
from itertools import chain

Expand Down Expand Up @@ -61,7 +61,7 @@ def equal_tractography(a, b):
)


def setup(*args, **kwargs):
def setup_module(*args, **kwargs):
global dimensions
global tracts
global tracts_data
Expand Down Expand Up @@ -100,15 +100,13 @@ def setup(*args, **kwargs):
tractography = Tractography(tracts, tracts_data)


@with_setup(setup)
def test_creation():
assert(equal_tracts(tractography.tracts(), tracts))
assert(equal_tracts_data(tractography.tracts_data(), tracts_data))
assert(not tractography.are_tracts_subsampled())
assert(not tractography.are_tracts_filtered())


@with_setup(setup)
def test_subsample_tracts():
tractography.subsample_tracts(5)

Expand All @@ -133,7 +131,6 @@ def test_subsample_tracts():
assert(not tractography.are_tracts_filtered())


@with_setup(setup)
def test_append():
old_tracts = copy.deepcopy(tractography.tracts())
new_data = {}
Expand All @@ -147,7 +144,6 @@ def test_append():


if VTK:
@with_setup(setup)
def test_saveload_vtk():
import tempfile
import os
Expand All @@ -164,7 +160,7 @@ def test_saveload_vtk():

os.remove(fname)

@with_setup(setup)

def test_saveload_vtp():
import tempfile
import os
Expand All @@ -179,7 +175,6 @@ def test_saveload_vtp():
os.remove(fname)


@with_setup(setup)
def test_saveload_trk():
import tempfile
import os
Expand Down Expand Up @@ -208,7 +203,6 @@ def test_saveload_trk():
os.remove(fname)


@with_setup(setup)
def test_saveload():
import tempfile
import os
Expand Down
Loading