[Open] richardkoehler opened 8 months ago
Just to add, was discussed briefly with @adam2392 in https://github.com/mne-tools/mne-connectivity/pull/163#discussion_r1428342308
I'm in favor of supporting this. @larsoner do you have any opinions about how to expose this option? Some ideas off the cuff:
- `mne_connectivity.set_precision()` (as @richardkoehler suggested)
- `mne_connectivity.set_config("precision", "half")` (similar to `mne.set_config()`, i.e. backed by a `.json` file and can also optionally set env variables)
- a decorator like `@verbose` that could inject a precision argument into every function where it made sense to do so.

I'm sure there are other possibilities too...
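For illustration, a minimal sketch of what a JSON-backed `set_config("precision", ...)` might look like. All names here (`set_config`, `get_precision`, the config file location) are hypothetical and not part of mne-connectivity's actual API:

```python
# Hypothetical sketch of a set_config-style precision option, loosely
# modelled on mne.set_config(); none of these names exist in mne-connectivity.
import json
from pathlib import Path

import numpy as np

_CONFIG_PATH = Path.home() / ".mne_connectivity.json"  # assumed location
_PRECISIONS = {
    "half": (np.float32, np.complex64),
    "full": (np.float64, np.complex128),
}


def set_config(key, value):
    """Persist a config value to a JSON file."""
    config = {}
    if _CONFIG_PATH.exists():
        config = json.loads(_CONFIG_PATH.read_text())
    config[key] = value
    _CONFIG_PATH.write_text(json.dumps(config))


def get_precision():
    """Return (real_dtype, complex_dtype) for the configured precision."""
    key = "full"  # highest precision by default
    if _CONFIG_PATH.exists():
        key = json.loads(_CONFIG_PATH.read_text()).get("precision", "full")
    return _PRECISIONS[key]
```

Computation code would then look up `get_precision()` instead of hard-coding dtypes.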
Just to give an idea of something I used for a package I wrote: I had a class that was initialised by default with double precision. The initialised object would be imported in any class involving computations (example) and used to specify the dtype (example).
```python
import numpy as np

from xyz._utils import _precision

...
my_real_var = np.mean([x, y], dtype=_precision.real)
my_complex_var = np.mean([x, y], dtype=_precision.complex)
```
If you wanted to change the precision you could call some `set_precision()` function (example) and compute whatever.
```python
from xyz import set_precision

set_precision("single")  # real: np.float32; complex: np.complex64
set_precision("double")  # real: np.float64; complex: np.complex128
```
Not saying it's the best method, but it worked for me in the past.
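Filling in the sketch above as a self-contained example (the class and function names follow the description but are illustrative, not any package's actual API):

```python
# Self-contained sketch of the global precision object described above;
# _Precision / set_precision are hypothetical names.
import numpy as np


class _Precision:
    """Holds the dtypes used for real- and complex-valued computations."""

    def __init__(self):
        self.set("double")  # double precision by default

    def set(self, precision):
        if precision == "single":
            self.real, self.complex = np.float32, np.complex64
        elif precision == "double":
            self.real, self.complex = np.float64, np.complex128
        else:
            raise ValueError(f"unknown precision: {precision!r}")


_precision = _Precision()


def set_precision(precision):
    """Switch the package-wide computation precision."""
    _precision.set(precision)


# Computation code reads the dtypes from the shared object:
x, y = 1.0, 2.0
my_real_var = np.mean([x, y], dtype=_precision.real)  # float64 by default
set_precision("single")
my_complex_var = np.mean([x, y], dtype=_precision.complex)  # now complex64
```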
Global states usually end up being a pain. I would start small -- concretely, where is this needed/useful? For example, if it's in `envelope_connectivity` you could add a `dtype="float"` there, which means "always use float". If a user passes `dtype=None` it means "use the native dtype of the data passed". Or they can pass `dtype=np.float32` for example, which would mean "always cast to float32". And then we make sure that the computation functions actually support the given dtype after doing the casting that the user requested. For most (all?) of those it's going to be just `float` / `np.float32`, I think.

Then if this ends up being useful in a lot of places we can make nice ways of handling `dtype=...` by using private functions. Maybe someday a decorator. But starting with one and then generalizing seems like the right approach, and then we see how much additional refactoring/engineering ends up being worth it.
Describe the problem
When performing analyses with high sampling frequencies and with both time- and frequency-resolved connectivity (with the wavelet method), I frequently run into memory (and time) limits, even when running on high-performance clusters (with up to 200 GB RAM). This is of course also related to the fact that I use bootstrapping methods, and certainly related to the fact that I use multivariate methods like MVGC, where large matrices have to be created for computation. I have found that one part of the solution, next to reducing the number of sampled frequencies and running analyses in sequence, was to explicitly use a lower-precision data type in the source code (e.g. `np.complex64` instead of `np.complex128` and `np.float32` instead of `np.float64`). The results did not change significantly for my analyses, but memory usage was often almost cut in half and computation time was also reduced.
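The memory saving is easy to quantify: a NumPy array's footprint scales directly with the itemsize, so halving the precision halves the memory. A quick generic illustration (not connectivity-specific):

```python
# complex128 uses 16 bytes per element, complex64 uses 8, so casting
# exactly halves the array's footprint (at the cost of precision).
import numpy as np

n = 1_000_000
a128 = np.zeros(n, dtype=np.complex128)
a64 = a128.astype(np.complex64)
print(a128.nbytes // 2**20, "MiB")  # 15 MiB
print(a64.nbytes // 2**20, "MiB")   # 7 MiB
```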
Describe your solution
It would be awesome to have the option to reduce the precision of the calculations if desired (the default would obviously remain the highest possible precision). This could, for example, be implemented in a function like
```python
mne_connectivity.set_precision("full")  # or "half"
```
or alternatively be more specific, e.g.
I hope this is something that could be considered! Maybe @tsbinns you have some thoughts on this, or on potential implementation details?