nbara / python-meegkit

🔧🧠 MEEGkit: MEG & EEG processing toolkit in Python
https://nbara.github.io/python-meegkit/
BSD 3-Clause "New" or "Revised" License
182 stars, 50 forks

Reduce memory footprint of `dss_line` #58

eort closed this 2 years ago

eort commented 2 years ago

Here is the promised PR. I don't have a good estimate of the memory saved (as it never worked with the original code), but the gain is at least 50% (peak usage went from far more than 24 GB to at most 15 GB).

In a nutshell:

Generally, my motivation was quite selfish: I wanted to make the algorithm work for my data. So, while I obviously tried not to break anything, there is a chance that things no longer work for other data. This applies particularly to the edit I suggested in `matrix.py`, because I simply don't understand why the original code was written the way it was. The new version works for my purposes, but perhaps not for others.

If the PR is of any use to you, I can try to add a few tests.

Overall, I think the results look really nice: [image]

P.S. I am not the most experienced GitHub user, so sorry that this is a separate PR rather than an update to #57 (if that was what you wanted).

nbara commented 2 years ago

Thanks @eort. I have only looked at the changes quickly, but I can already tell you that the CI checks no longer pass (you can check this locally by running `make pep` from meegkit's root directory).

I think only a subset of those changes are beneficial (the in-place recentring, for instance). Others I am less sure about (for instance, computing `X - X_filt` four times in the `dss_line()` code).

I will try to post proper benchmarks (computation time and memory consumption) soon to get to the bottom of this.
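The memory side of the in-place recentring can be checked with Python's standard `tracemalloc` module, which also records NumPy's array allocations. This is a minimal sketch using plain NumPy, not meegkit's `demean` (which also handles weights and trials); the helper names `peak_bytes`, `demean_copy` and `demean_inplace` are hypothetical. It only shows that subtracting the mean into the existing buffer avoids allocating a second full-size array.

```python
import tracemalloc

import numpy as np


def peak_bytes(fn, *args):
    """Return the peak traced allocation (bytes) while running fn(*args)."""
    tracemalloc.start()
    try:
        fn(*args)
        _, peak = tracemalloc.get_traced_memory()
    finally:
        tracemalloc.stop()
    return peak


def demean_copy(X):
    # Allocates a brand-new array the same size as X
    return X - X.mean(axis=0, keepdims=True)


def demean_inplace(X):
    # Writes into X's own buffer; only the small (1, n_chans) mean is allocated
    X -= X.mean(axis=0, keepdims=True)
    return X


rng = np.random.default_rng(0)
X1 = rng.standard_normal((100_000, 16))
X2 = X1.copy()

print(f"array size:    {X1.nbytes / 2**20:.1f} MiB")
print(f"copy peak:     {peak_bytes(demean_copy, X1) / 2**20:.1f} MiB")
print(f"in-place peak: {peak_bytes(demean_inplace, X2) / 2**20:.1f} MiB")
```

The price of the in-place version is that the caller's array is modified, which matters for any code that reuses the original input afterwards.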

eort commented 2 years ago

Yeah, I wasn't sure about those either, but for me saving memory was more critical than saving time. In any case, looking forward to those benchmarks!

(Hope CI is happier now)
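The memory-versus-time trade-off discussed here can be sketched with `tracemalloc`: keeping the residual `X - X_filt` in a variable costs one full-size array for as long as the variable is alive, whereas recomputing it each time frees the temporary right away, at the price of repeating the subtraction. The two functions below are hypothetical stand-ins, not the `dss_line` code; `np.abs(X).sum()` plays the role of other large intermediate work done while the residual would otherwise be held.

```python
import tracemalloc

import numpy as np


def peak_mib(fn, *args):
    """Peak traced allocation, in MiB, while running fn(*args)."""
    tracemalloc.start()
    try:
        fn(*args)
        _, peak = tracemalloc.get_traced_memory()
    finally:
        tracemalloc.stop()
    return peak / 2**20


def store_residual(X, X_filt):
    resid = X - X_filt            # one full-size array, alive until return
    total = resid.sum()
    other = np.abs(X).sum()       # second full-size temporary while resid lives
    return total, other, resid.max()


def recompute_residual(X, X_filt):
    total = (X - X_filt).sum()    # temporary freed as soon as the sum is done
    other = np.abs(X).sum()
    return total, other, (X - X_filt).max()   # subtraction repeated (extra time)


rng = np.random.default_rng(4)
X = rng.standard_normal((200_000, 8))
X_filt = 0.5 * X
size = X.nbytes / 2**20

print(f"one array: {size:.1f} MiB")
print(f"store:     {peak_mib(store_residual, X, X_filt):.1f} MiB")
print(f"recompute: {peak_mib(recompute_residual, X, X_filt):.1f} MiB")
```

Recomputing only lowers the peak when the stored residual would overlap with other large allocations; it always costs extra arithmetic.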

nbara commented 2 years ago

> (Hope CI is happier now)

Bah, I meant running `pytest --noplots`, which runs the unit tests; those are broken (apparently due to the changes in `demean`). Sorry!

I'll post the benchmarks shortly
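One plausible way an in-place `demean` can break existing tests (a guess at the failure mode, not a diagnosis of this one) is that it silently modifies the caller's array, so later code that reuses the same input sees recentred data. A common guard is to make in-place behaviour opt-in. The function below is a hypothetical stand-in, not meegkit's `demean`, and it ignores weights and multi-trial data.

```python
import numpy as np


def demean(X, inplace=False):
    """Remove the per-channel mean along the first (time) axis.

    With inplace=False (the default) the input is left untouched and a new
    array is returned; with inplace=True the input buffer is overwritten.
    In-place subtraction requires a floating-point array.
    """
    mean = X.mean(axis=0, keepdims=True)
    if inplace:
        X -= mean
        return X
    return X - mean


rng = np.random.default_rng(1)
data = rng.standard_normal((1000, 4)) + 5.0

out = demean(data)                          # caller's array keeps its offset
assert np.allclose(data.mean(axis=0), 5.0, atol=0.2)

out_inplace = demean(data, inplace=True)    # same buffer, now recentred
assert out_inplace is data
```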

nbara commented 2 years ago

OK, I've run a memory profiler to assess the changes.

I used simulated data with 100 channels and 1e6 time points at 200 Hz.
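To relate the "Increment" column in the listings below to array copies, it helps to compute the size of one full copy of an array of a given shape. The generator below is a hypothetical example of this kind of test data (white noise plus a sinusoidal line artifact with random per-channel gains); the 50 Hz line frequency, the noise model and all function names are assumptions, not the script used for these benchmarks.

```python
import numpy as np


def copy_size_mib(n_samples, n_chans, n_trials=1, itemsize=8):
    """Size in MiB of one array of the given shape (float64 by default)."""
    return n_samples * n_chans * n_trials * itemsize / 2**20


def simulate_line_noise(n_samples, n_chans, sfreq, fline=50.0, seed=0):
    """Gaussian background noise plus a line artifact mixed into every channel."""
    rng = np.random.default_rng(seed)
    t = np.arange(n_samples) / sfreq
    line = np.sin(2 * np.pi * fline * t)       # shape (n_samples,)
    gains = rng.standard_normal(n_chans)       # one gain per channel
    return rng.standard_normal((n_samples, n_chans)) + np.outer(line, gains)


# One float64 copy of a (1e6, 100) array
print(f"{copy_size_mib(1_000_000, 100):.1f} MiB")   # 762.9 MiB

X = simulate_line_noise(10_000, 100, sfreq=200.0)   # smaller, for a quick run
print(X.shape, X.dtype, f"{X.nbytes / 2**20:.1f} MiB")
```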

With your changes

Memory:

```bash
Line #    Mem usage    Increment Occurences   Line Contents
============================================================
   139   3058.7 MiB   3058.7 MiB          1   @profile
   140                                        def dss_line(X, fline, sfreq, nremove=1, nfft=1024, nkeep=None, blocksize=None,
   141                                                     show=False):
   142                                            """Apply DSS to remove power line artifacts.
   143
   144                                            Implements the ZapLine algorithm described in [1]_.
   145
   146                                            Parameters
   147                                            ----------
   148                                            X : data, shape=(n_samples, n_chans, n_trials)
   149                                                Input data.
   150                                            fline : float
   151                                                Line frequency (normalized to sfreq, if ``sfreq`` == 1).
   152                                            sfreq : float
   153                                                Sampling frequency (default=1, which assymes ``fline`` is normalised).
   154                                            nremove : int
   155                                                Number of line noise components to remove (default=1).
   156                                            nfft : int
   157                                                FFT size (default=1024).
   158                                            nkeep : int
   159                                                Number of components to keep in DSS (default=None).
   160                                            blocksize : int
   161                                                If not None (default), covariance is computed on blocks of
   162                                                ``blocksize`` samples. This may improve performance for large datasets.
   163                                            show: bool
   164                                                If True, show DSS results (default=False).
   165
   166                                            Returns
   167                                            -------
   168                                            y : array, shape=(n_samples, n_chans, n_trials)
   169                                                Denoised data.
   170                                            artifact : array, shape=(n_samples, n_chans, n_trials)
   171                                                Artifact
   172
   173                                            Examples
   174                                            --------
   175                                            Apply to X, assuming line frequency=50Hz and sampling rate=1000Hz, plot
   176                                            results:
   177                                            >>> dss_line(X, 50/1000)
   178
   179                                            Removing 4 line-dominated components:
   180                                            >>> dss_line(X, 50/1000, 4)
   181
   182                                            Truncating PCs beyond the 30th to avoid overfitting:
   183                                            >>> dss_line(X, 50/1000, 4, nkeep=30);
   184
   185                                            Return cleaned data in y, noise in yy, do not plot:
   186                                            >>> [y, artifact] = dss_line(X, 60/1000)
   187
   188                                            References
   189                                            ----------
   190                                            .. [1] de Cheveigné, A. (2019). ZapLine: A simple and effective method to
   191                                               remove power line artifacts [Preprint]. https://doi.org/10.1101/782029
   192
   193                                            """
   194   3058.7 MiB      0.0 MiB          1       if X.shape[0] < nfft:
   195                                                print('Reducing nfft to {}'.format(X.shape[0]))
   196                                                nfft = X.shape[0]
   197   3058.7 MiB      0.0 MiB          1       n_samples, n_chans, n_trials = theshapeof(X)
   198   3058.7 MiB      0.0 MiB          1       if blocksize is None:
   199   3058.7 MiB      0.0 MiB          1           blocksize = n_samples
   200
   201                                            # Recentre data
   202   3058.7 MiB      0.0 MiB          1       X = demean(X, inplace=True)
   203
   204                                            # Cancel line_frequency and harmonics + light lowpass
   205   5370.5 MiB   2311.7 MiB          1       X_filt = smooth(X, sfreq / fline)
   206
   207                                            # X - X_filt results in the artifact plus some residual biological signal
   208                                            # Reduce dimensionality to avoid overfitting
   209   5370.5 MiB      0.0 MiB          1       if nkeep is not None:
   210                                                cov_X_res = tscov(X - X_filt)[0]
   211                                                V, _ = pca(cov_X_res, nkeep)
   212                                                X_noise_pca = (X - X_filt) @ V
   213                                            else:
   214   7636.4 MiB   2265.9 MiB          1           X_noise_pca = (X - X_filt).copy()
   215   7636.4 MiB      0.0 MiB          1           nkeep = n_chans
   216
   217                                            # Compute blockwise covariances of raw and biased data
   218   7636.4 MiB      0.0 MiB          1       n_harm = np.floor((sfreq / 2) / fline).astype(int)
   219   7636.4 MiB      0.0 MiB          1       c0 = np.zeros((nkeep, nkeep))
   220   7636.5 MiB      0.1 MiB          1       c1 = np.zeros((nkeep, nkeep))
   221   7777.3 MiB      0.0 MiB          4       for X_block in sliding_window_view(X_noise_pca, (blocksize, nkeep),
   222   7636.5 MiB      0.0 MiB          2                                          axis=(0, 1))[::blocksize, 0]:
   223                                                # if n_trials>1, reshape to (n_samples, nkeep, n_trials)
   224   7636.5 MiB      0.0 MiB          1           if X_block.ndim == 3:
   225                                                    X_block = X_block.transpose(1, 2, 0)
   226
   227                                                # bias data
   228   7637.0 MiB      0.5 MiB          1           c0 += tscov(X_block)[0]
   229   7777.3 MiB    140.3 MiB          1           c1 += tscov(gaussfilt(X_block, sfreq, fline, fwhm=1, n_harm=n_harm))[0]
   230
   231                                            # DSS to isolate line components from residual
   232   7778.1 MiB      0.8 MiB          1       todss, _, pwr0, pwr1 = dss0(c0, c1)
   233
   234   7778.1 MiB      0.0 MiB          1       if show:
   235                                                import matplotlib.pyplot as plt
   236                                                plt.plot(pwr1 / pwr0, '.-')
   237                                                plt.xlabel('component')
   238                                                plt.ylabel('score')
   239                                                plt.title('DSS to enhance line frequencies')
   240                                                plt.show()
   241
   242                                            # Remove line components from X_noise
   243   7778.1 MiB      0.0 MiB          1       idx_remove = np.arange(nremove)
   244   7778.1 MiB      0.0 MiB          1       X_artifact = matmul3d(X_noise_pca, todss[:, idx_remove])
   245  10056.1 MiB   2278.0 MiB          1       X_res = tsr(X - X_filt, X_artifact)[0]  # project them out
   246                                            # reconstruct clean signal
   247  12322.0 MiB   2265.9 MiB          1       y = X_filt + X_res
   248
   249                                            # Power of components
   250  12322.0 MiB      0.0 MiB          1       p = wpwr(X - y)[0] / wpwr(X)[0]
   251  12322.1 MiB      0.0 MiB          1       print('Power of components removed by DSS: {:.2f}'.format(p))
   252                                            # return the reconstructed clean signal, and the artifact
   253  14588.0 MiB   2265.9 MiB          1       return y, X - y
```
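The first memory-heavy step in both listings is `X_filt = smooth(X, sfreq / fline)`. The ZapLine idea behind it is that averaging over exactly one line period cancels the line frequency and all of its harmonics, while only lightly low-pass filtering everything else. This sketch illustrates that with a plain moving average, assuming an integer number of samples per period (20 here); meegkit's `smooth` is a separate implementation and may differ in details such as edge handling and non-integer periods, so `boxcar_smooth` is a hypothetical illustration, not a drop-in replacement.

```python
import numpy as np


def boxcar_smooth(X, period):
    """Moving average over `period` samples along axis 0 ('valid' part only).

    A boxcar of length `period` has spectral zeros at sfreq/period and at all
    of its integer multiples, so it removes the line frequency and harmonics.
    """
    kernel = np.ones(period) / period
    return np.apply_along_axis(np.convolve, 0, X, kernel, mode="valid")


sfreq, fline = 1000.0, 50.0
period = int(round(sfreq / fline))        # 20 samples per line cycle
t = np.arange(2000) / sfreq

# 50 Hz line noise plus its 3rd harmonic, and a slow 1 Hz "signal" with an offset
line = np.sin(2 * np.pi * fline * t) + 0.5 * np.sin(2 * np.pi * 3 * fline * t)
slow = 1.0 + np.sin(2 * np.pi * 1.0 * t)
X = np.column_stack([line, line + slow])

X_filt = boxcar_smooth(X, period)
print(np.abs(X_filt[:, 0]).max())         # only floating-point round-off remains
```

The residual `X - X_filt` (over the aligned samples) therefore contains mostly the line artifact plus whatever high-frequency signal the boxcar attenuates, which is what the later DSS step works on.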

Computation time:

```bash
         458451 function calls (455790 primitive calls) in 64.340 seconds

   Ordered by: cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000   64.340   64.340 /memory_profiler.py:1140(wrapper)
        1    0.313    0.313   64.017   64.017 /memory_profiler.py:715(f)
        1    7.629    7.629   63.703   63.703 /meegkit/meegkit/dss.py:138(dss_line)
 1473/272    1.176    0.001   36.986    0.136 {built-in method numpy.core._multiarray_umath.implement_array_function}
        1    4.796    4.796   29.480   29.480 /meegkit/meegkit/utils/sig.py:279(gaussfilt)
        2    0.000    0.000   24.615   12.308 /numpy/fft/_pocketfft.py:49(_raw_fft)
        2   24.615   12.307   24.615   12.307 {built-in method numpy.fft._pocketfft_internal.execute}
        1    0.000    0.000   14.306   14.306 <__array_function__ internals>:2(fft)
        1    0.000    0.000   14.306   14.306 /numpy/fft/_pocketfft.py:122(fft)
        1    0.000    0.000   10.309   10.309 <__array_function__ internals>:2(ifft)
        1    0.000    0.000   10.309   10.309 /numpy/fft/_pocketfft.py:219(ifft)
        1    4.082    4.082   10.040   10.040 /meegkit/meegkit/tspca.py:71(tsr)
        1    0.000    0.000    7.492    7.492 /meegkit/meegkit/utils/sig.py:114(smooth)
    100/1    0.001    0.000    7.492    7.492 <__array_function__ internals>:2(apply_along_axis)
    100/1    1.763    0.018    7.490    7.490 /numpy/lib/shape_base.py:267(apply_along_axis)
       99    0.008    0.000    6.035    0.061 /meegkit/meegkit/utils/sig.py:171(_smooth1d)
       99    0.008    0.000    5.997    0.061 /scipy/signal/signaltools.py:1866(lfilter)
       99    0.001    0.000    5.536    0.056 /scipy/signal/signaltools.py:2038()
       99    0.001    0.000    5.536    0.056 <__array_function__ internals>:2(convolve)
       99    0.003    0.000    5.534    0.056 /numpy/core/numeric.py:753(convolve)
       99    5.531    0.056    5.531    0.056 {built-in method numpy.core._multiarray_umath.correlate}
        4    3.269    0.817    4.477    1.119 /meegkit/meegkit/utils/denoise.py:10(demean)
        4    0.000    0.000    4.386    1.097 /meegkit/meegkit/utils/covariances.py:170(tscov)
       17    3.882    0.228    3.882    0.228 {method 'copy' of 'numpy.ndarray' objects}
        6    0.000    0.000    2.974    0.496 /meegkit/meegkit/utils/matrix.py:211(multishift)
        2    2.283    1.142    2.765    1.382 /meegkit/meegkit/utils/denoise.py:102(wpwr)
       64    0.530    0.008    1.899    0.030 /meegkit/meegkit/utils/matrix.py:653(_check_data)
      259    1.757    0.007    1.757    0.007 {method 'reduce' of 'numpy.ufunc' objects}
        1    0.000    0.000    1.713    1.713 /meegkit/meegkit/utils/covariances.py:103(tsxcov)
       50    0.000    0.000    1.500    0.030 /meegkit/meegkit/utils/matrix.py:472(theshapeof)
      122    0.002    0.000    1.484    0.012 /numpy/core/fromnumeric.py:70(_wrapreduction)
        2    0.000    0.000    1.201    0.600 <__array_function__ internals>:2(einsum)
        2    0.000    0.000    1.201    0.600 /numpy/core/einsumfunc.py:997(einsum)
        2    1.201    0.600    1.201    0.600 {built-in method numpy.core._multiarray_umath.c_einsum}
       10    0.000    0.000    1.166    0.117 <__array_function__ internals>:2(dot)
```
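The computation-time tables are standard `cProfile`/`pstats` reports sorted by cumulative time (the memory listings come from the `memory_profiler` package, whose wrapper also appears at the top of these tables). A minimal sketch of producing such a report with the standard library, where `workload` is a hypothetical NumPy stand-in for `dss_line`:

```python
import cProfile
import io
import pstats

import numpy as np


def workload(n=200_000):
    """Hypothetical stand-in for the profiled function."""
    x = np.random.default_rng(0).standard_normal(n)
    spectrum = np.fft.rfft(x)
    return np.fft.irfft(spectrum, n)


profiler = cProfile.Profile()
profiler.enable()
workload()
profiler.disable()

# Write the report into a string buffer, sorted like the tables above
stream = io.StringIO()
stats = pstats.Stats(profiler, stream=stream)
stats.sort_stats(pstats.SortKey.CUMULATIVE).print_stats(10)  # top 10 rows
report = stream.getvalue()
print(report)
```

Sorting by cumulative time puts the outermost calls first, which is why `dss_line` and its wrapper head both tables.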

Without your changes (#57)

Memory:

```bash
Line #    Mem usage    Increment Occurences   Line Contents
============================================================
   138   3050.5 MiB   3050.5 MiB          1   @profile
   139                                        def dss_line(X, fline, sfreq, nremove=1, nfft=1024, nkeep=None, blocksize=None,
   140                                                     show=False):
   141                                            """Apply DSS to remove power line artifacts.
   142
   143                                            Implements the ZapLine algorithm described in [1]_.
   144
   145                                            Parameters
   146                                            ----------
   147                                            X : data, shape=(n_samples, n_chans, n_trials)
   148                                                Input data.
   149                                            fline : float
   150                                                Line frequency (normalized to sfreq, if ``sfreq`` == 1).
   151                                            sfreq : float
   152                                                Sampling frequency (default=1, which assymes ``fline`` is normalised).
   153                                            nremove : int
   154                                                Number of line noise components to remove (default=1).
   155                                            nfft : int
   156                                                FFT size (default=1024).
   157                                            nkeep : int
   158                                                Number of components to keep in DSS (default=None).
   159                                            blocksize : int
   160                                                If not None (default), covariance is computed on blocks of
   161                                                ``blocksize`` samples. This may improve performance for large datasets.
   162                                            show: bool
   163                                                If True, show DSS results (default=False).
   164
   165                                            Returns
   166                                            -------
   167                                            y : array, shape=(n_samples, n_chans, n_trials)
   168                                                Denoised data.
   169                                            artifact : array, shape=(n_samples, n_chans, n_trials)
   170                                                Artifact
   171
   172                                            Examples
   173                                            --------
   174                                            Apply to X, assuming line frequency=50Hz and sampling rate=1000Hz, plot
   175                                            results:
   176                                            >>> dss_line(X, 50/1000)
   177
   178                                            Removing 4 line-dominated components:
   179                                            >>> dss_line(X, 50/1000, 4)
   180
   181                                            Truncating PCs beyond the 30th to avoid overfitting:
   182                                            >>> dss_line(X, 50/1000, 4, nkeep=30);
   183
   184                                            Return cleaned data in y, noise in yy, do not plot:
   185                                            >>> [y, artifact] = dss_line(X, 60/1000)
   186
   187                                            References
   188                                            ----------
   189                                            .. [1] de Cheveigné, A. (2019). ZapLine: A simple and effective method to
   190                                               remove power line artifacts [Preprint]. https://doi.org/10.1101/782029
   191
   192                                            """
   193   3050.5 MiB      0.0 MiB          1       if X.shape[0] < nfft:
   194                                                print('Reducing nfft to {}'.format(X.shape[0]))
   195                                                nfft = X.shape[0]
   196   3050.5 MiB      0.0 MiB          1       n_samples, n_chans, n_trials = theshapeof(X)
   197   3050.5 MiB      0.0 MiB          1       if blocksize is None:
   198   3050.5 MiB      0.0 MiB          1           blocksize = n_samples
   199
   200                                            # Recentre data
   201   5316.5 MiB   2265.9 MiB          1       X = demean(X)
   202
   203                                            # Cancel line_frequency and harmonics + light lowpass
   204   7628.2 MiB   2311.7 MiB          1       X_filt = smooth(X, sfreq / fline)
   205
   206                                            # Subtract clean data from original data. The result is the artifact plus
   207                                            # some residual biological signal
   208   9894.1 MiB   2265.9 MiB          1       X_noise = X - X_filt
   209
   210                                            # Reduce dimensionality to avoid overfitting
   211   9894.1 MiB      0.0 MiB          1       if nkeep is not None:
   212                                                cov_X_res = tscov(X_noise)[0]
   213                                                V, _ = pca(cov_X_res, nkeep)
   214                                                X_noise_pca = X_noise @ V
   215                                            else:
   216  12160.1 MiB   2265.9 MiB          1           X_noise_pca = X_noise.copy()
   217  12160.1 MiB      0.0 MiB          1           nkeep = n_chans
   218
   219                                            # Compute blockwise covariances of raw and biased data
   220  12160.1 MiB      0.0 MiB          1       n_harm = np.floor((sfreq / 2) / fline).astype(int)
   221  12160.1 MiB      0.0 MiB          1       c0 = np.zeros((nkeep, nkeep))
   222  12160.1 MiB      0.0 MiB          1       c1 = np.zeros((nkeep, nkeep))
   223  12160.1 MiB      0.0 MiB          4       for X_block in sliding_window_view(X_noise_pca, (blocksize, nkeep),
   224  12160.1 MiB      0.0 MiB          2                                          axis=(0, 1))[::blocksize, 0]:
   225                                                # if n_trials>1, reshape to (n_samples, nkeep, n_trials)
   226  12160.1 MiB      0.0 MiB          1           if X_block.ndim == 3:
   227                                                    X_block = X_block.transpose(1, 2, 0)
   228
   229                                                # bias data
   230   4661.1 MiB  -7499.0 MiB          1           X_bias = gaussfilt(X_block, sfreq, fline, fwhm=1, n_harm=n_harm)
   231   6728.9 MiB   2067.8 MiB          1           c0 += tscov(X_block)[0]
   232   6729.2 MiB      0.3 MiB          1           c1 += tscov(X_bias)[0]
   233
   234                                            # DSS to isolate line components from residual
   235   6731.8 MiB  -5428.2 MiB          1       todss, _, pwr0, pwr1 = dss0(c0, c1)
   236
   237   6731.8 MiB      0.0 MiB          1       if show:
   238                                                import matplotlib.pyplot as plt
   239                                                plt.plot(pwr1 / pwr0, '.-')
   240                                                plt.xlabel('component')
   241                                                plt.ylabel('score')
   242                                                plt.title('DSS to enhance line frequencies')
   243                                                plt.show()
   244
   245                                            # Remove line components from X_noise
   246   6731.8 MiB      0.0 MiB          1       idx_remove = np.arange(nremove)
   247   6754.8 MiB     23.0 MiB          1       X_artifact = matmul3d(X_noise_pca, todss[:, idx_remove])
   248   6822.9 MiB     68.1 MiB          1       X_res = tsr(X_noise, X_artifact)[0]  # project them out
   249
   250                                            # reconstruct clean signal
   251  12540.8 MiB   5717.9 MiB          1       y = X_filt + X_res
   252  17072.3 MiB   4531.4 MiB          1       artifact = X - y
   253
   254                                            # Power of components
   255   9515.1 MiB  -7557.2 MiB          1       p = wpwr(X - y)[0] / wpwr(X)[0]
   256   9515.3 MiB      0.2 MiB          1       print('Power of components removed by DSS: {:.2f}'.format(p))
   257   9515.3 MiB      0.0 MiB          1       return y, artifact
```
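In the listing above, `y = X_filt + X_res` and `artifact = X - y` each allocate a full new array (+5717.9 MiB and +4531.4 MiB). When one operand is no longer needed afterwards, NumPy's `out=` argument (or an augmented assignment such as `X_filt += X_res`) lets the result reuse that operand's buffer instead. This sketch assumes `X_filt` is not needed once `y` is formed, which would have to be checked against the real `dss_line` before applying anything like it; the arrays are random stand-ins whose names just mirror the listing.

```python
import numpy as np

rng = np.random.default_rng(2)
X = rng.standard_normal((50_000, 8))
X_filt = 0.5 * X                          # stand-ins for the arrays in dss_line
X_res = rng.standard_normal(X.shape)

# Out-of-place: allocates a brand-new array for the result
y_new = X_filt + X_res

# In-place: y reuses X_filt's buffer (X_filt is overwritten and must not be
# used afterwards with its old meaning)
y = np.add(X_filt, X_res, out=X_filt)

# The artifact still needs its own array if both X and y are returned
artifact = X - y

print(np.shares_memory(y, X_filt), np.allclose(y, y_new))  # True True
```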

Computation time:

```bash
         458499 function calls (455838 primitive calls) in 99.589 seconds

   Ordered by: cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.002    0.002   99.589   99.589 /memory_profiler.py:1140(wrapper)
        1    1.202    1.202   99.242   99.242 /memory_profiler.py:715(f)
        1   18.173   18.173   98.039   98.039 /meegkit/meegkit/dss.py:138(dss_line)
 1478/277    1.039    0.001   47.523    0.172 {built-in method numpy.core._multiarray_umath.implement_array_function}
        1    3.782    3.782   37.715   37.715 /meegkit/meegkit/utils/sig.py:279(gaussfilt)
        2    0.001    0.000   33.862   16.931 /numpy/fft/_pocketfft.py:49(_raw_fft)
        2   33.861   16.931   33.861   16.931 {built-in method numpy.fft._pocketfft_internal.execute}
        1    4.969    4.969   20.961   20.961 /meegkit/meegkit/tspca.py:71(tsr)
        1    0.000    0.000   18.838   18.838 <__array_function__ internals>:2(ifft)
        1    0.000    0.000   18.838   18.838 /numpy/fft/_pocketfft.py:219(ifft)
        4    7.452    1.863   15.295    3.824 /meegkit/meegkit/utils/denoise.py:10(demean)
        1    0.000    0.000   15.024   15.024 <__array_function__ internals>:2(fft)
        1    0.000    0.000   15.024   15.024 /numpy/fft/_pocketfft.py:122(fft)
       22   11.456    0.521   11.456    0.521 {method 'copy' of 'numpy.ndarray' objects}
        1    0.000    0.000    7.932    7.932 /meegkit/meegkit/utils/sig.py:114(smooth)
    100/1    0.001    0.000    7.932    7.932 <__array_function__ internals>:2(apply_along_axis)
    100/1    1.763    0.018    7.931    7.931 /numpy/lib/shape_base.py:267(apply_along_axis)
        5    0.001    0.000    6.655    1.331 /meegkit/meegkit/utils/matrix.py:497(fold)
       99    0.008    0.000    6.465    0.065 /meegkit/meegkit/utils/sig.py:171(_smooth1d)
       99    0.008    0.000    6.427    0.065 /scipy/signal/signaltools.py:1866(lfilter)
       99    0.001    0.000    5.971    0.060 /scipy/signal/signaltools.py:2038()
       99    0.001    0.000    5.971    0.060 <__array_function__ internals>:2(convolve)
       99    0.003    0.000    5.969    0.060 /numpy/core/numeric.py:753(convolve)
       99    5.966    0.060    5.966    0.060 {built-in method numpy.core._multiarray_umath.correlate}
        2    3.924    1.962    5.269    2.634 /meegkit/meegkit/utils/denoise.py:93(wpwr)
        4    0.001    0.000    5.217    1.304 /meegkit/meegkit/utils/covariances.py:170(tscov)
        6    0.000    0.000    3.875    0.646 /meegkit/meegkit/utils/matrix.py:211(multishift)
       64    0.454    0.007    2.657    0.042 /meegkit/meegkit/utils/matrix.py:652(_check_data)
       50    0.000    0.000    2.263    0.045 /meegkit/meegkit/utils/matrix.py:472(theshapeof)
      259    1.915    0.007    1.915    0.007 {method 'reduce' of 'numpy.ufunc' objects}
        1    0.000    0.000    1.721    1.721 /meegkit/meegkit/utils/covariances.py:103(tsxcov)
      385    1.698    0.004    1.698    0.004 {built-in method numpy.zeros}
      122    0.002    0.000    1.627    0.013 /numpy/core/fromnumeric.py:70(_wrapreduction)
       73    0.000    0.000    1.570    0.022 <__array_function__ internals>:2(iscomplex)
       73    0.002    0.000    1.569    0.021 /numpy/lib/type_check.py:210(iscomplex)
        2    0.000    0.000    1.223    0.612 <__array_function__ internals>:2(einsum)
        2    0.000    0.000    1.223    0.612 /numpy/core/einsumfunc.py:997(einsum)
        2    1.223    0.612    1.223    0.612 {built-in method numpy.core._multiarray_umath.c_einsum}
       15    0.000    0.000    1.075    0.072 /meegkit/meegkit/utils/matrix.py:511(unfold)
       10    0.000    0.000    1.027    0.103 <__array_function__ internals>:2(dot)
```
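`tscov` appears in both timing tables (about 4.4 s and 5.2 s cumulative). In `dss_line` it is applied block by block, and the per-block results are summed into `c0` and `c1`. For plain zero-lag cross-products, summing over blocks gives exactly the full-data result, which is what lets per-block intermediates (such as the `gaussfilt` output in the listings) stay the size of one block. This sketch uses `X_block.T @ X_block` as a stand-in; meegkit's `tscov` returns a tuple and can also apply time shifts (note the `multishift` calls in the tables), so it is not a drop-in equivalent. A window view stepped by `blocksize`, as in the listings, only yields complete blocks; the hypothetical `blockwise_crossprod` below handles a shorter final block explicitly.

```python
import numpy as np


def blockwise_crossprod(X, blocksize):
    """Sum of X_block.T @ X_block over consecutive blocks of rows of X.

    X has shape (n_samples, n_chans). The last block may be shorter than
    `blocksize`; it is included so that no samples are dropped.
    """
    n_chans = X.shape[1]
    C = np.zeros((n_chans, n_chans))
    for start in range(0, X.shape[0], blocksize):
        X_block = X[start:start + blocksize]    # a view, no copy
        C += X_block.T @ X_block
    return C


rng = np.random.default_rng(3)
X = rng.standard_normal((10_007, 5))           # deliberately not a multiple of 1000

C_full = X.T @ X
C_blocks = blockwise_crossprod(X, blocksize=1000)
print(np.allclose(C_full, C_blocks))           # True
```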
eort commented 2 years ago

Okay, that is surprising.

That's it, right?

Pytest now passes locally. The issue was the change to `matrix.py`.

nbara commented 2 years ago

I've made some changes. The only thing I reverted from your code is the repeated `X - X_filt` computation, which cannot possibly be beneficial in terms of computation time.

If this is OK with you, let's merge this into #57.

eort commented 2 years ago

Sure, sounds good.