pytorch / torchtitan

A native PyTorch Library for large model training
BSD 3-Clause "New" or "Revised" License

Checkpoint saves failing for eager mode training #168

Closed: chauhang closed this issue 6 months ago

chauhang commented 8 months ago

There seems to be a tricky timeout issue during checkpoint saves. It is failing for most of my runs, on multiple machines.

Steps to reproduce:

  1. git clone and install torchtrain
  2. modify one of the config files and set checkpoint_folder = "./outputs"
  3. Launch training run: CONFIG_FILE=./train_configs/llama_1b.toml ./run_llama_train.sh

Fails with Error:

[rank0]:Thread 0x00007f07a6ffd640 (most recent call first):
[rank0]:  File "/home/gchauhan/my_envs/llm-amd/lib/python3.11/threading.py", line 324 in wait
[rank0]:  File "/home/gchauhan/my_envs/llm-amd/lib/python3.11/queue.py", line 180 in get
[rank0]:  File "/home/gchauhan/my_envs/llm-amd/lib/python3.11/site-packages/tensorboard/summary/writer/event_file_writer.py", line 269 in _run
[rank0]:  File "/home/gchauhan/my_envs/llm-amd/lib/python3.11/site-packages/tensorboard/summary/writer/event_file_writer.py", line 244 in run
[rank0]:  File "/home/gchauhan/my_envs/llm-amd/lib/python3.11/threading.py", line 1038 in _bootstrap_inner
[rank0]:  File "/home/gchauhan/my_envs/llm-amd/lib/python3.11/threading.py", line 995 in _bootstrap
[rank0]:
[rank0]:Thread 0x00007f13c65f5400 (most recent call first):
[rank0]:  File "/home/gchauhan/my_envs/llm-amd/lib/python3.11/site-packages/torch/distributed/distributed_c10d.py", line 3230 in scatter
[rank0]:  File "/home/gchauhan/my_envs/llm-amd/lib/python3.11/site-packages/torch/distributed/c10d_logger.py", line 75 in wrapper
[rank0]:  File "/home/gchauhan/my_envs/llm-amd/lib/python3.11/site-packages/torch/distributed/distributed_c10d.py", line 2770 in scatter_object_list
[rank0]:  File "/home/gchauhan/my_envs/llm-amd/lib/python3.11/site-packages/torch/distributed/c10d_logger.py", line 75 in wrapper
[rank0]:  File "/home/gchauhan/my_envs/llm-amd/lib/python3.11/site-packages/torch/distributed/checkpoint/utils.py", line 133 in scatter_object
[rank0]:  File "/home/gchauhan/my_envs/llm-amd/lib/python3.11/site-packages/torch/distributed/checkpoint/utils.py", line 188 in reduce_scatter
[rank0]:  File "/home/gchauhan/my_envs/llm-amd/lib/python3.11/site-packages/torch/distributed/checkpoint/state_dict_saver.py", line 274 in _save_state_dict
[rank0]:  File "/home/gchauhan/my_envs/llm-amd/lib/python3.11/site-packages/torch/distributed/checkpoint/state_dict_saver.py", line 145 in save
[rank0]:  File "/home/gchauhan/my_envs/llm-amd/lib/python3.11/site-packages/torch/distributed/checkpoint/utils.py", line 427 in inner_func
[rank0]:  File "/home/gchauhan/meta/torchtrain/torchtrain/checkpoint.py", line 114 in save
[rank0]:  File "/home/gchauhan/meta/torchtrain/train.py", line 368 in main
[rank0]:  File "/home/gchauhan/my_envs/llm-amd/lib/python3.11/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 347 in wrapper
[rank0]:  File "/home/gchauhan/meta/torchtrain/train.py", line 389 in <module>

Failures are happening on Nvidia H100, A100, and AMD MI250X; this trace is from the AMD run. Attached: the full training log for the 1B model training checkpoint save, and the flight recorder trace.

Environment

python -m torch.utils.collect_env
PyTorch version: 2.4.0.dev20240326+rocm6.0
Is debug build: False
CUDA used to build PyTorch: N/A
ROCM used to build PyTorch: 6.0.32830-d62f6a171

OS: CentOS Stream 9 (x86_64)
GCC version: (GCC) 11.4.1 20231218 (Red Hat 11.4.1-3)
Clang version: Could not collect
CMake version: version 3.26.5
Libc version: glibc-2.34

Python version: 3.11.5 (main, Sep 11 2023, 13:54:46) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-6.4.3-0_fbk1_zion_755_ga25447393a1d-x86_64-with-glibc2.34
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: AMD Instinct MI250X / MI250 (gfx90a:sramecc+:xnack-)
Nvidia driver version: Could not collect
cuDNN version: Could not collect
HIP runtime version: 6.0.32830
MIOpen runtime version: 3.0.0
Is XNNPACK available: True

CPU:
Architecture:                    x86_64
CPU op-mode(s):                  32-bit, 64-bit
Address sizes:                   48 bits physical, 48 bits virtual
Byte Order:                      Little Endian
CPU(s):                          256
On-line CPU(s) list:             0-255
Vendor ID:                       AuthenticAMD
Model name:                      AMD EPYC 7713 64-Core Processor
CPU family:                      25
Model:                           1
Thread(s) per core:              2
Core(s) per socket:              64
Socket(s):                       2
Stepping:                        1
Frequency boost:                 enabled
CPU(s) scaling MHz:              100%
CPU max MHz:                     2000.0000
CPU min MHz:                     1500.0000
BogoMIPS:                        3992.39
Flags:                           fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 invpcid_single hw_pstate ssbd mba ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 invpcid cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin brs arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold v_vmsave_vmload vgif v_spec_ctrl umip pku ospke vaes vpclmulqdq rdpid overflow_recov succor smca
Virtualization:                  AMD-V
L1d cache:                       4 MiB (128 instances)
L1i cache:                       4 MiB (128 instances)
L2 cache:                        64 MiB (128 instances)
L3 cache:                        512 MiB (16 instances)
NUMA node(s):                    2
NUMA node0 CPU(s):               0-63,128-191
NUMA node1 CPU(s):               64-127,192-255
Vulnerability Itlb multihit:     Not affected
Vulnerability L1tf:              Not affected
Vulnerability Mds:               Not affected
Vulnerability Meltdown:          Not affected
Vulnerability Mmio stale data:   Not affected
Vulnerability Retbleed:          Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1:        Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:        Vulnerable, IBPB: conditional, IBRS_FW, STIBP: always-on, RSB filling, PBRSB-eIBRS: Not affected
Vulnerability Srbds:             Not affected
Vulnerability Tsx async abort:   Not affected

Versions of relevant libraries:
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.26.0
[pip3] pytorch-triton-rocm==3.0.0+0a22a91d04
[pip3] torch==2.4.0.dev20240326+rocm6.0
[conda] numpy                     1.26.0                   pypi_0    pypi
[conda] pytorch-triton-rocm       3.0.0+0a22a91d04          pypi_0    pypi
[conda] torch                     2.4.0.dev20240326+rocm6.0          pypi_0    pypi
wanchaol commented 8 months ago

@wz337 mentioned that she tried it today and it works on her end. @wz337 could you work with @chauhang to resolve this?

wanchaol commented 8 months ago

I just tried this locally with debug_model.toml, and it also seems to be working on my end 🤔 Could there be setup differences?

wz337 commented 8 months ago

CONFIG_FILE=./train_configs/llama_1b.toml ./run_llama_train.sh

@chauhang Are you using llama_7b.toml, or do you have a llama_1b.toml that is not checked into main? I just want to make sure I have exactly the same setup as you do.

wconstab commented 8 months ago

To me, the relevant lines of the log are:

[rank0]:2024-03-26 19:44:12,171 - root - INFO - Saving a checkpoint at step 1000
[rank0]:[rank0]:[E326 19:44:37.537349911 ProcessGroupNCCL.cpp:1332] [PG 0 Rank 0] Received a global timeout from another rank and will start to dump the debug info. Last enqueued NCCL work: 57301, last completed NCCL work: 57301.
[rank0]

I suspect this issue is caused by a combination of (a) a short timeout, and (b) some ranks doing CPU work for checkpointing while other ranks have already called a collective.

We first need to pin this down to a specific collective and identify whether it comes from inside or outside checkpointing. If it's what I think it is, the fix may be to set a longer timeout before performing checkpointing and restore the short timeout afterward. I would rather not just land a change that increases the train_timeout_seconds default without first understanding why the longer timeout is needed and whether we can change DCP so that a shorter timeout is compatible with it.
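For illustration, one way to scope a longer timeout to just the checkpoint path, without raising the training default, is to run DCP's collectives on a dedicated process group created with a larger timeout and pass it via dcp.save's process_group argument. This is a rough sketch, not the actual torchtrain code; the helper name, the gloo backend choice, and the 30-minute value are assumptions:

```python
# Rough sketch (not the actual torchtrain code): issue the checkpoint
# collectives on a dedicated process group with a larger timeout, so the
# training group can keep a short timeout for fast failure detection.
from datetime import timedelta

import torch.distributed as dist
import torch.distributed.checkpoint as dcp

# new_group() is itself a collective, so create this once on all ranks
# right after init_process_group().
checkpoint_pg = dist.new_group(
    backend="gloo",                 # assumption: gloo for the CPU-side plan exchange
    timeout=timedelta(minutes=30),  # generous bound for the checkpoint save
)

def save_checkpoint(state_dict, step, folder="./outputs"):
    # dcp.save accepts a process_group; its internal gather/scatter of save
    # plans then uses this group's longer timeout instead of the training one.
    dcp.save(
        state_dict,
        checkpoint_id=f"{folder}/step-{step}",
        process_group=checkpoint_pg,
    )
```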

Also, @geeta, a safe workaround should be to change the timeout setting in the .toml or in your command-line args in the .sh:

--comm.train_timeout_seconds <sec> or

[comm]
train_timeout_seconds=<sec>
wz337 commented 8 months ago

@wconstab Thanks for looking into the issue. If it's what you suspect, I think changing dcp.save to dcp.async_save could help, as it stages the state_dict to CPU and then calls save in a separate thread.

https://github.com/pytorch/torchtrain/blob/main/torchtrain/checkpoint.py#L114
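For illustration, a minimal sketch of what that swap could look like (the class and names here are placeholders, not the actual torchtrain/checkpoint.py code; it assumes the trainer already assembles a state_dict):

```python
# Placeholder sketch of switching dcp.save to dcp.async_save.
import torch.distributed.checkpoint as dcp

class CheckpointManager:
    def __init__(self, folder="./outputs"):
        self.folder = folder
        self._save_future = None

    def save(self, state_dict, step):
        # Wait for any in-flight save so two writers never race on the folder.
        if self._save_future is not None:
            self._save_future.result()
        # async_save stages the state_dict to CPU and returns a Future; the
        # actual write happens in a background thread.
        self._save_future = dcp.async_save(
            state_dict,
            checkpoint_id=f"{self.folder}/step-{step}",
        )

    def wait(self):
        # Call at the end of training to make sure the last save completed.
        if self._save_future is not None:
            self._save_future.result()
            self._save_future = None
```

Whether this fully avoids the timeout depends on where the slow collective comes from, per the discussion above.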

wconstab commented 8 months ago

I still think we need a design review for DCP with regard to timeouts.

Directionally, we want to have shorter timeouts when possible to get faster error signals.

We should decide: is it up to the user to estimate how much time DCP needs and to adjust their timeout before calling DCP, or is there something DCP can do to help? It should be possible for DCP to issue its own collectives with a longer timeout than the default one, if DCP knows how long the timeouts should be. (And if DCP doesn't roughly know how long the timeouts should be, how would a user know?)

fegin commented 8 months ago

During this step, only rank0 is doing some reduction work on the plans, but I'm surprised it would be slow enough to cause the NCCL timeout. Verifying with a large timeout can help identify the issue.
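As a concrete way to do that verification, one could rerun with --comm.train_timeout_seconds set to a large value as mentioned above, or, in a standalone reproduction, pass a generous timeout when initializing the process group (the two-hour value below is arbitrary and for debugging only):

```python
# Debugging sketch: use a deliberately large collective timeout so a slow
# checkpoint save can finish instead of tripping the NCCL watchdog.
from datetime import timedelta

import torch.distributed as dist

dist.init_process_group(
    backend="nccl",
    timeout=timedelta(hours=2),  # arbitrary, generous bound for the experiment
)
```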

wz337 commented 6 months ago

Closing as we cannot repro this issue.