Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Prefer using visit_Constant #63

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 40 additions & 33 deletions tract_querier/query_processor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import ast
import numbers
from os import path
from copy import deepcopy
from operator import lt, gt
Expand Down Expand Up @@ -240,12 +241,6 @@ def visit_UnaryOp(self, node):
raise TractQuerierSyntaxError(
"Syntax error in query line %d" % node.lineno)

def visit_Str(self, node):
query_info = FiberQueryInfo()
for name in fnmatch.filter(self.evaluated_queries_info.keys(), node.s):
query_info.update(self.evaluated_queries_info[name])
return query_info

def visit_Call(self, node):
# Single string argument function
if (
Expand Down Expand Up @@ -558,31 +553,40 @@ def visit_Attribute(self, node):
(node.lineno, query_name)
)

def visit_Num(self, node):
if (
node.n in
self.tractography_spatial_indexing.crossing_labels_tracts
):
tracts = (
self.tractography_spatial_indexing.
crossing_labels_tracts[node.n]
def visit_Constant(self, node):
if isinstance(node.value, numbers.Number):
if (
node.n in
self.tractography_spatial_indexing.crossing_labels_tracts
):
tracts = (
self.tractography_spatial_indexing.
crossing_labels_tracts[node.n]
)
else:
tracts = set()

endpoints = (set(), set())
for i in (0, 1):
elt = self.tractography_spatial_indexing.ending_labels_tracts[i]
if node.n in elt:
endpoints[i].update(elt[node.n])

labelset = set((node.n,))
query_info = FiberQueryInfo(
tracts, labelset,
endpoints
)
else:
tracts = set()

endpoints = (set(), set())
for i in (0, 1):
elt = self.tractography_spatial_indexing.ending_labels_tracts[i]
if node.n in elt:
endpoints[i].update(elt[node.n])

labelset = set((node.n,))
tract_info = FiberQueryInfo(
tracts, labelset,
endpoints
)
elif isinstance(node.value, str):
query_info = FiberQueryInfo()
for name in fnmatch.filter(self.evaluated_queries_info.keys(),
node.s):
query_info.update(self.evaluated_queries_info[name])
else:
raise NotImplementedError(f"{node.value} not supported.")

return tract_info
return query_info

def visit_Expr(self, node):
if isinstance(node.value, ast.Name):
Expand Down Expand Up @@ -735,11 +739,14 @@ def visit_Name(self, node):
node
)

def visit_Str(self, node):
return ast.copy_location(
ast.Str(s=node.s.lower()),
node
)
def visit_Constant(self, node):
if isinstance(node.s, str):
return ast.copy_location(
ast.Constant(node.s.lower()),
node
)
else:
return self.generic_visit(node)

def visit_Import(self, node):
try:
Expand Down
Loading