Skip to content

Commit

Permalink
Handle channel information in annotations when loading data from and …
Browse files Browse the repository at this point in the history
…exporting to EDF file (#11960)

Co-authored-by: Paul ROUJANSKY <[email protected]>
  • Loading branch information
2 people authored and larsoner committed Sep 6, 2023
1 parent f95a9c5 commit 209595b
Show file tree
Hide file tree
Showing 6 changed files with 71 additions and 24 deletions.
1 change: 1 addition & 0 deletions doc/changes/1.5.inc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ Bugs
- Fix bug with :func:`mne.chpi.compute_head_pos` for CTF data where digitization points were modified in-place, producing an incorrect result during a save-load round-trip (:gh:`11934` by `Eric Larson`_)
- Fix bug with notebooks when using PyVista 0.42 by implementing ``trame`` backend support (:gh:`11956` by `Eric Larson`_)
- Fix bug with ``subject_info`` when loading data from and exporting to EDF file (:gh:`11952` by `Paul Roujansky`_)
- Fix handling of channel information in annotations when loading data from and exporting to EDF file (:gh:`11960` by `Paul Roujansky`_)


.. _changes_1_5_0:
Expand Down
8 changes: 6 additions & 2 deletions mne/annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1231,11 +1231,15 @@ def read_annotations(fname, sfreq="auto", uint16_codec=None):
annotations = _read_annotations_eeglab(fname, uint16_codec=uint16_codec)

elif name.endswith(("edf", "bdf", "gdf")):
onset, duration, description = _read_annotations_edf(fname)
onset, duration, description, ch_names = _read_annotations_edf(fname)
onset = np.array(onset, dtype=float)
duration = np.array(duration, dtype=float)
annotations = Annotations(
onset=onset, duration=duration, description=description, orig_time=None
onset=onset,
duration=duration,
description=description,
orig_time=None,
ch_names=ch_names,
)

elif name.startswith("events_") and fname.endswith("mat"):
Expand Down
31 changes: 21 additions & 10 deletions mne/export/_edf.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,9 +264,7 @@ def _export_raw(fname, raw, physical_range, add_ch_type):
annots = raw.annotations
if annots is not None:
n_annotations = len(raw.annotations)
n_annot_chans = int(n_annotations / n_blocks)
if np.mod(n_annotations, n_blocks):
n_annot_chans += 1
n_annot_chans = int(n_annotations / n_blocks) + 1
if n_annot_chans > 1:
hdl.setNumberOfAnnotationSignals(n_annot_chans)

Expand Down Expand Up @@ -305,17 +303,30 @@ def _export_raw(fname, raw, physical_range, add_ch_type):

# write annotations
if annots is not None:
for desc, onset, duration in zip(
for desc, onset, duration, ch_names in zip(
raw.annotations.description,
raw.annotations.onset,
raw.annotations.duration,
raw.annotations.ch_names,
):
# annotations are written in terms of 100 microseconds
onset = onset * 10000
duration = duration * 10000
if hdl.writeAnnotation(onset, duration, desc) != 0:
raise RuntimeError(
f"writeAnnotation() returned an error "
f"trying to write {desc} at {onset} "
f"for {duration} seconds."
)
if ch_names:
for ch_name in ch_names:
if (
hdl.writeAnnotation(onset, duration, desc + f"@@{ch_name}")
!= 0
):
raise RuntimeError(
f"writeAnnotation() returned an error "
f"trying to write {desc}@@{ch_name} at {onset} "
f"for {duration} seconds."
)
else:
if hdl.writeAnnotation(onset, duration, desc) != 0:
raise RuntimeError(
f"writeAnnotation() returned an error "
f"trying to write {desc} at {onset} "
f"for {duration} seconds."
)
2 changes: 2 additions & 0 deletions mne/export/tests/test_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@ def test_export_edf_annotations(tmp_path):
onset=[0.01, 0.05, 0.90, 1.05],
duration=[0, 1, 0, 0],
description=["test1", "test2", "test3", "test4"],
ch_names=[["0"], ["0", "1"], [], ["1"]],
)
raw.set_annotations(annotations)

Expand All @@ -227,6 +228,7 @@ def test_export_edf_annotations(tmp_path):
assert_array_equal(raw.annotations.onset, raw_read.annotations.onset)
assert_array_equal(raw.annotations.duration, raw_read.annotations.duration)
assert_array_equal(raw.annotations.description, raw_read.annotations.description)
assert_array_equal(raw.annotations.ch_names, raw_read.annotations.ch_names)


@pytest.mark.skipif(
Expand Down
33 changes: 27 additions & 6 deletions mne/io/edf/edf.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ def __init__(
)

# Read annotations from file and set it
onset, duration, desc = list(), list(), list()
onset, duration, desc, ch_names = list(), list(), list(), list()
if len(edf_info["tal_idx"]) > 0:
# Read TAL data exploiting the header info (no regexp)
idx = np.empty(0, int)
Expand All @@ -205,14 +205,18 @@ def __init__(
np.ones((len(idx), 1)),
None,
)
onset, duration, desc = _read_annotations_edf(
onset, duration, desc, ch_names = _read_annotations_edf(
tal_data[0],
encoding=encoding,
)

self.set_annotations(
Annotations(
onset=onset, duration=duration, description=desc, orig_time=None
onset=onset,
duration=duration,
description=desc,
orig_time=None,
ch_names=ch_names,
)
)

Expand Down Expand Up @@ -1950,14 +1954,31 @@ def _read_annotations_edf(annotations, encoding="utf8"):
" You might want to try setting \"encoding='latin1'\"."
) from e

events = []
events = {}
offset = 0.0
for k, ev in enumerate(triggers):
onset = float(ev[0]) + offset
duration = float(ev[2]) if ev[2] else 0
for description in ev[3].split("\x14")[1:]:
if description:
events.append([onset, duration, description])
if "@@" in description:
description, ch_name = description.split("@@")
key = f"{onset}_{duration}_{description}"
else:
ch_name = None
key = f"{onset}_{duration}_{description}"
if key in events:
key += f"_{k}" # make key unique
if key in events and ch_name:
events[key][3] += (ch_name,)
else:
events[key] = [
onset,
duration,
description,
(ch_name,) if ch_name else (),
]

elif k == 0:
# The startdate/time of a file is specified in the EDF+ header
# fields 'startdate of recording' and 'starttime of recording'.
Expand All @@ -1969,7 +1990,7 @@ def _read_annotations_edf(annotations, encoding="utf8"):
# header. If X=0, then the .X may be omitted.
offset = -onset

return zip(*events) if events else (list(), list(), list())
return zip(*events.values()) if events else (list(), list(), list(), list())


def _get_annotations_gdf(edf_info, sfreq):
Expand Down
20 changes: 14 additions & 6 deletions mne/io/edf/tests/test_edf.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ def test_parse_annotation(tmp_path):
]
)
for tal_channel in [tal_channel_A, tal_channel_B]:
onset, duration, description = _read_annotations_edf([tal_channel])
onset, duration, description, ch_names = _read_annotations_edf([tal_channel])
assert_allclose(onset, want_onset)
assert_allclose(duration, want_duration)
assert description == want_description
Expand Down Expand Up @@ -478,18 +478,26 @@ def test_read_annot(tmp_path):
with open(annot_file, "wb") as f:
f.write(annot)

