Skip to content

Commit

Permalink
ENH: Remove dependency on nibabel.trackvis
Browse files Browse the repository at this point in the history
Remove dependency on `nibabel.trackvis`: the module was deprecated in
`NiBabel` 2.5.0 and removed in version 4.0.0. So this patch set adapts
the `tractography.trackvis` code so that it uses the modern API.

Fixes:
```
tract_querier/tractography/tests/test_tractography.py::test_saveload_trk
  site-packages/nibabel/deprecated.py:35:
 DeprecationWarning:
 The trackvis interface has been deprecated and will be removed in v4.0; please use the 'nibabel.streamlines' interface.
    mod = __import__(self._module_name, fromlist=[''])
```

and
```
tract_querier/tractography/trackvis.py:7: in <module>
    from nibabel import trackvis
E   ImportError: cannot import name 'trackvis' from 'nibabel'
 (/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/nibabel/__init__.py)
```

and related errors raised for example in:
https://github.com/demianw/tract_querier/actions/runs/12195905001/job/34022456528#step:6:85
and
https://github.com/demianw/tract_querier/actions/runs/12195905001/job/34022456201#step:6:32
  • Loading branch information
jhlegarreta committed Dec 9, 2024
1 parent 93ab73f commit d8b7f2c
Showing 1 changed file with 153 additions and 2 deletions.
155 changes: 153 additions & 2 deletions tract_querier/tractography/trackvis.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
import os
from warnings import warn

import numpy

from .tractography import Tractography

import nibabel as nib
from nibabel import trackvis


def tractography_to_trackvis_file(filename, tractography, affine=None, image_dimensions=None):
trk_header = trackvis.empty_header()
trk_header_new_nbb = nib.streamlines.TrkFile.create_empty_header()

