From 704f3e25b8fedbeda0443ff623d2c3bfde210bb1 Mon Sep 17 00:00:00 2001 From: "Lumberbot (aka Jack)" <39504233+meeseeksmachine@users.noreply.github.com> Date: Wed, 29 May 2024 21:35:59 +0200 Subject: [PATCH] Backport PR #12633 on branch maint/1.7 ([BUG] Cross-spectral density missing frequencies) (#12634) Co-authored-by: Thomas S. Binns --- doc/changes/devel/12633.bugfix.rst | 1 + mne/time_frequency/csd.py | 9 ++++++--- mne/time_frequency/tests/test_csd.py | 6 +++--- 3 files changed, 10 insertions(+), 6 deletions(-) create mode 100644 doc/changes/devel/12633.bugfix.rst diff --git a/doc/changes/devel/12633.bugfix.rst b/doc/changes/devel/12633.bugfix.rst new file mode 100644 index 00000000000..dfc69bc2fe7 --- /dev/null +++ b/doc/changes/devel/12633.bugfix.rst @@ -0,0 +1 @@ +Fix bug where :func:`mne.time_frequency.csd_multitaper`, :func:`mne.time_frequency.csd_fourier`, :func:`mne.time_frequency.csd_array_multitaper`, and :func:`mne.time_frequency.csd_array_fourier` would return cross-spectral densities with the ``fmin`` and ``fmax`` frequencies missing, by `Thomas Binns`_ \ No newline at end of file diff --git a/mne/time_frequency/csd.py b/mne/time_frequency/csd.py index e2ea5ac1ba7..f86eafc81e4 100644 --- a/mne/time_frequency/csd.py +++ b/mne/time_frequency/csd.py @@ -810,9 +810,10 @@ def csd_array_fourier( n_fft = n_times if n_fft is None else n_fft # Preparing frequencies of interest - # orig_frequencies = fftfreq(n_fft, 1. / sfreq) orig_frequencies = rfftfreq(n_fft, 1.0 / sfreq) - freq_mask = (orig_frequencies > fmin) & (orig_frequencies < fmax) + freq_mask = ( + (orig_frequencies > 0) & (orig_frequencies >= fmin) & (orig_frequencies <= fmax) + ) frequencies = orig_frequencies[freq_mask] if len(frequencies) == 0: @@ -1013,7 +1014,9 @@ def csd_array_multitaper( # Preparing frequencies of interest orig_frequencies = rfftfreq(n_fft, 1.0 / sfreq) - freq_mask = (orig_frequencies > fmin) & (orig_frequencies < fmax) + freq_mask = ( + (orig_frequencies > 0) & (orig_frequencies >= fmin) & (orig_frequencies <= fmax) + ) frequencies = orig_frequencies[freq_mask] if len(frequencies) == 0: diff --git a/mne/time_frequency/tests/test_csd.py b/mne/time_frequency/tests/test_csd.py index aefc3e2aaac..027eae6d9a2 100644 --- a/mne/time_frequency/tests/test_csd.py +++ b/mne/time_frequency/tests/test_csd.py @@ -394,15 +394,15 @@ def _test_fourier_multitaper_parameters(epochs, csd_epochs, csd_array): fmin=20, fmax=10, ) - raises(ValueError, csd_epochs, epochs, fmin=20, fmax=20.1) + raises(ValueError, csd_epochs, epochs, fmin=20.11, fmax=20.19) raises( ValueError, csd_array, epochs._data, epochs.info["sfreq"], epochs.tmin, - fmin=20, - fmax=20.1, + fmin=20.11, + fmax=20.19, ) raises(ValueError, csd_epochs, epochs, tmin=0.15, tmax=0.1) raises(