Skip to content

Commit

Permalink
Merge pull request xraypy#412 from woutdenolf/xas_data_source
Browse files Browse the repository at this point in the history
XasDataSource: common API to XAS data
  • Loading branch information
newville authored Mar 14, 2024
2 parents af3d08a + 66372fe commit dc59aaf
Show file tree
Hide file tree
Showing 13 changed files with 1,826 additions and 9 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,4 @@ Thumbs.db
.vscode/
.empty/
.eggs/
.coverage
5 changes: 4 additions & 1 deletion larch/io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
from .stepscan_file import read_stepscan

from .nexus_xas import NXxasFile
from .xas_data_source import open_xas_source, read_xas_source

def read_tiff(fname, *args, **kws):
"""read image data from a TIFF file as an array"""
Expand Down Expand Up @@ -103,7 +104,9 @@ def read_tiff(fname, *args, **kws):
str2rng=str2rng_larch,
read_specfile=read_specfile,
specfile=open_specfile,
read_fdmnes=read_fdmnes
read_fdmnes=read_fdmnes,
open_xas_source=open_xas_source,
read_xas_source=read_xas_source
)

_larch_builtins = {'_io':__exports__}
2 changes: 2 additions & 0 deletions larch/io/xas_data_source/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .read import open_xas_source # noqa F401
from .read import read_xas_source # noqa F401
31 changes: 31 additions & 0 deletions larch/io/xas_data_source/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from typing import NamedTuple, List, Optional
from numpy.typing import ArrayLike


class XasScan(NamedTuple):
    """Immutable record holding one scan returned by an XAS data source."""

    # Identifier of the scan inside its source (the key passed to `get_scan`).
    name: str
    # Human-readable description; the NeXus source in this package uses the
    # scan title and falls back to the scan name when there is no title.
    description: str
    # Free-form information text; the NeXus source stores the description here.
    info: str
    # Start time as a string (empty when the source does not provide one).
    start_time: str
    # Names of the data series, in the same order as the entries of `data`.
    labels: List[str]
    # One entry per label; `read_xas_source` iterates it in parallel with
    # `labels`. NOTE(review): presumably shaped (n_labels, n_points) - confirm.
    data: ArrayLike


class XasDataSource:
    """Common interface to a file containing one or more XAS scans.

    Concrete sources override ``TYPE`` and implement ``get_source_info``,
    ``get_scan`` and ``get_scan_names``.
    """

    # Identifier of the file format; concrete subclasses set a string here.
    TYPE = NotImplemented

    def __init__(self, filename: str) -> None:
        """Remember the path of the file this source reads from."""
        self._filename = filename

    def get_source_info(self) -> str:
        """Return a short description of the source (subclass hook)."""
        raise NotImplementedError

    def get_scan(self, scan_name: str) -> Optional["XasScan"]:
        """Return the scan called ``scan_name`` (subclass hook)."""
        raise NotImplementedError

    def get_scan_names(self) -> List[str]:
        """Return the names of all scans in the source (subclass hook)."""
        raise NotImplementedError

    @staticmethod
    def _scan_name_sort_key(scan_name: str):
        """Return a sort key that never compares a number with a string.

        Purely numeric names map to ``(0, <number>)`` and every other name to
        ``(1, <name>)``, so numeric names sort first in numerical order and
        the remaining names follow in lexicographic order.
        """
        # ``isdecimal`` (unlike ``isdigit``) only accepts characters that
        # ``float`` can parse, so names such as "²" cannot raise ValueError.
        if scan_name.isdecimal():
            return (0, float(scan_name))
        return (1, scan_name)

    def get_sorted_scan_names(self) -> List[str]:
        """Return the scan names in natural order.

        Numeric names are ordered by value ("2" before "10"); non-numeric
        names are ordered alphabetically after the numeric ones. A mix of
        both kinds is supported.
        """
        return sorted(self.get_scan_names(), key=self._scan_name_sort_key)
72 changes: 72 additions & 0 deletions larch/io/xas_data_source/hdf5_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
from contextlib import contextmanager
from typing import Iterator, Optional
import h5py

# First three components of the installed h5py version, used for feature checks.
H5PY_VERSION = h5py.version.version_tuple[:3]
# True for h5py >= 3.5; `open` below only passes the ``locking`` keyword to
# ``h5py.File`` when this flag is set.
H5PY_HAS_LOCKING = H5PY_VERSION >= (3, 5)


