From 957556e19be1c5ee899b30065816a8e9dbe79f6c Mon Sep 17 00:00:00 2001 From: Guillaume Favelier Date: Thu, 17 Dec 2020 16:28:21 +0100 Subject: [PATCH] MRG: Plot label time course with Brain (#8335) * Follow Brain update * Upload prototype * Remove 'tube' borders * Rename variables * Update clear_glyphs() * Use _brain_color * Fix docstring * Improve support of traces and update tests * Introduce _configure_label_time_course * Add label tool bar * Add support for max mode * Refactor _update * Add support for src * Remove GFP in label picking mode * Plot time line * Fetch src when possible * Update overview table * Use only src_vol * TST: refactor _triage_stc out * Use _check_stc * Fix test_brain_init[mayavi] * Fix test_brain_traces[pyvista-mixed] * Disable annot traces for volume/mixed * Uncomment stc ValueError * Improve testing and fix label_name * Improve API * Improve testing * Improve testing * Fix add_data parameter * Improve coverage * Use dict for label_data * Prototype with show_traces * Add _read_annot_cands * Configure UI for traces mode toggle * Organize signals, clear memory and update tests * Switch to simplified interface * Exclude .ctab files * Improve coverage * Update overview table * Improve coverage * Remove cruft * Improve coverage * Improve coverage * Improve coverage * Improve coverage * Remove cruft * Improve coverage * Improve coverage * TST: Hide app window during testing * Revert "TST: Hide app window during testing" This reverts commit f5134487de96468b9d1201dd6cdefbb78aad1893. * Refactor tests * Fix default annot bug * Fix renderer bug * Test GHA skipping [skip ci] * Synchronize branch * Fix test_resolution_matrix * Add _configure_trace_mode * Improve read_annot_cands * Simplify annot mode color * Remove cruft * Add _check_stc_src * Cast to int64 * Add support for vec * Add support for mixed * Do not show trace mode for volumes * Update tests * Fix vec * Test mixed too * Update tests * Cleanup * Trigger Circle [circle front] * Cover vol as well * Update tests * Add vector to the list * Fix style * Restore label name * Exclude pca_flip * DRY * Update _add_vertex_glyph with render parameter --- mne/label.py | 3 +- mne/source_estimate.py | 51 ++-- mne/viz/_3d.py | 6 +- mne/viz/_brain/_brain.py | 387 ++++++++++++++++++++++------- mne/viz/_brain/_linkviewer.py | 16 +- mne/viz/_brain/tests/test_brain.py | 134 ++++++++-- mne/viz/backends/renderer.py | 2 - 7 files changed, 456 insertions(+), 143 deletions(-) diff --git a/mne/label.py b/mne/label.py index 1ac3d23b744..e0e6a8ccddd 100644 --- a/mne/label.py +++ b/mne/label.py @@ -1880,7 +1880,8 @@ def _read_annot_cands(dir_name): raise IOError('Directory for annotation does not exist: %s', dir_name) cands = os.listdir(dir_name) - cands = sorted(set(c.lstrip('lh.').lstrip('rh.').rstrip('.annot') + cands = sorted(set(c.replace('lh.', '').replace('rh.', '').replace( + '.annot', '') for c in cands if '.annot' in c), key=lambda x: x.lower()) # exclude .ctab files diff --git a/mne/source_estimate.py b/mne/source_estimate.py index e461511194f..152a9c62865 100644 --- a/mne/source_estimate.py +++ b/mne/source_estimate.py @@ -2860,6 +2860,16 @@ def _temporary_vertices(src, vertices): s['vertno'] = v +def _check_stc_src(stc, src): + if stc is not None and src is not None: + for s, v, hemi in zip(src, stc.vertices, ('left', 'right')): + n_missing = (~np.in1d(v, s['vertno'])).sum() + if n_missing: + raise ValueError('%d/%d %s hemisphere stc vertices ' + 'missing from the source space, likely ' + 'mismatch' % (n_missing, len(v), hemi)) + + def _prepare_label_extraction(stc, labels, src, mode, allow_empty, use_sparse): """Prepare indices and flips for extract_label_time_course.""" # If src is a mixed src space, the first 2 src spaces are surf type and @@ -2872,18 +2882,8 @@ def _prepare_label_extraction(stc, labels, src, mode, allow_empty, use_sparse): # if source estimate provided in stc, get vertices from source space and # check that they are the same as in the stcs - if stc is not None: - vertno = stc.vertices - - for s, v, hemi in zip(src, stc.vertices, ('left', 'right')): - n_missing = (~np.in1d(v, s['vertno'])).sum() - if n_missing: - raise ValueError('%d/%d %s hemisphere stc vertices missing ' - 'from the source space, likely mismatch' - % (n_missing, len(v), hemi)) - else: - vertno = [s['vertno'] for s in src] - + _check_stc_src(stc, src) + vertno = [s['vertno'] for s in src] if stc is None else stc.vertices nvert = [len(vn) for vn in vertno] # initialization @@ -3059,15 +3059,30 @@ def _dep_trans(trans): 'pass it as an argument', DeprecationWarning) +def _get_default_label_modes(): + return sorted(_label_funcs.keys()) + ['auto'] + + +def _get_allowed_label_modes(stc): + if isinstance(stc, (_BaseVolSourceEstimate, + _BaseVectorSourceEstimate)): + return ('mean', 'max', 'auto') + else: + return _get_default_label_modes() + + def _gen_extract_label_time_course(stcs, labels, src, mode='mean', allow_empty=False, trans=None, mri_resolution=True, verbose=None): # loop through source estimates and extract time series + if src is None and mode in ['mean', 'max']: + kind = 'surface' + else: + _validate_type(src, SourceSpaces) + kind = src.kind _dep_trans(trans) - _validate_type(src, SourceSpaces) - _check_option('mode', mode, sorted(_label_funcs.keys()) + ['auto']) + _check_option('mode', mode, _get_default_label_modes()) - kind = src.kind if kind in ('surface', 'mixed'): if not isinstance(labels, list): labels = [labels] @@ -3082,11 +3097,11 @@ def _gen_extract_label_time_course(stcs, labels, src, mode='mean', for si, stc in enumerate(stcs): _validate_type(stc, _BaseSourceEstimate, 'stcs[%d]' % (si,), 'source estimate') + _check_option( + 'mode', mode, _get_allowed_label_modes(stc), + 'when using a vector and/or volume source estimate') if isinstance(stc, (_BaseVolSourceEstimate, _BaseVectorSourceEstimate)): - _check_option( - 'mode', mode, ('mean', 'max', 'auto'), - 'when using a vector and/or volume source estimate') mode = 'mean' if mode == 'auto' else mode else: mode = 'mean_flip' if mode == 'auto' else mode diff --git a/mne/viz/_3d.py b/mne/viz/_3d.py index be5c11ba612..15e87488206 100644 --- a/mne/viz/_3d.py +++ b/mne/viz/_3d.py @@ -1789,7 +1789,8 @@ def plot_source_estimates(stc, subject=None, surface='inflated', hemi='lh', - https://openwetware.org/wiki/Beauchamp:FreeSurfer """ # noqa: E501 from .backends.renderer import _get_3d_backend, set_3d_backend - from ..source_estimate import _BaseSourceEstimate + from ..source_estimate import _BaseSourceEstimate, _check_stc_src + _check_stc_src(stc, src) _validate_type(stc, _BaseSourceEstimate, 'stc', 'source estimate') subjects_dir = get_subjects_dir(subjects_dir=subjects_dir, raise_error=True) @@ -1989,7 +1990,8 @@ def _plot_stc(stc, subject, surface, hemi, colormap, time_label, _check_option('time_viewer', time_viewer, (True, False, 'auto')) _validate_type(show_traces, (str, bool, 'numeric'), 'show_traces') if isinstance(show_traces, str): - _check_option('show_traces', show_traces, ('auto', 'separate'), + _check_option('show_traces', show_traces, + ('auto', 'separate', 'vertex', 'label'), extra='when a string') if time_viewer == 'auto': time_viewer = not using_mayavi diff --git a/mne/viz/_brain/_brain.py b/mne/viz/_brain/_brain.py index 5afba92007f..931527245ed 100644 --- a/mne/viz/_brain/_brain.py +++ b/mne/viz/_brain/_brain.py @@ -303,16 +303,14 @@ class Brain(object): +---------------------------+--------------+---------------+ | foci | ✓ | | +---------------------------+--------------+---------------+ - | labels | ✓ | | - +---------------------------+--------------+---------------+ - | labels_dict | ✓ | | - +---------------------------+--------------+---------------+ - | remove_data | ✓ | | + | labels | ✓ | ✓ | +---------------------------+--------------+---------------+ | remove_foci | ✓ | | +---------------------------+--------------+---------------+ | remove_labels | ✓ | ✓ | +---------------------------+--------------+---------------+ + | remove_annotations | - | ✓ | + +---------------------------+--------------+---------------+ | scale_data_colormap | ✓ | | +---------------------------+--------------+---------------+ | save_image | ✓ | ✓ | @@ -335,6 +333,10 @@ class Brain(object): +---------------------------+--------------+---------------+ | flatmaps | | ✓ | +---------------------------+--------------+---------------+ + | vertex picking | | ✓ | + +---------------------------+--------------+---------------+ + | label picking | | ✓ | + +---------------------------+--------------+---------------+ """ def __init__(self, subject_id, hemi, surf, title=None, @@ -398,7 +400,10 @@ def __init__(self, subject_id, hemi, surf, title=None, self._subjects_dir = subjects_dir self._views = views self._times = None - self._label_data = {'lh': list(), 'rh': list()} + self._vertex_to_label_id = dict() + self._annotation_labels = dict() + self._labels = {'lh': list(), 'rh': list()} + self._annots = {'lh': list(), 'rh': list()} self._layered_meshes = {} # for now only one color bar can be added # since it is the same for all figures @@ -490,6 +495,8 @@ def setup_time_viewer(self, time_viewer=True, show_traces=True): """ if self.time_viewer: return + if not self._data: + raise ValueError("No data to visualize. See ``add_data``.") self.time_viewer = time_viewer self.orientation = list(_lh_views_dict.keys()) self.default_smoothing_range = [0, 15] @@ -507,12 +514,25 @@ def setup_time_viewer(self, time_viewer=True, show_traces=True): self.default_playback_speed_range = [0.01, 1] self.default_playback_speed_value = 0.05 self.default_status_bar_msg = "Press ? for help" + self.default_label_extract_modes = { + "stc": ["mean", "max"], + "src": ["mean_flip", "pca_flip", "auto"], + } + self.default_trace_modes = ('vertex', 'label') + self.annot = None + self.label_extract_mode = None all_keys = ('lh', 'rh', 'vol') self.act_data_smooth = {key: (None, None) for key in all_keys} - self.color_cycle = None + self.color_list = _get_color_list() + # remove grey for better contrast on the brain + self.color_list.remove("#7f7f7f") + self.color_cycle = _ReuseCycle(self.color_list) self.mpl_canvas = None + self.gfp = None + self.picked_patches = {key: list() for key in all_keys} self.picked_points = {key: list() for key in all_keys} self.pick_table = dict() + self._spheres = list() self._mouse_no_mvt = -1 self.icons = dict() self.actions = dict() @@ -524,6 +544,9 @@ def setup_time_viewer(self, time_viewer=True, show_traces=True): self.slider_color = (0.43137255, 0.44313725, 0.45882353) self.slider_tube_width = 0.04 self.slider_tube_color = (0.69803922, 0.70196078, 0.70980392) + self._trace_mode_widget = None + self._annot_cands_widget = None + self._label_mode_widget = None # Direct access parameters: self._iren = self._renderer.plotter.iren @@ -537,9 +560,15 @@ def setup_time_viewer(self, time_viewer=True, show_traces=True): _validate_type(show_traces, (bool, str, 'numeric'), 'show_traces') self.interactor_fraction = 0.25 if isinstance(show_traces, str): - assert show_traces == 'separate' # should be guaranteed earlier self.show_traces = True - self.separate_canvas = True + self.separate_canvas = False + self.traces_mode = 'vertex' + if show_traces == 'separate': + self.separate_canvas = True + elif show_traces == 'label': + self.traces_mode = 'label' + else: + assert show_traces == 'vertex' # guaranteed above else: if isinstance(show_traces, bool): self.show_traces = show_traces @@ -551,19 +580,20 @@ def setup_time_viewer(self, time_viewer=True, show_traces=True): f'got {show_traces}') self.show_traces = True self.interactor_fraction = show_traces + self.traces_mode = 'vertex' self.separate_canvas = False del show_traces - self._spheres = list() self._load_icons() self._configure_time_label() self._configure_sliders() self._configure_scalar_bar() self._configure_playback() - self._configure_point_picking() self._configure_menu() self._configure_tool_bar() self._configure_status_bar() + self._configure_picking() + self._configure_trace_mode() # show everything at the end self.toggle_interface() @@ -573,7 +603,8 @@ def setup_time_viewer(self, time_viewer=True, show_traces=True): @safe_event def _clean(self): # resolve the reference cycle - self.clear_points() + self.clear_glyphs() + self.remove_annotations() # clear init actors for hemi in self._hemis: self._layered_meshes[hemi]._clean() @@ -966,12 +997,7 @@ def _configure_sliders(self): def _configure_playback(self): self.plotter.add_callback(self._play, self.refresh_rate_ms) - def _configure_point_picking(self): - if not self.show_traces: - return - from ..backends._pyvista import _update_picking_callback - # use a matplotlib canvas - self.color_cycle = _ReuseCycle(_get_color_list()) + def _configure_mplcanvas(self): win = self.plotter.app_window dpi = win.windowHandle().screen().logicalDotsPerInch() ratio = (1 - self.interactor_fraction) / self.interactor_fraction @@ -1001,27 +1027,19 @@ def _configure_point_picking(self): ) self.mpl_canvas.show() - # get data for each hemi - for idx, hemi in enumerate(['vol', 'lh', 'rh']): - hemi_data = self._data.get(hemi) - if hemi_data is not None: - act_data = hemi_data['array'] - if act_data.ndim == 3: - act_data = np.linalg.norm(act_data, axis=1) - smooth_mat = hemi_data.get('smooth_mat') - vertices = hemi_data['vertices'] - if hemi == 'vol': - assert smooth_mat is None - smooth_mat = sparse.csr_matrix( - (np.ones(len(vertices)), - (vertices, np.arange(len(vertices))))) - self.act_data_smooth[hemi] = (act_data, smooth_mat) + def _configure_vertex_time_course(self): + if not self.show_traces: + return + if self.mpl_canvas is None: + self._configure_mplcanvas() + else: + self.clear_glyphs() # plot the GFP y = np.concatenate(list(v[0] for v in self.act_data_smooth.values() if v[0] is not None)) y = np.linalg.norm(y, axis=0) / np.sqrt(len(y)) - self.mpl_canvas.axes.plot( + self.gfp, = self.mpl_canvas.axes.plot( self._data['time'], y, lw=3, label='GFP', zorder=3, color=self._fg_color, alpha=0.5, ls=':') @@ -1056,7 +1074,26 @@ def _configure_point_picking(self): else: mesh = self._layered_meshes[hemi]._polydata vertex_id = vertices[ind[0]] - self.add_point(hemi, mesh, vertex_id) + self._add_vertex_glyph(hemi, mesh, vertex_id) + + def _configure_picking(self): + from ..backends._pyvista import _update_picking_callback + + # get data for each hemi + for idx, hemi in enumerate(['vol', 'lh', 'rh']): + hemi_data = self._data.get(hemi) + if hemi_data is not None: + act_data = hemi_data['array'] + if act_data.ndim == 3: + act_data = np.linalg.norm(act_data, axis=1) + smooth_mat = hemi_data.get('smooth_mat') + vertices = hemi_data['vertices'] + if hemi == 'vol': + assert smooth_mat is None + smooth_mat = sparse.csr_matrix( + (np.ones(len(vertices)), + (vertices, np.arange(len(vertices))))) + self.act_data_smooth[hemi] = (act_data, smooth_mat) _update_picking_callback( self.plotter, @@ -1066,6 +1103,85 @@ def _configure_point_picking(self): self._on_pick ) + def _configure_trace_mode(self): + from ...source_estimate import _get_allowed_label_modes + from ...label import _read_annot_cands + from PyQt5.QtWidgets import QComboBox, QLabel + if not self.show_traces: + return + + # do not show trace mode for volumes + if (self._data.get('src', None) is not None and + self._data['src'].kind == 'volume'): + self._configure_vertex_time_course() + return + + # setup candidate annots + def _set_annot(annot): + self.clear_glyphs() + self.remove_labels() + self.remove_annotations() + self.annot = annot + + if annot == 'None': + self.traces_mode = 'vertex' + self._configure_vertex_time_course() + else: + self.traces_mode = 'label' + self._configure_label_time_course() + self._update() + + dir_name = op.join(self._subjects_dir, self._subject_id, 'label') + cands = _read_annot_cands(dir_name) + self.tool_bar.addSeparator() + self.tool_bar.addWidget(QLabel("Annotation")) + self._annot_cands_widget = QComboBox() + self.tool_bar.addWidget(self._annot_cands_widget) + self._annot_cands_widget.addItem('None') + for cand in cands: + self._annot_cands_widget.addItem(cand) + self.annot = cands[0] + + # setup label extraction parameters + def _set_label_mode(mode): + if self.traces_mode != 'label': + return + import copy + glyphs = copy.deepcopy(self.picked_patches) + self.label_extract_mode = mode + self.clear_glyphs() + for hemi in self._hemis: + for label_id in glyphs[hemi]: + label = self._annotation_labels[hemi][label_id] + vertex_id = label.vertices[0] + self._add_label_glyph(hemi, None, vertex_id) + self.mpl_canvas.axes.relim() + self.mpl_canvas.axes.autoscale_view() + self.mpl_canvas.update_plot() + self._update() + + self.tool_bar.addSeparator() + self.tool_bar.addWidget(QLabel("Label extraction mode")) + self._label_mode_widget = QComboBox() + self.tool_bar.addWidget(self._label_mode_widget) + stc = self._data["stc"] + modes = _get_allowed_label_modes(stc) + if self._data["src"] is None: + modes = [m for m in modes if m not in + self.default_label_extract_modes["src"]] + for mode in modes: + self._label_mode_widget.addItem(mode) + self.label_extract_mode = mode + + if self.traces_mode == 'vertex': + _set_annot('None') + else: + _set_annot(self.annot) + self._annot_cands_widget.setCurrentText(self.annot) + self._label_mode_widget.setCurrentText(self.label_extract_mode) + self._annot_cands_widget.currentTextChanged.connect(_set_annot) + self._label_mode_widget.currentTextChanged.connect(_set_label_mode) + def _load_icons(self): from PyQt5.QtGui import QIcon from ..backends._pyvista import _init_resources @@ -1124,7 +1240,7 @@ def _configure_tool_bar(self): self.actions["clear"] = self.tool_bar.addAction( self.icons["clear"], "Clear traces", - self.clear_points + self.clear_glyphs ) self.actions["help"] = self.tool_bar.addAction( self.icons["help"], @@ -1178,6 +1294,9 @@ def _on_button_release(self, vtk_picker, event): self._mouse_no_mvt = 0 def _on_pick(self, vtk_picker, event): + if not self.show_traces: + return + # vtk_picker is a vtkCellPicker cell_id = vtk_picker.GetCellId() mesh = vtk_picker.GetDataSet() @@ -1198,12 +1317,12 @@ def _on_pick(self, vtk_picker, event): if found_sphere is not None: break if found_sphere is not None: - assert found_sphere._is_point + assert found_sphere._is_glyph mesh = found_sphere # 2) Remove sphere if it's what we have - if hasattr(mesh, "_is_point"): - self.remove_point(mesh) + if hasattr(mesh, "_is_glyph"): + self._remove_vertex_glyph(mesh) return # 3) Otherwise, pick the objects in the scene @@ -1258,26 +1377,38 @@ def _on_pick(self, vtk_picker, event): idx = np.argmin(abs(vertices - pos), axis=0) vertex_id = cell[idx[0]] - if vertex_id not in self.picked_points[hemi]: - self.add_point(hemi, mesh, vertex_id) + if self.traces_mode == 'label': + self._add_label_glyph(hemi, mesh, vertex_id) + else: + self._add_vertex_glyph(hemi, mesh, vertex_id) - def add_point(self, hemi, mesh, vertex_id): - """Pick a vertex on the brain. + def _add_label_glyph(self, hemi, mesh, vertex_id): + if hemi == 'vol': + return + label_id = self._vertex_to_label_id[hemi][vertex_id] + label = self._annotation_labels[hemi][label_id] - Parameters - ---------- - hemi : str - The hemisphere id of the vertex. - mesh : object - The mesh where picking is expected. - vertex_id : int - The vertex identifier in the mesh. + # remove the patch if already picked + if label_id in self.picked_patches[hemi]: + self._remove_label_glyph(hemi, label_id) + return + + if hemi == label.hemi: + self.add_label(label, borders=True, reset_camera=False) + self.picked_patches[hemi].append(label_id) + + def _remove_label_glyph(self, hemi, label_id): + label = self._annotation_labels[hemi][label_id] + label._line.remove() + self.color_cycle.restore(label._color) + self.mpl_canvas.update_plot() + self._layered_meshes[hemi].remove_overlay(label.name) + self.picked_patches[hemi].remove(label_id) + + def _add_vertex_glyph(self, hemi, mesh, vertex_id): + if vertex_id in self.picked_points[hemi]: + return - Returns - ------- - sphere : object - The glyph created for the picked point. - """ # skip if the wrong hemi is selected if self.act_data_smooth[hemi][0] is None: return @@ -1326,7 +1457,7 @@ def add_point(self, hemi, mesh, vertex_id): # add metadata for picking for sphere in spheres: - sphere._is_point = True + sphere._is_glyph = True sphere._hemi = hemi sphere._line = line sphere._actors = actors @@ -1338,14 +1469,7 @@ def add_point(self, hemi, mesh, vertex_id): self.pick_table[vertex_id] = spheres return sphere - def remove_point(self, mesh): - """Remove the picked point from its glyph. - - Parameters - ---------- - mesh : object - The mesh associated to the point to remove. - """ + def _remove_vertex_glyph(self, mesh, render=True): vertex_id = mesh._vertex_id if vertex_id not in self.pick_table: return @@ -1364,20 +1488,28 @@ def remove_point(self, mesh): self.color_cycle.restore(color) for sphere in spheres: # remove all actors - self.plotter.remove_actor(sphere._actors) + self.plotter.remove_actor(sphere._actors, render=render) sphere._actors = None self._spheres.pop(self._spheres.index(sphere)) self.pick_table.pop(vertex_id) - def clear_points(self): - """Clear the picked points.""" - if not hasattr(self, '_spheres'): + def clear_glyphs(self): + """Clear the picking glyphs.""" + if not self.time_viewer: return for sphere in list(self._spheres): # will remove itself, so copy - self.remove_point(sphere) + self._remove_vertex_glyph(sphere, render=False) assert sum(len(v) for v in self.picked_points.values()) == 0 assert len(self.pick_table) == 0 assert len(self._spheres) == 0 + for hemi in self._hemis: + for label_id in list(self.picked_patches[hemi]): + self._remove_label_glyph(hemi, label_id) + assert sum(len(v) for v in self.picked_patches.values()) == 0 + if self.gfp is not None: + self.gfp.remove() + self.gfp = None + self._update() def plot_time_course(self, hemi, vertex_id, color): """Plot the vertex time course. @@ -1701,8 +1833,6 @@ def add_data(self, array, fmin=None, fmid=None, fmax=None, self._data['transparent'] = transparent # data specific for a hemi self._data[hemi] = dict() - self._data[hemi]['actors'] = None - self._data[hemi]['mesh'] = None self._data[hemi]['glyph_dataset'] = None self._data[hemi]['glyph_mapper'] = None self._data[hemi]['glyph_actor'] = None @@ -1791,8 +1921,16 @@ def remove_labels(self): """Remove all the ROI labels from the image.""" for hemi in self._hemis: mesh = self._layered_meshes[hemi] - mesh.remove_overlay(self._label_data[hemi]) - self._label_data[hemi].clear() + mesh.remove_overlay(self._labels[hemi]) + self._labels[hemi].clear() + self._update() + + def remove_annotations(self): + """Remove all annotations from the image.""" + for hemi in self._hemis: + mesh = self._layered_meshes[hemi] + mesh.remove_overlay(self._annots[hemi]) + self._annots[hemi].clear() self._update() def _add_volume_data(self, hemi, src, volume_options): @@ -1912,7 +2050,8 @@ def _add_volume_data(self, hemi, src, volume_options): return actor_pos, actor_neg def add_label(self, label, color=None, alpha=1, scalar_thresh=None, - borders=False, hemi=None, subdir=None): + borders=False, hemi=None, subdir=None, + reset_camera=True): """Add an ROI label to the image. Parameters @@ -1943,6 +2082,9 @@ def add_label(self, label, color=None, alpha=1, scalar_thresh=None, label directory rather than in the label directory itself (e.g. for ``$SUBJECTS_DIR/$SUBJECT/label/aparc/lh.cuneus.label`` ``brain.add_label('cuneus', subdir='aparc')``). + reset_camera : bool + If True, reset the camera view after adding the label. Defaults + to True. Notes ----- @@ -2004,11 +2146,23 @@ def add_label(self, label, color=None, alpha=1, scalar_thresh=None, if scalar_thresh is not None: ids = ids[scalars >= scalar_thresh] - # XXX: add support for label_name - self._label_name = label_name + scalars = np.zeros(self.geo[hemi].coords.shape[0]) + scalars[ids] = 1 + + if self.time_viewer and self.show_traces: + stc = self._data["stc"] + src = self._data["src"] + tc = stc.extract_label_time_course(label, src=src, + mode=self.label_extract_mode) + tc = tc[0] if tc.ndim == 2 else tc[0, 0, :] + color = next(self.color_cycle) + line = self.mpl_canvas.plot( + self._data['time'], tc, label=label_name, + color=color) + else: + line = None - label = np.zeros(self.geo[hemi].coords.shape[0]) - label[ids] = 1 + orig_color = color color = colorConverter.to_rgba(color, alpha) cmap = np.array([(0, 0, 0, 0,), color]) ctable = np.round(cmap * 255).astype(np.uint8) @@ -2016,11 +2170,11 @@ def add_label(self, label, color=None, alpha=1, scalar_thresh=None, for ri, ci, v in self._iter_views(hemi): self._renderer.subplot(ri, ci) if borders: - n_vertices = label.size + n_vertices = scalars.size edges = mesh_edges(self.geo[hemi].faces) edges = edges.tocoo() - border_edges = label[edges.row] != label[edges.col] - show = np.zeros(n_vertices, dtype=np.int) + border_edges = scalars[edges.row] != scalars[edges.col] + show = np.zeros(n_vertices, dtype=np.int64) keep_idx = np.unique(edges.row[border_edges]) if isinstance(borders, int): for _ in range(borders): @@ -2031,19 +2185,22 @@ def add_label(self, label, color=None, alpha=1, scalar_thresh=None, keep_idx, axis=1)] keep_idx = np.unique(keep_idx) show[keep_idx] = 1 - label *= show + scalars *= show mesh = self._layered_meshes[hemi] mesh.add_overlay( - scalars=label, + scalars=scalars, colormap=ctable, rng=None, opacity=alpha, name=label_name, ) - self._label_data[hemi].append(label_name) - self._renderer.set_camera(**views_dicts[hemi][v]) - + if reset_camera: + self._renderer.set_camera(**views_dicts[hemi][v]) + if self.time_viewer and self.traces_mode == 'label': + label._color = orig_color + label._line = line + self._labels[hemi].append(label) self._update() def add_foci(self, coords, coords_as_verts=False, map_surface=None, @@ -2140,6 +2297,34 @@ def add_text(self, x, y, text, name=None, color=None, opacity=1.0, self._renderer.text2d(x_window=x, y_window=y, text=text, color=color, size=font_size, justification=justification) + def _configure_label_time_course(self): + from ...label import read_labels_from_annot + if not self.show_traces: + return + if self.mpl_canvas is None: + self._configure_mplcanvas() + else: + self.clear_glyphs() + self.traces_mode = 'label' + self.add_annotation(self.annot, color="w", alpha=0.75) + + # now plot the time line + self.plot_time_line() + self.mpl_canvas.update_plot() + + for hemi in self._hemis: + labels = read_labels_from_annot( + subject=self._subject_id, + parc=self.annot, + hemi=hemi, + subjects_dir=self._subjects_dir + ) + self._vertex_to_label_id[hemi] = np.full( + self.geo[hemi].coords.shape[0], -1) + self._annotation_labels[hemi] = labels + for idx, label in enumerate(labels): + self._vertex_to_label_id[hemi][label.vertices] = idx + def add_annotation(self, annot, borders=True, alpha=1, hemi=None, remove_existing=True, color=None, **kwargs): """Add an annotation file. @@ -2213,7 +2398,6 @@ def add_annotation(self, annot, borders=True, alpha=1, hemi=None, annot = 'annotation' for hemi, (labels, cmap) in zip(hemis, annots): - # Maybe zero-out the non-border vertices self._to_borders(labels, hemi, borders) @@ -2243,14 +2427,21 @@ def add_annotation(self, annot, borders=True, alpha=1, hemi=None, cmap[:, :3] = rgb.astype(cmap.dtype) ctable = cmap.astype(np.float64) - mesh = self._layered_meshes[hemi] - mesh.add_overlay( - scalars=ids, - colormap=ctable, - rng=[np.min(ids), np.max(ids)], - opacity=alpha, - name=annot, - ) + for ri, ci, _ in self._iter_views(hemi): + self._renderer.subplot(ri, ci) + mesh = self._layered_meshes[hemi] + mesh.add_overlay( + scalars=ids, + colormap=ctable, + rng=[np.min(ids), np.max(ids)], + opacity=alpha, + name=annot, + ) + self._annots[hemi].append(annot) + if not self.time_viewer or self.traces_mode == 'vertex': + from ..backends._pyvista import _set_colormap_range + _set_colormap_range(mesh._actor, cmap.astype(np.uint8), + None) self._update() @@ -2692,6 +2883,10 @@ def data(self): """Data used by time viewer and color bar widgets.""" return self._data + @property + def labels(self): + return self._labels + @property def views(self): return self._views diff --git a/mne/viz/_brain/_linkviewer.py b/mne/viz/_brain/_linkviewer.py index b84c07e1680..14620f8e920 100644 --- a/mne/viz/_brain/_linkviewer.py +++ b/mne/viz/_brain/_linkviewer.py @@ -57,12 +57,12 @@ def _time_func(*args, **kwargs): if picking: def _func_add(*args, **kwargs): for brain in self.brains: - brain._add_point(*args, **kwargs) + brain._add_vertex_glyph2(*args, **kwargs) brain.plotter.update() def _func_remove(*args, **kwargs): for brain in self.brains: - brain._remove_point(*args, **kwargs) + brain._remove_vertex_glyph2(*args, **kwargs) # save initial picked points initial_points = dict() @@ -74,18 +74,18 @@ def _func_remove(*args, **kwargs): # link the viewers for brain in self.brains: - brain.clear_points() - brain._add_point = brain.add_point - brain.add_point = _func_add - brain._remove_point = brain.remove_point - brain.remove_point = _func_remove + brain.clear_glyphs() + brain._add_vertex_glyph2 = brain._add_vertex_glyph + brain._add_vertex_glyph = _func_add + brain._remove_vertex_glyph2 = brain._remove_vertex_glyph + brain._remove_vertex_glyph = _func_remove # link the initial points for hemi in initial_points.keys(): if hemi in brain._layered_meshes: mesh = brain._layered_meshes[hemi]._polydata for vertex_id in initial_points[hemi]: - self.leader.add_point(hemi, mesh, vertex_id) + self.leader._add_vertex_glyph(hemi, mesh, vertex_id) if colorbar: fmin = self.leader._data["fmin"] diff --git a/mne/viz/_brain/tests/test_brain.py b/mne/viz/_brain/tests/test_brain.py index 93a1cd13684..6a0897bf12b 100644 --- a/mne/viz/_brain/tests/test_brain.py +++ b/mne/viz/_brain/tests/test_brain.py @@ -15,16 +15,21 @@ import numpy as np from numpy.testing import assert_allclose, assert_array_equal -from mne import (read_source_estimate, SourceEstimate, MixedSourceEstimate, +from mne import (read_source_estimate, read_evokeds, read_cov, + read_forward_solution, pick_types_forward, + SourceEstimate, MixedSourceEstimate, VolSourceEstimate) +from mne.minimum_norm import apply_inverse, make_inverse_operator from mne.source_space import (read_source_spaces, vertex_to_mni, setup_volume_source_space) from mne.datasets import testing from mne.utils import check_version +from mne.label import read_label from mne.viz._brain import Brain, _LinkViewer, _BrainScraper, _LayeredMesh from mne.viz._brain.colormap import calculate_lut from matplotlib import cm, image +from matplotlib.lines import Line2D import matplotlib.pyplot as plt data_path = testing.data_path(download=False) @@ -32,6 +37,12 @@ subjects_dir = path.join(data_path, 'subjects') fname_stc = path.join(data_path, 'MEG/sample/sample_audvis_trunc-meg') fname_label = path.join(data_path, 'MEG/sample/labels/Vis-lh.label') +fname_cov = path.join( + data_path, 'MEG', 'sample', 'sample_audvis_trunc-cov.fif') +fname_evoked = path.join(data_path, 'MEG', 'sample', + 'sample_audvis_trunc-ave.fif') +fname_fwd = path.join( + data_path, 'MEG', 'sample', 'sample_audvis_trunc-meg-eeg-oct-4-fwd.fif') src_fname = path.join(data_path, 'subjects', 'sample', 'bem', 'sample-oct-6-src.fif') @@ -138,7 +149,6 @@ def test_brain_init(renderer, tmpdir, pixel_ratio, brain_gc): """Test initialization of the Brain instance.""" if renderer._get_3d_backend() != 'pyvista': pytest.skip('TimeViewer tests only supported on PyVista') - from mne.label import read_label from mne.source_estimate import _BaseSourceEstimate class FakeSTC(_BaseSourceEstimate): @@ -167,6 +177,8 @@ def __init__(self): cortex=cortex, units='m', **kwargs) with pytest.raises(TypeError, match='not supported'): brain._check_stc(hemi='lh', array=FakeSTC(), vertices=None) + with pytest.raises(ValueError, match='add_data'): + brain.setup_time_viewer(time_viewer=True) brain._hemi = 'foo' # for testing: hemis with pytest.raises(ValueError, match='not be None'): brain._check_hemi(hemi=None) @@ -258,6 +270,8 @@ def __init__(self): brain.add_label('foo', subdir='bar') label.name = None # test unnamed label brain.add_label(label, scalar_thresh=0.) + assert isinstance(brain.labels[label.hemi], list) + assert 'unnamed' in brain._layered_meshes[label.hemi]._overlays brain.remove_labels() brain.add_label(fname_label) brain.add_label('V1', borders=True) @@ -278,6 +292,13 @@ def __init__(self): borders = [True, 2] alphas = [1, 0.5] colors = [None, 'r'] + brain = Brain(subject_id='fsaverage', hemi='both', size=size, + surf='inflated', subjects_dir=subjects_dir) + with pytest.raises(RuntimeError, match="both hemispheres"): + brain.add_annotation(annots[-1]) + with pytest.raises(ValueError, match="does not exist"): + brain.add_annotation('foo') + brain.close() brain = Brain(subject_id='fsaverage', hemi=hemi, size=size, surf='inflated', subjects_dir=subjects_dir) for a, b, p, color in zip(annots, borders, alphas, colors): @@ -360,6 +381,11 @@ def test_brain_time_viewer(renderer_interactive, pixel_ratio, brain_gc): _create_testing_brain(hemi='lh', surf='white', src='volume', volume_options={'foo': 'bar'}) brain = _create_testing_brain(hemi='both', show_traces=False) + # test sub routines when show_traces=False + brain._on_pick(None, None) + brain._configure_vertex_time_course() + brain._configure_label_time_course() + brain.setup_time_viewer() # for coverage brain.callbacks["time"](value=0) brain.callbacks["orientation_lh_0_0"]( value='lat', @@ -412,6 +438,7 @@ def test_brain_time_viewer(renderer_interactive, pixel_ratio, brain_gc): ]) @pytest.mark.parametrize('src', [ 'surface', + pytest.param('vector', marks=pytest.mark.slowtest), pytest.param('volume', marks=pytest.mark.slowtest), pytest.param('mixed', marks=pytest.mark.slowtest), ]) @@ -421,21 +448,84 @@ def test_brain_traces(renderer_interactive, hemi, src, tmpdir, """Test brain traces.""" if renderer_interactive._get_3d_backend() != 'pyvista': pytest.skip('Only PyVista supports traces') + + hemi_str = list() + if src in ('surface', 'vector', 'mixed'): + hemi_str.extend([hemi] if hemi in ('lh', 'rh') else ['lh', 'rh']) + if src in ('mixed', 'volume'): + hemi_str.extend(['vol']) + + # label traces + brain = _create_testing_brain( + hemi=hemi, surf='white', src=src, show_traces='label', + volume_options=None, # for speed, don't upsample + n_time=5, initial_time=0, + ) + if src == 'surface': + brain._data['src'] = None # test src=None + if src in ('surface', 'vector', 'mixed'): + assert brain.show_traces + assert brain.traces_mode == 'label' + brain._label_mode_widget.setCurrentText('max') + + # test picking a cell at random + rng = np.random.RandomState(0) + for idx, current_hemi in enumerate(hemi_str): + if current_hemi == 'vol': + continue + current_mesh = brain._layered_meshes[current_hemi]._polydata + cell_id = rng.randint(0, current_mesh.n_cells) + test_picker = TstVTKPicker( + current_mesh, cell_id, current_hemi, brain) + assert len(brain.picked_patches[current_hemi]) == 0 + brain._on_pick(test_picker, None) + assert len(brain.picked_patches[current_hemi]) == 1 + for label_id in list(brain.picked_patches[current_hemi]): + label = brain._annotation_labels[current_hemi][label_id] + assert isinstance(label._line, Line2D) + brain._label_mode_widget.setCurrentText('mean') + brain.clear_glyphs() + assert len(brain.picked_patches[current_hemi]) == 0 + brain._on_pick(test_picker, None) # picked and added + assert len(brain.picked_patches[current_hemi]) == 1 + brain._on_pick(test_picker, None) # picked again so removed + assert len(brain.picked_patches[current_hemi]) == 0 + # test switching from 'label' to 'vertex' + brain._annot_cands_widget.setCurrentText('None') + brain._label_mode_widget.setCurrentText('max') + else: # volume + assert brain._trace_mode_widget is None + assert brain._annot_cands_widget is None + assert brain._label_mode_widget is None + brain.close() + + # test colormap + if src != 'vector': + brain = _create_testing_brain( + hemi=hemi, surf='white', src=src, show_traces=0.5, initial_time=0, + volume_options=None, # for speed, don't upsample + n_time=1 if src == 'mixed' else 5, diverging=True, + add_data_kwargs=dict(colorbar_kwargs=dict(n_labels=3)), + ) + # mne_analyze should be chosen + ctab = brain._data['ctable'] + assert_array_equal(ctab[0], [0, 255, 255, 255]) # opaque cyan + assert_array_equal(ctab[-1], [255, 255, 0, 255]) # opaque yellow + assert_allclose(ctab[len(ctab) // 2], [128, 128, 128, 0], atol=3) + brain.close() + + # vertex traces brain = _create_testing_brain( hemi=hemi, surf='white', src=src, show_traces=0.5, initial_time=0, volume_options=None, # for speed, don't upsample - n_time=1 if src == 'mixed' else 5, diverging=True, + n_time=1 if src == 'mixed' else 5, add_data_kwargs=dict(colorbar_kwargs=dict(n_labels=3)), ) assert brain.show_traces + assert brain.traces_mode == 'vertex' assert hasattr(brain, "picked_points") assert hasattr(brain, "_spheres") assert brain.plotter.scalar_bar.GetNumberOfLabels() == 3 - # mne_analyze should be chosen - ctab = brain._data['ctable'] - assert_array_equal(ctab[0], [0, 255, 255, 255]) # opaque cyan - assert_array_equal(ctab[-1], [255, 255, 0, 255]) # opaque yellow - assert_allclose(ctab[len(ctab) // 2], [128, 128, 128, 0], atol=3) # add foci should work for volumes brain.add_foci([[0, 0, 0]], hemi='lh' if src == 'surface' else 'vol') @@ -443,11 +533,6 @@ def test_brain_traces(renderer_interactive, hemi, src, tmpdir, # test points picked by default picked_points = brain.get_picked_points() spheres = brain._spheres - hemi_str = list() - if src in ('surface', 'mixed'): - hemi_str.extend([hemi] if hemi in ('lh', 'rh') else ['lh', 'rh']) - if src in ('mixed', 'volume'): - hemi_str.extend(['vol']) for current_hemi in hemi_str: assert len(picked_points[current_hemi]) == 1 n_spheres = len(hemi_str) @@ -455,8 +540,12 @@ def test_brain_traces(renderer_interactive, hemi, src, tmpdir, n_spheres += 1 assert len(spheres) == n_spheres + # test switching from 'vertex' to 'label' + if src == 'surface': + brain._annot_cands_widget.setCurrentText('aparc') + brain._annot_cands_widget.setCurrentText('None') # test removing points - brain.clear_points() + brain.clear_glyphs() assert len(spheres) == 0 for key in ('lh', 'rh', 'vol'): assert len(picked_points[key]) == 0 @@ -480,6 +569,7 @@ def test_brain_traces(renderer_interactive, hemi, src, tmpdir, assert cell_id == test_picker.cell_id assert test_picker.point_id is None brain._on_pick(test_picker, None) + brain._on_pick(test_picker, None) assert test_picker.point_id is not None assert len(picked_points[current_hemi]) == 1 assert picked_points[current_hemi][0] == test_picker.point_id @@ -563,7 +653,7 @@ def test_brain_linkviewer(renderer_interactive, brain_gc): picking=False, ) - brain_data = _create_testing_brain(hemi='split', show_traces=True) + brain_data = _create_testing_brain(hemi='split', show_traces='vertex') link_viewer = _LinkViewer( [brain2, brain_data], time=True, @@ -691,11 +781,23 @@ def test_calculate_lut(): def _create_testing_brain(hemi, surf='inflated', src='surface', size=300, n_time=5, diverging=False, **kwargs): - assert src in ('surface', 'mixed', 'volume') + assert src in ('surface', 'vector', 'mixed', 'volume') meth = 'plot' if src in ('surface', 'mixed'): sample_src = read_source_spaces(src_fname) klass = MixedSourceEstimate if src == 'mixed' else SourceEstimate + if src == 'vector': + fwd = read_forward_solution(fname_fwd) + fwd = pick_types_forward(fwd, meg=True, eeg=False) + evoked = read_evokeds(fname_evoked, baseline=(None, 0))[0] + noise_cov = read_cov(fname_cov) + free = make_inverse_operator( + evoked.info, fwd, noise_cov, loose=1.) + stc = apply_inverse(evoked, free, pick_ori='vector') + return stc.plot( + subject=subject_id, hemi=hemi, size=size, + subjects_dir=subjects_dir, colormap='auto', + **kwargs) if src in ('volume', 'mixed'): vol_src = setup_volume_source_space( subject_id, 7., mri='aseg.mgz', diff --git a/mne/viz/backends/renderer.py b/mne/viz/backends/renderer.py index 680a0574a55..fad9f882f97 100644 --- a/mne/viz/backends/renderer.py +++ b/mne/viz/backends/renderer.py @@ -98,8 +98,6 @@ def set_3d_backend(backend_name, verbose=None): +--------------------------------------+--------+---------+ | Save offline movie | ✓ | ✓ | +--------------------------------------+--------+---------+ - | Point picking | | ✓ | - +--------------------------------------+--------+---------+ .. note:: In the case of `plot_vector_source_estimates` with PyVista, the glyph