From 518c8227eb1a1cb294e0593ad07c3d220b84f40a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jon=20Haitz=20Legarreta=20Gorro=C3=B1o?= Date: Thu, 5 Dec 2024 15:49:57 -0500 Subject: [PATCH] ENH: Prefer using `visit_Constant` Prefer using `visit_Constant` instead of `visit_Num` and `visit_Str`: both were deprecated since Python 3.8. Fixes: ``` tract_querier/tests/test_query_eval.py: 40 warnings tract_querier/tests/test_query_files.py: 6527 warnings /opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/ast.py:407: DeprecationWarning: visit_Num is deprecated; add visit_Constant return visitor(node) ``` and ``` tract_querier/tests/test_query_eval.py: 1 warning tract_querier/tests/test_query_files.py: 24 warnings /opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/ast.py:407: DeprecationWarning: visit_Str is deprecated; add visit_Constant return visitor(node) ``` raised for example in: https://github.com/demianw/tract_querier/actions/runs/12187495803/job/33998410948?pr=61#step:6:75 and https://github.com/demianw/tract_querier/actions/runs/12187495803/job/33998410948?pr=61#step:6:80 Documentation: https://docs.python.org/3/library/ast.html#ast.NodeVisitor.visit_Constant --- tract_querier/query_processor.py | 61 +++++++++++++++++--------------- 1 file changed, 32 insertions(+), 29 deletions(-) diff --git a/tract_querier/query_processor.py b/tract_querier/query_processor.py index 026dbb2..92705ee 100644 --- a/tract_querier/query_processor.py +++ b/tract_querier/query_processor.py @@ -240,12 +240,6 @@ def visit_UnaryOp(self, node): raise TractQuerierSyntaxError( "Syntax error in query line %d" % node.lineno) - def visit_Str(self, node): - query_info = FiberQueryInfo() - for name in fnmatch.filter(self.evaluated_queries_info.keys(), node.s): - query_info.update(self.evaluated_queries_info[name]) - return query_info - def visit_Call(self, node): # Single string argument function if ( @@ -558,31 +552,40 @@ def visit_Attribute(self, node): (node.lineno, query_name) ) - def visit_Num(self, node): - if ( - node.n in - self.tractography_spatial_indexing.crossing_labels_tracts - ): - tracts = ( - self.tractography_spatial_indexing. - crossing_labels_tracts[node.n] + def visit_Constant(self, node): + if isinstance(node.value, (int, float, complex)): + if ( + node.n in + self.tractography_spatial_indexing.crossing_labels_tracts + ): + tracts = ( + self.tractography_spatial_indexing. + crossing_labels_tracts[node.n] + ) + else: + tracts = set() + + endpoints = (set(), set()) + for i in (0, 1): + elt = self.tractography_spatial_indexing.ending_labels_tracts[i] + if node.n in elt: + endpoints[i].update(elt[node.n]) + + labelset = set((node.n,)) + query_info = FiberQueryInfo( + tracts, labelset, + endpoints ) - else: - tracts = set() - - endpoints = (set(), set()) - for i in (0, 1): - elt = self.tractography_spatial_indexing.ending_labels_tracts[i] - if node.n in elt: - endpoints[i].update(elt[node.n]) - labelset = set((node.n,)) - tract_info = FiberQueryInfo( - tracts, labelset, - endpoints - ) + elif isinstance(node.value, str): + query_info = FiberQueryInfo() + for name in fnmatch.filter(self.evaluated_queries_info.keys(), + node.s): + query_info.update(self.evaluated_queries_info[name]) + else: + raise NotImplementedError(f"{node.value} not supported.") - return tract_info + return query_info def visit_Expr(self, node): if isinstance(node.value, ast.Name): @@ -735,7 +738,7 @@ def visit_Name(self, node): node ) - def visit_Str(self, node): + def visit_Constant(self, node): return ast.copy_location( ast.Str(s=node.s.lower()), node