diff --git a/tract_querier/tractography/tests/test_tractography.py b/tract_querier/tractography/tests/test_tractography.py index b842b12..9c048e2 100644 --- a/tract_querier/tractography/tests/test_tractography.py +++ b/tract_querier/tractography/tests/test_tractography.py @@ -30,7 +30,7 @@ def equal_tracts(a, b): for t1, t2 in zip(a, b): - if not (len(t1) == len(t2) and allclose(t1, t2)): + if not (len(t1) == len(t2) and allclose(t1, t2, atol=1e-7)): return False return True diff --git a/tract_querier/tractography/trackvis.py b/tract_querier/tractography/trackvis.py index 9398632..17d09ac 100644 --- a/tract_querier/tractography/trackvis.py +++ b/tract_querier/tractography/trackvis.py @@ -1,14 +1,21 @@ +import os from warnings import warn import numpy from .tractography import Tractography -from nibabel import trackvis +import nibabel as nib def tractography_to_trackvis_file(filename, tractography, affine=None, image_dimensions=None): - trk_header = trackvis.empty_header() + + # The below could have used + # https://github.com/nipy/nibabel/blob/3.0.0/nibabel/streamlines/trk.py#L226 + # trk_file = TrkFile(tractogram, header) + # trk_file.save(filename) + + trk_header_new_nbb = nib.streamlines.TrkFile.create_empty_header() if affine is not None: pass @@ -17,12 +24,13 @@ def tractography_to_trackvis_file(filename, tractography, affine=None, image_dim else: raise ValueError("Affine transform has to be provided") - trackvis.aff_to_hdr(affine, trk_header, True, True) - trk_header['origin'] = 0. + trk_header_new_nbb["vox_to_ras"] = affine + + trk_header_new_nbb["origin"] = 0. if image_dimensions is not None: - trk_header['dim'] = image_dimensions - elif hasattr(tractography, 'image_dimensions'): - trk_header['dim'] = tractography.image_dimensions + trk_header_new_nbb["dimensions"] = image_dimensions + elif hasattr(tractography, "image_dimensions"): + trk_header_new_nbb['dimensions'] = tractography.image_dimensions else: raise ValueError("Image dimensions needed to save a trackvis file") @@ -39,72 +47,100 @@ def tractography_to_trackvis_file(filename, tractography, affine=None, image_dim else: data[k] = v - #data_new = {} - # for k, v in data.iteritems(): - # if (v[0].ndim > 1 and v[0].shape[1] > 1): - # for i in range(v[0].shape[1]): - # data_new['%s_%02d' % (k, i)] = [ - # v_[:, i] for v_ in v - # ] - # else: - # data_new[k] = v - trk_header['n_count'] = len(tractography.tracts()) - trk_header['n_properties'] = 0 - trk_header['n_scalars'] = len(data) + trk_header_new_nbb["nb_streamlines"] = len(tractography.tracts()) + trk_header_new_nbb["nb_properties_per_streamline"] = 0 + trk_header_new_nbb["nb_scalars_per_point"] = len(data) if len(data) > 10: raise ValueError('At most 10 scalars permitted per point') - trk_header['scalar_name'][:len(data)] = numpy.array( + trk_header_new_nbb["scalar_name"][:len(data)] = numpy.array( [n[:20] for n in data], - dtype='|S20' + dtype="|S20" ) - trk_tracts = [] - for i, sl in enumerate(tractography.tracts()): - scalars = None - if len(data) > 0: - scalars = numpy.vstack([ - data[k.decode('utf8')][i].squeeze() - for k in trk_header['scalar_name'][:len(data)] - ]).T + trk_header_new_nbb["image_orientation_patient"] = compute_image_orientation_patient(affine, True, True) - trk_tracts.append((sl, scalars, None)) + affine_voxmm_to_rasmm = nib.streamlines.trk.get_affine_trackvis_to_rasmm(trk_header_new_nbb) - trackvis.write(filename, trk_tracts, trk_header, points_space='rasmm') + streamlines = tractography.tracts() + data_per_streamline = None + data_per_point = data + tractogram_nb = nib.streamlines.Tractogram( + streamlines=streamlines, + data_per_streamline=data_per_streamline, + data_per_point=data_per_point, + affine_to_rasmm=affine_voxmm_to_rasmm, + ) + nib.streamlines.save(tractogram_nb, filename, header=trk_header_new_nbb) def tractography_from_trackvis_file(filename): - tracts_and_data, header = trackvis.read(filename, points_space='rasmm') - tracts, scalars, properties = list(zip(*tracts_and_data)) + trk_file = nib.streamlines.load(filename) - scalar_names = [n for n in header['scalar_name'] if len(n) > 0] + strml_nibabel = trk_file.streamlines # same as trk_file.tractogram.streamlines + affine = nib.streamlines.trk.get_affine_trackvis_to_rasmm(trk_file.header) - #scalar_names_unique = [] - #scalar_names_subcomp = {} - # for sn in scalar_names: - # if re.match('.*_[0-9]{2}', sn): - # prefix = sn[:sn.rfind('_')] - # if prefix not in scalar_names_unique: - # scalar_names_unique.append(prefix) - # scalar_names_subcomp[prefix] = int(sn[-2:]) - # scalar_names_subcomp[prefix] = max(sn[-2:], scalar_names_subcomp[prefix]) - # else: - # scalar_names_unique.append(sn) + strml_nibabel_aff = nib.affines.apply_affine( + numpy.linalg.inv(affine), strml_nibabel.get_data()) - tracts_data = {} - for i, sn in enumerate(scalar_names): - if hasattr(sn, 'decode'): - sn = sn.decode() - tracts_data[sn] = [scalar[:, i][:, None] for scalar in scalars] + affine_nb = trk_file.header["voxel_to_rasmm"] + image_dims_nb = trk_file.header["dimensions"] - affine = header['vox_to_ras'] - image_dims = header['dim'] + tracts_data_store_nb = trk_file.tractogram.data_per_point.store + dpp = {k: array_sequence_to_dpp(v) for k, v in tracts_data_store_nb.items()} + tracts_nibabel = array_sequence_data_to_tracts( + strml_nibabel_aff, strml_nibabel._offsets, strml_nibabel._lengths) tr = Tractography( - tracts, tracts_data, - affine=affine, image_dims=image_dims + tracts_nibabel, dpp, + affine=affine_nb, image_dims=image_dims_nb ) return tr + + +def compute_image_orientation_patient(affine, pos_vox, set_order): + + # Borrowed from + # https://github.com/nipy/nibabel/blob/3.0.0/nibabel/trackvis.py#L685 + + affine = numpy.dot(numpy.diag([-1, -1, 1, 1]), affine) + # trans = affine[:3, 3] + # Get zooms + RZS = affine[:3, :3] + zooms = numpy.sqrt(numpy.sum(RZS * RZS, axis=0)) + RS = RZS / zooms + # If you said we could, adjust zooms to make RS correspond (below) to a + # true rotation matrix. We need to set the sign of one of the zooms to + # deal with this. Trackvis (the application) doesn't like negative zooms + # at all, so you might want to disallow this with the pos_vox option. + if not pos_vox and numpy.linalg.det(RS) < 0: + zooms[0] *= -1 + RS[:, 0] *= -1 + # retrieve rotation matrix from RS with polar decomposition. + # Discard shears because we cannot store them. + P, S, Qs = numpy.linalg.svd(RS) + R = numpy.dot(P, Qs) + return R[:, 0:2].T.ravel() + + +def array_sequence_to_dpp(array_seq): + + dpp = [] + for offset, length in zip(array_seq._offsets, array_seq._lengths): + val = array_seq._data[offset: offset + length] + dpp.append(val) + + return dpp + + +def array_sequence_data_to_tracts(array_seq_data, offsets, lengths): + + tracts = [] + for offset, length in zip(offsets, lengths): + val = array_seq_data[offset: offset + length] + tracts.append(val) + + return tuple(tracts)