if affine is not None:
pass
Expand All @@ -18,11 +21,16 @@ def tractography_to_trackvis_file(filename, tractography, affine=None, image_dim
raise ValueError("Affine transform has to be provided")

trackvis.aff_to_hdr(affine, trk_header, True, True)
trk_header_new_nbb['vox_to_ras'] = affine

trk_header['origin'] = 0.
trk_header_new_nbb['origin'] = 0.
if image_dimensions is not None:
trk_header['dim'] = image_dimensions
trk_header_new_nbb['dimensions'] = image_dimensions
elif hasattr(tractography, 'image_dimensions'):
trk_header['dim'] = tractography.image_dimensions
trk_header_new_nbb['dimensions'] = tractography.image_dimensions
else:
raise ValueError("Image dimensions needed to save a trackvis file")

Expand All @@ -49,8 +57,11 @@ def tractography_to_trackvis_file(filename, tractography, affine=None, image_dim
# else:
# data_new[k] = v
trk_header['n_count'] = len(tractography.tracts())
trk_header_new_nbb['nb_streamlines'] = len(tractography.tracts())
trk_header['n_properties'] = 0
trk_header_new_nbb['nb_properties_per_streamline'] = 0
trk_header['n_scalars'] = len(data)
trk_header_new_nbb['nb_scalars_per_point'] = len(data)

if len(data) > 10:
raise ValueError('At most 10 scalars permitted per point')
Expand All @@ -59,6 +70,10 @@ def tractography_to_trackvis_file(filename, tractography, affine=None, image_dim
[n[:20] for n in data],
dtype='|S20'
)
trk_header_new_nbb['scalar_name'][:len(data)] = numpy.array(
[n[:20] for n in data],
dtype='|S20'
)
trk_tracts = []

for i, sl in enumerate(tractography.tracts()):
Expand All @@ -73,12 +88,70 @@ def tractography_to_trackvis_file(filename, tractography, affine=None, image_dim

trackvis.write(filename, trk_tracts, trk_header, points_space='rasmm')

trk_header_new_nbb['image_orientation_patient'] = trk_header['image_orientation_patient']

affine_voxmm_to_rasmm = nib.streamlines.trk.get_affine_trackvis_to_rasmm(trk_header_new_nbb)

# Save the data using the new NiBabel interface
streamlines = tractography.tracts()
data_per_streamline = None
data_per_point = data
tractogram_nb = nib.streamlines.Tractogram(
streamlines=streamlines,
data_per_streamline=data_per_streamline,
data_per_point=data_per_point,
affine_to_rasmm=affine_voxmm_to_rasmm,
)
path = os.path.dirname(filename)
file_basename = os.path.basename(filename).split(".")[0] + "_new_nbb.trk"
filename_new_nbb = os.path.join(path, file_basename)
nib.streamlines.save(tractogram_nb, filename_new_nbb, header=trk_header_new_nbb)

# Read the tractograms back and compare
tracts_and_data, header_rt = trackvis.read(filename, points_space="rasmm")
tracts_trackvis, scalars, properties = list(zip(*tracts_and_data))

trk_file = nib.streamlines.load(filename_new_nbb)

# Check the scalars (data per point)
tract_data_original = tractography.tracts_data()
tract_data_trackvis = tractography_scalars_to_dpp(scalars, header_rt)
tract_data_nibabel = get_nibabel_trk_dpp(trk_file)

assert_dpp_equality(tract_data_original, tract_data_trackvis)
assert_dpp_equality(tract_data_original, tract_data_nibabel)

# Check the tractography data
strml_original = nib.streamlines.ArraySequence(tractography.tracts())
strml_trackvis = nib.streamlines.ArraySequence(tracts_trackvis)

strml_nibabel = trk_file.streamlines # same as trk_file.tractogram.streamlines
aff = nib.streamlines.trk.get_affine_trackvis_to_rasmm(trk_file.header)
strml_data_original_aff = nib.affines.apply_affine(aff, strml_original.get_data())

assert numpy.allclose(strml_original.get_data(), strml_trackvis.get_data())
assert numpy.allclose(strml_nibabel.get_data(), strml_data_original_aff, atol=1e-7)


def tractography_from_trackvis_file(filename):
tracts_and_data, header = trackvis.read(filename, points_space='rasmm')
trk_file = nib.streamlines.load(filename)

tracts, scalars, properties = list(zip(*tracts_and_data))

# Check the tractography data
strml_trackvis = nib.streamlines.ArraySequence(tracts)

strml_nibabel = trk_file.streamlines # same as trk_file.tractogram.streamlines
aff = nib.streamlines.trk.get_affine_trackvis_to_rasmm(trk_file.header)
strml_data_trackvis_aff = nib.affines.apply_affine(aff, strml_trackvis.get_data())

assert numpy.allclose(strml_nibabel.get_data(), strml_data_trackvis_aff, atol=1e-7)

# Equivalently
strml_nibabel_aff = nib.affines.apply_affine(numpy.linalg.inv(aff), strml_nibabel.get_data())
assert numpy.allclose(strml_nibabel_aff, strml_trackvis.get_data(), atol=1e-7)

scalar_names = [n for n in header['scalar_name'] if len(n) > 0]

#scalar_names_unique = []
Expand All @@ -102,9 +175,87 @@ def tractography_from_trackvis_file(filename):
affine = header['vox_to_ras']
image_dims = header['dim']

affine_nb = trk_file.header['voxel_to_rasmm']
image_dims_nb = trk_file.header['dimensions']

# Check some header data
assert numpy.allclose(affine, affine_nb)
assert numpy.allclose(image_dims, image_dims_nb)

# Check the scalars (data per point)
tracts_data_store_nb = trk_file.tractogram.data_per_point.store
dpp = {k: array_sequence_to_dpp(v) for k, v in tracts_data_store_nb.items()}
assert [numpy.allclose(lh_item, rh_item) for lh_list, rh_list in
zip(tracts_data.values(), dpp.values()) for lh_item, rh_item in
zip(lh_list, rh_list)]

# tr = Tractography(
# tracts, tracts_data,
# affine=affine, image_dims=image_dims
#)

# This should build an exactly equivalent Tractography instance
tracts_nibabel = array_sequence_data_to_tracts(
strml_nibabel_aff, strml_nibabel._offsets, strml_nibabel._lengths)
tr = Tractography(
tracts, tracts_data,
affine=affine, image_dims=image_dims
tracts_nibabel, dpp,
affine=affine_nb, image_dims=image_dims_nb
)

return tr


def array_sequence_to_dpp(array_seq):

dpp = []
for offset, length in zip(array_seq._offsets, array_seq._lengths):
val = array_seq._data[offset: offset + length]
dpp.append(val)

return dpp


def array_sequence_data_to_tracts(array_seq_data, offsets, lengths):

tracts = []
for offset, length in zip(offsets, lengths):
val = array_seq_data[offset: offset + length]
tracts.append(val)

return tuple(tracts)


def get_scalar_names_from_trackvis_header(header):

return [x.decode() for x in header['scalar_name'] if isinstance(x, (bytes, bytearray)) and x]


def tractography_scalars_to_dpp(scalars, header):

import numpy as np
scalar_names = get_scalar_names_from_trackvis_header(header)

scalar_count = scalars[0].shape[-1]
assert len(scalar_names) == scalar_count
assert all(s.shape[-1] == scalar_count for s in scalars)

dpp = {elem: [] for elem in scalar_names}

for array in scalars:
for i, elem in enumerate(scalar_names):
dpp[elem].append(array[:, i].reshape(-1, 1))

return dpp


def get_nibabel_trk_dpp(trk_file):

_tract_data = trk_file.tractogram.data_per_point.store
return {k: array_sequence_to_dpp(v) for k, v in _tract_data.items()}


def assert_dpp_equality(lh, rh):

for (k1, v1), (k2, v2) in zip(lh.items(), rh.items()):
assert len(v1) == len(v2)
assert k1 == k2 and all([numpy.allclose(v1[i], v2[i]) for i in range(len(v1))])

0 comments on commit d8b7f2c

Please sign in to comment.