Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Guess open linear transform formats #160

Merged
merged 3 commits into from
Apr 28, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 24 additions & 1 deletion nitransforms/io/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,34 @@
# emacs: -*- mode: python-mode; py-indent-offset: 4; indent-tabs-mode: nil -*-
# vi: set ft=python sts=4 ts=4 sw=4 et:
"""Read and write transforms."""
from . import afni, fsl, itk, lta
from nitransforms.io import afni, fsl, itk, lta
from nitransforms.io.base import TransformIOError, TransformFileError

__all__ = [
"afni",
"fsl",
"itk",
"lta",
"get_linear_factory",
"TransformFileError",
"TransformIOError",
]

_IO_TYPES = {
"itk": (itk, "ITKLinearTransform"),
"ants": (itk, "ITKLinearTransform"),
"elastix": (itk, "ITKLinearTransform"),
"lta": (lta, "FSLinearTransform"),
"fs": (lta, "FSLinearTransform"),
"fsl": (fsl, "FSLLinearTransform"),
"afni": (afni, "AFNILinearTransform"),
}


def get_linear_factory(fmt, is_array=True):
"""Return the type required by a given format."""
if fmt.lower() not in _IO_TYPES:
raise TypeError(f"Unsupported transform format <{fmt}>.")

module, classname = _IO_TYPES[fmt.lower()]
return getattr(module, f"{classname}{'Array' * is_array}")
14 changes: 9 additions & 5 deletions nitransforms/io/afni.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,12 +95,16 @@ def from_string(cls, string):
if not lines:
raise TransformFileError

parameters = np.vstack(
(
np.genfromtxt([lines[0].encode()], dtype="f8").reshape((3, 4)),
(0.0, 0.0, 0.0, 1.0),
try:
parameters = np.vstack(
(
np.genfromtxt([lines[0].encode()], dtype="f8").reshape((3, 4)),
(0.0, 0.0, 0.0, 1.0),
)
)
)
except ValueError as e:
raise TransformFileError from e

sa["parameters"] = parameters
return tf

Expand Down
8 changes: 6 additions & 2 deletions nitransforms/io/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,12 @@
from ..patched import LabeledWrapStruct


class TransformFileError(Exception):
"""A custom exception for transform files."""
class TransformIOError(IOError):
"""General I/O exception while reading/writing transforms."""


class TransformFileError(TransformIOError):
"""Specific I/O exception when a file does not meet the expected format."""


class StringBasedStruct(LabeledWrapStruct):
Expand Down
5 changes: 3 additions & 2 deletions nitransforms/io/fsl.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
BaseLinearTransformList,
LinearParameters,
DisplacementsField,
TransformIOError,
TransformFileError,
_ensure_image,
)
Expand Down Expand Up @@ -40,7 +41,7 @@ def from_ras(cls, ras, moving=None, reference=None):
moving = reference

if reference is None:
raise ValueError("Cannot build FSL linear transform without a reference")
raise TransformIOError("Cannot build FSL linear transform without a reference")

reference = _ensure_image(reference)
moving = _ensure_image(moving)
Expand Down Expand Up @@ -77,7 +78,7 @@ def from_string(cls, string):
def to_ras(self, moving=None, reference=None):
"""Return a nitransforms internal RAS+ matrix."""
if reference is None:
raise ValueError("Cannot build FSL linear transform without a reference")
raise TransformIOError("Cannot build FSL linear transform without a reference")

if moving is None:
warnings.warn(
Expand Down
5 changes: 3 additions & 2 deletions nitransforms/io/itk.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
BaseLinearTransformList,
DisplacementsField,
LinearParameters,
TransformIOError,
TransformFileError,
)

Expand Down Expand Up @@ -306,7 +307,7 @@ def from_filename(cls, filename):
from h5py import File as H5File

if not str(filename).endswith(".h5"):
raise RuntimeError("Extension is not .h5")
raise TransformFileError("Extension is not .h5")

with H5File(str(filename)) as f:
return cls.from_h5obj(f)
Expand Down Expand Up @@ -355,7 +356,7 @@ def from_h5obj(cls, fileobj, check=True):
)
continue

