pytorch / torchtune

A Native-PyTorch Library for LLM Fine-tuning
BSD 3-Clause "New" or "Revised" License

Apply QLoRA to output projections and token embedding #1000

Open rohan-varma opened 1 month ago

rohan-varma commented 1 month ago

Currently, we don't apply QLoRA to either the output projection or the token embeddings. There's no strong reason not to quantize the output projection; we skip it only because of a limitation in torchao (quantized large weights somehow take up more memory than unquantized ones). We should start quantizing the output projection once this is fixed in torchao.

Code pointer to where torchtune skips quantizing the output projection: https://github.com/pytorch/torchtune/blob/main/torchtune/models/llama2/_component_builders.py#L235-L239

On token embedding quantization: applying LoRA to embeddings is relatively unexplored research-wise, but there have been some requests, e.g. in https://github.com/huggingface/peft/issues/349. We might want to explore this for even more memory savings.
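For reference, one way LoRA could be applied to a token embedding is sketched here, assuming the usual LoRA factorization applied to the lookup table: the pretrained table stays frozen, and each token's vector gets a low-rank correction `lora_a[token] @ lora_b`. The class name `LoRAEmbedding`, its arguments, and the initialization (zero-initialized `lora_b`, so the layer starts out identical to the frozen embedding) are hypothetical and are not a torchtune API. The sketch also leaves the base table unquantized.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class LoRAEmbedding(nn.Module):
    """Hypothetical LoRA wrapper around a frozen embedding table (sketch, not a torchtune API)."""

    def __init__(self, num_embeddings: int, embedding_dim: int, rank: int = 8, alpha: float = 16.0):
        super().__init__()
        # Frozen "pretrained" table; in practice this would be loaded from a checkpoint
        self.weight = nn.Parameter(torch.randn(num_embeddings, embedding_dim), requires_grad=False)
        # Trainable low-rank factors: one rank-sized row per token, shared projection to embedding_dim
        self.lora_a = nn.Parameter(torch.randn(num_embeddings, rank) * 0.01)
        self.lora_b = nn.Parameter(torch.zeros(rank, embedding_dim))  # zero init: no change at step 0
        self.scaling = alpha / rank

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        base = F.embedding(tokens, self.weight)                 # frozen lookup
        delta = F.embedding(tokens, self.lora_a) @ self.lora_b  # low-rank update per token
        return base + self.scaling * delta
```

Only `lora_a` and `lora_b` receive gradients, so the trainable parameter count is `num_embeddings * rank + rank * embedding_dim` instead of `num_embeddings * embedding_dim`.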

Optimox commented 1 month ago

@rohan-varma is it OK to simply unfreeze the embedding layer before fine-tuning with LoRA, with something like this?

# Make the token embedding fully trainable on top of the LoRA setup
for param in model.tok_embeddings.parameters():
    param.requires_grad = True
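A self-contained version of that idea on a toy model follows. `ToyDecoder` is a hypothetical stand-in (only the attribute name `tok_embeddings` mirrors the snippet in this thread; real torchtune models use their own LoRA modules). It assumes the common pattern of freezing everything except adapter weights, then unfreezing the embedding, and builds the optimizer from trainable parameters only. It does not address FSDP wrapping.

```python
import torch
import torch.nn as nn


class ToyDecoder(nn.Module):
    """Hypothetical stand-in for a LoRA-wrapped decoder; only `tok_embeddings` mirrors the thread."""

    def __init__(self, vocab_size: int = 1000, embed_dim: int = 64, rank: int = 8):
        super().__init__()
        self.tok_embeddings = nn.Embedding(vocab_size, embed_dim)
        self.proj = nn.Linear(embed_dim, embed_dim, bias=False)  # stands in for a frozen pretrained layer
        self.lora_a = nn.Linear(embed_dim, rank, bias=False)     # LoRA down-projection
        self.lora_b = nn.Linear(rank, embed_dim, bias=False)     # LoRA up-projection
        nn.init.zeros_(self.lora_b.weight)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        h = self.tok_embeddings(tokens)
        return self.proj(h) + self.lora_b(self.lora_a(h))


model = ToyDecoder()

# Usual LoRA setup: only the adapter weights are trainable
for name, param in model.named_parameters():
    param.requires_grad = name.startswith("lora_")

# The change discussed here: also fully fine-tune the token embedding
for param in model.tok_embeddings.parameters():
    param.requires_grad = True

trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable, lr=1e-4)
num_trainable = sum(p.numel() for p in trainable)
print(num_trainable)  # prints: 65024 (64000 embedding + 512 + 512 LoRA)
```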

ebsmothers commented 1 month ago

Hi @Optimox, is your idea to just fine-tune the embedding layer directly, without any additional LoRA weights? If so, this will work. If you're doing a distributed run, you may need to be a bit careful about the FSDP wrapping (I don't think anything will break, but you may end up using extra memory if you don't change the wrapping). Btw, adding a proper LoRAEmbedding layer is still pretty high on our wishlist; if you're interested in helping out, this would be a great contribution.

Optimox commented 1 month ago

@ebsmothers yes, I've tried simply unfreezing the embeddings and training the rest of the model with LoRA. It seems to work OK on a single machine. I'm just wondering whether this is good practice, or whether there is a good reason to keep the embeddings frozen when fine-tuning with LoRA. About the LoRAEmbedding: will the initial embedding layer become a memory bottleneck at some point?

ebsmothers commented 3 weeks ago

@Optimox sorry somehow your message slipped through the cracks here.

I'm not sure what the best practice is here in terms of model quality. One thing that could matter is whether you are trying to learn new embeddings. E.g. if you have a new special token in your tokenizer and it's untrained, making the full embedding trainable during fine-tuning may be a good way to learn richer information about that special token (whereas with LoRA you would wind up learning a much lower-dimensional representation of this token than of the other tokens the model was already pretrained on).
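To make the "lower-dimensional representation" point concrete: with LoRA the update to the embedding table is `lora_a @ lora_b`, so every token's update, including a new special token's, is a combination of the same `rank` rows of `lora_b`. The toy sketch that follows uses made-up sizes and random factors as stand-ins for trained weights, and compares the rank of an unconstrained update with that of a LoRA update.

```python
import torch

torch.manual_seed(0)
vocab_size, embed_dim, rank = 50, 32, 8  # toy sizes, not Llama3's

# Full fine-tuning: the update to the table can be any (vocab_size x embed_dim) matrix
full_delta = torch.randn(vocab_size, embed_dim)

# LoRA: the update is lora_a @ lora_b, so all rows live in the span of lora_b's `rank` rows
lora_a = torch.randn(vocab_size, rank)
lora_b = torch.randn(rank, embed_dim)
lora_delta = lora_a @ lora_b

full_rank = torch.linalg.matrix_rank(full_delta).item()
lora_rank = torch.linalg.matrix_rank(lora_delta).item()
print(full_rank, lora_rank)  # prints: 32 8
```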

Re memory usage, say we use a LoRA rank of 8. Taking Llama3-8B as an example, the vocab size is about 128k and the embed dim is about 4k. This means that if you fully fine-tune the embedding matrix, you will have 128k * 4k gradients; in bf16 this is about 1 GB. With LoRA you would only have 128k * 8 + 4k * 8, which is more like 2-3 MB of gradients (all math here is very approximate). So applying LoRA to the embedding gives nontrivial memory savings; whether that matters depends on how much memory you are using elsewhere and how much you have available.
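The estimate can be reproduced with a few lines of arithmetic. This sketch assumes Llama3-8B's published shapes (vocab size 128,256, embed dim 4,096, which the comment rounds to 128k and 4k), 2 bytes per bf16 gradient element, and LoRA factors of shape (vocab x rank) and (rank x dim); `embedding_grad_bytes` is a hypothetical helper.

```python
VOCAB_SIZE = 128_256   # Llama3-8B vocab (rounded to 128k in the comment)
EMBED_DIM = 4_096      # Llama3-8B embed dim (rounded to 4k)
RANK = 8
BYTES_PER_BF16 = 2


def embedding_grad_bytes(vocab_size, embed_dim, rank=None, bytes_per_elem=BYTES_PER_BF16):
    """Gradient bytes for the embedding: full table if rank is None, else LoRA factors only."""
    if rank is None:
        num_elems = vocab_size * embed_dim
    else:
        num_elems = vocab_size * rank + rank * embed_dim
    return num_elems * bytes_per_elem


full = embedding_grad_bytes(VOCAB_SIZE, EMBED_DIM)
lora = embedding_grad_bytes(VOCAB_SIZE, EMBED_DIM, rank=RANK)
print(f"full: {full / 1e9:.2f} GB, LoRA: {lora / 1e6:.2f} MB, ratio: {full / lora:.0f}x")
# prints: full: 1.05 GB, LoRA: 2.12 MB, ratio: 496x
```

This counts gradients only; optimizer state such as AdamW's two moment buffers scales with the same trainable-parameter counts, so the gap grows further.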