@contextmanager
def open(filename) -> Iterator[h5py.File]:
    """Context manager yielding ``filename`` opened read-only with h5py.

    ``locking=False`` is requested whenever the installed h5py version
    supports that keyword. The file is closed when the context exits.
    """
    extra_options = {"locking": False} if H5PY_HAS_LOCKING else {}
    with h5py.File(filename, mode="r", **extra_options) as h5file:
        yield h5file


def nexus_creator(filename: str) -> str:
    """Return the ``creator`` attribute of the HDF5 root group.

    Returns an empty string when the attribute is missing. The value is
    passed through `asstr` because string attributes may be read back as
    bytes, which would never compare equal to a ``str`` in callers.
    """
    with open(filename) as nxroot:
        return asstr(nxroot.attrs.get("creator", ""))


def nexus_instrument(filename: str) -> str:
    """Return the instrument name stored in a NeXus file.

    Looks up the first ``NXentry`` group under the root, then the first
    ``NXinstrument`` group inside it, and reads its ``name`` item. An empty
    string is returned when any of these is missing.
    """
    with open(filename) as nxroot:
        entry = find_nexus_class(nxroot, "NXentry")
        instrument = (
            None if entry is None else find_nexus_class(entry, "NXinstrument")
        )
        if instrument is None or "name" not in instrument:
            return ""
        return asstr(instrument["name"][()])


def nexus_source(filename: str) -> str:
    """Return the name of the radiation source stored in a NeXus file.

    The ``NXsource`` group is searched directly inside the first ``NXentry``
    and, failing that, inside that entry's ``NXinstrument`` group. An empty
    string is returned when no such group or no ``name`` item is found.
    """
    with open(filename) as nxroot:
        entry = find_nexus_class(nxroot, "NXentry")
        if entry is None:
            return ""

        source = find_nexus_class(entry, "NXsource")
        if source is None:
            # Fall back to an NXsource nested in the NXinstrument group.
            instrument = find_nexus_class(entry, "NXinstrument")
            if instrument is not None:
                source = find_nexus_class(instrument, "NXsource")

        if source is None or "name" not in source:
            return ""
        return asstr(source["name"][()])


def asstr(s):
    """Decode ``s`` when it is a bytes object; return anything else as-is."""
    return s.decode() if isinstance(s, bytes) else s


def find_nexus_class(parent: h5py.Group, nxclass: str) -> Optional[h5py.Group]:
    """Return the first child of ``parent`` whose ``NX_class`` is ``nxclass``.

    Children are visited in the iteration order of ``parent``. Children that
    cannot be dereferenced are skipped. ``None`` is returned when no child
    matches.
    """
    for key in parent:
        try:
            candidate = parent[key]
        except KeyError:
            # Broken link: there is nothing to inspect.
            continue
        if asstr(candidate.attrs.get("NX_class", "")) == nxclass:
            return candidate
    return None
160 changes: 160 additions & 0 deletions larch/io/xas_data_source/nexus.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
import re
from contextlib import contextmanager
from typing import Iterator, List, Optional, Tuple
import numpy
import h5py
from . import base
from . import hdf5_utils


