Skip to content

Commit

Permalink
fix: extend tests and normalize error types
Browse files Browse the repository at this point in the history
  • Loading branch information
oesteban committed Mar 2, 2022
1 parent b692390 commit 1fa7eeb
Show file tree
Hide file tree
Showing 8 changed files with 59 additions and 32 deletions.
3 changes: 2 additions & 1 deletion nitransforms/io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# vi: set ft=python sts=4 ts=4 sw=4 et:
"""Read and write transforms."""
from nitransforms.io import afni, fsl, itk, lta
from nitransforms.io.base import TransformFileError
from nitransforms.io.base import TransformIOError, TransformFileError

__all__ = [
"afni",
Expand All @@ -11,6 +11,7 @@
"lta",
"get_linear_factory",
"TransformFileError",
"TransformIOError",
]

_IO_TYPES = {
Expand Down
14 changes: 9 additions & 5 deletions nitransforms/io/afni.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,12 +95,16 @@ def from_string(cls, string):
if not lines:
raise TransformFileError

parameters = np.vstack(
(
np.genfromtxt([lines[0].encode()], dtype="f8").reshape((3, 4)),
(0.0, 0.0, 0.0, 1.0),
try:
parameters = np.vstack(
(
np.genfromtxt([lines[0].encode()], dtype="f8").reshape((3, 4)),
(0.0, 0.0, 0.0, 1.0),
)
)
)
except ValueError as e:
raise TransformFileError from e

sa["parameters"] = parameters
return tf

Expand Down
8 changes: 6 additions & 2 deletions nitransforms/io/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,12 @@
from ..patched import LabeledWrapStruct


class TransformFileError(Exception):
"""A custom exception for transform files."""
class TransformIOError(IOError):
"""General I/O exception while reading/writing transforms."""


class TransformFileError(TransformIOError):
"""Specific I/O exception when a file does not meet the expected format."""


class StringBasedStruct(LabeledWrapStruct):
Expand Down
5 changes: 3 additions & 2 deletions nitransforms/io/fsl.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
BaseLinearTransformList,
LinearParameters,
DisplacementsField,
TransformIOError,
TransformFileError,
_ensure_image,
)
Expand Down Expand Up @@ -40,7 +41,7 @@ def from_ras(cls, ras, moving=None, reference=None):
moving = reference

if reference is None:
raise ValueError("Cannot build FSL linear transform without a reference")
raise TransformIOError("Cannot build FSL linear transform without a reference")

reference = _ensure_image(reference)
moving = _ensure_image(moving)
Expand Down Expand Up @@ -77,7 +78,7 @@ def from_string(cls, string):
def to_ras(self, moving=None, reference=None):
"""Return a nitransforms internal RAS+ matrix."""
if reference is None:
raise ValueError("Cannot build FSL linear transform without a reference")
raise TransformIOError("Cannot build FSL linear transform without a reference")

if moving is None:
warnings.warn(
Expand Down
5 changes: 3 additions & 2 deletions nitransforms/io/itk.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
BaseLinearTransformList,
DisplacementsField,
LinearParameters,
TransformIOError,
TransformFileError,
)

Expand Down Expand Up @@ -306,7 +307,7 @@ def from_filename(cls, filename):
from h5py import File as H5File

if not str(filename).endswith(".h5"):
raise RuntimeError("Extension is not .h5")
raise TransformFileError("Extension is not .h5")

with H5File(str(filename)) as f:
return cls.from_h5obj(f)
Expand Down Expand Up @@ -355,7 +356,7 @@ def from_h5obj(cls, fileobj, check=True):
)
continue

raise NotImplementedError(
raise TransformIOError(
f"Unsupported transform type {xfm['TransformType'][0]}"
)

Expand Down
8 changes: 4 additions & 4 deletions nitransforms/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,11 +211,11 @@ def from_filename(cls, filename, fmt=None, reference=None, moving=None):
raise TypeError("Cannot load transform array '%s'" % filename)
matrix = matrix[0]
return cls(matrix, reference=reference)
except TransformFileError:
except (TransformFileError, FileNotFoundError):
continue

raise TransformFileError(
f"Could not open <{filename}> (formats tried: {', '.join(fmtlist)}."
f"Could not open <{filename}> (formats tried: {', '.join(fmtlist)})."
)

def __repr__(self):
Expand Down Expand Up @@ -468,11 +468,11 @@ def load(filename, fmt=None, reference=None, moving=None):
Examples
--------
>>> xfm = load(regress_dir / "affine-LAS.itk.tfm", fmt="itk")
>>> xfm = load(regress_dir / "affine-LAS.itk.tfm")
>>> isinstance(xfm, Affine)
True
>>> xfm = load(regress_dir / "itktflist.tfm", fmt="itk")
>>> xfm = load(regress_dir / "itktflist.tfm")
>>> isinstance(xfm, LinearTransformsMapping)
True
Expand Down
6 changes: 3 additions & 3 deletions nitransforms/tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
FSLinearTransform as LT,
FSLinearTransformArray as LTA,
)
from ..io.base import LinearParameters, TransformFileError
from ..io.base import LinearParameters, TransformIOError, TransformFileError

LPS = np.diag([-1, -1, 1, 1])
ITK_MAT = LPS.dot(np.ones((4, 4)).dot(LPS))
Expand Down Expand Up @@ -224,7 +224,7 @@ def test_Linear_common(tmpdir, data_path, sw, image_orientation, get_testdata):

