Can you identify where exactly it hangs? Maybe @edbeeching or @vwxyzjn have tested this in #850.
I have faced the same issue; it gets stuck at https://github.com/huggingface/trl/blob/20428c48ba651b2bc11b2a73e9eb9568b1af3f96/trl/trainer/ppo_trainer.py#L1347-L1351
I put a raise ValueError after L1346 and it does raise the error.
I put a raise ValueError after L1351 and it never raises, so execution is stuck somewhere inside L1347-L1351.
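For anyone following along, the bisection trick above looks roughly like this. It is a runnable sketch, not the trainer code: the sleep stands in for the hanging span at L1347-L1351, and the two raises are the probes, run once per position.

```python
# Illustrative hang-bisection probe; the sleep is a stand-in for the
# suspected gather_object span (L1347-L1351), not real trainer code.
import time

PROBE_BEFORE_SPAN = True  # flip to False for the second run

if PROBE_BEFORE_SPAN:
    raise ValueError("probe after L1346 fires")  # the error you do see

time.sleep(3600)  # stands in for the stuck collective call

raise ValueError("probe after L1351 never fires")  # the error you never see
```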
Thanks! cc @lvwerra
Thanks for raising the issue. I can reproduce it. I think the problem is the gather_object(batch) function call. I remember it hangs with nested inputs, so I changed it to deal with flattened inputs; not sure why it stopped working again... Investigating further.
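For context, the flattening workaround mentioned above looks roughly like this. This is a hedged sketch with illustrative keys and values, not the exact trainer code:

```python
# Gather each flattened column separately instead of passing the nested
# batch dict to gather_object in one call.
from accelerate.utils import gather_object

# Per-process batch; keys and values are illustrative.
batch = {"query": ["q0", "q1"], "response": ["r0", "r1"]}

# Reported to hang with nested inputs:
# gathered = gather_object(batch)

# Flattened: one collective call per column, each on a plain list.
gathered = {key: gather_object(values) for key, values in batch.items()}
```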
OK, so the issue is that we are calling gather_object within the if self.accelerator.is_main_process: block, so the other processes never call gather_object and the main process hangs waiting for them. @lvwerra #1177 has fixed it.
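In other words, gather_object is a collective operation: every process must reach it, or the ranks that do call it block forever. A sketch of the buggy pattern and the fix, with illustrative names rather than the exact #1177 diff:

```python
from accelerate import Accelerator
from accelerate.utils import gather_object

accelerator = Accelerator()
samples = [f"sample from rank {accelerator.process_index}"]

# Buggy pattern: only rank 0 enters the block, the other ranks never call
# gather_object, and rank 0 hangs waiting for them.
# if accelerator.is_main_process:
#     gathered = gather_object(samples)

# Fixed pattern: all ranks participate in the gather; only rank 0 logs.
gathered = gather_object(samples)
if accelerator.is_main_process:
    print(f"logging {len(gathered)} samples")  # e.g. to wandb
```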
Can confirm it works this time: with a per-GPU batch size of 128 on 2 GPUs, 256 samples were logged, as expected.
Description
The PPO script hangs when logging to wandb in a multi-GPU setup, but it works fine without wandb.
Potential diagnosis
It is caused when calling log_stats here, which probably triggers some error here when the condition is true. It could be related to commit 481ef96293d9ecc68acc636287ce97489f1d23d4.
Code to reproduce
I am testing it on a machine with two 3090 GPUs.
Packages:
Command:
accelerate launch --main_process_port "${RANDOMPORT}" --multi_gpu --num_processes 2 main.py
Code:
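The packages list and the original main.py are not included above. As a stand-in, here is a minimal sketch of a PPO loop with the same shape: gpt2, the toy dataset, and the constant reward are placeholders, and it assumes the pre-#1177 trl PPOTrainer API with wandb installed and logged in.

```python
# main.py: hedged minimal repro sketch, not the reporter's actual script.
import torch
from datasets import Dataset
from transformers import AutoTokenizer
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer

config = PPOConfig(model_name="gpt2", log_with="wandb",
                   batch_size=128, mini_batch_size=16)

tokenizer = AutoTokenizer.from_pretrained(config.model_name)
tokenizer.pad_token = tokenizer.eos_token

# Toy dataset; log_stats logs the "query" and "response" columns by default.
dataset = Dataset.from_dict({"query": ["Hello there"] * 1024})
dataset = dataset.map(lambda x: {"input_ids": tokenizer(x["query"]).input_ids})
dataset.set_format("torch")

def collator(data):
    # Keep variable-length tensors as plain lists instead of stacking them.
    return {key: [d[key] for d in data] for key in data[0]}

model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)
ppo_trainer = PPOTrainer(config, model, ref_model=None, tokenizer=tokenizer,
                         dataset=dataset, data_collator=collator)

for batch in ppo_trainer.dataloader:
    query_tensors = batch["input_ids"]
    response_tensors = ppo_trainer.generate(
        query_tensors, return_prompt=False,
        max_new_tokens=16, pad_token_id=tokenizer.eos_token_id)
    batch["response"] = tokenizer.batch_decode(response_tensors)
    rewards = [torch.tensor(1.0) for _ in response_tensors]  # dummy reward
    stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
    # With log_with="wandb" under --multi_gpu, this is where the run hangs.
    ppo_trainer.log_stats(stats, batch, rewards)
```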