aehrc / cxrmate

CXRMate: Longitudinal Data and a Semantic Similarity Reward for Chest X-Ray Report Generation
https://huggingface.co/aehrc/cxrmate
Apache License 2.0

Validation hangs after first validation sample when using deepspeed stage 3 with multi-gpu #15

Open AceMcAwesome77 opened 2 months ago

AceMcAwesome77 commented 2 months ago

Hi, thanks again for this excellent repo. The model trains fine for me on a multi-GPU system using DeepSpeed stage 2, with or without offload, and I see the expected training/validation time reduction. With DeepSpeed stage 3, however, the training_step progresses fine, but the model hangs indefinitely right after the validation step starts and one validation sample is completed. By adding logging, I determined that the GPUs generate tokens one-by-one simultaneously, as expected, but as soon as one GPU finishes generating the sequence for its validation sample, all the other GPUs hang. This makes me suspect a synchronization issue across the GPUs: once one GPU has finished, the others are not allowed to continue through the calls to the forward function that generates tokens one-by-one during validation. However, I don't understand why this issue would manifest only during stage 3 and not also during stage 2.

Has anyone else successfully trained this model, including the validation step, using deepspeed stage 3? I've tried adjusting some of the deepspeed_config parameters, but no success yet. Thanks!
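A minimal sketch of the suspected mechanism (pure Python, no DeepSpeed; the per-rank EOS positions are hypothetical): under ZeRO-3 the parameters are sharded, so every forward pass involves a collective all-gather across ranks, and every rank must therefore make the same number of forward calls. Greedy decoding stops at EOS, so ranks whose samples emit EOS earlier make fewer calls:

```python
# Sketch: count how many "collective" forward calls each rank would make
# if each rank greedily decodes until its own sample's EOS. Under ZeRO-3
# a forward pass all-gathers parameter shards from every rank, so a
# mismatch in call counts means the slower ranks block on a collective
# that the finished rank never joins -- the observed hang.

def decode_steps(eos_position: int, max_length: int) -> int:
    """Token-by-token forward calls a rank performs if its sample
    emits EOS at step `eos_position` (hypothetical value)."""
    return min(eos_position, max_length)

# Hypothetical EOS positions for one validation sample per rank.
eos_positions = {0: 12, 1: 25, 2: 25, 3: 40}
calls = {rank: decode_steps(pos, max_length=60)
         for rank, pos in eos_positions.items()}

# Rank 0 leaves the generation loop after 12 collective calls;
# ranks 1-3 then wait forever on the 13th all-gather.
assert len(set(calls.values())) > 1
```

This also suggests why stage 2 is unaffected: with full parameter replicas on each rank, a lone forward pass needs no collective, so ranks can finish decoding at different times.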

AceMcAwesome77 commented 2 months ago

I think this thread may be relevant to the problem:

https://github.com/microsoft/DeepSpeed/issues/860

I thought I could potentially solve this by setting eos_token_id=None in the self.encoder_decoder.generate() call and letting all GPUs generate the same number of tokens during validation_step, but then I got a new error along the lines of "disagreement between rank0 and rank4" that I don't know how to fix. I do think stas00 is probably correct about the cause in that thread.
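For what it's worth, Hugging Face transformers handles this case with the `synced_gpus` option to `generate()`: ranks that have hit EOS keep executing dummy forward passes until an all-reduce of the finished flags reports that every rank is done, so all ranks issue the same number of collective calls. A minimal sketch of that loop shape (pure Python, no DeepSpeed; `eos_positions` and the all-reduce are simulated stand-ins):

```python
# Sketch of the lockstep decoding pattern behind
# `generate(..., synced_gpus=True)`: every rank keeps stepping (a dummy
# forward once its own sequence is finished) until ALL ranks report
# done, so each rank issues the same number of collective calls.

def lockstep_decode(eos_positions: dict, max_length: int) -> dict:
    """Simulate decoding on several ranks; return forward-call counts.

    `eos_positions` maps rank -> step at which that rank's sample emits
    EOS (hypothetical). The all-reduce over the per-rank `finished`
    flags is simulated with `all()`.
    """
    finished = {rank: False for rank in eos_positions}
    calls = {rank: 0 for rank in eos_positions}
    for step in range(1, max_length + 1):
        for rank in eos_positions:
            # Real or dummy forward pass: still a collective under ZeRO-3.
            calls[rank] += 1
            if step >= eos_positions[rank]:
                finished[rank] = True
        if all(finished.values()):  # simulated all-reduce across ranks
            break
    return calls

calls = lockstep_decode({0: 12, 1: 25, 2: 25, 3: 40}, max_length=60)
assert len(set(calls.values())) == 1  # every rank made 40 calls
```

If self.encoder_decoder.generate() ultimately dispatches to the transformers generation loop, passing synced_gpus=True there may be a cleaner fix than disabling eos_token_id, since it pads with dummy forwards instead of forcing every rank to max_length.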