mne-tools / mne-connectivity

Connectivity algorithms that leverage the MNE-Python API.
https://mne.tools/mne-connectivity/dev/index.html
BSD 3-Clause "New" or "Revised" License

RAM Usage in VAR Models #28

Closed adam2392 closed 2 years ago

adam2392 commented 3 years ago

I suspect that using sklearn here is very memory-inefficient. To compute a ridge regression, you need to assemble X.T @ X and X.T @ y and then solve a linear system. Time series packages assemble X.T @ X and X.T @ y on the fly by looping over the signal's time stamps. This avoids a huge matrix allocation for storing X (the data with all delays, etc.).

_Originally posted by @agramfort in https://github.com/mne-tools/mne-connectivity/pull/23#discussion_r688336154_
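To put rough numbers on "huge", here is a back-of-the-envelope sketch of the memory needed by the explicit lagged design matrix versus the two products the solver actually uses. It assumes float64 data and no intercept column. The function name `var_memory_bytes` and the example sizes are hypothetical and only illustrative.

```python
def var_memory_bytes(n_times, n_channels, order, itemsize=8):
    """Rough memory (bytes) for the main arrays of a VAR(order) least-squares fit.

    Assumes float64 data (itemsize=8) and no intercept column.
    """
    n_obs = n_times - order          # rows left once `order` lags are taken
    n_features = n_channels * order  # one column per (lag, channel) pair
    return {
        "design_matrix": n_obs * n_features * itemsize,   # explicit lagged X
        "gram": n_features * n_features * itemsize,       # X.T @ X
        "cross": n_features * n_channels * itemsize,      # X.T @ y, one column per channel
    }


# Illustrative sizes (not from the thread): 1e6 samples, 100 channels, order 10
est = var_memory_bytes(n_times=1_000_000, n_channels=100, order=10)
# design_matrix: ~8 GB, gram: ~8 MB, cross: ~0.8 MB
```

The design matrix grows with the number of samples, while X.T @ X and X.T @ y depend only on n_channels and the model order.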

agramfort commented 3 years ago

The ridge regression solution with regularization reg reads:

coef = linalg.solve(X.T @ X + reg * np.eye(n_features), X.T @ y)

So to solve it, you only need X.T @ X and X.T @ y.
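A minimal sketch of solving the ridge system from those two products alone, assuming `linalg` refers to `scipy.linalg`. The helper name `ridge_from_gram` is hypothetical. As a sanity check it compares against ordinary least squares on the augmented system [X; sqrt(reg) I] coef = [y; 0], which has the same normal equations.

```python
import numpy as np
from scipy import linalg


def ridge_from_gram(XtX, Xty, reg):
    """Solve (X.T @ X + reg * I) coef = X.T @ y using only the two products."""
    n_features = XtX.shape[0]
    # For reg > 0 the matrix is symmetric positive definite, so assume_a="pos"
    # lets SciPy use a Cholesky-based solver.
    return linalg.solve(XtX + reg * np.eye(n_features), Xty, assume_a="pos")


rng = np.random.default_rng(0)
X = rng.standard_normal((500, 6))
y = rng.standard_normal((500, 2))
reg = 0.5

coef = ridge_from_gram(X.T @ X, X.T @ y, reg)

# Cross-check: ridge equals OLS on the augmented system [X; sqrt(reg) I] coef = [y; 0]
X_aug = np.vstack([X, np.sqrt(reg) * np.eye(X.shape[1])])
y_aug = np.vstack([y, np.zeros((X.shape[1], y.shape[1]))])
coef_lstsq = np.linalg.lstsq(X_aug, y_aug, rcond=None)[0]
```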

X.T @ X is nothing more than

\sum_i np.outer(X[i, :], X[i, :])  # sum of rank-1, n_features x n_features matrices

Here the sum runs over samples, which are time points in your case.
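A small sketch of that identity for the no-lag case: accumulate one rank-1 outer product per sample and never hold more than one row of X at a time. The function name `accumulate_gram` is hypothetical, and a pure Python loop like this trades speed for memory.

```python
import numpy as np


def accumulate_gram(X, y):
    """Build X.T @ X and X.T @ y one sample (time point) at a time."""
    n_features = X.shape[1]
    XtX = np.zeros((n_features, n_features))
    Xty = np.zeros((n_features,) + y.shape[1:])   # works for 1D or 2D y
    for i in range(X.shape[0]):
        x_i = X[i]                                # row i, shape (n_features,)
        XtX += np.outer(x_i, x_i)                 # rank-1 update for sample i
        Xty += np.multiply.outer(x_i, y[i])       # cross term for sample i
    return XtX, Xty


rng = np.random.default_rng(1)
X = rng.standard_normal((200, 5))
y = rng.standard_normal((200, 3))
XtX, Xty = accumulate_gram(X, y)   # matches X.T @ X and X.T @ y
```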

Does that help, @adam2392?

adam2392 commented 3 years ago

Ah yeah I see what you mean.

I can see this will definitely help when the model order is high. However, I suspect that if RAM usage isn't too high, explicitly forming X and X.T is not too bad and possibly faster?

What would be a good cutoff in estimated RAM usage for automatically switching to the for loop over samples to construct the X.T @ X matrix?

Or should we allow users to decide with a kwarg?

agramfort commented 3 years ago

I would have only one implementation, using the online approach. It would be a lot easier to test, and it should be systematically faster.

adam2392 commented 3 years ago

I'm working on this right now, and I prefer copying the code over from statsmodels rather than having statsmodels as a dependency, because the statsmodels API is more R-like than Pythonic. Moreover, the statsmodels VAR code is not very well documented in its docstrings.

However, I'll probably end up copying a decent chunk of code if I also allow "order selection" using information criteria such as AIC/BIC.

Is this okay, or should I just hack together a wrapper around statsmodels to handle the RAM issue?
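For reference, a minimal sketch of order selection with AIC/BIC, using the Lütkepohl-style definitions ln det(Sigma) + penalty * n_params / n_obs with no intercept. statsmodels may count extra trend parameters, so treat this as an assumption rather than a copy of its code. The helper names are hypothetical, and for brevity the sketch builds the lagged design explicitly; the residual covariance could also be derived from accumulated Z.T @ Z and Z.T @ Y. All orders are fit on the same samples so their scores are comparable.

```python
import numpy as np


def lagged_design(data, order, start):
    """Rows are t = start..n_times-1; columns are [x_{t-1}, ..., x_{t-order}]."""
    n_times = data.shape[0]
    Z = np.hstack([data[start - lag:n_times - lag] for lag in range(1, order + 1)])
    return Z, data[start:]


def select_var_order(data, maxlags):
    """Return {order: (aic, bic)} for orders 1..maxlags (OLS, no intercept).

    data has shape (n_times, n_channels). Every order uses targets t >= maxlags.
    """
    n_channels = data.shape[1]
    scores = {}
    for order in range(1, maxlags + 1):
        Z, Y = lagged_design(data, order, start=maxlags)
        n_obs = Y.shape[0]
        B = np.linalg.lstsq(Z, Y, rcond=None)[0]
        resid = Y - Z @ B
        sigma = resid.T @ resid / n_obs        # ML estimate of the noise covariance
        logdet = np.linalg.slogdet(sigma)[1]
        n_params = order * n_channels ** 2     # lag coefficients only
        aic = logdet + 2.0 * n_params / n_obs
        bic = logdet + np.log(n_obs) * n_params / n_obs
        scores[order] = (aic, bic)
    return scores
```

The selected order is then, for example, `min(scores, key=lambda p: scores[p][1])` for BIC.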

@agramfort @larsoner

agramfort commented 3 years ago

I would copy the relevant code and avoid a dependency on statsmodels, but I would also add a test checking that we don't deviate from statsmodels, so that if they fix a bug, we detect it.
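A sketch of such a regression test, assuming statsmodels' `VAR(endog).fit(maxlags=..., trend="n")` and its `coefs` array of shape (order, n_channels, n_channels), where `coefs[l]` multiplies x_{t-l-1}. `fit_var_normal_equations` is a hypothetical stand-in for the copied code. In a real pytest suite, `pytest.importorskip("statsmodels")` would replace the manual skip.

```python
import numpy as np

try:  # statsmodels is needed only by the test, not at runtime
    from statsmodels.tsa.api import VAR
except ImportError:
    VAR = None


def fit_var_normal_equations(data, order):
    """Hypothetical stand-in: OLS VAR fit via Z.T @ Z and Z.T @ Y (no intercept).

    data has shape (n_times, n_channels). Returns coefs of shape
    (order, n_channels, n_channels) with x_t = sum_l coefs[l] @ x_{t-l-1} + e_t.
    """
    n_times, n_channels = data.shape
    Z = np.hstack([data[order - lag:n_times - lag] for lag in range(1, order + 1)])
    Y = data[order:]
    B = np.linalg.solve(Z.T @ Z, Z.T @ Y)   # shape (order * n_channels, n_channels)
    # Row block l of B holds the lag-(l+1) coefficients, transposed.
    return B.reshape(order, n_channels, n_channels).transpose(0, 2, 1)


def test_matches_statsmodels(order=3, seed=0):
    if VAR is None:
        return  # statsmodels not installed; nothing to compare against
    data = np.random.default_rng(seed).standard_normal((2000, 4))
    ours = fit_var_normal_equations(data, order)
    ref = VAR(data).fit(maxlags=order, trend="n").coefs
    np.testing.assert_allclose(ours, ref, rtol=1e-6, atol=1e-10)
```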

adam2392 commented 3 years ago

@agramfort I'm looking at the statsmodels code that I repurposed, and I realized it doesn't actually do the above (looping over samples/channels): https://www.statsmodels.org/stable/_modules/statsmodels/tsa/vector_ar/var_model.html#VAR

They actually explicitly form the Z predictor matrix.

Is there an example package you were thinking of that actually does the above, but for lags as well? Handling the VAR(1) case and directly forming Z.T @ Z and Z.T @ Y is easy enough, but it gets a bit more involved with lags.
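One way the loop can handle lags, sketched under the assumption that data has shape (n_times, n_channels) and that row t of Z is [x_{t-1}, ..., x_{t-order}]: that row is just the flattened window of the previous `order` samples, so each time step still contributes one rank-1 update. `lagged_gram_streaming` is a hypothetical name, and this is not taken from any particular package.

```python
import numpy as np


def lagged_gram_streaming(data, order):
    """Accumulate Z.T @ Z and Z.T @ Y for a VAR(order) without building Z.

    data: (n_times, n_channels). Row t of the never-materialized Z is
    [x_{t-1}, x_{t-2}, ..., x_{t-order}]; row t of Y is x_t, for t >= order.
    """
    n_times, n_channels = data.shape
    n_features = order * n_channels
    ZtZ = np.zeros((n_features, n_features))
    ZtY = np.zeros((n_features, n_channels))
    for t in range(order, n_times):
        # Reverse the window so lag 1 comes first, then flatten to one row of Z
        z_t = data[t - order:t][::-1].ravel()
        ZtZ += np.outer(z_t, z_t)
        ZtY += np.outer(z_t, data[t])
    return ZtZ, ZtY
```

Memory stays at O((order * n_channels)^2) regardless of n_times, but the Python-level loop is slow for long recordings. Processing time points in chunks would vectorize it while keeping memory bounded by the chunk size.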

adam2392 commented 2 years ago


@agramfort I'm trying to implement the for loop you suggested above to see whether it reduces RAM usage for large n_channels/n_samples.

Do you have an idea or tip on how to do it when lags are present? I got stuck wondering whether this is possible with lags.
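A possible answer, sketched under the same assumptions (data of shape (n_times, n_channels), Z rows ordered lag 1 first, hypothetical function name): block (i, j) of Z.T @ Z is sum_t x_{t-i} x_{t-j}^T, which is the product of two shifted slices of the data. Basic slices are views, so no lagged copy of the data is made, and each block is a single BLAS call. The blocks depend on i and j mostly through j - i, which is why Yule-Walker-style estimators approximate Z.T @ Z with a block-Toeplitz matrix of autocovariances.

```python
import numpy as np


def lagged_gram_blocks(data, order):
    """Z.T @ Z and Z.T @ Y for a VAR(order), computed block by block from slices.

    Block (i, j), for lags i, j in 1..order, equals
    data[order-i:n-i].T @ data[order-j:n-j].
    """
    n_times, n_ch = data.shape
    ZtZ = np.empty((order * n_ch, order * n_ch))
    ZtY = np.empty((order * n_ch, n_ch))
    Y = data[order:]
    for i in range(1, order + 1):
        Xi = data[order - i:n_times - i]             # x_{t-i} for all t >= order (a view)
        rows = slice((i - 1) * n_ch, i * n_ch)
        ZtY[rows] = Xi.T @ Y
        for j in range(i, order + 1):
            block = Xi.T @ data[order - j:n_times - j]
            cols = slice((j - 1) * n_ch, j * n_ch)
            ZtZ[rows, cols] = block
            ZtZ[cols, rows] = block.T                # Z.T @ Z is symmetric
    return ZtZ, ZtY


rng = np.random.default_rng(0)
data = rng.standard_normal((10_000, 8))
ZtZ, ZtY = lagged_gram_blocks(data, order=5)
# Ridge solution with an illustrative reg=1e-3; shape (order * n_ch, n_ch)
coef = np.linalg.solve(ZtZ + 1e-3 * np.eye(ZtZ.shape[0]), ZtY)
```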