Skip to content

Commit

Permalink
enh: add test comparing bsplines and displacements fields
Browse files Browse the repository at this point in the history
  • Loading branch information
oesteban committed Feb 23, 2022
1 parent cc2ea73 commit a467b65
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 33 deletions.
72 changes: 44 additions & 28 deletions nitransforms/nonlinear.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,21 @@ class DisplacementsFieldTransform(TransformBase):
__slots__ = ["_field"]

def __init__(self, field, reference=None):
"""Create a dense deformation field transform."""
"""
Create a dense deformation field transform.
Example
-------
>>> DisplacementsFieldTransform(test_dir / "someones_displacement_field.nii.gz")
<(57, 67, 56) field of 3D displacements>
"""
super().__init__()

self._field = np.asanyarray(field.dataobj) if hasattr(field, "dataobj") else field
field = _ensure_image(field)
self._field = np.squeeze(
np.asanyarray(field.dataobj) if hasattr(field, "dataobj") else field
)
self.reference = reference or field.__class__(
np.zeros(self._field.shape[:-1]), field.affine, field.header
)
Expand All @@ -46,6 +57,10 @@ def __init__(self, field, reference=None):
"the number of dimensions (%d)" % (self._field.shape[-1], ndim)
)

def __repr__(self):
"""Beautify the python representation."""
return f"<{self._field.shape[:3]} field of {self._field.shape[-1]}D displacements>"

def map(self, x, inverse=False):
r"""
Apply the transformation to a list of physical coordinate points.
Expand All @@ -71,15 +86,12 @@ def map(self, x, inverse=False):
Examples
--------
>>> field = np.zeros((10, 10, 10, 3))
>>> field[..., 0] = 4.0
>>> fieldimg = nb.Nifti1Image(field, np.diag([2., 2., 2., 1.]))
>>> xfm = DisplacementsFieldTransform(fieldimg)
>>> xfm([4.0, 4.0, 4.0]).tolist()
[[8.0, 4.0, 4.0]]
>>> xfm = DisplacementsFieldTransform(test_dir / "someones_displacement_field.nii.gz")
>>> xfm.map([-6.5, -36., -19.5]).tolist()
[[-6.5, -36.475167989730835, -19.5]]
>>> xfm([[4.0, 4.0, 4.0], [8, 2, 10]]).tolist()
[[8.0, 4.0, 4.0], [12.0, 2.0, 10.0]]
>>> xfm.map([[-6.5, -36., -19.5], [-1., -41.5, -11.25]]).tolist()
[[-6.5, -36.475167989730835, -19.5], [-1.0, -42.038356602191925, -11.25]]
"""
if inverse is True:
Expand Down Expand Up @@ -112,26 +124,31 @@ class BSplineFieldTransform(TransformBase):

__slots__ = ['_coeffs', '_knots', '_weights', '_order', '_moving']

def __init__(self, reference, coefficients, order=3):
def __init__(self, coefficients, reference=None, order=3):
"""Create a smooth deformation field using B-Spline basis."""
super(BSplineFieldTransform, self).__init__()
self._order = order
self.reference = reference

coefficients = _ensure_image(coefficients)
if coefficients.shape[-1] != self.ndim:
raise ValueError(
'Number of components of the coefficients does '
'not match the number of dimensions')

self._coeffs = np.asanyarray(coefficients.dataobj)
self._knots = ImageGrid(four_to_three(coefficients)[0])
self._weights = None
if reference is not None:
self.reference = reference

if coefficients.shape[-1] != self.ndim:
raise ValueError(
'Number of components of the coefficients does '
'not match the number of dimensions')

def to_field(self, reference=None):
def to_field(self, reference=None, dtype="float32"):
"""Generate a displacements deformation field from this B-Spline field."""
reference = _ensure_image(reference)
_ref = self.reference if reference is None else SpatialReference.factory(reference)
if _ref is None:
raise ValueError("A reference must be defined")

ndim = self._coeffs.shape[-1]

# If locations to be interpolated are on a grid, use faster tensor-bspline calculation
Expand All @@ -143,7 +160,7 @@ def to_field(self, reference=None):
for d in range(ndim):
field[:, d] = self._coeffs[..., d].reshape(-1) @ self._weights

return field.astype("float32")
return field.astype(dtype)

