[Open] lessw2020 opened this issue 1 year ago
FWIW, I couldn't repro this by just compiling the LlamaRotaryEmbedding from HuggingFace on a single GPU without FSDP.
I think the core issue is how torch.compile interacts with FSDP's mixed precision. Per IBM, if you move your weights to pure BF16 (so no mixed precision), this issue goes away (though they then report it errors out with a stride mismatch... but we'll get to that after this is resolved).
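A minimal sketch of the two setups being contrasted above, using the standard `torch.distributed.fsdp.MixedPrecision` API. The tiny `nn.Linear` is a stand-in for the Llama weights, and the actual FSDP wrapping is only indicated in a comment since it requires an initialized process group:

```python
import torch
import torch.nn as nn
from torch.distributed.fsdp import MixedPrecision

# The kind of bf16 mixed-precision policy that triggers the reported
# assert when combined with torch.compile:
mp_policy = MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.bfloat16,
    buffer_dtype=torch.bfloat16,
)

model = nn.Linear(16, 16)  # stand-in for the Llama 7B model

# Reported workaround: cast the whole model to pure bf16 up front and
# wrap with FSDP *without* a mixed_precision policy, e.g.
#   model = FSDP(model)            # instead of
#   model = FSDP(model, mixed_precision=mp_policy)
model = model.to(torch.bfloat16)
assert all(p.dtype == torch.bfloat16 for p in model.parameters())
```

With pure bf16 weights, FSDP performs no parameter dtype casting during forward, which is consistent with the mismatch disappearing.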
Horace: There's another issue with DDP. Is there anyone signed up to own FSDP + torch.compile / DDP + torch.compile? This is basically @voznesenskym.
> Is there anyone signed up to own FSDP + torch.compile / DDP + torch.compile?
I'm currently working on this.
I am working on compile + FSDP. Tracked via Meta-internal posts.
Sounds good @voznesenskym we could sync later on this as needed.
We are working on compile + FSDP, which is preferred over graph-break FSDP. We aim to have it at the "prototype" release stage by the end of H1.
Ran into the same issue. Will this be fixed in an upcoming release? Also, is there any workaround in the meantime? Thanks.
🐛 Describe the bug
Running torch.compile with Llama 7B and FSDP mixed precision results in an assertion error during the first forward pass of training. (You can repro by checking out https://github.com/lessw2020/llama-recipes/tree/rotary_embeddings and running `bash run.sh`.)
from this section (full trace below):
Effectively there is a type mismatch, but after adding some debugging to the rotary cache and the incoming tensors, everything appears to be fp32.
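A hedged sketch (not the actual HF code) of why the caches can look all-fp32 while a mismatch still exists: the rotary cos/sin caches are built in fp32, but under bf16 mixed precision the incoming q/k activations are bf16, and eager mode's type promotion silently papers over the difference.

```python
import torch

# Simplified stand-in for the LlamaRotaryEmbedding cache construction:
dim, seq_len = 8, 16
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
t = torch.arange(seq_len, dtype=torch.float32)
cos_cached = torch.outer(t, inv_freq).cos()  # fp32, matching the debugging above

# Incoming activations under a bf16 FSDP mixed-precision policy:
q = torch.randn(seq_len, dim // 2, dtype=torch.bfloat16)

# Eager mode promotes bf16 * fp32 -> fp32 without complaint; compiled or
# guarded code can instead assert on the same dtype mismatch.
out = q * cos_cached
assert cos_cached.dtype == torch.float32
assert q.dtype == torch.bfloat16
assert out.dtype == torch.float32
```

So printing the cache dtypes alone is not enough to rule out a mismatch; the dtype of the tensors arriving at the embedding matters too.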
Here's the full stack trace:
Versions
Collecting environment information...
PyTorch version: 2.1.0.dev20230825+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.6 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
Clang version: Could not collect
CMake version: version 3.26.4
Libc version: glibc-2.31

Python version: 3.9.12 (main, Apr 5 2022, 06:56:58) [GCC 7.5.0] (64-bit runtime)
Python platform: Linux-5.15.0-1038-aws-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA A100-SXM4-40GB
GPU 1: NVIDIA A100-SXM4-40GB
GPU 2: NVIDIA A100-SXM4-40GB
GPU 3: NVIDIA A100-SXM4-40GB
GPU 4: NVIDIA A100-SXM4-40GB
GPU 5: NVIDIA A100-SXM4-40GB
GPU 6: NVIDIA A100-SXM4-40GB
GPU 7: NVIDIA A100-SXM4-40GB

Nvidia driver version: 525.85.12
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
Address sizes: 46 bits physical, 48 bits virtual
CPU(s): 96
On-line CPU(s) list: 0-95
Thread(s) per core: 2
Core(s) per socket: 24
Socket(s): 2
NUMA node(s): 2
Vendor ID: GenuineIntel
CPU family: 6
Model: 85
Model name: Intel(R) Xeon(R) Platinum 8275CL CPU @ 3.00GHz
Stepping: 7
CPU MHz: 1250.736
BogoMIPS: 5999.99
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 1.5 MiB
L1i cache: 1.5 MiB
L2 cache: 48 MiB
L3 cache: 71.5 MiB
NUMA node0 CPU(s): 0-23,48-71
NUMA node1 CPU(s): 24-47,72-95
Vulnerability Itlb multihit: KVM: Mitigation: VMX unsupported
Vulnerability L1tf: Mitigation; PTE Inversion
Vulnerability Mds: Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Vulnerability Meltdown: Mitigation; PTI
Vulnerability Mmio stale data: Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Vulnerability Retbleed: Vulnerable
Vulnerability Spec store bypass: Vulnerable
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines, STIBP disabled, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single pti fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid mpx avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves ida arat pku ospke

Versions of relevant libraries:
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.24.1
[pip3] pytorch-triton==2.1.0+e6216047b8
[pip3] st-moe-pytorch==0.0.22
[pip3] torch==2.1.0.dev20230825+cu121
[pip3] torchaudio==2.1.0.dev20230825+cu121
[pip3] torchinfo==1.8.0
[pip3] torchvision==0.16.0.dev20230825+cu121
[pip3] vit-pytorch==1.4.1
[conda] numpy 1.24.1 pypi_0 pypi
[conda] pytorch-triton 2.1.0+e6216047b8 pypi_0 pypi
[conda] st-moe-pytorch 0.0.22 pypi_0 pypi
[conda] torch 2.1.0.dev20230825+cu121 pypi_0 pypi
[conda] torchaudio 2.1.0.dev20230825+cu121 pypi_0 pypi
[conda] torchinfo 1.8.0 pypi_0 pypi
[conda] torchvision 0.16.0.dev20230825+cu121 pypi_0 pypi
[conda] vit-pytorch 1.4.1 pypi_0 pypi
cc @ezyang @gchanan @zou3519 @kadeng @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @chauhang @msaroufim @bdhirsh @anijain2305 @kiukchung @d4l3k @lucasllc