Skip to content

Commit d78f2e9

Browse files
committed
Merge remote-tracking branch 'upstream/main' into openmeeg
* upstream/main: BUG: Spectrum deprecation cleanup [circle deploy] (mne-tools#11115) Add API entry list and map (mne-tools#10999) Add legacy decorator (mne-tools#11097) [ENH, MRG] Add time-frequency epoch source estimation (mne-tools#11095)
2 parents a5592b9 + d587b89 commit d78f2e9

File tree

15 files changed

+350
-103
lines changed

15 files changed

+350
-103
lines changed

doc/changes/latest.inc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ Enhancements
3838
- Allow :func:`mne.beamformer.make_dics` to take ``pick_ori='vector'`` to compute vector source estimates (:gh:`19080` by `Alex Rockhill`_)
3939
- Add ``units`` parameter to :func:`mne.io.read_raw_edf` in case units are missing from the file (:gh:`11099` by `Alex Gramfort`_)
4040
- Add ``on_missing`` functionality to all of our classes that have a ``drop_channels`` method, to control what happens when channel names are not in the object (:gh:`11077` by `Andrew Quinn`_)
41+
- Add :func:`mne.minimum_norm.apply_inverse_tfr_epochs` to apply inverse methods to time-frequency resolved epochs (:gh:`11095` by `Alex Rockhill`_)
4142

4243
Bugs
4344
~~~~

doc/conf.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -468,6 +468,37 @@ def __call__(self, gallery_conf, fname, when):
468468
'matplotlib_animations': True,
469469
'compress_images': compress_images,
470470
'filename_pattern': '^((?!sgskip).)*$',
471+
'api_usage_ignore': (
472+
'('
473+
'.*__.*__|' # built-ins
474+
'.*Base.*|.*Array.*|mne.Vector.*|mne.Mixed.*|mne.Vol.*|' # inherited
475+
'mne.coreg.Coregistration.*|' # GUI
476+
# common
477+
'.*utils.*|.*verbose()|.*copy()|.*update()|.*save()|'
478+
'.*get_data()|'
479+
# mixins
480+
'.*add_channels()|.*add_reference_channels()|'
481+
'.*anonymize()|.*apply_baseline()|.*apply_function()|'
482+
'.*apply_hilbert()|.*as_type()|.*decimate()|'
483+
'.*drop()|.*drop_channels()|.*drop_log_stats()|'
484+
'.*export()|.*get_channel_types()|'
485+
'.*get_montage()|.*interpolate_bads()|.*next()|'
486+
'.*pick()|.*pick_channels()|.*pick_types()|'
487+
'.*plot_sensors()|.*rename_channels()|'
488+
'.*reorder_channels()|.*savgol_filter()|'
489+
'.*set_eeg_reference()|.*set_channel_types()|'
490+
'.*set_meas_date()|.*set_montage()|.*shift_time()|'
491+
'.*time_as_index()|.*to_data_frame()|'
492+
# dictionary inherited
493+
'.*clear()|.*fromkeys()|.*get()|.*items()|'
494+
'.*keys()|.*pop()|.*popitem()|.*setdefault()|'
495+
'.*values()|'
496+
# sklearn inherited
497+
'.*apply()|.*decision_function()|.*fit()|'
498+
'.*fit_transform()|.*get_params()|.*predict()|'
499+
'.*predict_proba()|.*set_params()|.*transform()|'
500+
# I/O, also related to mixins
501+
'.*.remove.*|.*.write.*)')
471502
}
472503
# Files were renamed from plot_* with:
473504
# find . -type f -name 'plot_*.py' -exec sh -c 'x="{}"; xn=`basename "${x}"`; git mv "$x" `dirname "${x}"`/${xn:5}' \; # noqa

doc/inverse.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ Inverse Solutions
1818
apply_inverse_cov
1919
apply_inverse_epochs
2020
apply_inverse_raw
21+
apply_inverse_tfr_epochs
2122
compute_source_psd
2223
compute_source_psd_epochs
2324
compute_rank_inverse

