pytorch / torchtune

A Native-PyTorch Library for LLM Fine-tuning
https://pytorch.org/torchtune/main/
BSD 3-Clause "New" or "Revised" License

Configurable sharding strategy for distributed finetune #1014

Closed: tambulkar closed this issue 1 month ago

tambulkar commented 4 months ago

Can we add an option to the config to make the sharding strategy configurable for the distributed fine tune runs?

kartikayk commented 4 months ago

Thank you for opening this issue!

This has been on our list of TODOs since different sharding strategies expose different memory-perf trade-offs. Would you be open to making this change? We'd be happy to review and land this. Should be reasonably straightforward:
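One way to picture the change is a minimal sketch, assuming the recipe reads a plain string from its config and maps it onto the FSDP enum. The config key `fsdp_sharding_strategy` and the helper `parse_enum` are hypothetical. The local `ShardingStrategy` class is a stand-in that mirrors the main member names of `torch.distributed.fsdp.ShardingStrategy` so the snippet runs without torch. In a real recipe you would pass the torch enum instead.

```python
from enum import Enum, auto
from typing import Type, TypeVar

E = TypeVar("E", bound=Enum)


# Stand-in mirroring the main member names of
# torch.distributed.fsdp.ShardingStrategy (assumption: used here only so the
# sketch runs without torch; swap in the real enum in a recipe).
class ShardingStrategy(Enum):
    FULL_SHARD = auto()     # shard params, grads and optimizer state
    SHARD_GRAD_OP = auto()  # shard grads and optimizer state (ZeRO-2 style)
    NO_SHARD = auto()       # replicate everything, DDP-like
    HYBRID_SHARD = auto()   # full shard within a node, replicate across nodes


def parse_enum(enum_cls: Type[E], value: str) -> E:
    """Hypothetical helper: resolve a config string like "full_shard" to a member."""
    try:
        return enum_cls[value.strip().upper()]
    except KeyError:
        valid = ", ".join(m.name for m in enum_cls)
        raise ValueError(
            f"Unknown {enum_cls.__name__} {value!r}; expected one of: {valid}"
        ) from None


# Hypothetical config contents; fall back to FULL_SHARD when the key is absent.
cfg = {"fsdp_sharding_strategy": "shard_grad_op"}
strategy = parse_enum(
    ShardingStrategy, cfg.get("fsdp_sharding_strategy", "FULL_SHARD")
)
print(strategy)  # ShardingStrategy.SHARD_GRAD_OP
```

The resolved value would then be handed to the `sharding_strategy` argument of `FullyShardedDataParallel`. Keeping the config value a string and validating it up front means a typo fails fast with the list of valid options, rather than surfacing later during distributed setup.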

tambulkar commented 4 months ago

Yup, I'll go ahead and implement this.

pbontrager commented 4 months ago

There are potentially some extra considerations here. First, specific models should have good default wrapping strategies, which we don't currently have a way of supporting. Second, there should be a direct way to override that default. For overriding the FSDP wrapping, you might run into config nesting limits, which would require using _get_component_from_path as well.

@ebsmothers do you have any thoughts on this?

pbontrager commented 4 months ago

Sorry, my previous comment was about making the wrap_policy configurable; I misread the question. Please ignore.