class NexusSingleXasDataSource(base.XasDataSource):
    """NeXus compliant HDF5 file. Each NXentry contains 1 XAS spectrum."""

    TYPE = "HDF5-NEXUS"

    def __init__(
        self,
        filename: str,
        title_regex_pattern: Optional[str] = None,
        counter_group: Optional[str] = None,
        **kw,
    ) -> None:
        """
        :param filename: path of the HDF5 file.
        :param title_regex_pattern: when given, only scans whose title matches
            this regular expression (``re.match``) are listed.
        :param counter_group: name of the group, relative to each scan, that
            holds the 1D datasets. When omitted the NXdetector/NXpositioner
            groups of the scan's NXinstrument group are used instead.
        :param kw: forwarded to the base class constructor.
        """
        self._nxroot = None
        # An empty pattern means "no filtering"; storing None (rather than
        # the empty string) keeps `_iter_scan_names` from calling `.match`
        # on a plain string.
        self._title_regex_pattern = (
            re.compile(title_regex_pattern) if title_regex_pattern else None
        )
        self._counter_group = counter_group
        # Name of the NXinstrument group, cached after the first lookup.
        self._instrument = None
        super().__init__(filename, **kw)

    def get_source_info(self) -> str:
        """Return a one-line description of the file."""
        return f"HDF5: {self._filename}"

    def get_scan(self, scan_name: str) -> Optional[base.XasScan]:
        """Read the scan stored under ``scan_name``.

        The 1D datasets of the scan are sorted by label. The description is
        the scan title, or the scan name when there is no title.

        :raises KeyError: when ``scan_name`` does not exist in the file.
        """
        with self._open() as nxroot:
            scan = nxroot[scan_name]
            datasets = sorted(self._iter_datasets(scan), key=lambda tpl: tpl[0])
            if datasets:
                labels, data = zip(*datasets)
            else:
                labels, data = [], []
            description = self._get_string(scan, "title") or scan_name
            start_time = self._get_string(scan, "start_time")
            # The array is built while the file is still open because the
            # datasets are only readable inside this context.
            return base.XasScan(
                name=scan_name,
                description=description,
                start_time=start_time,
                info=description,
                labels=list(labels),
                data=numpy.asarray(data),
            )

    def get_scan_names(self) -> List[str]:
        """Return the scan names in file order, filtered by title if requested."""
        return list(self._iter_scan_names())

    def _iter_scan_names(self) -> Iterator[str]:
        """Yield the names of the top-level items that pass the title filter."""
        with self._open() as nxroot:
            for name in nxroot["/"]:  # index at "/" to preserve order
                try:
                    scan = nxroot[name]
                except KeyError:
                    continue  # broken link
                if self._title_regex_pattern is not None:
                    title = self._get_string(scan, "title")
                    if not self._title_regex_pattern.match(title):
                        continue
                yield name

    @contextmanager
    def _open(self) -> Iterator[h5py.File]:
        """Re-entrant context to get access to the HDF5 file"""
        if self._nxroot is not None:
            # Already inside an outer `_open` context: reuse its handle and
            # leave closing to the outermost context.
            yield self._nxroot
            return
        with hdf5_utils.open(self._filename) as nxroot:
            self._nxroot = nxroot
            try:
                yield nxroot
            finally:
                self._nxroot = None

    def _iter_datasets(self, scan: h5py.Group) -> Iterator[Tuple[str, h5py.Dataset]]:
        """Yield ``(label, dataset)`` for every 1D dataset of ``scan``."""
        if self._counter_group:
            yield from self._iter_counter_group(scan)
        else:
            yield from self._iter_instrument_group(scan)

    def _iter_counter_group(
        self, scan: h5py.Group
    ) -> Iterator[Tuple[str, h5py.Dataset]]:
        """Yield the 1D datasets found directly in the counter group."""
        try:
            counter_group = scan[self._counter_group]
        except KeyError:
            return  # broken link or not existing
        for name in counter_group:
            try:
                dset = counter_group[name]
            except KeyError:
                continue  # broken link
            # Items without `ndim` (e.g. sub-groups) are not datasets.
            if getattr(dset, "ndim", None) == 1:
                yield name, dset

    def _iter_instrument_group(
        self, scan: h5py.Group
    ) -> Iterator[Tuple[str, h5py.Dataset]]:
        """Yield the 1D data of the NXdetector/NXpositioner groups.

        NXpositioner groups provide their ``value`` item and NXdetector
        groups their ``data`` item; the label is the name of the group.
        """
        instrument = self._get_instrument(scan)
        if instrument is None:
            return
        for name in instrument:
            try:
                detector = instrument[name]
            except KeyError:
                continue  # broken link
            # Decode the attribute like `hdf5_utils.find_nexus_class` does:
            # a bytes value would never compare equal to the names below.
            nxclass = hdf5_utils.asstr(detector.attrs.get("NX_class", ""))
            if nxclass == "NXpositioner":
                data_name = "value"
            elif nxclass == "NXdetector":
                data_name = "data"
            else:
                continue
            try:
                dset = detector[data_name]
            except KeyError:
                continue  # no data
            # Same guard as in `_iter_counter_group`: skip non-datasets.
            if getattr(dset, "ndim", None) == 1:
                yield name, dset

    def _get_instrument(self, scan: h5py.Group) -> Optional[h5py.Group]:
        """Return the NXinstrument group of ``scan`` or ``None``.

        The group name found first is cached and tried first for later
        scans; when a scan does not contain that name the group is searched
        by its ``NX_class`` again.
        """
        if self._instrument:
            try:
                return scan[self._instrument]
            except KeyError:
                pass  # this scan names its instrument group differently
        instrument = hdf5_utils.find_nexus_class(scan, "NXinstrument")
        if instrument is not None:
            self._instrument = instrument.name.split("/")[-1]
        return instrument

    def _get_string(self, group: h5py.Group, name) -> str:
        """Return the value of ``group[name]`` as text, or "" when missing."""
        try:
            s = group[name][()]
        except KeyError:
            return ""
        return hdf5_utils.asstr(s)