examples/visualization/topo_customized.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
import mne
2727
from mne.viz import iter_topography
2828
from mne import io
29-
from mne.time_frequency import psd_welch
3029
from mne.datasets import sample
3130

3231
print(__doc__)
@@ -42,8 +41,9 @@
4241
tmin, tmax = 0, 120 # use the first 120s of data
4342
fmin, fmax = 2, 20 # look at frequencies between 2 and 20Hz
4443
n_fft = 2048 # the FFT size (n_fft). Ideally a power of 2
45-
psds, freqs = psd_welch(raw, picks=picks, tmin=tmin, tmax=tmax,
46-
fmin=fmin, fmax=fmax)
44+
spectrum = raw.compute_psd(
45+
picks=picks, tmin=tmin, tmax=tmax, fmin=fmin, fmax=fmax)
46+
psds, freqs = spectrum.get_data(exclude=(), return_freqs=True)
4747
psds = 20 * np.log10(psds) # scale to dB
4848

4949

mne/minimum_norm/__init__.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@
22

33
from .inverse import (InverseOperator, read_inverse_operator, apply_inverse,
44
apply_inverse_raw, make_inverse_operator,
5-
apply_inverse_epochs, write_inverse_operator,
6-
compute_rank_inverse, prepare_inverse_operator,
7-
estimate_snr, apply_inverse_cov, INVERSE_METHODS)
5+
apply_inverse_epochs, apply_inverse_tfr_epochs,
6+
write_inverse_operator, compute_rank_inverse,
7+
prepare_inverse_operator, estimate_snr,
8+
apply_inverse_cov, INVERSE_METHODS)
89
from .time_frequency import (source_band_induced_power, source_induced_power,
910
compute_source_psd, compute_source_psd_epochs)
1011
from .resolution_matrix import (make_inverse_resolution_matrix,

mne/minimum_norm/inverse.py

Lines changed: 89 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from ..io.pick import channel_type, pick_info, pick_types, pick_channels
2828
from ..cov import (compute_whitener, _read_cov, _write_cov, Covariance,
2929
prepare_noise_cov)
30-
from ..epochs import BaseEpochs
30+
from ..epochs import BaseEpochs, EpochsArray
3131
from ..evoked import EvokedArray, Evoked
3232
from ..forward import (compute_depth_prior, _read_forward_meas_info,
3333
is_fixed_orient, compute_orient_prior,
@@ -883,6 +883,8 @@ def apply_inverse(evoked, inverse_operator, lambda2=1. / 9., method="dSPM",
883883
--------
884884
apply_inverse_raw : Apply inverse operator to raw object.
885885
apply_inverse_epochs : Apply inverse operator to epochs object.
886+
apply_inverse_tfr_epochs : Apply inverse operator to epochs tfr object.
887+
apply_inverse_cov : Apply inverse operator to covariance object.
886888
887889
Notes
888890
-----
@@ -1067,8 +1069,10 @@ def apply_inverse_raw(raw, inverse_operator, lambda2, method="dSPM",
10671069
10681070
See Also
10691071
--------
1070-
apply_inverse_epochs : Apply inverse operator to epochs object.
10711072
apply_inverse : Apply inverse operator to evoked object.
1073+
apply_inverse_epochs : Apply inverse operator to epochs object.
1074+
apply_inverse_tfr_epochs : Apply inverse operator to epochs tfr object.
1075+
apply_inverse_cov : Apply inverse operator to covariance object.
10721076
"""
10731077
_validate_type(raw, BaseRaw, 'raw')
10741078
_check_reference(raw, inverse_operator['info']['ch_names'])
@@ -1257,13 +1261,15 @@ def apply_inverse_epochs(epochs, inverse_operator, lambda2, method="dSPM",
12571261
12581262
Returns
12591263
-------
1260-
stc : list of (SourceEstimate | VectorSourceEstimate | VolSourceEstimate)
1264+
stcs : list of (SourceEstimate | VectorSourceEstimate | VolSourceEstimate)
12611265
The source estimates for all epochs.
12621266
12631267
See Also
12641268
--------
12651269
apply_inverse_raw : Apply inverse operator to raw object.
12661270
apply_inverse : Apply inverse operator to evoked object.
1271+
apply_inverse_tfr_epochs : Apply inverse operator to epochs tfr object.
1272+
apply_inverse_cov : Apply inverse operator to a covariance object.
12671273
"""
12681274
stcs = _apply_inverse_epochs_gen(
12691275
epochs, inverse_operator, lambda2, method=method, label=label,
@@ -1277,6 +1283,85 @@ def apply_inverse_epochs(epochs, inverse_operator, lambda2, method="dSPM",
12771283
return stcs
12781284

12791285

1286+
def _apply_inverse_tfr_epochs_gen(epochs_tfr, inverse_operator, lambda2,
1287+
method, label, nave, pick_ori, prepared,
1288+
method_params, use_cps):
1289+
for freq_idx in range(epochs_tfr.freqs.size):
1290+
epochs = EpochsArray(epochs_tfr.data[:, :, freq_idx, :],
1291+
epochs_tfr.info, events=epochs_tfr.events,
1292+
tmin=epochs_tfr.tmin)
1293+
this_inverse_operator = inverse_operator[freq_idx] if \
1294+
isinstance(inverse_operator, (list, tuple)) else inverse_operator
1295+
stcs = _apply_inverse_epochs_gen(
1296+
epochs, this_inverse_operator, lambda2, method=method,
1297+
label=label, nave=nave, pick_ori=pick_ori, prepared=prepared,
1298+
method_params=method_params, use_cps=use_cps)
1299+
yield stcs
1300+
1301+
1302+
@verbose
1303+
def apply_inverse_tfr_epochs(epochs_tfr, inverse_operator, lambda2,
1304+
method="dSPM", label=None, nave=1, pick_ori=None,
1305+
return_generator=False, prepared=False,
1306+
method_params=None, use_cps=True, verbose=None):
1307+
"""Apply inverse operator to EpochsTFR.
1308+
1309+
Parameters
1310+
----------
1311+
epochs_tfr : EpochsTFR object
1312+
Single trial, phase-amplitude (complex-valued), time-frequency epochs.
1313+
inverse_operator : list of dict | dict
1314+
The inverse operator for each frequency or a single inverse operator
1315+
to use for all frequencies.
1316+
lambda2 : float
1317+
The regularization parameter.
1318+
method : "MNE" | "dSPM" | "sLORETA" | "eLORETA"
1319+
Use minimum norm, dSPM (default), sLORETA, or eLORETA.
1320+
label : Label | None
1321+
Restricts the source estimates to a given label. If None,
1322+
source estimates will be computed for the entire source space.
1323+
nave : int
1324+
Number of averages used to regularize the solution.
1325+
Set to 1 on single Epoch by default.
1326+
%(pick_ori)s
1327+
return_generator : bool
1328+
Return a generator object instead of a list. This allows iterating
1329+
over the stcs without having to keep them all in memory.
1330+
prepared : bool
1331+
If True, do not call :func:`prepare_inverse_operator`.
1332+
method_params : dict | None
1333+
Additional options for eLORETA. See Notes of :func:`apply_inverse`.
1334+
%(use_cps_restricted)s
1335+
%(verbose)s
1336+
1337+
Returns
1338+
-------
1339+
stcs : list of list of (SourceEstimate | VectorSourceEstimate | VolSourceEstimate)
1340+
The source estimates for all frequencies (outside list) and for
1341+
all epochs (inside list).
1342+
1343+
See Also
1344+
--------
1345+
apply_inverse_raw : Apply inverse operator to raw object.
1346+
apply_inverse : Apply inverse operator to evoked object.
1347+
apply_inverse_epochs : Apply inverse operator to epochs object.
1348+
apply_inverse_cov : Apply inverse operator to a covariance object.
1349+
""" # noqa E501
1350+
from ..time_frequency.tfr import _check_tfr_complex
1351+
_check_tfr_complex(epochs_tfr)
1352+
if isinstance(inverse_operator, (list, tuple)) and \
1353+
len(inverse_operator) != epochs_tfr.freqs.size:
1354+
raise ValueError(f'Expected {epochs_tfr.freqs.size} inverse '
1355+
f'operators, got {len(inverse_operator)}')
1356+
stcs = _apply_inverse_tfr_epochs_gen(
1357+
epochs_tfr, inverse_operator, lambda2,
1358+
method, label, nave, pick_ori, prepared,
1359+
method_params, use_cps)
1360+
if not return_generator:
1361+
stcs = [[stc for stc in tfr_stcs] for tfr_stcs in stcs]
1362+
return stcs
1363+
1364+
12801365
@verbose
12811366
def apply_inverse_cov(cov, info, inverse_operator, nave=1, lambda2=1 / 9,
12821367
method="dSPM", pick_ori=None, prepared=False,
@@ -1319,6 +1404,7 @@ def apply_inverse_cov(cov, info, inverse_operator, nave=1, lambda2=1 / 9,
13191404
apply_inverse : Apply inverse operator to evoked object.
13201405
apply_inverse_raw : Apply inverse operator to raw object.
13211406
apply_inverse_epochs : Apply inverse operator to epochs object.
1407+
apply_inverse_tfr_epochs : Apply inverse operator to epochs tfr object.
13221408
13231409
Notes
13241410
-----

mne/minimum_norm/tests/test_inverse.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,14 @@
2525
convert_forward_solution, Covariance, combine_evoked,
2626
SourceEstimate, make_sphere_model, make_ad_hoc_cov,
2727
pick_channels_forward, compute_raw_covariance)
28-
from mne.io import read_raw_fif
28+
from mne.io import read_raw_fif, read_info
2929
from mne.minimum_norm import (apply_inverse, read_inverse_operator,
3030
apply_inverse_raw, apply_inverse_epochs,
31+
apply_inverse_tfr_epochs,
3132
make_inverse_operator, apply_inverse_cov,
3233
write_inverse_operator, prepare_inverse_operator,
3334
compute_rank_inverse, INVERSE_METHODS)
35+
from mne.time_frequency import EpochsTFR
3436
from mne.utils import catch_logging, _record_warnings
3537

3638
test_path = testing.data_path(download=False)
@@ -1097,6 +1099,51 @@ def test_apply_mne_inverse_epochs():
10971099
inverse_operator, 1.)
10981100

10991101

1102+
@pytest.mark.slowtest
1103+
@testing.requires_testing_data
1104+
@pytest.mark.parametrize('return_generator', (True, False))
1105+
def test_apply_inverse_tfr(return_generator):
1106+
"""Test applying an inverse to time-frequency data."""
1107+
rng = np.random.default_rng(11)
1108+
n_epochs = 4
1109+
info = read_info(fname_raw)
1110+
inverse_operator = read_inverse_operator(fname_full)
1111+
freqs = np.arange(8, 10)
1112+
sfreq = info['sfreq']
1113+
times = np.arange(sfreq) / sfreq # make epochs 1s long
1114+
data = rng.random((n_epochs, len(info.ch_names), freqs.size, times.size))
1115+
data = data + 1j * data # make complex to simulate amplitude + phase
1116+
epochs_tfr = EpochsTFR(info, data, times=times, freqs=freqs)
1117+
epochs_tfr.apply_baseline((0, 0.5))
1118+
pick_ori = 'vector'
1119+
1120+
with pytest.raises(ValueError, match='Expected 2 inverse operators, '
1121+
'got 3'):
1122+
apply_inverse_tfr_epochs(epochs_tfr, [inverse_operator] * 3, lambda2)
1123+
1124+
# test epochs
1125+
stcs = apply_inverse_tfr_epochs(
1126+
epochs_tfr, inverse_operator, lambda2, "dSPM", pick_ori=pick_ori,
1127+
return_generator=return_generator)
1128+
1129+
n_orient = 3 if pick_ori == 'vector' else 1
1130+
if return_generator:
1131+
stcs = [[s for s in these_stcs] for these_stcs in stcs]
1132+
assert_allclose(stcs[0][0].times, times)
1133+
assert len(stcs) == freqs.size
1134+
assert all([len(s) == len(epochs_tfr) for s in stcs])
1135+
assert all([s.data.shape == (inverse_operator['nsource'],
1136+
n_orient, times.size)
1137+
for these_stcs in stcs for s in these_stcs])
1138+
1139+
evoked = EvokedArray(data.mean(axis=(0, 2)), info, epochs_tfr.tmin)
1140+
stc = apply_inverse(
1141+
evoked, inverse_operator, lambda2, "dSPM", pick_ori=pick_ori)
1142+
tfr_stc_data = np.array([[stc.data for stc in tfr_stcs]
1143+
for tfr_stcs in stcs])
1144+
assert_allclose(stc.data, tfr_stc_data.mean(axis=(0, 1)))
1145+
1146+
11001147
def test_make_inverse_operator_bads(evoked, noise_cov):
11011148
"""Test MNE inverse computation given a mismatch of bad channels."""
11021149
fwd_op = read_forward_solution_meg(fname_fwd, surf_ori=True)

mne/time_frequency/psd.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -254,9 +254,10 @@ def psd_welch(inst, fmin=0, fmax=np.inf, tmin=None, tmax=None, n_fft=256,
254254
.. versionadded:: 0.12.0
255255
"""
256256
spectrum = inst.compute_psd(
257-
'welch', fmin, fmax, tmin, tmax, picks, proj, reject_by_annotation,
258-
n_jobs=n_jobs, verbose=verbose, n_fft=n_fft, n_overlap=n_overlap,
259-
n_per_seg=n_per_seg, average=average, window=window)
257+
'welch', fmin=fmin, fmax=fmax, tmin=tmin, tmax=tmax, picks=picks,
258+
proj=proj, n_jobs=n_jobs, verbose=verbose, n_fft=n_fft,
259+
n_overlap=n_overlap, n_per_seg=n_per_seg, average=average,
260+
window=window)
260261
return spectrum.get_data(return_freqs=True)
261262

262263

@@ -323,8 +324,8 @@ def psd_multitaper(inst, fmin=0, fmax=np.inf, tmin=None, tmax=None,
323324
.. footbibliography::
324325
"""
325326
spectrum = inst.compute_psd(
326-
'multitaper', fmin, fmax, tmin, tmax, picks, proj,
327-
reject_by_annotation, n_jobs=n_jobs, verbose=verbose,
328-
bandwidth=bandwidth, adaptive=adaptive, low_bias=low_bias,
329-
normalization=normalization)
327+
'multitaper', fmin=fmin, fmax=fmax, tmin=tmin, tmax=tmax, picks=picks,
328+
proj=proj, reject_by_annotation=reject_by_annotation, n_jobs=n_jobs,
329+
verbose=verbose, bandwidth=bandwidth, adaptive=adaptive,
330+
low_bias=low_bias, normalization=normalization)
330331
return spectrum.get_data(return_freqs=True)

mne/time_frequency/tfr.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2658,3 +2658,9 @@ def _preproc_tfr_instance(tfr, picks, tmin, tmax, fmin, fmax, vmin, vmax, dB,
26582658
tfr.data = data
26592659

26602660
return tfr
2661+
2662+
2663+
def _check_tfr_complex(tfr, reason='source space estimation'):
2664+
"""Check that time-frequency epochs or average data is complex."""
2665+
if not np.iscomplexobj(tfr.data):
2666+
raise RuntimeError(f'Time-frequency data must be complex for {reason}')

mne/utils/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
sys_info, _get_extra_data_path, _get_root_dir,
3030
_get_numpy_libs)
3131
from .docs import (copy_function_doc_to_method_doc, copy_doc, linkcode_resolve,
32-
open_docs, deprecated, fill_doc, deprecated_alias,
32+
open_docs, deprecated, fill_doc, deprecated_alias, legacy,
3333
copy_base_doc_to_subclass_doc, docdict as _docdict)
3434
from .fetching import _url_to_local_path
3535
from ._logging import (verbose, logger, set_log_level, set_log_file,

0 commit comments

Comments
 (0)