diff --git a/mne/viz/tests/test_utils.py b/mne/viz/tests/test_utils.py index 8cb99d5224f..5f8932c9219 100644 --- a/mne/viz/tests/test_utils.py +++ b/mne/viz/tests/test_utils.py @@ -98,7 +98,8 @@ def test_auto_scale(): for inst in [raw, epochs]: scale_grad = 1e10 - scalings_def = dict([('eeg', 'auto'), ('grad', scale_grad)]) + scalings_def = dict([('eeg', 'auto'), ('grad', scale_grad), + ('stim', 'auto')]) # Test for wrong inputs assert_raises(ValueError, inst.plot, scalings='foo') diff --git a/mne/viz/utils.py b/mne/viz/utils.py index 62312cebdd4..8eeb7a15fc4 100644 --- a/mne/viz/utils.py +++ b/mne/viz/utils.py @@ -1058,7 +1058,6 @@ def _compute_scalings(scalings, inst): A scalings dictionary with updated values """ from ..io.base import _BaseRaw - from ..io.pick import _picks_by_type from ..epochs import _BaseEpochs if not isinstance(inst, (_BaseRaw, _BaseEpochs)): raise ValueError('Must supply either Raw or Epochs') @@ -1066,11 +1065,12 @@ def _compute_scalings(scalings, inst): # If scalings is None just return it and do nothing return scalings - ch_types = _picks_by_type(inst.info) - unique_ch_types = [i_type[0] for i_type in ch_types] + ch_types = channel_indices_by_type(inst.info) + ch_types = dict([(i_type, i_ixs) + for i_type, i_ixs in ch_types.items() if len(i_ixs) != 0]) if scalings == 'auto': # If we want to auto-compute everything - scalings = dict((i_type, 'auto') for i_type in unique_ch_types) + scalings = dict((i_type, 'auto') for i_type in ch_types.keys()) if not isinstance(scalings, dict): raise ValueError('scalings must be a dictionary of ch_type: val pairs,' ' not type %s ' % type(scalings)) @@ -1096,15 +1096,13 @@ def _compute_scalings(scalings, inst): data = inst._data if isinstance(inst, _BaseEpochs): data = inst._data.reshape([len(inst.ch_names), -1]) - # Iterate through ch types and update scaling if ' auto' for key, value in scalings.items(): if value != 'auto': continue - if key not in unique_ch_types: + if key not in ch_types.keys(): raise ValueError("Sensor {0} doesn't exist in data".format(key)) - this_ixs = [i_ixs for key_, i_ixs in ch_types if key_ == key] - this_data = data[this_ixs] + this_data = data[ch_types[key]] scale_factor = np.percentile(this_data.ravel(), [0.5, 99.5]) scale_factor = np.max(np.abs(scale_factor)) scalings[key] = scale_factor