diff --git a/mne/annotations.py b/mne/annotations.py index 629ee7b20cb..b60ea1502e0 100644 --- a/mne/annotations.py +++ b/mne/annotations.py @@ -36,6 +36,7 @@ _check_option, _check_pandas_installed, _check_time_format, + _check_wfdb_installed, _convert_times, _DefaultEventParser, _dt_to_stamp, @@ -1140,7 +1141,14 @@ def _write_annotations_txt(fname, annot): @fill_doc def read_annotations( - fname, sfreq="auto", uint16_codec=None, encoding="utf8", ignore_marker_types=False + fname, + sfreq="auto", + *, + uint16_codec=None, + encoding="utf8", + ignore_marker_types=False, + fmt="auto", + suffix=None, ) -> Annotations: r"""Read annotations from a file. @@ -1222,6 +1230,7 @@ def read_annotations( ".bdf": {"encoding": encoding}, ".gdf": {"encoding": encoding}, } + print(fname.suffix) if fname.suffix in readers: annotations = readers[fname.suffix](fname, **kwargs.get(fname.suffix, {})) elif fname.name.endswith(("fif", "fif.gz")): @@ -1231,6 +1240,8 @@ def read_annotations( annotations = _read_annotations_fif(fid, tree) elif fname.name.startswith("events_") and fname.suffix == ".mat": annotations = _read_brainstorm_annotations(fname) + elif fmt == "wfdb": + annotations = _read_wfdb_annotations(fname, suffix=suffix) else: raise OSError(f'Unknown annotation file format "{fname}"') @@ -1513,6 +1524,18 @@ def _check_event_description(event_desc, events): return event_desc +def _read_wfdb_annotations(fname, suffix=None, sfreq="auto"): + """Read annotations from wfdb format.""" + wfdb = _check_wfdb_installed(strict=True) + anno = wfdb.io.rdann(fname.stem, extension=suffix) + anno_dict = anno.__dict__ + sfreq = anno_dict["fs"] if sfreq == "auto" else sfreq + onset = anno_dict["sample"] / sfreq + duration = np.zeros_like(onset) + description = anno_dict["symbol"] + return Annotations(onset, duration, description) + + @verbose def events_from_annotations( raw, diff --git a/mne/utils/__init__.pyi b/mne/utils/__init__.pyi index 46d272e972d..770afd0bde6 100644 --- a/mne/utils/__init__.pyi +++ b/mne/utils/__init__.pyi @@ -48,6 +48,7 @@ __all__ = [ "_check_preload", "_check_pybv_installed", "_check_pymatreader_installed", + "_check_wfdb_installed", "_check_qt_version", "_check_range", "_check_rank", @@ -254,6 +255,7 @@ from .check import ( _check_stc_units, _check_subject, _check_time_format, + _check_wfdb_installed, _ensure_events, _ensure_int, _import_h5io_funcs, diff --git a/mne/utils/check.py b/mne/utils/check.py index 973fa33fe79..723b88b05bc 100644 --- a/mne/utils/check.py +++ b/mne/utils/check.py @@ -453,6 +453,11 @@ def _check_edfio_installed(strict=True): return _soft_import("edfio", "exporting to EDF", strict=strict) +def _check_wfdb_installed(strict=True): + """Aux function.""" + return _soft_import("wfdb", "MIT WFDB IO", strict=strict) + + def _check_pybv_installed(strict=True): """Aux function.""" return _soft_import("pybv", "exporting to BrainVision", strict=strict)