pytorch / torchtune

PyTorch native finetuning library
https://pytorch.org/torchtune/main/
BSD 3-Clause "New" or "Revised" License

Model distillation use-case #1430

Open ScottHoang opened 2 months ago

ScottHoang commented 2 months ago

Hi team! Fantastic work so far; I genuinely enjoy working with this repository. I have a question regarding the distillation use case. Let's say that for every transformer block I generate an internal `_distillation_loss`. In a non-FSDP world, I would traverse the model with `model.named_modules()`, gather all of my internal losses, and add them to my objective loss in the training loop (roughly the pattern sketched below). Would this still work under FSDP?
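A minimal sketch of the single-device pattern described above: each block stashes a per-layer loss on itself during forward, and the training loop collects the losses via `named_modules()`. The names (`StudentBlock`, `_distillation_loss`, the MSE matching loss) are illustrative assumptions, not torchtune APIs.

```python
import torch
import torch.nn as nn

class StudentBlock(nn.Module):
    """Hypothetical block that records its own distillation loss each forward pass."""

    def __init__(self, dim: int):
        super().__init__()
        self.proj = nn.Linear(dim, dim)
        self._distillation_loss = None  # populated on every forward

    def forward(self, x: torch.Tensor, teacher_out: torch.Tensor) -> torch.Tensor:
        out = self.proj(x)
        # Store the per-block loss on the module instead of returning it
        self._distillation_loss = nn.functional.mse_loss(out, teacher_out)
        return out

def gather_internal_losses(model: nn.Module) -> torch.Tensor:
    """Walk the module tree and sum every stored per-block loss."""
    losses = [
        m._distillation_loss
        for _, m in model.named_modules()
        if getattr(m, "_distillation_loss", None) is not None
    ]
    return torch.stack(losses).sum()

# In the (single-GPU) training loop:
# logits = model(batch)  # forward populates each block's _distillation_loss
# loss = ce_loss(logits, labels) + gather_internal_losses(model)
# loss.backward()
```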

ebsmothers commented 2 months ago

Hi @ScottHoang thanks for creating the issue. Distillation is actually something we're currently looking into, so this is a timely question. How exactly is the distillation loss calculated? One naive suggestion is to just return all the hidden states of the transformer, then use those to calculate the loss in the same place you compute the usual cross-entropy (something like the sketch below). In that case the necessary changes to how the model is sharded vs. the usual SFT setup should be relatively minimal, since you don't have to go into each layer and manually gather the loss. I may be oversimplifying things, though; if you can tell me a bit more about the loss and training setup, I can give more detailed suggestions.
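A rough sketch of the "return the hidden states and compute the loss next to cross-entropy" idea, using standard PyTorch forward hooks to capture activations so nothing inside the sharded model has to hold on to a loss. The `model.layers` attribute, the layer indices, and the MSE matching loss are assumptions for illustration, not a specific torchtune API.

```python
import torch
import torch.nn.functional as F

def attach_hidden_state_hooks(model, layer_indices):
    """Register forward hooks that cache the output of the given layers."""
    cache, handles = {}, []
    for idx in layer_indices:
        layer = model.layers[idx]  # assumes an nn.ModuleList of transformer blocks
        def _hook(module, inputs, output, idx=idx):
            cache[idx] = output  # keep the activation for the loss computation
        handles.append(layer.register_forward_hook(_hook))
    return cache, handles

# Training step, in the same place you'd compute the usual cross-entropy:
# student_cache, _ = attach_hidden_state_hooks(student, [1])
# teacher_cache, _ = attach_hidden_state_hooks(teacher, [3])
# logits = student(tokens)
# with torch.no_grad():
#     teacher(tokens)
# distill_loss = F.mse_loss(student_cache[1], teacher_cache[3])
# loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1)) + distill_loss
```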

ScottHoang commented 2 months ago

Hi @ebsmothers, thank you for responding. In my use case I am actively reducing x of the model's Y layers: replacing something like Input -> L1 -> L2 -> L3 -> L4 -> output with Input -> L1 -> M -> L4 -> output, where I try to match the output of M against the output of L2 -> L3 (roughly the sketch below). I want to compute that loss internally, inside the module, so I don't have to keep a copy of the intermediate activations, hence looping through the model's modules. Would this setup still work with FSDP?
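A sketch of the layer-reduction setup described above, under the assumption that a single trainable block M replaces the frozen teacher sub-stack L2 -> L3 and computes the matching loss in its own forward. The class and attribute names are illustrative only.

```python
import torch
import torch.nn as nn

class ReplacementBlock(nn.Module):
    """Hypothetical block M that distills a frozen L2 -> L3 teacher sub-stack."""

    def __init__(self, student_block: nn.Module, teacher_blocks: nn.Sequential):
        super().__init__()
        self.student_block = student_block    # M, trainable
        self.teacher_blocks = teacher_blocks  # frozen L2 -> L3
        for p in self.teacher_blocks.parameters():
            p.requires_grad_(False)
        self._distillation_loss = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.student_block(x)
        with torch.no_grad():
            target = self.teacher_blocks(x)  # teacher target, no grad tracking
        # Record the matching loss internally so the training loop can collect it
        self._distillation_loss = nn.functional.mse_loss(out, target)
        return out
```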

ebsmothers commented 2 months ago

Hi @ScottHoang thanks for the additional info. Are you calling backward multiple times then (e.g. once per loss)? I assume you would need to do this to free the activations. Also, if activation memory is the concern, can you just recompute the activations in the backward pass (i.e. activation checkpointing)? That would be another way to save much of that memory (albeit at the cost of some perf), and it should work with FSDP. A rough sketch is below.
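A minimal sketch of the recompute-on-backward suggestion using the stock PyTorch activation-checkpointing wrapper (which composes with FSDP): each transformer block's activations are freed after the forward and recomputed during backward. The `block_cls` check is an assumption about how you identify your transformer blocks; substitute your own block class.

```python
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    apply_activation_checkpointing,
    checkpoint_wrapper,
)

def enable_activation_checkpointing(model: nn.Module, block_cls: type) -> None:
    """Wrap every module of type `block_cls` so its activations are recomputed in backward."""
    apply_activation_checkpointing(
        model,
        checkpoint_wrapper_fn=checkpoint_wrapper,
        check_fn=lambda module: isinstance(module, block_cls),
    )

# e.g. enable_activation_checkpointing(model, MyTransformerBlock)
```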