mne-tools / mne-python

MNE: Magnetoencephalography (MEG) and Electroencephalography (EEG) in Python
https://mne.tools
BSD 3-Clause "New" or "Revised" License

TimeFrequency Estimator modifies parameters in constructor #10971

Closed Dod12 closed 2 years ago

Dod12 commented 2 years ago

Describe the bug

The mne.decoding.TimeFrequency transformer modifies its constructor arguments (here n_cycles), which violates scikit-learn's guidelines for estimators. As a result, cloning fails when the transformer is used in a pipeline. I was able to resolve the issue by moving the _check_tfr_param call from the constructor to the transform method, alongside the other checks already performed there. See the changes made to mne/decoding/time_frequency.py.
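A minimal sketch of the pattern the fix relies on, assuming scikit-learn and NumPy are installed. The classes EagerTF and LazyTF are hypothetical stand-ins, not MNE's implementation. scikit-learn's clone rebuilds an estimator from get_params() and raises RuntimeError when a parameter stored by the constructor is not the very object that was passed in, so any conversion inside __init__ breaks cloning. Deferring validation to fit/transform keeps the stored parameters untouched.

```python
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin, clone


class EagerTF(BaseEstimator, TransformerMixin):
    """Hypothetical transformer that converts n_cycles in __init__ (anti-pattern)."""

    def __init__(self, freqs, n_cycles=7.0):
        self.freqs = freqs
        # np.array() makes a copy, so the stored object differs from the argument
        self.n_cycles = np.array(n_cycles, dtype=float)

    def fit(self, X, y=None):
        return self

    def transform(self, X):
        return X


class LazyTF(BaseEstimator, TransformerMixin):
    """Hypothetical transformer that stores its arguments verbatim."""

    def __init__(self, freqs, n_cycles=7.0):
        # Store constructor arguments unchanged, as scikit-learn expects
        self.freqs = freqs
        self.n_cycles = n_cycles

    def _check_params(self):
        # Hypothetical helper: validate/convert at use time instead of in __init__
        freqs = np.atleast_1d(np.asarray(self.freqs, dtype=float))
        n_cycles = np.broadcast_to(np.asarray(self.n_cycles, dtype=float), freqs.shape)
        return freqs, n_cycles

    def fit(self, X, y=None):
        self._check_params()
        return self

    def transform(self, X):
        _freqs, _n_cycles = self._check_params()  # validated values; unused in this stub
        # A real implementation would compute the time-frequency transform here
        return X


freqs = np.array([5.0])

try:
    clone(EagerTF(freqs, freqs / 5.0))
    eager_ok = True
except RuntimeError as err:
    # Same kind of error as in the report: "... modifies parameter n_cycles"
    eager_ok = False
    print(err)

lazy = clone(LazyTF(freqs, freqs / 5.0))  # succeeds
```

Because LazyTF never rebinds self.n_cycles, get_params() returns the same object that clone passed to the constructor, so scikit-learn's identity check passes.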

Steps to reproduce

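import mne
import numpy as np
from sklearn import pipeline, linear_model

# 100 epochs x 10 channels x 1000 time points of dummy data
tfr_data = np.ones((100, 10, 1000))

freqs = np.array([5.])

estimator = pipeline.make_pipeline(
    # positional arguments: freqs, sfreq, method, n_cycles
    mne.decoding.TimeFrequency(freqs, 10, "morlet", freqs/5., output="power"),
    mne.decoding.Vectorizer(),
    linear_model.LogisticRegression(),
)

# cross_val_multiscore clones the pipeline for each fold, which triggers the error
mne.decoding.cross_val_multiscore(estimator, tfr_data, np.random.binomial(1, 0.5, 100))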

Expected results

Successful completion of cross validation.

Actual results

Traceback (most recent call last):
  File "/Users/daniel/miniconda3/envs/mne-python/lib/python3.10/site-packages/joblib/parallel.py", line 822, in dispatch_one_batch
    tasks = self._ready_batches.get(block=False)
  File "/Users/daniel/miniconda3/envs/mne-python/lib/python3.10/queue.py", line 168, in get
    raise Empty
_queue.Empty
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
  File "/Users/daniel/miniconda3/envs/mne-python/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3552, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-2-1fb399c6f0bb>", line 1, in <cell line: 1>
    runfile('error.py', wdir='/Users/daniel/Documents/Coding_Projects/GitHub/mne-python')
  File "/Users/daniel/Library/Application Support/JetBrains/Toolbox/apps/PyCharm-P/ch-0/221.6008.17/PyCharm.app/Contents/plugins/python/helpers/pydev/_pydev_bundle/pydev_umd.py", line 198, in runfile
    pydev_imports.execfile(filename, global_vars, local_vars)  # execute the script
  File "/Users/daniel/Library/Application Support/JetBrains/Toolbox/apps/PyCharm-P/ch-0/221.6008.17/PyCharm.app/Contents/plugins/python/helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "error.py", line 15, in <module>
    mne.decoding.cross_val_multiscore(estimator, tfr_data, np.random.binomial(1, 0.5, 100))
  File "<decorator-gen-447>", line 12, in cross_val_multiscore
  File "/Users/daniel/Documents/Coding_Projects/GitHub/mne-python/mne/decoding/base.py", line 435, in cross_val_multiscore
    scores = parallel(
  File "/Users/daniel/miniconda3/envs/mne-python/lib/python3.10/site-packages/joblib/parallel.py", line 1043, in __call__
    if self.dispatch_one_batch(iterator):
  File "/Users/daniel/miniconda3/envs/mne-python/lib/python3.10/site-packages/joblib/parallel.py", line 833, in dispatch_one_batch
    islice = list(itertools.islice(iterator, big_batch_size))
  File "/Users/daniel/Documents/Coding_Projects/GitHub/mne-python/mne/decoding/base.py", line 437, in <genexpr>
    estimator=clone(estimator), X=X, y=y, scorer=scorer, train=train,
  File "/Users/daniel/miniconda3/envs/mne-python/lib/python3.10/site-packages/sklearn/base.py", line 87, in clone
    new_object_params[name] = clone(param, safe=False)
  File "/Users/daniel/miniconda3/envs/mne-python/lib/python3.10/site-packages/sklearn/base.py", line 65, in clone
    return estimator_type([clone(e, safe=safe) for e in estimator])
  File "/Users/daniel/miniconda3/envs/mne-python/lib/python3.10/site-packages/sklearn/base.py", line 65, in <listcomp>
    return estimator_type([clone(e, safe=safe) for e in estimator])
  File "/Users/daniel/miniconda3/envs/mne-python/lib/python3.10/site-packages/sklearn/base.py", line 65, in clone
    return estimator_type([clone(e, safe=safe) for e in estimator])
  File "/Users/daniel/miniconda3/envs/mne-python/lib/python3.10/site-packages/sklearn/base.py", line 65, in <listcomp>
    return estimator_type([clone(e, safe=safe) for e in estimator])
  File "/Users/daniel/miniconda3/envs/mne-python/lib/python3.10/site-packages/sklearn/base.py", line 96, in clone
    raise RuntimeError(
RuntimeError: Cannot clone object TimeFrequency(None), as the constructor either does not set or modifies parameter n_cycles

Additional information

Platform:      macOS-11.6.6-x86_64-i386-64bit
Python:        3.10.5 | packaged by conda-forge | (main, Jun 14 2022, 07:09:13) [Clang 13.0.1 ]
Executable:    /Users/daniel/miniconda3/envs/mne-python/bin/python
CPU:           i386: 4 cores
Memory:        16.0 GB
mne:           0.23.4
numpy:         1.22.4 {blas=NO_ATLAS_INFO, lapack=lapack}
scipy:         1.8.1
matplotlib:    3.5.2 {backend=module://backend_interagg}
sklearn:       1.1.1
numba:         0.55.2
nibabel:       4.0.1
nilearn:       0.6.2
dipy:          1.5.0
cupy:          Not found
pandas:        1.4.3
mayavi:        4.8.0
pyvista:       0.35.2 {pyvistaqt=0.9.0, OpenGL 4.1 ATI-4.6.21 via AMD Radeon R9 M295X OpenGL Engine}
vtk:           
PyQt5:         5.12.3
welcome[bot] commented 2 years ago

Hello! 👋 Thanks for opening your first issue here! ❤️ We will try to get back to you soon. 🚴🏽‍♂️

larsoner commented 2 years ago

@Dod12 agreed this seems like a bug, would you be up for making a PR to fix it? The minimal example above is already a good start for a unit test!

Dod12 commented 2 years ago

@larsoner Sure, I'll work on the tests over the weekend.