scverse / scvi-tools

Deep probabilistic analysis of single-cell and spatial omics data
http://scvi-tools.org/
BSD 3-Clause "New" or "Revised" License

Add torch neuron support for scvi tools #2619

Closed: LinearParadox closed this issue 1 week ago

LinearParadox commented 3 months ago

Torch Neuron is the PyTorch integration of the AWS Neuron SDK; it lets PyTorch run on AWS Trainium and Inferentia accelerator instances. These instances are somewhat cheaper, especially for large models that may be a little too large for GPUs such as T4s but are not worth an A100. It would be a nice addition.

martinkim0 commented 3 months ago

I'm not familiar with torch-neuron, what sort of changes would be necessary in order to enable this?

LinearParadox commented 3 months ago

I think some analogous code would likely have to be added. For example, something like:

if use_neuron:
    ...  # Neuron training code
else:
    ...  # normal training code

The API seems pretty analogous to PyTorch:

https://awsdocs-neuron.readthedocs-hosted.com/en/latest/frameworks/torch/torch-neuronx/programming-guide/training/pytorch-neuron-programming-guide.html#pytorch-neuronx-programming-guide
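A slightly more concrete version of the `use_neuron` branch, following the pattern in the linked guide: the model and tensors go to the XLA device, and `xm.mark_step()` is called after each optimizer step to execute the lazily recorded graph. This is a minimal sketch, not scvi-tools code. `select_backend` and `train_epoch` are hypothetical helper names, and the generic `(x, y)` supervised loop is an assumption for illustration (scvi-tools models compute their own losses).

```python
import importlib.util
from typing import Any, Callable, Iterable, Tuple


def neuron_available() -> bool:
    # torch-neuronx builds on PyTorch/XLA, so check that torch_xla is importable.
    return importlib.util.find_spec("torch_xla") is not None


def select_backend(use_neuron: bool, default_device: str = "cpu") -> Tuple[Any, Callable[[], None]]:
    """Hypothetical helper: return (device, sync_fn) for the training loop."""
    if use_neuron:
        if not neuron_available():
            raise ImportError("use_neuron=True but torch_xla / torch-neuronx is not installed")
        import torch_xla.core.xla_model as xm

        # Neuron path: XLA device; mark_step() compiles and runs the pending graph.
        return xm.xla_device(), xm.mark_step
    # Normal path: eager PyTorch, so no explicit sync point is needed.
    return default_device, lambda: None


def train_epoch(model, loader: Iterable, optimizer, loss_fn, device, sync: Callable[[], None]) -> None:
    """Hypothetical helper: one epoch of a standard PyTorch training loop."""
    model.train()
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        loss = loss_fn(model(x), y)
        loss.backward()
        optimizer.step()
        sync()  # no-op on the normal path; xm.mark_step() on Neuron
```

Usage would be `device, sync = select_backend(use_neuron)`, then `model.to(device)` before building the optimizer, then `train_epoch(..., device, sync)`. Passing `sync` in keeps a single training loop for both paths instead of duplicating it.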

I'm not super experienced with torch, but I can dig in after this week to see whether it's a trivial modification or entails a larger redesign.

The one major difference that might pose an issue is that Neuron builds graphs lazily (via PyTorch/XLA), while PyTorch normally executes operations eagerly. I'm not sure how impactful this will be in practice, though.
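To make the lazy-vs-eager difference concrete, here is a toy pure-Python illustration; it is not how Neuron or PyTorch/XLA are implemented. Arithmetic on the hypothetical `LazyScalar` class only records a graph node, and nothing is computed until `.item()` reads the value. In PyTorch/XLA, reading a tensor's value on the host (e.g. `loss.item()` for logging) similarly forces the pending graph to execute, which is one place the difference can show up in practice.

```python
from typing import Callable, Optional


class LazyScalar:
    """Toy lazy value: arithmetic builds graph nodes instead of computing."""

    evaluations = 0  # how many nodes have actually been computed

    def __init__(self, compute: Callable[[], float]) -> None:
        self._compute = compute
        self._cached: Optional[float] = None

    @staticmethod
    def const(x: float) -> "LazyScalar":
        return LazyScalar(lambda: x)

    def __add__(self, other: "LazyScalar") -> "LazyScalar":
        # Record the op; no arithmetic happens here.
        return LazyScalar(lambda: self.item() + other.item())

    def __mul__(self, other: "LazyScalar") -> "LazyScalar":
        return LazyScalar(lambda: self.item() * other.item())

    def item(self) -> float:
        # Reading the value is the sync point: the graph runs only now.
        if self._cached is None:
            LazyScalar.evaluations += 1
            self._cached = self._compute()
        return self._cached
```

Building `c = a * b + a` leaves `LazyScalar.evaluations` at 0; only `c.item()` triggers the computation. Eager PyTorch would instead compute each operation as soon as its line runs.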

martinkim0 commented 3 months ago

Let me know if you're able to look into this! Would be happy to take a PR if it's a small modification. If it seems like it's going to be a larger redesign, I think it would make sense for us to discuss it internally before anything is implemented.