tatsu-lab / stanford_alpaca

Code and documentation to train Stanford's Alpaca models, and generate the data.
https://crfm.stanford.edu/2023/03/13/alpaca.html
Apache License 2.0

The OOM problem caused by the Transformers version #278

Open kiseliu opened 1 year ago

kiseliu commented 1 year ago

A month ago, I trained Alpaca with 4 A100 GPUs (80 GB each) and per_device_train_batch_size=4, using transformers==4.28.1.

Today I retrained Alpaca with the same hardware and the same code, but ran into an OOM error; training only works with per_device_train_batch_size=1. Looking at wandb, I found that the transformers version in my virtual environment is now transformers==4.31.dev0. After I changed the transformers version back to 4.28.1, I could train Alpaca with per_device_train_batch_size=4 again.

Does anyone have an idea what causes this?
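One way to guard against this kind of silent dependency drift is to assert the expected transformers version before training starts. Below is a minimal sketch, assuming 4.28.1 is the known-good version from the report above; this check is not part of the original training script.

```python
# Minimal sketch: fail fast if the installed transformers version is not the
# one the run was tuned for (4.28.1 here, per the report above).
from importlib.metadata import version

EXPECTED = "4.28.1"  # assumption: the last version known to fit batch size 4

installed = version("transformers")
if installed != EXPECTED:
    raise RuntimeError(
        f"transformers=={installed} is installed, but this setup was tuned for "
        f"transformers=={EXPECTED}; memory usage may differ (risk of OOM)."
    )
```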

yxchng commented 1 year ago

Are you using decapoda-research/llama-7b-hf, and the exact same training command as in the README?

kiseliu commented 1 year ago

> Are you using decapoda-research/llama-7b-hf, and the exact same training command as in the README?

Yes, the exact same training command as in the README.
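If staying on the newer transformers build turns out to be unavoidable, a common workaround for this kind of OOM is to lower the per-device batch size while raising gradient accumulation by the same factor, so the effective batch size is unchanged, optionally adding gradient checkpointing. The sketch below uses the Hugging Face TrainingArguments API; the specific values are illustrative and not taken from the original run.

```python
# Minimal sketch, assuming training goes through the Hugging Face Trainer.
# per_device_train_batch_size is dropped from 4 to 1 and
# gradient_accumulation_steps is scaled up by the same factor, keeping the
# effective per-GPU batch size the same; gradient checkpointing trades
# extra compute for lower activation memory.
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./alpaca-out",       # placeholder output path
    per_device_train_batch_size=1,   # was 4 before the OOM appeared
    gradient_accumulation_steps=32,  # 4x whatever the original run used
    gradient_checkpointing=True,     # recompute activations to save memory
    bf16=True,
    num_train_epochs=3,
)
```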