diff --git a/tract_querier/query_processor.py b/tract_querier/query_processor.py index 026dbb2..84010a1 100644 --- a/tract_querier/query_processor.py +++ b/tract_querier/query_processor.py @@ -1,4 +1,5 @@ import ast +import numbers from os import path from copy import deepcopy from operator import lt, gt @@ -240,12 +241,6 @@ def visit_UnaryOp(self, node): raise TractQuerierSyntaxError( "Syntax error in query line %d" % node.lineno) - def visit_Str(self, node): - query_info = FiberQueryInfo() - for name in fnmatch.filter(self.evaluated_queries_info.keys(), node.s): - query_info.update(self.evaluated_queries_info[name]) - return query_info - def visit_Call(self, node): # Single string argument function if ( @@ -558,31 +553,40 @@ def visit_Attribute(self, node): (node.lineno, query_name) ) - def visit_Num(self, node): - if ( - node.n in - self.tractography_spatial_indexing.crossing_labels_tracts - ): - tracts = ( - self.tractography_spatial_indexing. - crossing_labels_tracts[node.n] + def visit_Constant(self, node): + if isinstance(node.value, numbers.Number): + if ( + node.n in + self.tractography_spatial_indexing.crossing_labels_tracts + ): + tracts = ( + self.tractography_spatial_indexing. + crossing_labels_tracts[node.n] + ) + else: + tracts = set() + + endpoints = (set(), set()) + for i in (0, 1): + elt = self.tractography_spatial_indexing.ending_labels_tracts[i] + if node.n in elt: + endpoints[i].update(elt[node.n]) + + labelset = set((node.n,)) + query_info = FiberQueryInfo( + tracts, labelset, + endpoints ) - else: - tracts = set() - endpoints = (set(), set()) - for i in (0, 1): - elt = self.tractography_spatial_indexing.ending_labels_tracts[i] - if node.n in elt: - endpoints[i].update(elt[node.n]) - - labelset = set((node.n,)) - tract_info = FiberQueryInfo( - tracts, labelset, - endpoints - ) + elif isinstance(node.value, str): + query_info = FiberQueryInfo() + for name in fnmatch.filter(self.evaluated_queries_info.keys(), + node.s): + query_info.update(self.evaluated_queries_info[name]) + else: + raise NotImplementedError(f"{node.value} not supported.") - return tract_info + return query_info def visit_Expr(self, node): if isinstance(node.value, ast.Name): @@ -735,11 +739,14 @@ def visit_Name(self, node): node ) - def visit_Str(self, node): - return ast.copy_location( - ast.Str(s=node.s.lower()), - node - ) + def visit_Constant(self, node): + if isinstance(node.s, str): + return ast.copy_location( + ast.Constant(node.s.lower()), + node + ) + else: + return self.generic_visit(node) def visit_Import(self, node): try: diff --git a/tract_querier/tests/test_query_files.py b/tract_querier/tests/test_query_files.py index abaa3d4..6cd669c 100644 --- a/tract_querier/tests/test_query_files.py +++ b/tract_querier/tests/test_query_files.py @@ -9,9 +9,12 @@ def data_folder(): return os.path.join(os.path.dirname(__file__), '..', 'data') +#@pytest.mark.parametrize("filename", [ +# pytest.param(os.path.join(os.path.join(os.path.dirname(__file__), '..', 'data'), f), id=f) +# for f in fnmatch.filter(os.listdir(os.path.join(os.path.dirname(__file__), '..', 'data')), '*qry') +#]) @pytest.mark.parametrize("filename", [ - pytest.param(os.path.join(os.path.join(os.path.dirname(__file__), '..', 'data'), f), id=f) - for f in fnmatch.filter(os.listdir(os.path.join(os.path.dirname(__file__), '..', 'data')), '*qry') + "/home/jhlegarreta/src/tract_querier/tract_querier/data/debug_mori_queries_short.qry" ]) def test_query_files(data_folder, filename): query_file_test(filename, [data_folder])