onset, duration, desc = _read_annotations_edf(annotations=str(annot_file))
onset, duration, desc, ch_names = _read_annotations_edf(annotations=str(annot_file))
annotation = Annotations(
onset=onset, duration=duration, description=desc, orig_time=None
onset=onset,
duration=duration,
description=desc,
orig_time=None,
ch_names=ch_names,
)
_assert_annotations_equal(annotation, EXPECTED_ANNOTATIONS)

# Now test when reading from buffer of data
with open(annot_file, "rb") as fid:
ch_data = np.fromfile(fid, dtype="<i2", count=len(annot))
onset, duration, desc = _read_annotations_edf([ch_data])
onset, duration, desc, ch_names = _read_annotations_edf([ch_data])
annotation = Annotations(
onset=onset, duration=duration, description=desc, orig_time=None
onset=onset,
duration=duration,
description=desc,
orig_time=None,
ch_names=ch_names,
)
_assert_annotations_equal(annotation, EXPECTED_ANNOTATIONS)

Expand Down Expand Up @@ -538,7 +546,7 @@ def test_read_latin1_annotations(tmp_path):
samp=-1,
dtype_byte=None,
)
onset, duration, description = _read_annotations_edf(
onset, duration, description, ch_names = _read_annotations_edf(
tal_channel,
encoding="latin1",
)
Expand Down

0 comments on commit 209595b

Please sign in to comment.