ContextualAI / gritlm

Generative Representational Instruction Tuning
https://arxiv.org/abs/2402.09906
MIT License
479 stars, 33 forks

QLoRA / LoRA issue #20

Closed: Hisarlik closed this issue 3 months ago

Hisarlik commented 3 months ago

Hello, you have done an incredible job. I've read in the README file that the QLoRA and LoRA integration is not well-tested.

I tried the getting-started example for embeddings with QLoRA and with LoRA enabled, and the result is always:

TypeError: GemmaModel.forward() got an unexpected keyword argument 'labels'

I've tried other models and the output is the same.

Muennighoff commented 3 months ago

Does it work without QLoRA/LoRA? This looks like a problem with labels being passed to forward. Which attn setting are you using?

Hisarlik commented 3 months ago

I've tried the cccc, bb, and cc attn settings, with tiny-mistral and with other models such as Gemma 2B.

torchrun --nproc_per_node 1 -m training.run --output_dir test_path --model_name_or_path openaccess-ai-collective/tiny-mistral --train_data training/toy_data/toy_data_embedding.jsonl --learning_rate 1e-5 --num_train_epochs 5 --per_device_train_batch_size 2 --dataloader_drop_last True --normalized True --temperature 0.02 --query_max_len 32 --passage_max_len 128 --train_group_size 2 --mode embedding --attn cccc --lora True

Muennighoff commented 3 months ago

Does it work without QLoRA/LoRA?

Hisarlik commented 3 months ago

Yes, sorry. It works without LoRA.

Muennighoff commented 3 months ago

I think it is because LoRA rewraps the model, and then the labels kwarg is not passed through correctly, or something along those lines. I don't have time to debug it at the moment, but it shouldn't be too complicated. It would be amazing if you opened a PR in case you fix it :)
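One plausible reading of the traceback: the error names GemmaModel, a bare base model whose forward() has no labels parameter. If the adapter wrapper is a task-specific (causal-LM style) class that always forwards labels explicitly, even when it is None, the bare model rejects it, whereas a task-agnostic wrapper that passes keyword arguments through unchanged would not. The sketch below reproduces that pattern with stdlib Python only. BareEncoder, TaskSpecificWrapper, GenericWrapper, and try_forward are hypothetical stand-ins, not PEFT or transformers classes, and the assumption that gritlm's LoRA setup selects such a task-specific wrapper is not confirmed in this thread.

```python
# Hypothetical stand-ins that mimic the wrapping pattern; NOT real PEFT/transformers classes.


class BareEncoder:
    """Plays the role of a bare base model (like GemmaModel): forward() has no `labels`."""

    def forward(self, input_ids, attention_mask=None):
        # Return a trivial "output" in place of hidden states.
        return [len(input_ids)]


class TaskSpecificWrapper:
    """Mimics a causal-LM style adapter wrapper that always passes `labels` explicitly."""

    def __init__(self, base):
        self.base = base

    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
        # `labels` is forwarded even when it is None, so the base must accept it.
        return self.base.forward(
            input_ids=input_ids, attention_mask=attention_mask, labels=labels, **kwargs
        )


class GenericWrapper:
    """Mimics a task-agnostic wrapper that passes arguments through unchanged."""

    def __init__(self, base):
        self.base = base

    def forward(self, *args, **kwargs):
        return self.base.forward(*args, **kwargs)


def try_forward(wrapper):
    """Run one forward pass; return the output, or the TypeError message if it fails."""
    try:
        return wrapper.forward(input_ids=[1, 2, 3], attention_mask=[1, 1, 1])
    except TypeError as err:
        return f"TypeError: {err}"


if __name__ == "__main__":
    base = BareEncoder()
    print(try_forward(TaskSpecificWrapper(base)))  # fails on the unexpected `labels` kwarg
    print(try_forward(GenericWrapper(base)))       # succeeds
```

The point of the comparison is that the failure comes from the wrapper's fixed argument list, not from LoRA itself, which would be consistent with the report that training works once the task-specific setting is dropped.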

Hisarlik commented 3 months ago

Of course. I'll give it a try. Congratulations again on your work.

Hisarlik commented 3 months ago

I have created a PR: https://github.com/ContextualAI/gritlm/pull/21

I have tested it with the embedding example. I haven't found the root cause of this behaviour in the PEFT library, but removing the task-specific parameter makes training work.
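Whether a task-specific wrapper is safe depends on whether the wrapped model's forward() accepts the extra keyword it injects. A quick way to check that before launching training is to inspect the signature with Python's standard inspect module. forward_accepts and the two model classes below are hypothetical illustrations, not part of gritlm, PEFT, or transformers.

```python
import inspect


def forward_accepts(forward_fn, name):
    """Hypothetical helper: True if `forward_fn` can be called with keyword `name`.

    A parameter counts if it is literally named `name` (and is not positional-only),
    or if the function accepts arbitrary **kwargs.
    """
    for param in inspect.signature(forward_fn).parameters.values():
        if param.kind is inspect.Parameter.VAR_KEYWORD:
            return True
        if param.name == name and param.kind in (
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
            inspect.Parameter.KEYWORD_ONLY,
        ):
            return True
    return False


# Hypothetical stand-ins: a bare encoder and a model with a language-modeling head.
class BareEncoder:
    def forward(self, input_ids, attention_mask=None):
        return input_ids


class EncoderWithLMHead:
    def forward(self, input_ids, attention_mask=None, labels=None):
        return input_ids


if __name__ == "__main__":
    print(forward_accepts(BareEncoder().forward, "labels"))        # False
    print(forward_accepts(EncoderWithLMHead().forward, "labels"))  # True
```

This is only a heuristic: a forward() that takes **kwargs passes the check even if the argument is later ignored or rejected further down the call chain.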