Skip to content

Commit 68a7757

Browse files
authored
MRG, ENH: Speed up forward computations with Numba (mne-tools#7133)
* ENH: Speed up forwards with Numba * FIX: Order * FIX: Missing mult * FIX: Refactor and add env option not to use numba (at import) * DOC: whats_new.rst * ENH: Faster cHPI * FIX: Missed one
1 parent a41ca39 commit 68a7757

File tree

8 files changed

+185
-123
lines changed

8 files changed

+185
-123
lines changed

doc/changes/latest.inc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,8 @@ Changelog
6767

6868
- Speed up :func:`mne.beamformer.make_lcmv` and :func:`mne.beamformer.make_dics` calculations by vectorizing linear algebra calls by `Dmitrii Altukhov`_ and `Eric Larson`_
6969

70+
- Speed up :func:`mne.make_forward_solution` using Numba, by `Eric Larson`_
71+
7072
- For KIT systems without built-in layout, :func:`mne.channels.find_layout` now falls back on an automatically generated layout, by `Christian Brodbeck`_
7173

7274
- :meth:`mne.Epochs.plot` now takes a ``epochs_colors`` parameter to color specific epoch segments by `Mainak Jas`_

mne/chpi.py

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,11 @@
4545
from .forward import (_magnetic_dipole_field_vec, _create_meg_coils,
4646
_concatenate_coils)
4747
from .cov import make_ad_hoc_cov, compute_whitener
48+
from .fixes import jit
4849
from .transforms import (apply_trans, invert_transform, _angle_between_quats,
4950
quat_to_rot, rot_to_quat)
5051
from .utils import (verbose, logger, use_log_level, _check_fname, warn,
51-
_check_option, _svd_lwork, _repeated_svd,
52-
ddot, dgemm, dgemv)
52+
_check_option)
5353

5454
# Eventually we should add:
5555
# hpicons
@@ -380,8 +380,7 @@ def _get_hpi_initial_fit(info, adjust=False, verbose=None):
380380
return hpi_rrs
381381

382382

383-
def _magnetic_dipole_objective(x, B, B2, coils, scale, method, too_close,
384-
lwork):
383+
def _magnetic_dipole_objective(x, B, B2, coils, scale, method, too_close):
385384
"""Project data onto right eigenvectors of whitened forward."""
386385
if method == 'forward':
387386
fwd = _magnetic_dipole_field_vec(x[np.newaxis, :], coils, too_close)
@@ -390,27 +389,32 @@ def _magnetic_dipole_objective(x, B, B2, coils, scale, method, too_close,
390389
# Eventually we can try incorporating external bases here, which
391390
# is why the :3 is on the SVD below
392391
fwd = _sss_basis(dict(origin=x, int_order=1, ext_order=0), coils).T
392+
return _magnetic_dipole_delta(fwd, scale, B, B2)
393+
394+
395+
@jit()
396+
def _magnetic_dipole_delta(fwd, scale, B, B2):
393397
# Here we use .T to get scale to Fortran order, which speeds things up
394-
fwd = dgemm(alpha=1., a=fwd, b=scale.T) # np.dot(fwd, scale.T)
395-
one = _repeated_svd(fwd, lwork, overwrite_a=True)[2]
396-
one = dgemv(alpha=1, a=one, x=B)
397-
Bm2 = ddot(one, one)
398+
fwd = np.dot(fwd, scale.T)
399+
one = np.linalg.svd(fwd, full_matrices=False)[2]
400+
one = np.dot(one, B)
401+
Bm2 = np.dot(one, one)
398402
return B2 - Bm2
399403

400404

401405
def _fit_magnetic_dipole(B_orig, x0, coils, scale, method, too_close):
402406
"""Fit a single bit of data (x0 = pos)."""
403407
from scipy.optimize import fmin_cobyla
404-
B = dgemv(alpha=1, a=scale, x=B_orig) # np.dot(scale, B_orig)
405-
B2 = ddot(B, B) # np.dot(B, B)
406-
lwork = _svd_lwork((3, B_orig.shape[0]))
408+
B = np.dot(scale, B_orig)
409+
B2 = np.dot(B, B)
407410
objective = partial(_magnetic_dipole_objective, B=B, B2=B2,
408411
coils=coils, scale=scale, method=method,
409-
too_close=too_close, lwork=lwork)
410-
x = fmin_cobyla(objective, x0, (), rhobeg=1e-4, rhoend=1e-5, disp=False)
412+
too_close=too_close)
413+
x = fmin_cobyla(objective, x0, (), rhobeg=1e-3, rhoend=1e-5, disp=False)
411414
return x, 1. - objective(x) / B2
412415

413416

417+
@jit()
414418
def _chpi_objective(x, coil_dev_rrs, coil_head_rrs):
415419
"""Compute objective function."""
416420
d = np.dot(coil_dev_rrs, quat_to_rot(x[:3]).T)
@@ -428,7 +432,8 @@ def _unit_quat_constraint(x):
428432
def _fit_chpi_quat(coil_dev_rrs, coil_head_rrs, x0):
429433
"""Fit rotation and translation (quaternion) parameters for cHPI coils."""
430434
from scipy.optimize import fmin_cobyla
431-
denom = np.sum((coil_head_rrs - np.mean(coil_head_rrs, axis=0)) ** 2)
435+
denom = np.linalg.norm(coil_head_rrs - np.mean(coil_head_rrs, axis=0))
436+
denom *= denom
432437
objective = partial(_chpi_objective, coil_dev_rrs=coil_dev_rrs,
433438
coil_head_rrs=coil_head_rrs)
434439
x0 = x0.copy()

mne/fixes.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import inspect
1616
from distutils.version import LooseVersion
1717
from math import log
18+
import os
1819
from pathlib import Path
1920
import warnings
2021

@@ -1195,14 +1196,25 @@ def jit(nopython=True, nogil=True, fastmath=True, cache=True,
11951196
return numba.jit(nopython=nopython, nogil=nogil, fastmath=fastmath,
11961197
cache=cache, **kwargs)
11971198
except ImportError:
1199+
has_numba = False
1200+
else:
1201+
has_numba = (os.getenv('MNE_USE_NUMBA', 'true').lower() == 'true')
1202+
1203+
1204+
if not has_numba:
11981205
def jit(**kwargs): # noqa
11991206
def _jit(func):
12001207
return func
12011208
return _jit
12021209
prange = range
1203-
has_numba = False
1210+
bincount = np.bincount
12041211
else:
1205-
has_numba = True
1212+
@jit()
1213+
def bincount(x, weights, minlength): # noqa: D103
1214+
out = np.zeros(minlength)
1215+
for idx, w in zip(x, weights):
1216+
out[idx] += w
1217+
return out
12061218

12071219

12081220
###############################################################################

0 commit comments

Comments
 (0)