Skip to content

Commit

Permalink
Merge pull request #160 from nipy/enh/autoload-linear
Browse files Browse the repository at this point in the history
ENH: Guess open linear transform formats
  • Loading branch information
oesteban authored Apr 28, 2022
2 parents 2fa335d + bb8bf32 commit 1a34ccc
Show file tree
Hide file tree
Showing 8 changed files with 144 additions and 101 deletions.
25 changes: 24 additions & 1 deletion nitransforms/io/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,34 @@
# emacs: -*- mode: python-mode; py-indent-offset: 4; indent-tabs-mode: nil -*-
# vi: set ft=python sts=4 ts=4 sw=4 et:
"""Read and write transforms."""
from . import afni, fsl, itk, lta
from nitransforms.io import afni, fsl, itk, lta
from nitransforms.io.base import TransformIOError, TransformFileError

__all__ = [
"afni",
"fsl",
"itk",
"lta",
"get_linear_factory",
"TransformFileError",
"TransformIOError",
]

_IO_TYPES = {
"itk": (itk, "ITKLinearTransform"),
"ants": (itk, "ITKLinearTransform"),
"elastix": (itk, "ITKLinearTransform"),
"lta": (lta, "FSLinearTransform"),
"fs": (lta, "FSLinearTransform"),
"fsl": (fsl, "FSLLinearTransform"),
"afni": (afni, "AFNILinearTransform"),
}


def get_linear_factory(fmt, is_array=True):
"""Return the type required by a given format."""
if fmt.lower() not in _IO_TYPES:
raise TypeError(f"Unsupported transform format <{fmt}>.")

module, classname = _IO_TYPES[fmt.lower()]
return getattr(module, f"{classname}{'Array' * is_array}")
14 changes: 9 additions & 5 deletions nitransforms/io/afni.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,12 +95,16 @@ def from_string(cls, string):
if not lines:
raise TransformFileError

parameters = np.vstack(
(
np.genfromtxt([lines[0].encode()], dtype="f8").reshape((3, 4)),
(0.0, 0.0, 0.0, 1.0),
try:
parameters = np.vstack(
(
np.genfromtxt([lines[0].encode()], dtype="f8").reshape((3, 4)),
(0.0, 0.0, 0.0, 1.0),
)
)
)
except ValueError as e:
raise TransformFileError from e

sa["parameters"] = parameters
return tf

Expand Down
8 changes: 6 additions & 2 deletions nitransforms/io/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,12 @@
from ..patched import LabeledWrapStruct


class TransformFileError(Exception):
"""A custom exception for transform files."""
class TransformIOError(IOError):
"""General I/O exception while reading/writing transforms."""


class TransformFileError(TransformIOError):
"""Specific I/O exception when a file does not meet the expected format."""


class StringBasedStruct(LabeledWrapStruct):
Expand Down
5 changes: 3 additions & 2 deletions nitransforms/io/fsl.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
BaseLinearTransformList,
LinearParameters,
DisplacementsField,
TransformIOError,
TransformFileError,
_ensure_image,
)
Expand Down Expand Up @@ -40,7 +41,7 @@ def from_ras(cls, ras, moving=None, reference=None):
moving = reference

if reference is None:
raise ValueError("Cannot build FSL linear transform without a reference")
raise TransformIOError("Cannot build FSL linear transform without a reference")

reference = _ensure_image(reference)
moving = _ensure_image(moving)
Expand Down Expand Up @@ -77,7 +78,7 @@ def from_string(cls, string):
def to_ras(self, moving=None, reference=None):
"""Return a nitransforms internal RAS+ matrix."""
if reference is None:
raise ValueError("Cannot build FSL linear transform without a reference")
raise TransformIOError("Cannot build FSL linear transform without a reference")

if moving is None:
warnings.warn(
Expand Down
5 changes: 3 additions & 2 deletions nitransforms/io/itk.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
BaseLinearTransformList,
DisplacementsField,
LinearParameters,
TransformIOError,
TransformFileError,
)

Expand Down Expand Up @@ -306,7 +307,7 @@ def from_filename(cls, filename):
from h5py import File as H5File

if not str(filename).endswith(".h5"):
raise RuntimeError("Extension is not .h5")
raise TransformFileError("Extension is not .h5")

