mehta-lab / waveorder

Wave optical models and inverse algorithms for label-agnostic imaging of density & orientation.
BSD 3-Clause "New" or "Revised" License

Support GPUs via `torch` #144

Open talonchandler opened 11 months ago

talonchandler commented 11 months ago

waveorder's simulations and reconstructions are moving to `torch` following the new models structure, and along the way we decided to temporarily drop GPU support to prioritize the migration.

We would like to restore GPU support for many of our operations, especially our heaviest reconstructions.

@ziw-liu, can you comment on the easiest path you see to GPU support?

ziw-liu commented 11 months ago

Conceptually, if every operation is functional (as in `torch.nn.functional`), then no explicit GPU switch is needed. The computation automatically happens on the device where the input tensor is stored, and internally created tensors can use `tensor(..., device=input.device)`. I don't think it will be hard (`torch` is a GPU-first library, after all); we just need to carefully test and fix things.

Edit: a simple example of a device-agnostic, tensor-in, tensor-out function.
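To illustrate the pattern described in this comment, here is a minimal sketch. It is not code from waveorder. The function name `hypothetical_lowpass` and its Fourier-space low-pass filter are illustrative assumptions. The idea is that every tensor created inside the function is placed on `image.device`, and only device-following `torch` ops are used. As a result, the same code runs on CPU or GPU depending solely on where the caller put the input.

```python
import torch


def hypothetical_lowpass(image: torch.Tensor, cutoff: float = 0.25) -> torch.Tensor:
    """Hypothetical tensor-in, tensor-out filter (not part of waveorder).

    Keeps spatial frequencies with radius <= `cutoff` (cycles per pixel) over
    the last two dimensions of `image` and returns a real tensor with the same
    shape, on the same device as the input.
    """
    ny, nx = image.shape[-2:]
    # Internally created tensors are placed on the input's device,
    # so the function never needs an explicit CPU/GPU switch.
    fy = torch.fft.fftfreq(ny, device=image.device)
    fx = torch.fft.fftfreq(nx, device=image.device)
    fyy, fxx = torch.meshgrid(fy, fx, indexing="ij")
    radius = torch.sqrt(fyy**2 + fxx**2)
    cutoff_t = torch.tensor(cutoff, device=image.device)
    mask = (radius <= cutoff_t).to(image.dtype)
    # fft2, the multiply, and ifft2 all run wherever their inputs live.
    return torch.fft.ifft2(torch.fft.fft2(image) * mask).real


if __name__ == "__main__":
    x = torch.rand(64, 64)
    print(hypothetical_lowpass(x).device)
    if torch.cuda.is_available():
        # Same function, no code changes: moving the input moves the work.
        print(hypothetical_lowpass(x.to("cuda")).device)
```

Note that the caller, not the function, decides the device. This keeps reconstruction code free of `if gpu:` branches, so GPU support reduces to testing that no helper silently creates a CPU tensor.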