SpikeInterface / spikeinterface

A Python-based module for creating flexible and robust spike sorting pipelines.
https://spikeinterface.readthedocs.io
MIT License

Handling torch params globally with set_global_torch_kwargs? #3337

Open yger opened 2 months ago

yger commented 2 months ago

More and more methods now rely on torch (motion estimation, some peak detectors), and I've made a working implementation of the SVD convolution used in the matching engines (wobble, circus-omp) that can also run with torch, which makes it faster. However, while trying to push it into main, I ran into some problems that I think are worth looking into.

The problem is that it might be good to have a mechanism, at the global level of spikeinterface, to configure torch (whether it should be used at all, and on which device). This is particularly important because when torch runs on a GPU device, worker processes must be spawned, whereas they can simply be forked if the device is the CPU. Currently, for example in the peak detectors (as I think was done by Alessio), the mp_context is hardcoded for some nodes, while it could (should?) be chosen depending on the torch context (whether we are using torch, and possibly how we are using it).
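A minimal sketch of choosing the start method from the torch settings instead of hardcoding it. The `use_torch` and `device` arguments and the `choose_mp_context` name are hypothetical, not an existing spikeinterface API. The GPU branch reflects the fact that the CUDA runtime does not support the "fork" start method in subprocesses.

```python
import multiprocessing as mp
from typing import Optional


def choose_mp_context(use_torch: bool, device: Optional[str]) -> str:
    """Pick a multiprocessing start method from (hypothetical) torch settings."""
    # CUDA cannot be used in forked children, so any CUDA device forces "spawn"
    if use_torch and device is not None and device.startswith("cuda"):
        return "spawn"
    # On CPU, keep the cheaper "fork" where the platform provides it
    if "fork" in mp.get_all_start_methods():
        return "fork"
    # e.g. on Windows: fall back to the platform default (first entry)
    return mp.get_all_start_methods()[0]
```

The returned string could then be passed wherever a node or executor currently receives a hardcoded mp_context.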

So why not, similarly to the global job_kwargs in core/globals, add a global torch_kwargs dictionary (with the corresponding set/get_global_torch_kwargs functions) holding keys such as {"use_torch": bool, "device": str}?
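A rough sketch of what such globals could look like, loosely modeled on the existing set_global_job_kwargs pattern. All names and default values here are hypothetical illustrations of the proposal, not code that exists in spikeinterface.

```python
# Hypothetical defaults for the proposed global torch settings
_default_torch_kwargs = {"use_torch": False, "device": "cpu"}
global_torch_kwargs = dict(_default_torch_kwargs)


def set_global_torch_kwargs(**torch_kwargs):
    """Update the global torch settings, rejecting unknown keys."""
    unknown = set(torch_kwargs) - set(_default_torch_kwargs)
    if unknown:
        raise ValueError(f"Unknown torch kwargs: {sorted(unknown)}")
    # Mutate in place so every module holding a reference sees the change
    global_torch_kwargs.update(torch_kwargs)


def get_global_torch_kwargs():
    """Return a copy so callers cannot modify the global state by accident."""
    return dict(global_torch_kwargs)


def reset_global_torch_kwargs():
    """Restore the hypothetical defaults."""
    global_torch_kwargs.clear()
    global_torch_kwargs.update(_default_torch_kwargs)
```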

If we had such a dict, we could then add functions like split_torch_kwargs() (similar to split_job_kwargs()) and simplify the signatures of all functions that might rely on torch, e.g.:

method_kwargs, job_kwargs = split_job_kwargs(kwargs)
method_kwargs, torch_kwargs = split_torch_kwargs(method_kwargs)
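One possible shape for split_torch_kwargs(), as a hedged sketch. The key set, the optional `defaults` argument, and the `radius_um` example kwarg in the usage are assumptions; only the call pattern comes from the lines above.

```python
from typing import Optional

# Hypothetical set of keys treated as torch settings
torch_keys = ("use_torch", "device")


def split_torch_kwargs(mixed_kwargs: dict, defaults: Optional[dict] = None):
    """Separate torch settings from method-specific kwargs.

    Torch settings not given explicitly are taken from `defaults`
    (e.g. the global torch settings), which is copied, not mutated.
    """
    torch_kwargs = dict(defaults or {})
    method_kwargs = {}
    for key, value in mixed_kwargs.items():
        if key in torch_keys:
            torch_kwargs[key] = value
        else:
            method_kwargs[key] = value
    return method_kwargs, torch_kwargs
```

For example, `split_torch_kwargs({"radius_um": 50, "device": "cuda"}, defaults={"use_torch": True, "device": "cpu"})` keeps `radius_um` as a method kwarg and overrides only the device.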

What do you think?

yger commented 2 months ago

Another option would be to add torch_device to job_kwargs, with multiple possible values (None (no torch), "auto", "cpu", "gpu"). The advantage is that we do not add any new functions, and spawn vs. fork could still be determined on the fly from the context. But this means adding a new entry to job_keys.
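A sketch of how the proposed torch_device values could be resolved into a concrete torch device string. The function names are hypothetical; `torch.cuda.is_available()` is a real torch call, imported lazily so torch stays optional, and the `cuda_available` argument exists only so the logic can be checked without a GPU.

```python
from typing import Callable, Optional

_allowed_torch_devices = (None, "auto", "cpu", "gpu")


def _torch_cuda_available() -> bool:
    # Lazy import: torch remains an optional dependency
    try:
        import torch
    except ImportError:
        return False
    return torch.cuda.is_available()


def resolve_torch_device(
    torch_device: Optional[str],
    cuda_available: Optional[Callable[[], bool]] = None,
) -> Optional[str]:
    """Map a (hypothetical) torch_device job kwarg to a torch device string.

    Returns None when torch should not be used at all.
    """
    if torch_device not in _allowed_torch_devices:
        raise ValueError(f"torch_device must be one of {_allowed_torch_devices}")
    if torch_device is None:
        return None
    if torch_device == "cpu":
        return "cpu"
    has_cuda = (cuda_available or _torch_cuda_available)()
    if torch_device == "gpu":
        if not has_cuda:
            raise RuntimeError("torch_device='gpu' requested but CUDA is not available")
        return "cuda"
    # "auto": prefer the GPU when one is visible
    return "cuda" if has_cuda else "cpu"
```

Whenever this resolves to "cuda", the executor would need "spawn"; otherwise "fork" remains possible.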

JoeZiminski commented 1 month ago

+1, I think this is a cool idea! I'm not familiar enough with the job kwargs setup to know which API is best, though I guess if the number of torch-related kwargs is small, they could be a subset of job_kwargs. Would it only ever require adding the context?

Can the torch settings and the usual way of setting up multiprocessing in spikeinterface ever clash? E.g., is n_jobs set via set_global_job_kwargs respected when running torch on CPU, or does torch do its own thing?
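For context on the CPU case: PyTorch runs CPU ops on its own intra-op thread pool (adjustable with `torch.set_num_threads`), so n_jobs worker processes each running torch could oversubscribe the cores. A hedged sketch of one way to split cores between workers; the helper names and the idea of calling it in a per-worker initializer are assumptions, and n_jobs is assumed to be already resolved to a positive integer (spikeinterface also accepts values such as -1).

```python
import os
from typing import Optional


def torch_threads_per_worker(n_jobs: int, n_cpus: Optional[int] = None) -> int:
    """Hypothetical helper: share the CPU cores evenly between n_jobs workers."""
    if n_jobs < 1:
        raise ValueError("n_jobs must be resolved to a positive integer first")
    if n_cpus is None:
        n_cpus = os.cpu_count() or 1
    # At least one thread per worker, even if n_jobs exceeds the core count
    return max(1, n_cpus // n_jobs)


def init_worker_torch_threads(n_threads: int) -> None:
    """Hypothetical per-worker initializer limiting torch's CPU threads."""
    import torch  # imported lazily, only inside the worker

    torch.set_num_threads(n_threads)
```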