# Test without images
if sw == "fsl":
with pytest.raises(ValueError):
with pytest.raises(TransformIOError):
factory.from_ras(RAS)
else:
xfm = factory.from_ras(RAS)
Expand Down Expand Up @@ -422,7 +422,7 @@ def test_itk_h5(testdata_path):
== 2
)

with pytest.raises(RuntimeError):
with pytest.raises(TransformFileError):
list(
itk.ITKCompositeH5.from_filename(
testdata_path
Expand Down
42 changes: 29 additions & 13 deletions nitransforms/tests/test_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,18 @@ def test_linear_valueerror():
nitl.Affine(np.ones((4, 4)))


def test_linear_load_unsupported(data_path):
"""Exercise loading transform without I/O implementation."""
with pytest.raises(TypeError):
nitl.load(data_path / "itktflist2.tfm", fmt="X5")


def test_linear_load_mistaken(data_path):
"""Exercise loading transform without I/O implementation."""
with pytest.raises(io.TransformFileError):
nitl.load(data_path / "itktflist2.tfm", fmt="afni")


def test_loadsave_itk(tmp_path, data_path, testdata_path):
"""Test idempotency."""
ref_file = testdata_path / "someones_anatomy.nii.gz"
Expand All @@ -73,9 +85,13 @@ def test_loadsave_itk(tmp_path, data_path, testdata_path):
)


@pytest.mark.parametrize("autofmt", (False, True))
@pytest.mark.parametrize("fmt", ["itk", "fsl", "afni", "lta"])
def test_loadsave(tmp_path, data_path, testdata_path, fmt):
def test_loadsave(tmp_path, data_path, testdata_path, autofmt, fmt):
"""Test idempotency."""
supplied_fmt = None if autofmt else fmt

# Load reference transform
ref_file = testdata_path / "someones_anatomy.nii.gz"
xfm = nitl.load(data_path / "itktflist2.tfm", fmt="itk")
xfm.reference = ref_file
Expand All @@ -85,33 +101,33 @@ def test_loadsave(tmp_path, data_path, testdata_path, fmt):

if fmt == "fsl":
# FSL should not read a transform without reference
with pytest.raises(ValueError):
nitl.load(fname, fmt=fmt)
nitl.load(fname, fmt=fmt, moving=ref_file)
with pytest.raises(io.TransformIOError):
nitl.load(fname, fmt=supplied_fmt)
nitl.load(fname, fmt=supplied_fmt, moving=ref_file)

with pytest.warns(UserWarning):
assert np.allclose(
xfm.matrix,
nitl.load(fname, fmt=fmt, reference=ref_file).matrix,
nitl.load(fname, fmt=supplied_fmt, reference=ref_file).matrix,
)

assert np.allclose(
xfm.matrix,
nitl.load(fname, fmt=fmt, reference=ref_file, moving=ref_file).matrix,
nitl.load(fname, fmt=supplied_fmt, reference=ref_file, moving=ref_file).matrix,
)
else:
assert xfm == nitl.load(fname, fmt=fmt, reference=ref_file)
assert xfm == nitl.load(fname, fmt=supplied_fmt, reference=ref_file)

xfm.to_filename(fname, fmt=fmt, moving=ref_file)

if fmt == "fsl":
assert np.allclose(
xfm.matrix,
nitl.load(fname, fmt=fmt, reference=ref_file, moving=ref_file).matrix,
nitl.load(fname, fmt=supplied_fmt, reference=ref_file, moving=ref_file).matrix,
rtol=1e-2, # FSL incurs into large errors due to rounding
)
else:
assert xfm == nitl.load(fname, fmt=fmt, reference=ref_file)
assert xfm == nitl.load(fname, fmt=supplied_fmt, reference=ref_file)

ref_file = testdata_path / "someones_anatomy.nii.gz"
xfm = nitl.load(data_path / "affine-LAS.itk.tfm", fmt="itk")
Expand All @@ -121,21 +137,21 @@ def test_loadsave(tmp_path, data_path, testdata_path, fmt):
if fmt == "fsl":
assert np.allclose(
xfm.matrix,
nitl.load(fname, fmt=fmt, reference=ref_file, moving=ref_file).matrix,
nitl.load(fname, fmt=supplied_fmt, reference=ref_file, moving=ref_file).matrix,
rtol=1e-2, # FSL incurs into large errors due to rounding
)
else:
assert xfm == nitl.load(fname, fmt=fmt, reference=ref_file)
assert xfm == nitl.load(fname, fmt=supplied_fmt, reference=ref_file)

xfm.to_filename(fname, fmt=fmt, moving=ref_file)
if fmt == "fsl":
assert np.allclose(
xfm.matrix,
nitl.load(fname, fmt=fmt, reference=ref_file, moving=ref_file).matrix,
nitl.load(fname, fmt=supplied_fmt, reference=ref_file, moving=ref_file).matrix,
rtol=1e-2, # FSL incurs into large errors due to rounding
)
else:
assert xfm == nitl.load(fname, fmt=fmt, reference=ref_file)
assert xfm == nitl.load(fname, fmt=supplied_fmt, reference=ref_file)


@pytest.mark.parametrize("image_orientation", ["RAS", "LAS", "LPS", "oblique"])
Expand Down

0 comments on commit 1fa7eeb

Please sign in to comment.