with H5File(str(filename)) as f:
return cls.from_h5obj(f)
Expand Down Expand Up @@ -354,7 +355,7 @@ def from_h5obj(cls, fileobj, check=True):
)
continue

raise NotImplementedError(
raise TransformIOError(
f"Unsupported transform type {xfm['TransformType'][0]}"
)

Expand Down
114 changes: 45 additions & 69 deletions nitransforms/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@

from nibabel.loadsave import load as _nbload

from .base import (
from nitransforms.base import (
ImageGrid,
TransformBase,
SpatialReference,
_as_homogeneous,
EQUALITY_TOL,
)
from . import io
from nitransforms.io import get_linear_factory, TransformFileError


class Affine(TransformBase):
Expand Down Expand Up @@ -183,51 +183,40 @@ def _to_hdf5(self, x5_root):
self.reference._to_hdf5(x5_root.create_group("Reference"))

def to_filename(self, filename, fmt="X5", moving=None):
"""Store the transform in BIDS-Transforms HDF5 file format (.x5)."""
if fmt.lower() in ["itk", "ants", "elastix"]:
itkobj = io.itk.ITKLinearTransform.from_ras(self.matrix)
itkobj.to_filename(filename)
return filename

# Rest of the formats peek into moving and reference image grids
moving = ImageGrid(moving) if moving is not None else self.reference

_factory = {
"afni": io.afni.AFNILinearTransform,
"fsl": io.fsl.FSLLinearTransform,
"lta": io.lta.FSLinearTransform,
"fs": io.lta.FSLinearTransform,
}

if fmt not in _factory:
raise NotImplementedError(f"Unsupported format <{fmt}>")

_factory[fmt].from_ras(
self.matrix, moving=moving, reference=self.reference
).to_filename(filename)
return filename
"""Store the transform in the requested output format."""
writer = get_linear_factory(fmt, is_array=False)

@classmethod
def from_filename(cls, filename, fmt="X5", reference=None, moving=None):
"""Create an affine from a transform file."""
if fmt.lower() in ("itk", "ants", "elastix"):
_factory = io.itk.ITKLinearTransformArray
elif fmt.lower() in ("lta", "fs"):
_factory = io.lta.FSLinearTransformArray
elif fmt.lower() == "fsl":
_factory = io.fsl.FSLLinearTransformArray
elif fmt.lower() == "afni":
_factory = io.afni.AFNILinearTransformArray
writer.from_ras(self.matrix).to_filename(filename)
else:
raise NotImplementedError
# Rest of the formats peek into moving and reference image grids
writer.from_ras(
self.matrix,
reference=self.reference,
moving=ImageGrid(moving) if moving is not None else self.reference,
).to_filename(filename)
return filename

struct = _factory.from_filename(filename)
matrix = struct.to_ras(reference=reference, moving=moving)
if cls == Affine:
if np.shape(matrix)[0] != 1:
raise TypeError("Cannot load transform array '%s'" % filename)
matrix = matrix[0]
return cls(matrix, reference=reference)
@classmethod
def from_filename(cls, filename, fmt=None, reference=None, moving=None):
"""Create an affine from a transform file."""
fmtlist = [fmt] if fmt is not None else ("itk", "lta", "afni", "fsl")

for potential_fmt in fmtlist:
try:
struct = get_linear_factory(potential_fmt).from_filename(filename)
matrix = struct.to_ras(reference=reference, moving=moving)
if cls == Affine:
if np.shape(matrix)[0] != 1:
raise TypeError("Cannot load transform array '%s'" % filename)
matrix = matrix[0]
return cls(matrix, reference=reference)
except (TransformFileError, FileNotFoundError):
continue

raise TransformFileError(
f"Could not open <{filename}> (formats tried: {', '.join(fmtlist)})."
)

def __repr__(self):
"""
Expand Down Expand Up @@ -353,31 +342,18 @@ def map(self, x, inverse=False):
return np.swapaxes(affine.dot(coords), 1, 2)

def to_filename(self, filename, fmt="X5", moving=None):
"""Store the transform in BIDS-Transforms HDF5 file format (.x5)."""
if fmt.lower() in ("itk", "ants", "elastix"):
itkobj = io.itk.ITKLinearTransformArray.from_ras(self.matrix)
itkobj.to_filename(filename)
return filename
"""Store the transform in the requested output format."""
writer = get_linear_factory(fmt, is_array=True)

# Rest of the formats peek into moving and reference image grids
if moving is not None:
moving = ImageGrid(moving)
if fmt.lower() in ("itk", "ants", "elastix"):
writer.from_ras(self.matrix).to_filename(filename)
else:
moving = self.reference

_factory = {
"afni": io.afni.AFNILinearTransformArray,
"fsl": io.fsl.FSLLinearTransformArray,
"lta": io.lta.FSLinearTransformArray,
"fs": io.lta.FSLinearTransformArray,
}

if fmt not in _factory:
raise NotImplementedError(f"Unsupported format <{fmt}>")

_factory[fmt].from_ras(
self.matrix, moving=moving, reference=self.reference
).to_filename(filename)
# Rest of the formats peek into moving and reference image grids
writer.from_ras(
self.matrix,
reference=self.reference,
moving=ImageGrid(moving) if moving is not None else self.reference,
).to_filename(filename)
return filename

def apply(
Expand Down Expand Up @@ -486,17 +462,17 @@ def apply(
return resampled


def load(filename, fmt="X5", reference=None, moving=None):
def load(filename, fmt=None, reference=None, moving=None):
"""
Load a linear transform file.
Examples
--------
>>> xfm = load(regress_dir / "affine-LAS.itk.tfm", fmt="itk")
>>> xfm = load(regress_dir / "affine-LAS.itk.tfm")
>>> isinstance(xfm, Affine)
True
>>> xfm = load(regress_dir / "itktflist.tfm", fmt="itk")
>>> xfm = load(regress_dir / "itktflist.tfm")
>>> isinstance(xfm, LinearTransformsMapping)
True
Expand Down
25 changes: 21 additions & 4 deletions nitransforms/tests/test_io.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
# emacs: -*- mode: python-mode; py-indent-offset: 4; indent-tabs-mode: nil -*-
# vi: set ft=python sts=4 ts=4 sw=4 et:
"""I/O test cases."""
import os
from subprocess import check_call
from io import StringIO
import filecmp
import shutil
import numpy as np
import pytest
from h5py import File as H5File

import nibabel as nb
from nibabel.eulerangles import euler2mat
Expand All @@ -24,7 +26,7 @@
FSLinearTransform as LT,
FSLinearTransformArray as LTA,
)
from ..io.base import LinearParameters, TransformFileError
from ..io.base import LinearParameters, TransformIOError, TransformFileError

LPS = np.diag([-1, -1, 1, 1])
ITK_MAT = LPS.dot(np.ones((4, 4)).dot(LPS))
Expand Down Expand Up @@ -224,7 +226,7 @@ def test_Linear_common(tmpdir, data_path, sw, image_orientation, get_testdata):

# Test without images
if sw == "fsl":
with pytest.raises(ValueError):
with pytest.raises(TransformIOError):
factory.from_ras(RAS)
else:
xfm = factory.from_ras(RAS)
Expand Down Expand Up @@ -408,7 +410,7 @@ def test_afni_Displacements():
afni.AFNIDisplacementsField.from_image(field)


def test_itk_h5(testdata_path):
def test_itk_h5(tmpdir, testdata_path):
"""Test displacements fields."""
assert (
len(
Expand All @@ -422,14 +424,29 @@ def test_itk_h5(testdata_path):
== 2
)

with pytest.raises(RuntimeError):
with pytest.raises(TransformFileError):
list(
itk.ITKCompositeH5.from_filename(
testdata_path
/ "ds-005_sub-01_from-T1w_to-MNI152NLin2009cAsym_mode-image_xfm.x5"
)
)

tmpdir.chdir()
shutil.copy(
testdata_path / "ds-005_sub-01_from-T1w_to-MNI152NLin2009cAsym_mode-image_xfm.h5",
"test.h5",
)
os.chmod("test.h5", 0o666)

with H5File("test.h5", "r+") as h5file:
h5group = h5file["TransformGroup"]
xfm = h5group[list(h5group.keys())[1]]
xfm["TransformType"][0] = b"InventTransform"

with pytest.raises(TransformIOError):
itk.ITKCompositeH5.from_filename("test.h5")


@pytest.mark.parametrize(
"file_type, test_file", [(LTA, "from-fsnative_to-scanner_mode-image.lta")]
Expand Down
Loading

0 comments on commit 1a34ccc

Please sign in to comment.