Skip to content

Commit

Permalink
ENH: Prefer using visit_Constant
Browse files Browse the repository at this point in the history
Prefer using `visit_Constant` instead of `visit_Num` and `visit_Str`:
both were deprecated since Python 3.8.

Fixes:
```
tract_querier/tests/test_query_eval.py: 40 warnings
tract_querier/tests/test_query_files.py: 6527 warnings
  /opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/ast.py:407:
 DeprecationWarning: visit_Num is deprecated; add visit_Constant
    return visitor(node)
```

and
```
tract_querier/tests/test_query_eval.py: 1 warning
tract_querier/tests/test_query_files.py: 24 warnings
  /opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/ast.py:407:
 DeprecationWarning: visit_Str is deprecated; add visit_Constant
    return visitor(node)
```

raised for example in:
https://github.com/demianw/tract_querier/actions/runs/12187495803/job/33998410948?pr=61#step:6:75
and
https://github.com/demianw/tract_querier/actions/runs/12187495803/job/33998410948?pr=61#step:6:80

Documentation:
https://docs.python.org/3/library/ast.html#ast.NodeVisitor.visit_Constant
  • Loading branch information
jhlegarreta committed Dec 5, 2024
1 parent e36ee38 commit 0af32a2
Showing 1 changed file with 38 additions and 28 deletions.
66 changes: 38 additions & 28 deletions tract_querier/query_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,12 +240,6 @@ def visit_UnaryOp(self, node):
raise TractQuerierSyntaxError(
"Syntax error in query line %d" % node.lineno)

def visit_Str(self, node):
query_info = FiberQueryInfo()
for name in fnmatch.filter(self.evaluated_queries_info.keys(), node.s):
query_info.update(self.evaluated_queries_info[name])
return query_info

def visit_Call(self, node):
# Single string argument function
if (
Expand Down Expand Up @@ -558,31 +552,40 @@ def visit_Attribute(self, node):
(node.lineno, query_name)
)

def visit_Num(self, node):
if (
node.n in
self.tractography_spatial_indexing.crossing_labels_tracts
):
tracts = (
self.tractography_spatial_indexing.
crossing_labels_tracts[node.n]
def visit_Constant(self, node):
if isinstance(node.value, (int, float, complex)):
if (
node.n in
self.tractography_spatial_indexing.crossing_labels_tracts
):
tracts = (
self.tractography_spatial_indexing.
crossing_labels_tracts[node.n]
)
else:
tracts = set()

endpoints = (set(), set())
for i in (0, 1):
elt = self.tractography_spatial_indexing.ending_labels_tracts[i]
if node.n in elt:
endpoints[i].update(elt[node.n])

labelset = set((node.n,))
query_info = FiberQueryInfo(
tracts, labelset,
endpoints
)
else:
tracts = set()

endpoints = (set(), set())
for i in (0, 1):
elt = self.tractography_spatial_indexing.ending_labels_tracts[i]
if node.n in elt:
endpoints[i].update(elt[node.n])

labelset = set((node.n,))
tract_info = FiberQueryInfo(
tracts, labelset,
endpoints
)
elif isinstance(node.value, str):
query_info = FiberQueryInfo()
for name in fnmatch.filter(self.evaluated_queries_info.keys(),
node.s):
query_info.update(self.evaluated_queries_info[name])
else:
raise NotImplementedError(f"{node.value} not supported.")

return tract_info
return query_info

def visit_Expr(self, node):
if isinstance(node.value, ast.Name):
Expand Down Expand Up @@ -735,6 +738,13 @@ def visit_Name(self, node):
node
)

# def visit_Constant(self, node):
# if isinstance(node.value, str):
# return ast.copy_location(
# ast.Str(s=node.s.lower()),
# node
# )
#
def visit_Str(self, node):
return ast.copy_location(
ast.Str(s=node.s.lower()),
Expand Down

0 comments on commit 0af32a2

Please sign in to comment.