pytorch / torchtune

PyTorch native finetuning library
https://pytorch.org/torchtune/main/
BSD 3-Clause "New" or "Revised" License

v0.3 regression: full_finetune_distributed slower? #1718

Open: Delaunay opened this issue 1 month ago

Delaunay commented 1 month ago

The full_finetune_distributed recipe appears to be much slower in v0.3 than in v0.2.1.

Everything seems to work as usual, but a job that used to finish in v0.2.1 times out in v0.3.0.

I don't have much detail yet, but since you are more familiar with the code base, maybe you already have an idea based on what changed recently!

joecummings commented 1 month ago

Can you share a few more details about which model you're using, the size of your dataset, and the machine type?

Off the top of my head, I'm not sure what would be going on.

Delaunay commented 1 month ago

I tried on

joecummings commented 1 month ago

Hey @Delaunay - I looked into this and was able to repro! Unfortunately, still digging into the root cause, but a quick fix is to upgrade your PyTorch version to the nightlies.

pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/

After doing this, training should be fast again:

[Screenshot from 2024-10-02 omitted]

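To confirm that the upgrade actually took effect, one option is to inspect the installed version string. The helper below is a hypothetical sketch, not part of PyTorch or torchtune, and it assumes nightly wheels are versioned as `<major>.<minor>.<patch>.dev<YYYYMMDD>` with an optional `+<local tag>` suffix (for example `2.6.0.dev20241002+cu121`).

```python
import re
from typing import Optional

# Assumed nightly version format: "<major>.<minor>.<patch>.dev<YYYYMMDD>[+<local tag>]",
# e.g. "2.6.0.dev20241002+cu121". Release builds such as "2.4.1+cu121" do not match.
_NIGHTLY_RE = re.compile(r"^\d+\.\d+\.\d+\.dev(\d{8})(?:\+.*)?$")


def nightly_date(version: str) -> Optional[str]:
    """Hypothetical helper: return the YYYYMMDD build date if `version` looks like a nightly, else None."""
    match = _NIGHTLY_RE.match(version)
    return match.group(1) if match else None
```

With PyTorch installed, calling `nightly_date(torch.__version__)` returns the build date for a nightly and `None` for a regular release.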
ebsmothers commented 1 month ago

I think @felipemello1 added some warnings about this in #1733; it seems the fix here is to run on PyTorch nightlies. Since we have a resolution, I am going to close the issue, but @Delaunay, if you are still blocked, please feel free to reopen.

Delaunay commented 1 month ago

I downgraded to 0.2.1 while I wait for PyTorch to release a new version; in my case, I cannot use nightlies.

But if it runs out of memory (OOM), why did I not see the error being raised? Is it related to CPU offloading, where tensors are moved to the CPU to avoid an OOM, so the run becomes very slow but eventually OOMs anyway?

ebsmothers commented 1 month ago

@Delaunay, since you can't use the nightlies, I'll reopen this. The main change between 0.2.1 and 0.3.0 is that we moved to FSDP2. Since this is a relatively new feature, it's likely that some optimizations have landed since PyTorch 2.4 (I can't pinpoint anything offhand, but I can dig in further here). Out of curiosity, where is it slow? Is it during training, checkpoint loading, or somewhere else? And do you see similar slowdowns when running smaller models (i.e. ones where we aren't doing CPU offload)?
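One lightweight way to answer "where is it slow?" is to wrap each stage of the run in a timer and compare totals between v0.2.1 and v0.3.0. The `PhaseTimer` class below is a hypothetical sketch (not a torchtune API), and the phase names are illustrative. Because CUDA kernels run asynchronously, a GPU loop should pass a synchronization callable such as `torch.cuda.synchronize` so that queued work is counted inside the phase that launched it.

```python
import time
from collections import defaultdict
from contextlib import contextmanager
from typing import Callable, Dict, Iterator, List, Optional


class PhaseTimer:
    """Hypothetical helper: accumulate wall-clock time per named phase."""

    def __init__(self) -> None:
        self.durations: Dict[str, List[float]] = defaultdict(list)

    @contextmanager
    def phase(self, name: str, sync: Optional[Callable[[], None]] = None) -> Iterator[None]:
        start = time.perf_counter()
        try:
            yield
        finally:
            if sync is not None:
                # e.g. torch.cuda.synchronize, so pending GPU work is included in this phase
                sync()
            # Record the elapsed time even if the wrapped block raised.
            self.durations[name].append(time.perf_counter() - start)

    def summary(self) -> Dict[str, float]:
        """Total seconds spent in each phase."""
        return {name: sum(times) for name, times in self.durations.items()}
```

In a training script, this would look like `with timer.phase("checkpoint_load"): ...` around model loading and `with timer.phase("train_step", sync=torch.cuda.synchronize): ...` around each step, then printing `timer.summary()` at the end.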