mlcommons / GaNDLF

A generalizable application framework for segmentation, regression, and classification using PyTorch
https://gandlf.org
Apache License 2.0

[FEATURE] Update PyTorch dependency #884

Open sarthakpati opened 1 month ago

sarthakpati commented 1 month ago

Is your feature request related to a problem? Please describe.

Now that PyTorch 2.3.0 has been out for a while, does it make sense to make the switch? There are a few backward-incompatible changes [ref] that potentially relate to the work being done by @Geeks-Sid, so I will definitely wait for his comments.

Describe the solution you'd like

N.A.

Describe alternatives you've considered

N.A.

Additional context

Comments/suggestions, @VukW, @szmazurek?

VukW commented 3 weeks ago

Sorry, I am not proficient enough in either the latest PyTorch changes or in how GaNDLF handles distributed training. Do we have any tests for multi-GPU training? I cannot find any. If yes, then maybe just running the tests would be enough to ensure the new version is OK for us. Anyway, we will have to update the dependency at some point, so why not now?

sarthakpati commented 3 weeks ago

Unfortunately, we do not have any GPU tests right now. 😞

I am fine with updating the dependency right now, but I would like to get the opinion of other developers/contributors/maintainers. 😄
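
If we do get GPU runners at some point, even a tiny smoke test would be a start. A hypothetical sketch (nothing like this exists in our suite yet; the test name and model are illustrative only):

```python
import pytest
import torch

# A multi-GPU variant would additionally gate on
# torch.cuda.device_count() >= 2.
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires a CUDA GPU")
def test_smoke_training_on_gpu():
    device = torch.device("cuda")
    model = torch.nn.Linear(16, 2).to(device)
    x = torch.randn(4, 16, device=device)
    model(x).sum().backward()  # one forward/backward pass as a smoke test
    assert all(p.grad is not None for p in model.parameters())
```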

szmazurek commented 3 weeks ago

Hey, from my perspective, why not? It would probably just be a matter of re-running the tests and making some corrections, which, as @VukW says, would need to happen anyway.

sarthakpati commented 3 weeks ago

Sounds good, thanks!

Just waiting for @Geeks-Sid to respond and then we can start.

Geeks-Sid commented 3 weeks ago

Looks like the backward-compatibility issue does not affect us; however, it would be good to run the tests on GPUs. We are good to go, but is there any issue in staying at the current version?

sarthakpati commented 3 weeks ago

> it would be good to run the tests on GPUs

Agreed - I am in discussion with a couple of CI providers to give us some extremely limited free GPU compute. Let's see how it goes.

> is there any issue in staying at the current version?

Nothing specific. Moving to the latest stable release just ensures that we aren't too far behind on PyTorch bug fixes getting propagated forward. And since we will be making a jump with the new API branch anyway, I figured it makes sense to go to the latest version.

szmazurek commented 1 week ago

Dear all, regarding the torch version: from version 2.2, PyTorch has a built-in flash attention mechanism, see https://pytorch.org/blog/pytorch2-2/. @sarthakpati mentioned that in the future we may integrate flash attention to speed up some models that employ attention; this would also be useful for the synthesis module, where some diffusion models use it too. So, considering version updates, we may look directly at 2.2, as that covers both the version update and flash attention.
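
For illustration, a minimal sketch (not GaNDLF code) of calling the built-in attention: `torch.nn.functional.scaled_dot_product_attention` dispatches to flash attention on eligible inputs, and `torch.backends.cuda.sdp_kernel` is the 2.x context manager for forcing a specific backend:

```python
import torch
import torch.nn.functional as F

# Q/K/V in the (batch, heads, seq_len, head_dim) layout that
# scaled_dot_product_attention expects; half precision on a CUDA
# device is required for the flash backend to be eligible.
batch, heads, seq_len, head_dim = 2, 8, 1024, 64
q = torch.randn(batch, heads, seq_len, head_dim, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Force the flash backend (this raises if it is unavailable); without
# the context manager, torch auto-selects the fastest eligible backend.
with torch.backends.cuda.sdp_kernel(
    enable_flash=True, enable_math=False, enable_mem_efficient=False
):
    out = F.scaled_dot_product_attention(q, k, v)
```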