@wz337 mentioned that she tried it today and it works on her end. @wz337 could you work with @chauhang to resolve this?
I just tried this locally on the debug_model.toml, and it also seems to be working on my end 🤔 could there be setup differences?
CONFIG_FILE=./train_configs/llama_1b.toml ./run_llama_train.sh
@chauhang Are you using the llama_7b.toml, or do you have a llama_1b.toml that is not checked into main? Just want to make sure I have the exact same setup as you do.
to me, the relevant lines of the log are
[rank0]:2024-03-26 19:44:12,171 - root - INFO - Saving a checkpoint at step 1000
[rank0]:[rank0]:[E326 19:44:37.537349911 ProcessGroupNCCL.cpp:1332] [PG 0 Rank 0] Received a global timeout from another rank and will start to dump the debug info. Last enqueued NCCL work: 57301, last completed NCCL work: 57301.
[rank0]
I suspect this issue is caused by a combination of (a) short timeout, and (b) some ranks are doing CPU work for checkpointing while other ranks already called a collective.
We need to first pin this down to a specific collective and identify whether it comes from inside checkpointing or outside it. If it's what I think it is, maybe the fix is to set a longer timeout before performing checkpointing, then set a short timeout again after checkpointing. I would rather not just land a change to increase the train_timeout_seconds default value without first understanding why the longer timeout is needed and whether we can change DCP so that a shorter timeout is compatible with it.
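As a rough illustration of that direction (not something torchtrain does today): instead of mutating the default process group's timeout around the save, one variant is to create a dedicated process group with a longer timeout and pass it to dcp.save, so only the checkpoint collectives get the relaxed deadline. The checkpoint_pg name, the 30-minute value, and the save_checkpoint wrapper below are assumptions for the sketch.

from datetime import timedelta

import torch.distributed as dist
import torch.distributed.checkpoint as dcp

# Assumption: the default process group was already initialized elsewhere with
# the short train_timeout_seconds value. This second group spans the same ranks
# but tolerates the slower CPU-side work that checkpointing does.
checkpoint_pg = dist.new_group(backend="nccl", timeout=timedelta(minutes=30))

def save_checkpoint(state_dict, step, folder):
    # dcp.save accepts a process_group argument, so the collectives it issues
    # (e.g. the plan reduction) run under the longer timeout while the training
    # loop keeps its short one.
    dcp.save(
        state_dict,
        storage_writer=dcp.FileSystemWriter(f"{folder}/step-{step}"),
        process_group=checkpoint_pg,
    )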
Also, @geeta, a safe workaround should be to change the timeout, either via a command-line arg to the .sh or in the .toml:
--comm.train_timeout_seconds <sec>
or
[comm]
train_timeout_seconds=<sec>
@wconstab Thanks for looking into the issue. If it's what you suspected, I think changing dcp.save to dcp.async_save would potentially help here, as we would first stage the state_dict on CPU and then call save in a separate thread.
https://github.com/pytorch/torchtrain/blob/main/torchtrain/checkpoint.py#L114
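For concreteness, a minimal sketch of what that swap could look like at a call site shaped like the one linked above (async_save needs a recent PyTorch build; self.states, self.folder, and save_future are placeholders, not the actual torchtrain code):

import torch.distributed.checkpoint as dcp

def save(self, curr_step: int):
    path = f"{self.folder}/step-{curr_step}"
    # Blocking version (current behavior): every rank sits in CPU-side staging
    # and file writes while other ranks may already be in the next collective.
    #   dcp.save(self.states, storage_writer=dcp.FileSystemWriter(path))
    # Async version: the state_dict is first staged to CPU, then the actual
    # write runs in a background thread, so ranks get back to training sooner.
    self.save_future = dcp.async_save(
        self.states, storage_writer=dcp.FileSystemWriter(path)
    )
    # Before starting the next checkpoint, wait on the previous one:
    #   self.save_future.result()

The tradeoff is extra host memory for the CPU copy and the need to join the background save before the next one starts (or before exit).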
I still think we need a design review for DCP with regard to timeouts.
Directionally, we want to have shorter timeouts when possible to get faster error signals.
We should decide: is it up to the user to estimate how much time DCP needs and adjust their timeout before calling DCP, or is there anything DCP can do to help here? It should be possible for DCP to issue its own collectives with a longer timeout than the default one, if DCP knows how long the timeouts should be. (And if DCP doesn't roughly know how long the timeouts should be, then how would a user know?)
During this step, only rank0 is doing some reduction work on the plans, but I'm surprised it would be slow enough to cause the NCCL timeout. Verifying with a large timeout can help identify the issue.
Closing as we cannot repro this issue.
There seems to be some tricky timeout issue during checkpoint saves. It is failing for most of my runs, on multiple machines.
Steps to reproduce:
checkpoint_folder = "./outputs"
CONFIG_FILE=./train_configs/llama_1b.toml ./run_llama_train.sh
Fails with error:
Failures are happening on NVIDIA H100, A100, and AMD MI250x; this trace is for the AMD run. Attached: full training log trace for the 1B model training checkpoint save, and the flight recorder trace.
Environment