pytorch / torchtune

PyTorch native finetuning library
https://pytorch.org/torchtune/main/
BSD 3-Clause "New" or "Revised" License

implement activation offloading and opt_in_bwd in knowledge_distillation recipes #1959

Open · felipemello1 opened this issue 2 weeks ago

felipemello1 commented 2 weeks ago

The current knowledge distillation recipes don't support activation offloading or optimizer-in-backward (opt_in_bwd).

The implementation should be similar to the one in other recipes, such as full_finetune_distributed.

After enabling these features in the recipes, they should also be enabled in the KD-related configs.

PRs with reference implementations:

- activation offloading: https://github.com/pytorch/torchtune/pull/1847
- opt_in_bwd: https://github.com/pytorch/torchtune/pull/1833

KD recipes:

- https://github.com/pytorch/torchtune/blob/main/recipes/knowledge_distillation_single_device.py
- https://github.com/pytorch/torchtune/blob/main/recipes/knowledge_distillation_distributed.py
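For orientation, here is a minimal sketch of the setup pattern the reference recipes follow. The helper names (`training.get_act_offloading_ctx_manager`, `training.register_optim_in_bwd_hooks`, `training.create_optim_in_bwd_wrapper`) are taken from `torchtune.training` as used in the PRs above; verify them against the current codebase before relying on this.

```python
# Sketch of the setup pattern from recipes like full_finetune_distributed;
# the hypothetical helper below mirrors what the KD recipes would need to add.
from torchtune import config, training


def setup_memory_features(model, cfg):
    # Activation offloading: a context manager that moves saved activations to
    # CPU during forward and brings them back for backward. In KD, the student's
    # forward pass should run inside it; the teacher runs under torch.no_grad()
    # and has no saved activations to offload.
    activations_handling_ctx = training.get_act_offloading_ctx_manager(
        model, cfg.get("enable_activation_offloading", False)
    )

    # Optimizer-in-backward (opt_in_bwd): one optimizer per parameter, stepped
    # from a backward hook so each gradient can be freed immediately.
    if cfg.get("optimizer_in_bwd", False):
        optim_dict = {
            p: config.instantiate(cfg.optimizer, [p]) for p in model.parameters()
        }
        training.register_optim_in_bwd_hooks(model=model, optim_dict=optim_dict)
        # The wrapper exposes a state_dict for checkpointing the per-param optimizers.
        optimizer = training.create_optim_in_bwd_wrapper(
            model=model, optim_dict=optim_dict
        )
    else:
        optimizer = config.instantiate(cfg.optimizer, model.parameters())

    return activations_handling_ctx, optimizer
```

In the training loop, the student forward pass then runs under the offloading context, and `optimizer.step()` / `zero_grad()` are skipped when opt_in_bwd is on, since the hooks already step each per-parameter optimizer. The existing recipes also reject opt_in_bwd combined with gradient accumulation, and the KD recipes should validate the same.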

After implementing it, run the recipe with each flag on and off and plot loss, memory, and words per second for the two runs. The easiest way to collect these is to add the wandb logger to the config.
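The existing recipes already log most of what's needed for this comparison. A minimal sketch, assuming torchtune's `training.get_memory_stats` helper and the metric logger's `log_dict` method (the helper function itself is hypothetical):

```python
from torchtune import training


def log_step_metrics(metric_logger, device, loss, num_tokens, step_time, step):
    """Hypothetical helper: log loss, throughput, and memory for one step."""
    log_dict = {
        "loss": loss.item(),
        "words_per_second": num_tokens / step_time,
    }
    if device.type == "cuda":
        # Adds peak active/allocated/reserved memory stats for the device.
        log_dict.update(training.get_memory_stats(device=device))
    metric_logger.log_dict(log_dict, step=step)
```

To send these to wandb, point the config's `metric_logger` component at `torchtune.training.metric_logging.WandBLogger` and compare the flag-on/flag-off runs in the wandb UI.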

To update configs in bulk, you can use the script here: https://github.com/pytorch/torchtune/pull/1954
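The script in that PR is the reference; for illustration only, a naive version of the idea might look like the following. Note this version drops YAML comments and formatting (which is why a purpose-built script is preferable), and the glob pattern is an assumption to be adjusted to the actual KD config filenames.

```python
# Naive illustration of bulk-adding new flags to the KD configs; use the
# script from PR #1954 instead, since this version drops YAML comments.
from pathlib import Path

import yaml

NEW_FLAGS = {"enable_activation_offloading": False, "optimizer_in_bwd": False}

for path in Path("recipes/configs").rglob("*knowledge_distillation*.yaml"):
    cfg = yaml.safe_load(path.read_text())
    for key, default in NEW_FLAGS.items():
        cfg.setdefault(key, default)  # only add the key if it's missing
    path.write_text(yaml.safe_dump(cfg, sort_keys=False))
    print(f"updated {path}")
```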

AnuravModak commented 5 days ago

Hi @felipemello1, is this available for external contributors? If yes, kindly assign it to me and I will look into it. Thanks in advance!

felipemello1 commented 4 days ago

@AnuravModak yes! Any issue with the "community help wanted" label is something we think would be a great fit for external contributors. Feel free to ask me questions and just submit a PR.

More info on contributing is in https://github.com/pytorch/torchtune/blob/main/CONTRIBUTING.md, but the TL;DR:

1) Fork torchtune (https://github.com/pytorch/torchtune), then clone your fork: `git clone https://github.com/<your-username>/torchtune.git`

2) Install dependencies:

```bash
cd torchtune
conda create -n torchtune python=3.11
conda activate torchtune
pip install --pre --upgrade torch torchvision torchao --index-url https://download.pytorch.org/whl/nightly/cu124
pip install -e ".[dev]"
pre-commit install
```

3) Create the PR