def apply(
self,
Expand Down Expand Up @@ -215,23 +232,22 @@ def map(self, x, inverse=False):
Examples
--------
>>> field = np.zeros((10, 10, 10, 3))
>>> field[..., 0] = 4.0
>>> fieldimg = nb.Nifti1Image(field, np.diag([2., 2., 2., 1.]))
>>> xfm = DisplacementsFieldTransform(fieldimg)
>>> xfm([4.0, 4.0, 4.0]).tolist()
[[8.0, 4.0, 4.0]]
>>> xfm([[4.0, 4.0, 4.0], [8, 2, 10]]).tolist()
[[8.0, 4.0, 4.0], [12.0, 2.0, 10.0]]
>>> xfm = BSplineFieldTransform(test_dir / "someones_bspline_coefficients.nii.gz")
>>> xfm.reference = test_dir / "someones_anatomy.nii.gz"
>>> xfm.map([-6.5, -36., -19.5]).tolist()
[[-6.5, -31.476097418406784, -19.5]]
>>> xfm.map([[-6.5, -36., -19.5], [-1., -41.5, -11.25]]).tolist()
[[-6.5, -31.476097418406784, -19.5], [-1.0, -3.8072675377121996, -11.25]]
"""
vfunc = partial(
_map_xyz,
reference=self.reference,
knots=self._knots,
coeffs=self._coeffs,
)
return [vfunc(_x) for _x in np.atleast_2d(x)]
return np.array([vfunc(_x).tolist() for _x in np.atleast_2d(x)])


def _map_xyz(x, reference, knots, coeffs):
Expand Down
36 changes: 31 additions & 5 deletions nitransforms/tests/test_nonlinear.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,11 @@
import numpy as np
import nibabel as nb
from ..io.base import TransformFileError
from ..nonlinear import DisplacementsFieldTransform, load as nlload
from ..nonlinear import (
BSplineFieldTransform,
DisplacementsFieldTransform,
load as nlload,
)
from ..io.itk import ITKDisplacementsField


Expand All @@ -33,21 +37,21 @@
def test_itk_disp_load(size):
"""Checks field sizes."""
with pytest.raises(TransformFileError):
ITKDisplacementsField.from_image(nb.Nifti1Image(np.zeros(size), None, None))
ITKDisplacementsField.from_image(nb.Nifti1Image(np.zeros(size), np.eye(4), None))


@pytest.mark.parametrize("size", [(20, 20, 20), (20, 20, 20, 1, 3)])
@pytest.mark.parametrize("size", [(20, 20, 20), (20, 20, 20, 2, 3)])
def test_displacements_bad_sizes(size):
"""Checks field sizes."""
with pytest.raises(ValueError):
DisplacementsFieldTransform(nb.Nifti1Image(np.zeros(size), None, None))
DisplacementsFieldTransform(nb.Nifti1Image(np.zeros(size), np.eye(4), None))


def test_itk_disp_load_intent():
"""Checks whether the NIfTI intent is fixed."""
with pytest.warns(UserWarning):
field = ITKDisplacementsField.from_image(
nb.Nifti1Image(np.zeros((20, 20, 20, 1, 3)), None, None)
nb.Nifti1Image(np.zeros((20, 20, 20, 1, 3)), np.eye(4), None)
)

assert field.header.get_intent()[0] == "vector"
Expand Down Expand Up @@ -177,3 +181,25 @@ def test_displacements_field2(tmp_path, testdata_path, sw_tool):
)
# A certain tolerance is necessary because of resampling at borders
assert np.sqrt((diff ** 2).mean()) < RMSE_TOL


def test_bspline(tmp_path, testdata_path):
"""Cross-check B-Splines and deformation field."""
os.chdir(str(tmp_path))

img_name = testdata_path / "someones_anatomy.nii.gz"
disp_name = testdata_path / "someones_displacement_field.nii.gz"
bs_name = testdata_path / "someones_bspline_coefficients.nii.gz"

bsplxfm = BSplineFieldTransform(bs_name, reference=img_name)
dispxfm = DisplacementsFieldTransform(disp_name)

out_disp = dispxfm.apply(img_name)
out_bspl = bsplxfm.apply(img_name)

out_disp.to_filename("resampled_field.nii.gz")
out_bspl.to_filename("resampled_bsplines.nii.gz")

assert np.sqrt(
(out_disp.get_fdata(dtype="float32") - out_bspl.get_fdata(dtype="float32")) ** 2
).mean() < 0.2

0 comments on commit a467b65

Please sign in to comment.