yxli2123 / LoftQ

MIT License

Method fails on Gemma-7B model #30

Open Ther-nullptr opened 6 months ago

Ther-nullptr commented 6 months ago

Hello, I have tried your method on the Gemma-7B model. I found that the method works on the GSM8K dataset but fails on the WikiText-2 dataset. This is my training log:

```
[WARNING|logging.py:329] 2024-05-15 10:23:39,953 >> `use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
/home/yujin-wa20/miniconda3/envs/gact/lib/python3.9/site-packages/torch/utils/checkpoint.py:429: UserWarning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.
  warnings.warn(
{'loss': 51.6576, 'grad_norm': 470.694580078125, 'learning_rate': 0.0003, 'epoch': 0.11}
{'loss': 47.9403, 'grad_norm': 437.8383483886719, 'learning_rate': 0.00029890633111470807, 'epoch': 0.21}
{'loss': 23.9947, 'grad_norm': 42.98173904418945, 'learning_rate': 0.00029564127261390776, 'epoch': 0.32}
{'loss': 23.057, 'grad_norm': 132.80783081054688, 'learning_rate': 0.00029025243640281223, 'epoch': 0.43}
{'loss': 21.0726, 'grad_norm': 24.4749755859375, 'learning_rate': 0.0002828184038479814, 'epoch': 0.53}
 19%|███████████████████████████▏                           | 5/27 [39:06<2:35:50, 425.04s/it]
```

I didn't change the original code. Do you know why this happens?

yxli2123 commented 6 months ago

The loss at the first epoch already looks very high, so there may be a problem with the initialization. Could you share the code you use to load the model? Since we haven't provided an official Gemma-7B LoftQ checkpoint, could you also share how you obtained the quantized backbone and the LoRA adapter with quantize_and_save.py?
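
For reference, loading a locally quantized LoftQ backbone plus its LoRA adapter usually follows the pattern below. This is a minimal sketch, not an official recipe for Gemma-7B: the local path `./gemma-7b-4bit-64rank` is a hypothetical output directory of quantize_and_save.py, and the `loftq_init` adapter subfolder is assumed to match the layout used by the published LoftQ checkpoints for other models.

```python
# Minimal sketch of loading a LoftQ-quantized backbone and its LoRA adapter.
# MODEL_DIR is a hypothetical path; point it at wherever quantize_and_save.py
# wrote its output for Gemma-7B.
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel

MODEL_DIR = "./gemma-7b-4bit-64rank"  # assumed output dir of quantize_and_save.py

base_model = AutoModelForCausalLM.from_pretrained(
    MODEL_DIR,
    torch_dtype=torch.bfloat16,
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        # A compute dtype that doesn't match the quantization settings used at
        # initialization is a common source of unexpectedly high initial loss.
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=False,
        bnb_4bit_quant_type="nf4",
    ),
)

# The LoftQ-initialized LoRA adapter is saved next to the backbone;
# "loftq_init" is the subfolder name the official checkpoints use (assumed here).
peft_model = PeftModel.from_pretrained(
    base_model,
    MODEL_DIR,
    subfolder="loftq_init",
    is_trainable=True,
)
```

If the loading code deviates from this pattern (for example, loading the adapter with random LoRA initialization instead of the LoftQ-initialized one), the quantization error is not compensated and the starting loss will be much higher than expected.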