LinearParadox closed this issue 1 week ago.
I'm not familiar with torch-neuron; what sort of changes would be necessary to enable this?
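I think some analogous Neuron-specific code would likely have to be added. For example, something like:

```python
use_neuron = False  # e.g. set from a config option or CLI flag

if use_neuron:
    ...  # Neuron training code
else:
    ...  # normal training code
```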
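To make the two branches concrete, here is a minimal sketch. It assumes the Neuron path uses torch-neuronx, which trains through PyTorch/XLA (`torch_xla`): the device comes from `xm.xla_device()` and each step's lazily recorded graph is flushed with `xm.mark_step()`. The helper names `pick_device` and `train_epoch` and the `mark_step` parameter are hypothetical, not existing code in this project. Imports are deferred so the CPU/GPU path doesn't need the Neuron packages installed.

```python
def pick_device(use_neuron: bool):
    """Return the device to train on (hypothetical helper)."""
    if use_neuron:
        # Assumes torch-neuronx is installed; it exposes NeuronCores as XLA devices.
        import torch_xla.core.xla_model as xm
        return xm.xla_device()
    import torch
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


def train_epoch(model, loader, optimizer, loss_fn, device, mark_step=None):
    """Run one epoch; the loop body is shared by both backends.

    `model` is assumed to already be on `device` (move it before building
    the optimizer). On Neuron, pass `mark_step=xm.mark_step` so the lazily
    recorded graph is compiled and executed once per iteration.
    Returns the number of batches processed.
    """
    model.train()
    batches = 0
    for inputs, targets in loader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        loss = loss_fn(model(inputs), targets)
        loss.backward()
        optimizer.step()
        if mark_step is not None:
            mark_step()  # Neuron: compile and run this step's recorded graph
        batches += 1
    return batches
```

With this split, the only Neuron-specific pieces are the device and the per-step flush, e.g. `train_epoch(model, loader, optimizer, loss_fn, device, mark_step=xm.mark_step if use_neuron else None)`.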
The API seems pretty analogous to PyTorch:
I'm not very experienced with torch, but I can try to dig in after this week to see whether it's a trivial modification or would entail a larger redesign.
The one major difference that might pose an issue is that Neuron builds graphs lazily, while standard PyTorch executes operations eagerly. I'm not sure how much this will matter in practice, though.
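To get a feel for the practical impact, here is a toy, pure-Python model of lazy execution. It is not how `torch_xla` or torch-neuronx is implemented; `Graph`, `LazyValue`, and `run` are hypothetical names. Arithmetic only records operations, and nothing is computed until `mark_step()` runs or a value is read with `.item()`, which forces the pending graph to execute.

```python
from typing import Callable, List, Union


class Graph:
    """Toy recorder: ops are queued, not run, until mark_step() flushes them."""

    def __init__(self) -> None:
        self.pending: List[Callable[[], None]] = []
        self.executions = 0  # number of times a non-empty graph was executed

    def mark_step(self) -> None:
        if not self.pending:
            return
        for op in self.pending:  # run queued ops in recording order
            op()
        self.pending.clear()
        self.executions += 1


class LazyValue:
    """Toy lazy scalar: arithmetic records an op; .item() forces execution."""

    def __init__(self, graph: Graph, value: Union[int, float, None] = None) -> None:
        self.graph = graph
        self.value = value

    def _record(self, other, fn) -> "LazyValue":
        out = LazyValue(self.graph)

        def op() -> None:
            rhs = other.value if isinstance(other, LazyValue) else other
            out.value = fn(self.value, rhs)

        self.graph.pending.append(op)  # recorded now, computed later
        return out

    def __add__(self, other):
        return self._record(other, lambda a, b: a + b)

    def __mul__(self, other):
        return self._record(other, lambda a, b: a * b)

    def item(self):
        self.graph.mark_step()  # reading a value syncs: pending ops must run first
        return self.value


def run(steps: int, read_every_step: bool):
    """Accumulate 0..steps-1 lazily; optionally read the value each step."""
    g = Graph()
    total = LazyValue(g, 0)
    for i in range(steps):
        total = total + i
        if read_every_step:
            total.item()  # like calling loss.item() for logging inside the loop
    return total.item(), g.executions
```

In the toy, `run(4, read_every_step=True)` executes four small graphs while `run(4, read_every_step=False)` executes one, with the same final value. On a real lazy backend the analogous costs are extra host/device synchronizations and, if the recorded graph changes shape between steps, recompilation, so the practical impact would likely depend on how often the training loop reads tensors back into Python.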
Let me know if you're able to look into this! Would be happy to take a PR if it's a small modification. If it seems like it's going to be a larger redesign, I think it would make sense for us to discuss it internally before anything is implemented.
Torch Neuron is AWS's PyTorch integration for its Trainium and Inferentia instances, which use custom ML accelerator chips rather than GPUs. These instances are somewhat cheaper, especially for models that are a little too large for GPUs such as T4s but don't justify an A100, so it would be a nice addition.