Skip to content

Commit

Permalink
ENH: Remove dependency on nibabel.trackvis
Browse files Browse the repository at this point in the history
Remove dependency on `nibabel.trackvis`: the module was deprecated in
`NiBabel` 2.5.0 and removed in version 4.0.0. So this patch set adapts
the `tractography.trackvis` code so that it uses the modern API.

Add a tolerance when comparing streamline data in `equal_tracts` to
account for some precision loss since data is transformed to `float32`
(vs `float64`) when applying the operations of the `NiBabel` API.

Fixes:
```
tract_querier/tractography/tests/test_tractography.py::test_saveload_trk
  site-packages/nibabel/deprecated.py:35:
 DeprecationWarning:
 The trackvis interface has been deprecated and will be removed in v4.0; please use the 'nibabel.streamlines' interface.
    mod = __import__(self._module_name, fromlist=[''])
```

and
```
tract_querier/tractography/trackvis.py:7: in <module>
    from nibabel import trackvis
E   ImportError: cannot import name 'trackvis' from 'nibabel'
 (/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/nibabel/__init__.py)
```

and related errors raised for example in:
https://github.com/demianw/tract_querier/actions/runs/12195905001/job/34022456528#step:6:85
and
https://github.com/demianw/tract_querier/actions/runs/12195905001/job/34022456201#step:6:32
  • Loading branch information
jhlegarreta committed Dec 10, 2024
1 parent 93ab73f commit c977552
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 55 deletions.
2 changes: 1 addition & 1 deletion tract_querier/tractography/tests/test_tractography.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

def equal_tracts(a, b):
for t1, t2 in zip(a, b):
if not (len(t1) == len(t2) and allclose(t1, t2)):
if not (len(t1) == len(t2) and allclose(t1, t2, atol=1e-7)):
return False

return True
Expand Down
144 changes: 90 additions & 54 deletions tract_querier/tractography/trackvis.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,21 @@
import os
from warnings import warn

import numpy

from .tractography import Tractography

from nibabel import trackvis
import nibabel as nib


def tractography_to_trackvis_file(filename, tractography, affine=None, image_dimensions=None):
trk_header = trackvis.empty_header()

# The below could have used
# https://github.com/nipy/nibabel/blob/3.0.0/nibabel/streamlines/trk.py#L226
# trk_file = TrkFile(tractogram, header)
# trk_file.save(filename)

trk_header_new_nbb = nib.streamlines.TrkFile.create_empty_header()