raise NotImplementedError(
raise TransformIOError(
f"Unsupported transform type {xfm['TransformType'][0]}"
)

Expand Down
114 changes: 45 additions & 69 deletions nitransforms/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@

from nibabel.loadsave import load as _nbload

from .base import (
from nitransforms.base import (
ImageGrid,
TransformBase,
SpatialReference,
_as_homogeneous,
EQUALITY_TOL,
)
from . import io
from nitransforms.io import get_linear_factory, TransformFileError


class Affine(TransformBase):
Expand Down Expand Up @@ -183,51 +183,40 @@ def _to_hdf5(self, x5_root):
self.reference._to_hdf5(x5_root.create_group("Reference"))

def to_filename(self, filename, fmt="X5", moving=None):
"""Store the transform in BIDS-Transforms HDF5 file format (.x5)."""
if fmt.lower() in ["itk", "ants", "elastix"]:
itkobj = io.itk.ITKLinearTransform.from_ras(self.matrix)
itkobj.to_filename(filename)
return filename

# Rest of the formats peek into moving and reference image grids
moving = ImageGrid(moving) if moving is not None else self.reference

_factory = {
"afni": io.afni.AFNILinearTransform,
"fsl": io.fsl.FSLLinearTransform,
"lta": io.lta.FSLinearTransform,
"fs": io.lta.FSLinearTransform,
}

if fmt not in _factory:
raise NotImplementedError(f"Unsupported format <{fmt}>")

_factory[fmt].from_ras(
self.matrix, moving=moving, reference=self.reference
).to_filename(filename)
return filename
"""Store the transform in the requested output format."""
writer = get_linear_factory(fmt, is_array=False)

@classmethod
def from_filename(cls, filename, fmt="X5", reference=None, moving=None):
"""Create an affine from a transform file."""
if fmt.lower() in ("itk", "ants", "elastix"):
_factory = io.itk.ITKLinearTransformArray
elif fmt.lower() in ("lta", "fs"):
_factory = io.lta.FSLinearTransformArray
elif fmt.lower() == "fsl":
_factory = io.fsl.FSLLinearTransformArray
elif fmt.lower() == "afni":
_factory = io.afni.AFNILinearTransformArray
writer.from_ras(self.matrix).to_filename(filename)
else:
raise NotImplementedError
# Rest of the formats peek into moving and reference image grids
writer.from_ras(
self.matrix,
reference=self.reference,
moving=ImageGrid(moving) if moving is not None else self.reference,
).to_filename(filename)
return filename

struct = _factory.from_filename(filename)
matrix = struct.to_ras(reference=reference, moving=moving)
if cls == Affine:
if np.shape(matrix)[0] != 1:
raise TypeError("Cannot load transform array '%s'" % filename)
matrix = matrix[0]
return cls(matrix, reference=reference)
@classmethod
def from_filename(cls, filename, fmt=None, reference=None, moving=None):
"""Create an affine from a transform file."""
fmtlist = [fmt] if fmt is not None else ("itk", "lta", "afni", "fsl")

for potential_fmt in fmtlist:
try:
struct = get_linear_factory(potential_fmt).from_filename(filename)
matrix = struct.to_ras(reference=reference, moving=moving)
if cls == Affine:
if np.shape(matrix)[0] != 1:
raise TypeError("Cannot load transform array '%s'" % filename)
matrix = matrix[0]
return cls(matrix, reference=reference)
except (TransformFileError, FileNotFoundError):
continue

raise TransformFileError(
f"Could not open <{filename}> (formats tried: {', '.join(fmtlist)})."
)

def __repr__(self):
"""
Expand Down Expand Up @@ -353,31 +342,18 @@ def map(self, x, inverse=False):
return np.swapaxes(affine.dot(coords), 1, 2)

def to_filename(self, filename, fmt="X5", moving=None):
"""Store the transform in BIDS-Transforms HDF5 file format (.x5)."""
if fmt.lower() in ("itk", "ants", "elastix"):
itkobj = io.itk.ITKLinearTransformArray.from_ras(self.matrix)
itkobj.to_filename(filename)
return filename
"""Store the transform in the requested output format."""
writer = get_linear_factory(fmt, is_array=True)

# Rest of the formats peek into moving and reference image grids
if moving is not None:
moving = ImageGrid(moving)
if fmt.lower() in ("itk", "ants", "elastix"):
writer.from_ras(self.matrix).to_filename(filename)
else:
moving = self.reference

_factory = {
"afni": io.afni.AFNILinearTransformArray,
"fsl": io.fsl.FSLLinearTransformArray,
"lta": io.lta.FSLinearTransformArray,
"fs": io.lta.FSLinearTransformArray,
}

if fmt not in _factory:
raise NotImplementedError(f"Unsupported format <{fmt}>")

_factory[fmt].from_ras(
self.matrix, moving=moving, reference=self.reference
).to_filename(filename)
# Rest of the formats peek into moving and reference image grids
writer.from_ras(
self.matrix,
reference=self.reference,
moving=ImageGrid(moving) if moving is not None else self.reference,
).to_filename(filename)
return filename

def apply(
Expand Down Expand Up @@ -486,17 +462,17 @@ def apply(
return resampled


def load(filename, fmt="X5", reference=None, moving=None):
def load(filename, fmt=None, reference=None, moving=None):
"""
Load a linear transform file.

Examples
--------
>>> xfm = load(regress_dir / "affine-LAS.itk.tfm", fmt="itk")
>>> xfm = load(regress_dir / "affine-LAS.itk.tfm")
>>> isinstance(xfm, Affine)
True

>>> xfm = load(regress_dir / "itktflist.tfm", fmt="itk")
>>> xfm = load(regress_dir / "itktflist.tfm")
>>> isinstance(xfm, LinearTransformsMapping)
True

Expand Down
25 changes: 21 additions & 4 deletions nitransforms/tests/test_io.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
# emacs: -*- mode: python-mode; py-indent-offset: 4; indent-tabs-mode: nil -*-
# vi: set ft=python sts=4 ts=4 sw=4 et:
"""I/O test cases."""
import os
from subprocess import check_call
from io import StringIO
import filecmp
import shutil
import numpy as np
import pytest
from h5py import File as H5File

import nibabel as nb
from nibabel.eulerangles import euler2mat
Expand All @@ -24,7 +26,7 @@
FSLinearTransform as LT,
FSLinearTransformArray as LTA,
)
from ..io.base import LinearParameters, TransformFileError
from ..io.base import LinearParameters, TransformIOError, TransformFileError

LPS = np.diag([-1, -1, 1, 1])
ITK_MAT = LPS.dot(np.ones((4, 4)).dot(LPS))
Expand Down Expand Up @@ -224,7 +226,7 @@ def test_Linear_common(tmpdir, data_path, sw, image_orientation, get_testdata):

# Test without images
if sw == "fsl":
with pytest.raises(ValueError):
with pytest.raises(TransformIOError):
factory.from_ras(RAS)
else:
xfm = factory.from_ras(RAS)
Expand Down Expand Up @@ -408,7 +410,7 @@ def test_afni_Displacements():
afni.AFNIDisplacementsField.from_image(field)


def test_itk_h5(testdata_path):
def test_itk_h5(tmpdir, testdata_path):
"""Test displacements fields."""
assert (
len(
Expand All @@ -422,14 +424,29 @@ def test_itk_h5(testdata_path):
== 2
)

with pytest.raises(RuntimeError):
with pytest.raises(TransformFileError):
list(
itk.ITKCompositeH5.from_filename(
testdata_path
/ "ds-005_sub-01_from-T1w_to-MNI152NLin2009cAsym_mode-image_xfm.x5"
)
)

tmpdir.chdir()
shutil.copy(
testdata_path / "ds-005_sub-01_from-T1w_to-MNI152NLin2009cAsym_mode-image_xfm.h5",
"test.h5",
)
os.chmod("test.h5", 0o666)

with H5File("test.h5", "r+") as h5file:
h5group = h5file["TransformGroup"]
xfm = h5group[list(h5group.keys())[1]]
xfm["TransformType"][0] = b"InventTransform"

with pytest.raises(TransformIOError):
itk.ITKCompositeH5.from_filename("test.h5")


@pytest.mark.parametrize(
"file_type, test_file", [(LTA, "from-fsnative_to-scanner_mode-image.lta")]
Expand Down
Loading