class EsrfSingleXasDataSource(NexusSingleXasDataSource):
    """NeXus source that reads counters from the ``measurement`` group by default."""

    TYPE = "HDF5-NEXUS-ESRF"

    def __init__(self, filename: str, **kw) -> None:
        # A ``counter_group`` supplied by the caller overrides the default.
        options = {"counter_group": "measurement", **kw}
        super().__init__(filename, **options)


class SoleilSingleXasDataSource(NexusSingleXasDataSource):
    """NeXus source that reads counters from the ``scan_data`` group by default."""

    TYPE = "HDF5-NEXUS-SOLEIL"

    def __init__(self, filename: str, **kw) -> None:
        # A ``counter_group`` supplied by the caller overrides the default.
        options = {"counter_group": "scan_data", **kw}
        super().__init__(filename, **options)
49 changes: 49 additions & 0 deletions larch/io/xas_data_source/read.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from typing import Optional
from . import sources
from . import hdf5_utils
from larch import Group


def open_xas_source(filename, **kw):
    """Open ``filename`` with the XAS data source class matching its format.

    The format is detected from the first bytes of the file:

    * an HDF5 signature selects one of the NeXus sources: "esrf" when the
      root ``creator`` attribute is "bliss", "soleil" when the source or
      instrument name contains "soleil", and the generic "nexus" otherwise;
    * a leading ``#S `` or ``#F `` selects the "spec" source.

    :param filename: path of the data file.
    :param kw: forwarded to the constructor of the selected source class.
    :returns: an instance of the selected source class.
    :raises ValueError: when the file format is not recognized.
    """
    with open(filename, "rb") as fh:
        topbytes = fh.read(10)

    if topbytes.startswith(b"\x89HDF\r"):
        # The creator attribute may be read back as bytes; decode it so the
        # comparison with "bliss" can succeed.
        creator = hdf5_utils.asstr(hdf5_utils.nexus_creator(filename)).lower()
        # Each check opens the file again, so later ones only run when the
        # earlier ones did not settle the class.
        if creator == "bliss":
            class_name = "esrf"
        elif "soleil" in hdf5_utils.nexus_source(filename).lower():
            class_name = "soleil"
        elif "soleil" in hdf5_utils.nexus_instrument(filename).lower():
            class_name = "soleil"
        else:
            class_name = "nexus"
    elif topbytes.startswith((b"#S ", b"#F ")):
        class_name = "spec"
    else:
        raise ValueError(f"Unknown file format: {filename}")
    return sources.get_source_type(class_name)(filename, **kw)


def read_xas_source(filename: str, scan: Optional[str] = None) -> Optional[Group]:
    """Read one scan of an XAS data file into a larch Group.

    The group carries the fields of the scan record (name, description,
    info, start_time, labels, data) plus one attribute per data label.

    :param filename: path of the data file.
    :param scan: name of the scan to read.
    :returns: the populated group, or ``None`` when ``scan`` is ``None`` or
        the source returns no scan for that name.
    """
    if scan is None:
        return None
    source = open_xas_source(filename)
    xas_scan = source.get_scan(scan)
    if xas_scan is None:
        # `get_scan` is allowed to return None; there is nothing to wrap.
        return None

    lgroup = Group(
        __name__=f"{source.TYPE} file: {filename}, scan: {xas_scan.name}",
        filename=filename,
        source_info=source.get_source_info(),
        datatype="xas",
    )
    for name, value in xas_scan._asdict().items():
        setattr(lgroup, name, value)
    # Set last, so a data label wins over a scan field of the same name.
    for name, value in zip(xas_scan.labels, xas_scan.data):
        setattr(lgroup, name, value)
    return lgroup
15 changes: 15 additions & 0 deletions larch/io/xas_data_source/sources.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from .nexus import NexusSingleXasDataSource
from .nexus import EsrfSingleXasDataSource
from .nexus import SoleilSingleXasDataSource
from .spec import SpecSingleXasDataSource

# Registry mapping the short format name chosen by the file-format detection
# to the data source class that reads that format.
_SOURCE_TYPES = {
    "nexus": NexusSingleXasDataSource,
    "esrf": EsrfSingleXasDataSource,
    "soleil": SoleilSingleXasDataSource,
    "spec": SpecSingleXasDataSource,
}


def get_source_type(name: str):
    """Return the data source class registered under ``name``.

    Raises ``KeyError`` when ``name`` is not one of the registered names.
    """
    return _SOURCE_TYPES[name]
Loading

0 comments on commit dc59aaf

Please sign in to comment.