if affine is not None:
pass
Expand All @@ -17,12 +24,13 @@ def tractography_to_trackvis_file(filename, tractography, affine=None, image_dim
else:
raise ValueError("Affine transform has to be provided")

trackvis.aff_to_hdr(affine, trk_header, True, True)
trk_header['origin'] = 0.
trk_header_new_nbb["vox_to_ras"] = affine

trk_header_new_nbb["origin"] = 0.
if image_dimensions is not None:
trk_header['dim'] = image_dimensions
elif hasattr(tractography, 'image_dimensions'):
trk_header['dim'] = tractography.image_dimensions
trk_header_new_nbb["dimensions"] = image_dimensions
elif hasattr(tractography, "image_dimensions"):
trk_header_new_nbb['dimensions'] = tractography.image_dimensions
else:
raise ValueError("Image dimensions needed to save a trackvis file")

Expand All @@ -39,72 +47,100 @@ def tractography_to_trackvis_file(filename, tractography, affine=None, image_dim
else:
data[k] = v

#data_new = {}
# for k, v in data.iteritems():
# if (v[0].ndim > 1 and v[0].shape[1] > 1):
# for i in range(v[0].shape[1]):
# data_new['%s_%02d' % (k, i)] = [
# v_[:, i] for v_ in v
# ]
# else:
# data_new[k] = v
trk_header['n_count'] = len(tractography.tracts())
trk_header['n_properties'] = 0
trk_header['n_scalars'] = len(data)
trk_header_new_nbb["nb_streamlines"] = len(tractography.tracts())
trk_header_new_nbb["nb_properties_per_streamline"] = 0
trk_header_new_nbb["nb_scalars_per_point"] = len(data)

if len(data) > 10:
raise ValueError('At most 10 scalars permitted per point')

trk_header['scalar_name'][:len(data)] = numpy.array(
trk_header_new_nbb["scalar_name"][:len(data)] = numpy.array(
[n[:20] for n in data],
dtype='|S20'
dtype="|S20"
)
trk_tracts = []

for i, sl in enumerate(tractography.tracts()):
scalars = None
if len(data) > 0:
scalars = numpy.vstack([
data[k.decode('utf8')][i].squeeze()
for k in trk_header['scalar_name'][:len(data)]
]).T
trk_header_new_nbb["image_orientation_patient"] = compute_image_orientation_patient(affine, True, True)

trk_tracts.append((sl, scalars, None))
affine_voxmm_to_rasmm = nib.streamlines.trk.get_affine_trackvis_to_rasmm(trk_header_new_nbb)

trackvis.write(filename, trk_tracts, trk_header, points_space='rasmm')
streamlines = tractography.tracts()
data_per_streamline = None
data_per_point = data
tractogram_nb = nib.streamlines.Tractogram(
streamlines=streamlines,
data_per_streamline=data_per_streamline,
data_per_point=data_per_point,
affine_to_rasmm=affine_voxmm_to_rasmm,
)
nib.streamlines.save(tractogram_nb, filename, header=trk_header_new_nbb)


def tractography_from_trackvis_file(filename):
tracts_and_data, header = trackvis.read(filename, points_space='rasmm')

tracts, scalars, properties = list(zip(*tracts_and_data))
trk_file = nib.streamlines.load(filename)

scalar_names = [n for n in header['scalar_name'] if len(n) > 0]
strml_nibabel = trk_file.streamlines # same as trk_file.tractogram.streamlines
affine = nib.streamlines.trk.get_affine_trackvis_to_rasmm(trk_file.header)

#scalar_names_unique = []
#scalar_names_subcomp = {}
# for sn in scalar_names:
# if re.match('.*_[0-9]{2}', sn):
# prefix = sn[:sn.rfind('_')]
# if prefix not in scalar_names_unique:
# scalar_names_unique.append(prefix)
# scalar_names_subcomp[prefix] = int(sn[-2:])
# scalar_names_subcomp[prefix] = max(sn[-2:], scalar_names_subcomp[prefix])
# else:
# scalar_names_unique.append(sn)
strml_nibabel_aff = nib.affines.apply_affine(
numpy.linalg.inv(affine), strml_nibabel.get_data())

tracts_data = {}
for i, sn in enumerate(scalar_names):
if hasattr(sn, 'decode'):
sn = sn.decode()
tracts_data[sn] = [scalar[:, i][:, None] for scalar in scalars]
affine_nb = trk_file.header["voxel_to_rasmm"]
image_dims_nb = trk_file.header["dimensions"]

affine = header['vox_to_ras']
image_dims = header['dim']
tracts_data_store_nb = trk_file.tractogram.data_per_point.store
dpp = {k: array_sequence_to_dpp(v) for k, v in tracts_data_store_nb.items()}

tracts_nibabel = array_sequence_data_to_tracts(
strml_nibabel_aff, strml_nibabel._offsets, strml_nibabel._lengths)
tr = Tractography(
tracts, tracts_data,
affine=affine, image_dims=image_dims
tracts_nibabel, dpp,
affine=affine_nb, image_dims=image_dims_nb
)

return tr


def compute_image_orientation_patient(affine, pos_vox, set_order):

# Borrowed from
# https://github.com/nipy/nibabel/blob/3.0.0/nibabel/trackvis.py#L685

affine = numpy.dot(numpy.diag([-1, -1, 1, 1]), affine)
# trans = affine[:3, 3]
# Get zooms
RZS = affine[:3, :3]
zooms = numpy.sqrt(numpy.sum(RZS * RZS, axis=0))
RS = RZS / zooms
# If you said we could, adjust zooms to make RS correspond (below) to a
# true rotation matrix. We need to set the sign of one of the zooms to
# deal with this. Trackvis (the application) doesn't like negative zooms
# at all, so you might want to disallow this with the pos_vox option.
if not pos_vox and numpy.linalg.det(RS) < 0:
zooms[0] *= -1
RS[:, 0] *= -1
# retrieve rotation matrix from RS with polar decomposition.
# Discard shears because we cannot store them.
P, S, Qs = numpy.linalg.svd(RS)
R = numpy.dot(P, Qs)
return R[:, 0:2].T.ravel()


def array_sequence_to_dpp(array_seq):

dpp = []
for offset, length in zip(array_seq._offsets, array_seq._lengths):
val = array_seq._data[offset: offset + length]
dpp.append(val)

return dpp


def array_sequence_data_to_tracts(array_seq_data, offsets, lengths):

tracts = []
for offset, length in zip(offsets, lengths):
val = array_seq_data[offset: offset + length]
tracts.append(val)

return tuple(tracts)

0 comments on commit c977552

Please sign in to comment.