Hi @aknvictor, thanks for creating the issue. NCCL timeouts can be hard to pinpoint. It looks like this is occurring at the end of an epoch? Also, I'm curious what's in the `scripts/2B_lora.yaml` config -- is it just a copy of `gemma/2B_lora.yaml`, or have you made other customizations?
Yes, the error occurs at the end of an epoch. And yes, `2B_lora.yaml` is a copy of the original `gemma/2B_lora.yaml` config, with modifications only to the file paths and `batch_size`.
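For reference, a minimal sketch of the kind of overrides involved, assuming the stock torchtune config layout (the field names follow the built-in Gemma configs; the paths and values are placeholders, not the actual ones from this run):

```yaml
# Illustrative overrides on top of gemma/2B_lora.yaml -- paths are placeholders.
tokenizer:
  path: /data/gemma-2b/tokenizer.model

checkpointer:
  checkpoint_dir: /data/gemma-2b/
  output_dir: /data/gemma-2b-lora/

batch_size: 2
```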
@aknvictor not sure what type of GPU you're on, but if possible can you try to run on a single device instead? Something like `tune run lora_finetune_single_device --config scripts/2B_lora_single_device.yaml`, where `scripts/2B_lora_single_device.yaml` is an analogous copy of torchtune's corresponding single-device config, `gemma/2B_lora_single_device.yaml`.
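For reference, one way to set that up (a sketch, assuming torchtune's `tune cp` utility for copying built-in configs):

```bash
# Copy the stock single-device config, then edit the paths/batch_size to
# match the distributed config before running it.
tune cp gemma/2B_lora_single_device scripts/2B_lora_single_device.yaml
```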
I tried to repro on my end and I see the same error as in #1122, so I'm wondering if that's the underlying cause here and the distributed run is masking the real source of the error.
Yes, I did run it on a single device (A100). It works fine after I skipped the erroneous key in the checkpoint save (as a temporary hack):
```python
if key == 'lm_head.weight':
    continue
```
Admittedly, the bug/issue may be broader than that (especially when the run is distributed).
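For context, a rough sketch of where such a filter could sit, assuming a typical save loop over the model's state dict (this is illustrative, not torchtune's actual checkpointing code; `model` and `checkpoint_path` are placeholders):

```python
import torch

# Illustrative workaround: drop the offending key before writing the
# checkpoint. Gemma ties its output projection to the embedding weights,
# so 'lm_head.weight' may not exist as a separate tensor to save.
state_dict = {
    key: value
    for key, value in model.state_dict().items()  # `model` from the recipe
    if key != "lm_head.weight"
}
torch.save(state_dict, checkpoint_path)  # `checkpoint_path` is a placeholder
```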
> It works fine after I skipped the erroneous key in the checkpoint save (as a temporary hack).
@aknvictor just to clarify, does skipping the key in the distributed case resolve the original timeout error? Or do you still see it even after removing that line?
The timeout error is still there when the run is distributed.
We fixed the Gemma checkpoint issue in #1190. Could you try running your script again without the workaround below?

```python
if key == 'lm_head.weight':
    continue
```
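(Assuming an install from source, picking up the fix would look something like the following; the URL is the standard torchtune repo.)

```bash
# Reinstall from the latest main to pick up the checkpoint fix.
pip install --upgrade git+https://github.com/pytorch/torchtune.git
```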
Yes. I'm still getting the error.
Hi @aknvictor, sorry for the delay here. If you're still seeing the timeout error on distributed runs after pulling from latest main, would you be able to (a) provide more details of your environment (`pip list`, what hardware you're running on) and (b) help pinpoint where exactly the hang is occurring? I'm assuming it's somewhere in the checkpoint save, maybe when trying to gather parameters from different GPUs, but I'm not sure. (One hacky way to narrow it down for (b) is to add a call to `torch.distributed.barrier()` and then raise an error immediately afterwards; you can then bisect where in the code the hang is occurring based on whether you get this error or not.)
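Concretely, the probe would look something like this (a minimal sketch; drop it at a candidate point in the recipe and move it around to bisect):

```python
import torch.distributed as dist

# Debug probe: if every rank reaches this point, the barrier completes and
# the run dies with the RuntimeError below. If the run instead hangs (or
# times out) here, at least one rank never reached this line, so the real
# hang is earlier in the code.
dist.barrier()
raise RuntimeError("reached the barrier -- the hang is after this point")
```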
The issue has been resolved in the latest main. Thanks!
I keep getting this error when I run with

```bash
CUDA_LAUNCH_BLOCKING=1; tune run --nproc_per_node 4 lora_finetune_distributed --config scripts/2B_lora.yaml
```

Any thoughts on what I might be doing wrong? I'm running the latest version (from GitHub).
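One incidental note on the command as written: in a POSIX shell, the semicolon after `CUDA_LAUNCH_BLOCKING=1` makes the assignment a separate statement, so the variable is never exported to the `tune run` process. To set it for that command only, drop the semicolon:

```bash
# No semicolon: the assignment applies to this single command's environment.
CUDA_LAUNCH_BLOCKING=1 tune run --nproc_per_node 4 lora_finetune_distributed --config scripts/2B_lora.yaml
```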