diff --git a/doc/changes/1.5.inc b/doc/changes/1.5.inc index a69f0a53f2e..20961b7ab7b 100644 --- a/doc/changes/1.5.inc +++ b/doc/changes/1.5.inc @@ -27,6 +27,7 @@ Bugs - Fix bug with :func:`mne.chpi.compute_head_pos` for CTF data where digitization points were modified in-place, producing an incorrect result during a save-load round-trip (:gh:`11934` by `Eric Larson`_) - Fix bug with notebooks when using PyVista 0.42 by implementing ``trame`` backend support (:gh:`11956` by `Eric Larson`_) - Fix bug with ``subject_info`` when loading data from and exporting to EDF file (:gh:`11952` by `Paul Roujansky`_) +- Fix handling of channel information in annotations when loading data from and exporting to EDF file (:gh:`11960` by `Paul Roujansky`_) .. _changes_1_5_0: diff --git a/mne/annotations.py b/mne/annotations.py index f8dec584868..23267d54b8d 100644 --- a/mne/annotations.py +++ b/mne/annotations.py @@ -1231,11 +1231,15 @@ def read_annotations(fname, sfreq="auto", uint16_codec=None): annotations = _read_annotations_eeglab(fname, uint16_codec=uint16_codec) elif name.endswith(("edf", "bdf", "gdf")): - onset, duration, description = _read_annotations_edf(fname) + onset, duration, description, ch_names = _read_annotations_edf(fname) onset = np.array(onset, dtype=float) duration = np.array(duration, dtype=float) annotations = Annotations( - onset=onset, duration=duration, description=description, orig_time=None + onset=onset, + duration=duration, + description=description, + orig_time=None, + ch_names=ch_names, ) elif name.startswith("events_") and fname.endswith("mat"): diff --git a/mne/export/_edf.py b/mne/export/_edf.py index da2acb72ff8..f0bad43e66a 100644 --- a/mne/export/_edf.py +++ b/mne/export/_edf.py @@ -264,9 +264,7 @@ def _export_raw(fname, raw, physical_range, add_ch_type): annots = raw.annotations if annots is not None: n_annotations = len(raw.annotations) - n_annot_chans = int(n_annotations / n_blocks) - if np.mod(n_annotations, n_blocks): - n_annot_chans += 1 + n_annot_chans = int(n_annotations / n_blocks) + 1 if n_annot_chans > 1: hdl.setNumberOfAnnotationSignals(n_annot_chans) @@ -305,17 +303,30 @@ def _export_raw(fname, raw, physical_range, add_ch_type): # write annotations if annots is not None: - for desc, onset, duration in zip( + for desc, onset, duration, ch_names in zip( raw.annotations.description, raw.annotations.onset, raw.annotations.duration, + raw.annotations.ch_names, ): # annotations are written in terms of 100 microseconds onset = onset * 10000 duration = duration * 10000 - if hdl.writeAnnotation(onset, duration, desc) != 0: - raise RuntimeError( - f"writeAnnotation() returned an error " - f"trying to write {desc} at {onset} " - f"for {duration} seconds." - ) + if ch_names: + for ch_name in ch_names: + if ( + hdl.writeAnnotation(onset, duration, desc + f"@@{ch_name}") + != 0 + ): + raise RuntimeError( + f"writeAnnotation() returned an error " + f"trying to write {desc}@@{ch_name} at {onset} " + f"for {duration} seconds." + ) + else: + if hdl.writeAnnotation(onset, duration, desc) != 0: + raise RuntimeError( + f"writeAnnotation() returned an error " + f"trying to write {desc} at {onset} " + f"for {duration} seconds." + ) diff --git a/mne/export/tests/test_export.py b/mne/export/tests/test_export.py index 8959ad5f84e..1cc01035589 100644 --- a/mne/export/tests/test_export.py +++ b/mne/export/tests/test_export.py @@ -215,6 +215,7 @@ def test_export_edf_annotations(tmp_path): onset=[0.01, 0.05, 0.90, 1.05], duration=[0, 1, 0, 0], description=["test1", "test2", "test3", "test4"], + ch_names=[["0"], ["0", "1"], [], ["1"]], ) raw.set_annotations(annotations) @@ -227,6 +228,7 @@ def test_export_edf_annotations(tmp_path): assert_array_equal(raw.annotations.onset, raw_read.annotations.onset) assert_array_equal(raw.annotations.duration, raw_read.annotations.duration) assert_array_equal(raw.annotations.description, raw_read.annotations.description) + assert_array_equal(raw.annotations.ch_names, raw_read.annotations.ch_names) @pytest.mark.skipif( diff --git a/mne/io/edf/edf.py b/mne/io/edf/edf.py index 40312414df8..8494f0edd94 100644 --- a/mne/io/edf/edf.py +++ b/mne/io/edf/edf.py @@ -192,7 +192,7 @@ def __init__( ) # Read annotations from file and set it - onset, duration, desc = list(), list(), list() + onset, duration, desc, ch_names = list(), list(), list(), list() if len(edf_info["tal_idx"]) > 0: # Read TAL data exploiting the header info (no regexp) idx = np.empty(0, int) @@ -205,14 +205,18 @@ def __init__( np.ones((len(idx), 1)), None, ) - onset, duration, desc = _read_annotations_edf( + onset, duration, desc, ch_names = _read_annotations_edf( tal_data[0], encoding=encoding, ) self.set_annotations( Annotations( - onset=onset, duration=duration, description=desc, orig_time=None + onset=onset, + duration=duration, + description=desc, + orig_time=None, + ch_names=ch_names, ) ) @@ -1950,14 +1954,31 @@ def _read_annotations_edf(annotations, encoding="utf8"): " You might want to try setting \"encoding='latin1'\"." ) from e - events = [] + events = {} offset = 0.0 for k, ev in enumerate(triggers): onset = float(ev[0]) + offset duration = float(ev[2]) if ev[2] else 0 for description in ev[3].split("\x14")[1:]: if description: - events.append([onset, duration, description]) + if "@@" in description: + description, ch_name = description.split("@@") + key = f"{onset}_{duration}_{description}" + else: + ch_name = None + key = f"{onset}_{duration}_{description}" + if key in events: + key += f"_{k}" # make key unique + if key in events and ch_name: + events[key][3] += (ch_name,) + else: + events[key] = [ + onset, + duration, + description, + (ch_name,) if ch_name else (), + ] + elif k == 0: # The startdate/time of a file is specified in the EDF+ header # fields 'startdate of recording' and 'starttime of recording'. @@ -1969,7 +1990,7 @@ def _read_annotations_edf(annotations, encoding="utf8"): # header. If X=0, then the .X may be omitted. offset = -onset - return zip(*events) if events else (list(), list(), list()) + return zip(*events.values()) if events else (list(), list(), list(), list()) def _get_annotations_gdf(edf_info, sfreq): diff --git a/mne/io/edf/tests/test_edf.py b/mne/io/edf/tests/test_edf.py index a93d0debcdd..124953c03e6 100644 --- a/mne/io/edf/tests/test_edf.py +++ b/mne/io/edf/tests/test_edf.py @@ -365,7 +365,7 @@ def test_parse_annotation(tmp_path): ] ) for tal_channel in [tal_channel_A, tal_channel_B]: - onset, duration, description = _read_annotations_edf([tal_channel]) + onset, duration, description, ch_names = _read_annotations_edf([tal_channel]) assert_allclose(onset, want_onset) assert_allclose(duration, want_duration) assert description == want_description @@ -478,18 +478,26 @@ def test_read_annot(tmp_path): with open(annot_file, "wb") as f: f.write(annot) - onset, duration, desc = _read_annotations_edf(annotations=str(annot_file)) + onset, duration, desc, ch_names = _read_annotations_edf(annotations=str(annot_file)) annotation = Annotations( - onset=onset, duration=duration, description=desc, orig_time=None + onset=onset, + duration=duration, + description=desc, + orig_time=None, + ch_names=ch_names, ) _assert_annotations_equal(annotation, EXPECTED_ANNOTATIONS) # Now test when reading from buffer of data with open(annot_file, "rb") as fid: ch_data = np.fromfile(fid, dtype="