Skip to content

Commit 9cb6648

Browse files
GuillaumeFavelieragramfort
authored andcommitted
MRG: Use _TimeViewer in plot_source_estimates (mne-tools#7153)
* Add basic version of smoothing_steps support * Update naming * Add support for hemi='both' * Remove unnecessary update * Add a second test slider * Fix returned parameter error * Add support for set_time_point * Fix test_plot_sparse_source_estimates() * Fix style * Improve coverage * Improve coverage * Update time label * Fix time_actor bug * Change default verbose parameter * Add shortcut to toggle_interface * Modify time slider to realtime * Fix segfault * Use latest dev version * Add experimental orientation slider * Fix style * Rework interface * Add some comments * Add prototype for colorbar sliders * Refactor orientation label visibility * Modify time slider to be horizontal * Switch to temporary interface * Update colorbar points * Add docs * Improve syntax * Use local import * Improve stability * Update feature overview table * Isolate the FutureWarning * Rework variable naming * Fix backward incompatibility * Update and trigger plot_visualize_stc.py * Remove report * TST: Disable offscreen rendering to allow widgets interactivity * TST: Try hotfix * Revert "TST: Try hotfix" This reverts commit 15193d1. * Connect scalar change to the colorbar * Change default values for smoothing * Hide time index * Setup IntSlider * Refresh UI * Use round function in custom sliders * Fix mesh scalar range * Manage when time or scalarbar is unavailable * Refresh UI * Add fscale slider * TST: Try to connect fscale to fmin/fmid/fmax * Fix docstring * Find a better range * Make better use of space for UI
1 parent 1ec8f80 commit 9cb6648

File tree

9 files changed

+477
-32
lines changed

9 files changed

+477
-32
lines changed

environment.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ dependencies:
3333
- mne
3434
- https://api.github.com/repos/numpy/numpydoc/zipball/master
3535
- vtk
36-
- pyvista>=0.23.0
36+
- https://api.github.com/repos/pyvista/pyvista/zipball/master
3737
- mayavi
3838
- PySurfer[save_movie]
3939
- dipy --only-binary dipy

mne/viz/_3d.py

+1
Original file line numberDiff line numberDiff line change
@@ -1675,6 +1675,7 @@ def plot_source_estimates(stc, subject=None, surface='inflated', hemi='lh',
16751675
from surfer import Brain, TimeViewer
16761676
else:
16771677
from ._brain import _Brain as Brain
1678+
from ._brain import _TimeViewer as TimeViewer
16781679
_check_option('hemi', hemi, ['lh', 'rh', 'split', 'both'])
16791680

16801681
time_label, times = _handle_time(time_label, time_unit, stc.times)

mne/viz/_brain/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -10,5 +10,6 @@
1010
# License: Simplified BSD
1111

1212
from ._brain import _Brain
13+
from ._timeviewer import _TimeViewer
1314

1415
__all__ = ['_Brain']

mne/viz/_brain/_brain.py

+173-22
Original file line numberDiff line numberDiff line change
@@ -120,8 +120,6 @@ class _Brain(object):
120120
+---------------------------+--------------+-----------------------+
121121
| labels_dict | ✓ | |
122122
+---------------------------+--------------+-----------------------+
123-
| overlays | ✓ | - |
124-
+---------------------------+--------------+-----------------------+
125123
| remove_data | ✓ | |
126124
+---------------------------+--------------+-----------------------+
127125
| remove_foci | ✓ | |
@@ -134,6 +132,8 @@ class _Brain(object):
134132
+---------------------------+--------------+-----------------------+
135133
| show_view | ✓ | - |
136134
+---------------------------+--------------+-----------------------+
135+
| TimeViewer | ✓ | ✓ |
136+
+---------------------------+--------------+-----------------------+
137137
138138
"""
139139

@@ -181,7 +181,6 @@ def __init__(self, subject_id, hemi, surf, title=None,
181181
self._subjects_dir = subjects_dir
182182
self._views = views
183183
self._n_times = None
184-
self._scalarbar = False
185184
# for now only one color bar can be added
186185
# since it is the same for all figures
187186
self._colorbar_added = False
@@ -336,6 +335,7 @@ def add_data(self, array, fmin=None, fmid=None, fmax=None,
336335

337336
hemi = self._check_hemi(hemi)
338337
array = np.asarray(array)
338+
self._data['array'] = array
339339

340340
# Create time array and add label if > 1D
341341
if array.ndim <= 1:
@@ -393,8 +393,10 @@ def time_label(x):
393393
self._data['initial_time'] = initial_time
394394
self._data['time_label'] = time_label
395395
self._data['time_idx'] = time_idx
396+
self._data['transparent'] = transparent
396397
# data specific for a hemi
397398
self._data[hemi + '_array'] = array
399+
self._data[hemi + '_vertices'] = vertices
398400

399401
self._data['alpha'] = alpha
400402
self._data['colormap'] = colormap
@@ -419,7 +421,7 @@ def time_label(x):
419421
dt_max = fmax
420422
dt_min = fmin if center is None else -1 * fmax
421423

422-
ctable = self.update_lut(transparent=transparent)
424+
ctable = self.update_lut()
423425

424426
for ri, v in enumerate(self._views):
425427
views_dict = lh_views_dict if hemi == 'lh' else rh_views_dict
@@ -428,22 +430,33 @@ def time_label(x):
428430
else:
429431
ci = 0 if hemi == 'lh' else 1
430432
self._renderer.subplot(ri, ci)
431-
mesh = self._renderer.mesh(x=self.geo[hemi].coords[:, 0],
432-
y=self.geo[hemi].coords[:, 1],
433-
z=self.geo[hemi].coords[:, 2],
434-
triangles=self.geo[hemi].faces,
435-
color=None,
436-
colormap=ctable,
437-
vmin=dt_min,
438-
vmax=dt_max,
439-
scalars=act_data)
433+
mesh_data = self._renderer.mesh(
434+
x=self.geo[hemi].coords[:, 0],
435+
y=self.geo[hemi].coords[:, 1],
436+
z=self.geo[hemi].coords[:, 2],
437+
triangles=self.geo[hemi].faces,
438+
color=None,
439+
colormap=ctable,
440+
vmin=dt_min,
441+
vmax=dt_max,
442+
scalars=act_data
443+
)
444+
if isinstance(mesh_data, tuple):
445+
actor, mesh = mesh_data
446+
else:
447+
actor, mesh = mesh_data, None
448+
self._data[hemi + '_actor'] = actor
449+
self._data[hemi + '_mesh'] = mesh
440450
if array.ndim >= 2 and callable(time_label):
441-
self._renderer.text2d(x_window=0.95, y_window=y_txt,
442-
size=time_label_size,
443-
text=time_label(time[time_idx]),
444-
justification='right')
451+
time_actor = self._renderer.text2d(
452+
x_window=0.95, y_window=y_txt,
453+
size=time_label_size,
454+
text=time_label(time[time_idx]),
455+
justification='right'
456+
)
457+
self._data[hemi + '_time_actor'] = time_actor
445458
if colorbar and not self._colorbar_added:
446-
self._renderer.scalarbar(source=mesh, n_labels=8,
459+
self._renderer.scalarbar(source=actor, n_labels=8,
447460
bgcolor=(0.5, 0.5, 0.5))
448461
self._colorbar_added = True
449462
self._renderer.set_camera(azimuth=views_dict[v].azim,
@@ -763,7 +776,7 @@ def screenshot(self, mode='rgb'):
763776
"""
764777
return self._renderer.screenshot(mode)
765778

766-
def update_lut(self, fmin=None, fmid=None, fmax=None, transparent=True):
779+
def update_lut(self, fmin=None, fmid=None, fmax=None):
767780
u"""Update color map.
768781
769782
Parameters
@@ -779,6 +792,7 @@ def update_lut(self, fmin=None, fmid=None, fmax=None, transparent=True):
779792
alpha = self._data['alpha']
780793
center = self._data['center']
781794
colormap = self._data['colormap']
795+
transparent = self._data['transparent']
782796
fmin = self._data['fmin'] if fmin is None else fmin
783797
fmid = self._data['fmid'] if fmid is None else fmid
784798
fmax = self._data['fmax'] if fmax is None else fmax
@@ -789,9 +803,146 @@ def update_lut(self, fmin=None, fmid=None, fmax=None, transparent=True):
789803

790804
return self._data['ctable']
791805

792-
@property
793-
def overlays(self):
794-
return self._overlays
806+
def set_data_smoothing(self, n_steps):
807+
"""Set the number of smoothing steps.
808+
809+
Parameters
810+
----------
811+
n_steps : int
812+
Number of smoothing steps
813+
"""
814+
from ..backends._pyvista import _set_mesh_scalars
815+
for hemi in ['lh', 'rh']:
816+
pd = self._data.get(hemi + '_mesh')
817+
if pd is not None:
818+
array = self._data[hemi + '_array']
819+
vertices = self._data[hemi + '_vertices']
820+
if pd is not None:
821+
time_idx = self._data['time_idx']
822+
if self._data['array'].ndim == 1:
823+
act_data = array
824+
elif self._data['array'].ndim == 2:
825+
act_data = array[:, time_idx]
826+
827+
adj_mat = mesh_edges(self.geo[hemi].faces)
828+
smooth_mat = smoothing_matrix(vertices,
829+
adj_mat, int(n_steps),
830+
verbose=False)
831+
act_data = smooth_mat.dot(act_data)
832+
_set_mesh_scalars(pd, act_data, 'Data')
833+
self._data[hemi + '_smooth_mat'] = smooth_mat
834+
835+
def set_time_point(self, time_idx):
836+
"""Set the time point shown."""
837+
from ..backends._pyvista import _set_mesh_scalars
838+
time_idx = int(time_idx)
839+
for hemi in ['lh', 'rh']:
840+
pd = self._data.get(hemi + '_mesh')
841+
if pd is not None:
842+
array = self._data[hemi + '_array']
843+
time = self._data['time']
844+
time_label = self._data['time_label']
845+
time_actor = self._data.get(hemi + '_time_actor')
846+
if array.ndim == 1:
847+
continue # skip data without time axis
848+
# interpolation
849+
if array.ndim == 2:
850+
act_data = array
851+
852+
if isinstance(time_idx, int):
853+
act_data = act_data[:, time_idx]
854+
855+
smooth_mat = self._data[hemi + '_smooth_mat']
856+
if smooth_mat is not None:
857+
act_data = smooth_mat.dot(act_data)
858+
_set_mesh_scalars(pd, act_data, 'Data')
859+
if callable(time_label) and time_actor is not None:
860+
time_actor.SetInput(time_label(time[time_idx]))
861+
self._data['time_idx'] = time_idx
862+
863+
def update_fmax(self, fmax):
864+
"""Set the colorbar max point."""
865+
from ..backends._pyvista import _set_colormap_range
866+
if fmax > self._data['fmid']:
867+
ctable = self.update_lut(fmax=fmax)
868+
ctable = (ctable * 255).astype(np.uint8)
869+
center = self._data['center']
870+
for hemi in ['lh', 'rh']:
871+
actor = self._data.get(hemi + '_actor')
872+
if actor is not None:
873+
fmin = self._data['fmin']
874+
center = self._data['center']
875+
dt_max = fmax
876+
dt_min = fmin if center is None else -1 * fmax
877+
rng = [dt_min, dt_max]
878+
if self._colorbar_added:
879+
scalar_bar = self._renderer.plotter.scalar_bar
880+
else:
881+
scalar_bar = None
882+
_set_colormap_range(actor, ctable, scalar_bar, rng)
883+
self._data['fmax'] = fmax
884+
self._data['ctable'] = ctable
885+
886+
def update_fmid(self, fmid):
887+
"""Set the colorbar mid point."""
888+
from ..backends._pyvista import _set_colormap_range
889+
if self._data['fmin'] < fmid < self._data['fmax']:
890+
ctable = self.update_lut(fmid=fmid)
891+
ctable = (ctable * 255).astype(np.uint8)
892+
for hemi in ['lh', 'rh']:
893+
actor = self._data.get(hemi + '_actor')
894+
if actor is not None:
895+
if self._colorbar_added:
896+
scalar_bar = self._renderer.plotter.scalar_bar
897+
else:
898+
scalar_bar = None
899+
_set_colormap_range(actor, ctable, scalar_bar)
900+
self._data['fmid'] = fmid
901+
self._data['ctable'] = ctable
902+
903+
def update_fmin(self, fmin):
904+
"""Set the colorbar min point."""
905+
from ..backends._pyvista import _set_colormap_range
906+
if fmin < self._data['fmid']:
907+
ctable = self.update_lut(fmin=fmin)
908+
ctable = (ctable * 255).astype(np.uint8)
909+
for hemi in ['lh', 'rh']:
910+
actor = self._data.get(hemi + '_actor')
911+
if actor is not None:
912+
fmax = self._data['fmax']
913+
center = self._data['center']
914+
dt_max = fmax
915+
dt_min = fmin if center is None else -1 * fmax
916+
rng = [dt_min, dt_max]
917+
if self._colorbar_added:
918+
scalar_bar = self._renderer.plotter.scalar_bar
919+
else:
920+
scalar_bar = None
921+
_set_colormap_range(actor, ctable, scalar_bar, rng)
922+
self._data['fmin'] = fmin
923+
self._data['ctable'] = ctable
924+
925+
def update_fscale(self, fscale):
926+
"""Scale the colorbar points."""
927+
from ..backends._pyvista import _set_colormap_range
928+
fmin = self._data['fmin'] * fscale
929+
fmid = self._data['fmid'] * fscale
930+
fmax = self._data['fmax'] * fscale
931+
ctable = self.update_lut(fmin=fmin, fmid=fmid, fmax=fmax)
932+
ctable = (ctable * 255).astype(np.uint8)
933+
for hemi in ['lh', 'rh']:
934+
actor = self._data.get(hemi + '_actor')
935+
if actor is not None:
936+
center = self._data['center']
937+
dt_max = fmax
938+
dt_min = fmin if center is None else -1 * fmax
939+
rng = [dt_min, dt_max]
940+
if self._colorbar_added:
941+
scalar_bar = self._renderer.plotter.scalar_bar
942+
else:
943+
scalar_bar = None
944+
_set_colormap_range(actor, ctable, scalar_bar, rng)
945+
self._data['ctable'] = ctable
795946

796947
@property
797948
def data(self):

0 commit comments

Comments
 (0)