magic-research / PLLaVA

Official repository for the paper PLLaVA

Hardware specs? #1

Open tkarthikeyan132 opened 7 months ago

tkarthikeyan132 commented 7 months ago

Great work! Please mention the hardware specs used for your training. Will 2 A100s (40GB) be sufficient for training the 13B model? Also, please point out any flash attention flags in your code and where to turn them off, or suggest ways to train/infer your model on lower-scale compute such as V100s.
Thanks.

cathyxl commented 7 months ago

Hi @tkarthikeyan132, we used 8 A100 80G GPUs to train the 13B model under DeepSpeed ZeRO-3. In your case of 2 A100 40G, you should reduce the per-GPU batch size to 4 and set gradient accumulation to 16 to reach the total batch size of 128 (2 GPUs × 4 × 16 = 128). Flash attention is already used in the code, see https://github.com/magic-research/PLLaVA/blob/8260a1c80b472d21eb2dc58dea5cf4a2203fc5ac/models/pllava/modeling_pllava.py#L294. If you want to use V100s, please try lowering the batch size further.
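For reference, here is a minimal sketch of DeepSpeed ZeRO-3 settings matching the numbers above. The key names follow the standard DeepSpeed config schema, but the values and the offload options are illustrative assumptions, not the repo's actual training config:

```python
# Hypothetical DeepSpeed ZeRO-3 settings for 2x A100 40G (not the repo's real config).
ds_config = {
    "train_micro_batch_size_per_gpu": 4,   # per-GPU batch size suggested above
    "gradient_accumulation_steps": 16,     # accumulate to recover the large batch
    "train_batch_size": 128,               # 4 * 16 * 2 GPUs = 128
    "zero_optimization": {
        "stage": 3,
        "offload_param": {"device": "cpu"},      # optional: may help fit 13B in 40G
        "offload_optimizer": {"device": "cpu"},  # optional
    },
    "bf16": {"enabled": True},
}
```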

hn18001 commented 7 months ago

@cathyxl How many GPUs were used to train the 34B model? And what was the total training time?

cathyxl commented 7 months ago

Hi @hn18001, for 34B we used 16 A100 GPUs to run the training; it took around 3 days to finish one epoch.

mariotrivinor commented 6 months ago

Hello @tkarthikeyan132,

To avoid problems with GPUs that are not of the Ampere architecture, such as the V100, you can change the following line in PLLaVA/models/pllava/modeling_pllava.py:

```python
self.language_model = AutoModelForCausalLM.from_config(config.text_config, torch_dtype=config.torch_dtype, attn_implementation="flash_attention_2")
```

to:

```python
self.language_model = AutoModelForCausalLM.from_config(config.text_config, torch_dtype=config.torch_dtype)
```

This change removes the specification of attn_implementation.
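If you prefer to keep the argument explicit rather than rely on the library default, a sketch of an alternative (assuming a recent transformers version that accepts the attn_implementation argument) is to request a non-flash backend such as "sdpa" or "eager", both of which run on pre-Ampere GPUs like the V100:

```python
# Sketch: explicitly select a non-flash attention backend for pre-Ampere GPUs.
# "sdpa" (PyTorch scaled_dot_product_attention) or "eager" both avoid
# flash_attention_2; availability depends on your installed transformers version.
self.language_model = AutoModelForCausalLM.from_config(
    config.text_config,
    torch_dtype=config.torch_dtype,
    attn_implementation="sdpa",  # or "eager"
)
```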

Namzakku commented 5 months ago

Hi @cathyxl, can I ask how many GPUs and how much time you used to train the 7B model?