Skip to content

Commit 9555588

Browse files
authored
MAINT: Refactor forward solution code (#11103)
1 parent db22d67 commit 9555588

File tree

9 files changed

+287
-413
lines changed

9 files changed

+287
-413
lines changed

mne/dipole.py

Lines changed: 58 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -722,10 +722,11 @@ def _write_dipole_bdip(fname, dip):
722722
# #############################################################################
723723
# Fitting
724724

725-
def _dipole_forwards(fwd_data, whitener, rr, n_jobs=None):
725+
def _dipole_forwards(*, sensors, fwd_data, whitener, rr, n_jobs=None):
726726
"""Compute the forward solution and do other nice stuff."""
727-
B = _compute_forwards_meeg(rr, fwd_data, n_jobs, silent=True)
728-
B = np.concatenate(B, axis=1)
727+
B = _compute_forwards_meeg(
728+
rr, sensors=sensors, fwd_data=fwd_data, n_jobs=n_jobs, silent=True)
729+
B = np.concatenate(list(B.values()), axis=1)
729730
assert np.isfinite(B).all()
730731
B_orig = B.copy()
731732

@@ -763,11 +764,13 @@ def _make_guesses(surf, grid, exclude, mindist, n_jobs=None, verbose=None):
763764
return SourceSpaces([src])
764765

765766

766-
def _fit_eval(rd, B, B2, fwd_svd=None, fwd_data=None, whitener=None,
767-
lwork=None):
767+
def _fit_eval(rd, B, B2, *, sensors, fwd_data, whitener, lwork, fwd_svd):
768768
"""Calculate the residual sum of squares."""
769769
if fwd_svd is None:
770-
fwd = _dipole_forwards(fwd_data, whitener, rd[np.newaxis, :])[0]
770+
assert sensors is not None
771+
fwd = _dipole_forwards(
772+
sensors=sensors, fwd_data=fwd_data, whitener=whitener,
773+
rr=rd[np.newaxis, :])[0]
771774
uu, sing, vv = _repeated_svd(fwd, lwork, overwrite_a=True)
772775
else:
773776
uu, sing, vv = fwd_svd
@@ -791,7 +794,7 @@ def _dipole_gof(uu, sing, vv, B, B2):
791794
return gof, one
792795

793796

794-
def _fit_Q(fwd_data, whitener, B, B2, B_orig, rd, ori=None):
797+
def _fit_Q(*, sensors, fwd_data, whitener, B, B2, B_orig, rd, ori=None):
795798
"""Fit the dipole moment once the location is known."""
796799
from scipy import linalg
797800
if 'fwd' in fwd_data:
@@ -805,8 +808,9 @@ def _fit_Q(fwd_data, whitener, B, B2, B_orig, rd, ori=None):
805808
assert scales.shape == (3,)
806809
fwd_svd = fwd_data['fwd_svd'][0]
807810
else:
808-
fwd, fwd_orig, scales = _dipole_forwards(fwd_data, whitener,
809-
rd[np.newaxis, :])
811+
fwd, fwd_orig, scales = _dipole_forwards(
812+
sensors=sensors, fwd_data=fwd_data, whitener=whitener,
813+
rr=rd[np.newaxis, :])
810814
fwd_svd = None
811815
if ori is None:
812816
if fwd_svd is None:
@@ -830,15 +834,18 @@ def _fit_Q(fwd_data, whitener, B, B2, B_orig, rd, ori=None):
830834

831835

832836
def _fit_dipoles(fun, min_dist_to_inner_skull, data, times, guess_rrs,
833-
guess_data, fwd_data, whitener, ori, n_jobs, rank, rhoend):
837+
guess_data, *, sensors, fwd_data, whitener, ori, n_jobs,
838+
rank, rhoend):
834839
"""Fit a single dipole to the given whitened, projected data."""
835840
from scipy.optimize import fmin_cobyla
836841
parallel, p_fun, n_jobs = parallel_func(fun, n_jobs)
837842
# parallel over time points
838-
res = parallel(p_fun(min_dist_to_inner_skull, B, t, guess_rrs,
839-
guess_data, fwd_data, whitener,
840-
fmin_cobyla, ori, rank, rhoend)
841-
for B, t in zip(data.T, times))
843+
res = parallel(
844+
p_fun(
845+
min_dist_to_inner_skull, B, t, guess_rrs, guess_data,
846+
sensors=sensors, fwd_data=fwd_data, whitener=whitener,
847+
fmin_cobyla=fmin_cobyla, ori=ori, rank=rank, rhoend=rhoend)
848+
for B, t in zip(data.T, times))
842849
pos = np.array([r[0] for r in res])
843850
amp = np.array([r[1] for r in res])
844851
ori = np.array([r[2] for r in res])
@@ -948,7 +955,7 @@ def _simplex_minimize(p, ftol, stol, fun, max_eval=1000):
948955
'''
949956

950957

951-
def _fit_confidence(rd, Q, ori, whitener, fwd_data):
958+
def _fit_confidence(*, rd, Q, ori, whitener, fwd_data, sensors):
952959
# As describedd in the Xfit manual, confidence intervals can be calculated
953960
# by examining a linearization of model at the best-fitting location,
954961
# i.e. taking the Jacobian and using the whitener:
@@ -977,11 +984,15 @@ def _fit_confidence(rd, Q, ori, whitener, fwd_data):
977984
for delta in deltas:
978985
this_r = rd[np.newaxis] + delta * direction[ii]
979986
fwds.append(
980-
np.dot(Q, _dipole_forwards(fwd_data, whitener, this_r)[0]))
987+
np.dot(Q, _dipole_forwards(
988+
sensors=sensors, fwd_data=fwd_data,
989+
whitener=whitener, rr=this_r)[0]))
981990
J[:, ii] = np.diff(fwds, axis=0)[0] / np.diff(deltas)[0]
982991
# Get current (Q) deltas in the dipole directions
983992
deltas = np.array([-0.01, 0.01]) * np.linalg.norm(Q)
984-
this_fwd = _dipole_forwards(fwd_data, whitener, rd[np.newaxis])[0]
993+
this_fwd = _dipole_forwards(
994+
sensors=sensors, fwd_data=fwd_data, whitener=whitener,
995+
rr=rd[np.newaxis])[0]
985996
for ii in range(3):
986997
fwds = []
987998
for delta in deltas:
@@ -1031,8 +1042,8 @@ def _sphere_constraint(rd, r0, R_adj):
10311042

10321043

10331044
def _fit_dipole(min_dist_to_inner_skull, B_orig, t, guess_rrs,
1034-
guess_data, fwd_data, whitener, fmin_cobyla, ori, rank,
1035-
rhoend):
1045+
guess_data, *, sensors, fwd_data, whitener, fmin_cobyla,
1046+
ori, rank, rhoend):
10361047
"""Fit a single bit of data."""
10371048
B = np.dot(whitener, B_orig)
10381049

@@ -1053,12 +1064,14 @@ def _fit_dipole(min_dist_to_inner_skull, B_orig, t, guess_rrs,
10531064
warn('Zero field found for time %s' % t)
10541065
return np.zeros(3), 0, np.zeros(3), 0, B
10551066

1056-
idx = np.argmin([_fit_eval(guess_rrs[[fi], :], B, B2, fwd_svd)
1057-
for fi, fwd_svd in enumerate(guess_data['fwd_svd'])])
1067+
idx = np.argmin([
1068+
_fit_eval(guess_rrs[[fi], :], B, B2, fwd_svd=fwd_svd,
1069+
fwd_data=None, sensors=None, whitener=None, lwork=None)
1070+
for fi, fwd_svd in enumerate(guess_data['fwd_svd'])])
10581071
x0 = guess_rrs[idx]
10591072
lwork = _svd_lwork((3, B.shape[0]))
10601073
fun = partial(_fit_eval, B=B, B2=B2, fwd_data=fwd_data, whitener=whitener,
1061-
lwork=lwork)
1074+
lwork=lwork, sensors=sensors, fwd_svd=None)
10621075

10631076
# Tested minimizers:
10641077
# Simplex, BFGS, CG, COBYLA, L-BFGS-B, Powell, SLSQP, TNC
@@ -1074,14 +1087,17 @@ def _fit_dipole(min_dist_to_inner_skull, B_orig, t, guess_rrs,
10741087

10751088
# Compute the dipole moment at the final point
10761089
Q, gof, residual_noproj, n_comp = _fit_Q(
1077-
fwd_data, whitener, B, B2, B_orig, rd_final, ori=ori)
1090+
sensors=sensors, fwd_data=fwd_data, whitener=whitener, B=B, B2=B2,
1091+
B_orig=B_orig, rd=rd_final, ori=ori)
10781092
khi2 = (1 - gof) * B2
10791093
nfree = rank - n_comp
10801094
amp = np.sqrt(np.dot(Q, Q))
10811095
norm = 1. if amp == 0. else amp
10821096
ori = Q / norm
10831097

1084-
conf = _fit_confidence(rd_final, Q, ori, whitener, fwd_data)
1098+
conf = _fit_confidence(
1099+
sensors=sensors, rd=rd_final, Q=Q, ori=ori, whitener=whitener,
1100+
fwd_data=fwd_data)
10851101

10861102
msg = '---- Fitted : %7.1f ms' % (1000. * t)
10871103
if surf is not None:
@@ -1095,7 +1111,7 @@ def _fit_dipole(min_dist_to_inner_skull, B_orig, t, guess_rrs,
10951111

10961112

10971113
def _fit_dipole_fixed(min_dist_to_inner_skull, B_orig, t, guess_rrs,
1098-
guess_data, fwd_data, whitener,
1114+
guess_data, *, sensors, fwd_data, whitener,
10991115
fmin_cobyla, ori, rank, rhoend):
11001116
"""Fit a data using a fixed position."""
11011117
B = np.dot(whitener, B_orig)
@@ -1104,8 +1120,9 @@ def _fit_dipole_fixed(min_dist_to_inner_skull, B_orig, t, guess_rrs,
11041120
warn('Zero field found for time %s' % t)
11051121
return np.zeros(3), 0, np.zeros(3), 0, np.zeros(6)
11061122
# Compute the dipole moment
1107-
Q, gof, residual_noproj = _fit_Q(guess_data, whitener, B, B2, B_orig,
1108-
rd=None, ori=ori)[:3]
1123+
Q, gof, residual_noproj = _fit_Q(
1124+
fwd_data=guess_data, whitener=whitener, B=B, B2=B2, B_orig=B_orig,
1125+
sensors=sensors, rd=None, ori=ori)[:3]
11091126
if ori is None:
11101127
amp = np.sqrt(np.dot(Q, Q))
11111128
norm = 1. if amp == 0. else amp
@@ -1318,18 +1335,16 @@ def fit_dipole(evoked, cov, bem, trans=None, min_dist=5., n_jobs=None,
13181335
# Forward model setup (setup_forward_model from setup.c)
13191336
ch_types = evoked.get_channel_types()
13201337

1321-
megcoils, compcoils, megnames, meg_info = [], [], [], None
1322-
eegels, eegnames = [], []
1338+
sensors = dict()
13231339
if 'grad' in ch_types or 'mag' in ch_types:
1324-
megcoils, compcoils, megnames, meg_info = \
1325-
_prep_meg_channels(info, exclude='bads',
1326-
accuracy=accuracy, verbose=verbose)
1340+
sensors['meg'] = _prep_meg_channels(
1341+
info, exclude='bads', accuracy=accuracy, verbose=verbose)
13271342
if 'eeg' in ch_types:
1328-
eegels, eegnames = _prep_eeg_channels(info, exclude='bads',
1329-
verbose=verbose)
1343+
sensors['eeg'] = _prep_eeg_channels(
1344+
info, exclude='bads', verbose=verbose)
13301345

13311346
# Ensure that MEG and/or EEG channels are present
1332-
if len(megcoils + eegels) == 0:
1347+
if len(sensors) == 0:
13331348
raise RuntimeError('No MEG or EEG channels found.')
13341349

13351350
# Whitener for the data
@@ -1378,14 +1393,13 @@ def fit_dipole(evoked, cov, bem, trans=None, min_dist=5., n_jobs=None,
13781393
'skull boundary' % (-1000 * check,))
13791394

13801395
# C code computes guesses w/sphere model for speed, don't bother here
1381-
fwd_data = dict(coils_list=[megcoils, eegels], infos=[meg_info, None],
1382-
ccoils_list=[compcoils, None], coil_types=['meg', 'eeg'],
1383-
inner_skull=inner_skull)
1384-
# fwd_data['inner_skull'] in head frame, bem in mri, confusing...
1385-
_prep_field_computation(guess_src['rr'], bem, fwd_data, n_jobs,
1386-
verbose=False)
1396+
fwd_data = _prep_field_computation(
1397+
guess_src['rr'], sensors=sensors, bem=bem, n_jobs=n_jobs,
1398+
verbose=False)
1399+
fwd_data['inner_skull'] = inner_skull
13871400
guess_fwd, guess_fwd_orig, guess_fwd_scales = _dipole_forwards(
1388-
fwd_data, whitener, guess_src['rr'], n_jobs=fit_n_jobs)
1401+
sensors=sensors, fwd_data=fwd_data, whitener=whitener,
1402+
rr=guess_src['rr'], n_jobs=fit_n_jobs)
13891403
# decompose ahead of time
13901404
guess_fwd_svd = [linalg.svd(fwd, full_matrices=False)
13911405
for fwd in np.array_split(guess_fwd,
@@ -1403,7 +1417,8 @@ def fit_dipole(evoked, cov, bem, trans=None, min_dist=5., n_jobs=None,
14031417
fun = _fit_dipole_fixed if fixed_position else _fit_dipole
14041418
out = _fit_dipoles(
14051419
fun, min_dist_to_inner_skull, data, times, guess_src['rr'],
1406-
guess_data, fwd_data, whitener, ori, n_jobs, rank, tol)
1420+
guess_data, sensors=sensors, fwd_data=fwd_data, whitener=whitener,
1421+
ori=ori, n_jobs=n_jobs, rank=rank, rhoend=tol)
14071422
assert len(out) == 8
14081423
if fixed_position and ori is not None:
14091424
# DipoleFixed

0 commit comments

Comments
 (0)