pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org

FSDP requires global device context #112658

Open · nairbv opened this issue 11 months ago

nairbv commented 11 months ago

🐛 Describe the bug

The only way to call an FSDP model (e.g. fsdp_model(inputs)) seems to be when torch.cuda.current_device() returns the rank/id of the current process's device, regardless of which device the model is on and regardless of which device_id is passed to the FSDP constructor.

That is, one must either call torch.cuda.set_device(device) first (which the docs "discourage") or wrap the call in a with torch.cuda.device(local_rank): block.

Without this kind of device context, the call fails with an error such as “Expects tensor to be on the compute device cuda:2”, “An FSDP-managed module unexpectedly has parameters on <device>. Make sure to move the module to <device> before training”, “Inconsistent compute device and device_id on rank”, and so on.

I suspect this is just a bug in the set of asserts across _flat_param.py, _runtime_utils.py, and/or _init_utils.py. If the requirement that torch.cuda.current_device() return the current rank is intentional, though, I think we should call it out more explicitly in the docs, the tutorials, and the error messages.
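If the requirement is intentional, an explicit check before the first forward pass could produce a clearer error than the asserts quoted above. The sketch below is hypothetical (check_current_device is not a PyTorch API); it only encodes the requirement observed in this issue, assuming the expected device index is LOCAL_RANK as set by torchrun. It takes plain integers so the logic can be exercised without a GPU; in a real script the first argument would be torch.cuda.current_device().

```python
def check_current_device(current_device: int, local_rank: int) -> None:
    """Raise a descriptive error if the current CUDA device is not this rank's GPU.

    Hypothetical guard: it assumes FSDP expects
    torch.cuda.current_device() == LOCAL_RANK, as observed in this issue.
    """
    if current_device != local_rank:
        raise RuntimeError(
            f"torch.cuda.current_device() is {current_device}, but this process "
            f"has LOCAL_RANK={local_rank}. Call torch.cuda.set_device({local_rank}) "
            f"or wrap FSDP calls in `with torch.cuda.device({local_rank}):`."
        )


if __name__ == "__main__":
    check_current_device(current_device=1, local_rank=1)  # matching: no error
    try:
        check_current_device(current_device=0, local_rank=1)  # mismatch
    except RuntimeError as err:
        print(err)
```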

Versions

$ srun python collect_env.py
Collecting environment information...
PyTorch version: 2.1.0
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.6 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0
Clang version: Could not collect
CMake version: version 3.16.3
Libc version: glibc-2.31

Python version: 3.11.5 (main, Sep 11 2023, 13:54:46) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.15.0-1041-aws-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA A100-SXM4-80GB
GPU 1: NVIDIA A100-SXM4-80GB
GPU 2: NVIDIA A100-SXM4-80GB
GPU 3: NVIDIA A100-SXM4-80GB
GPU 4: NVIDIA A100-SXM4-80GB
GPU 5: NVIDIA A100-SXM4-80GB
GPU 6: NVIDIA A100-SXM4-80GB
GPU 7: NVIDIA A100-SXM4-80GB

Nvidia driver version: 535.54.03
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture:                    x86_64
CPU op-mode(s):                  32-bit, 64-bit
Byte Order:                      Little Endian
Address sizes:                   46 bits physical, 48 bits virtual
CPU(s):                          96
On-line CPU(s) list:             0-95
Thread(s) per core:              2
Core(s) per socket:              24
Socket(s):                       2
NUMA node(s):                    2
Vendor ID:                       GenuineIntel
CPU family:                      6
Model:                           85
Model name:                      Intel(R) Xeon(R) Platinum 8275CL CPU @ 3.00GHz
Stepping:                        7
CPU MHz:                         3000.000
BogoMIPS:                        6000.00
Hypervisor vendor:               KVM
Virtualization type:             full
L1d cache:                       1.5 MiB
L1i cache:                       1.5 MiB
L2 cache:                        48 MiB
L3 cache:                        71.5 MiB
NUMA node0 CPU(s):               0-23,48-71
NUMA node1 CPU(s):               24-47,72-95
Vulnerability Itlb multihit:     KVM: Mitigation: VMX unsupported
Vulnerability L1tf:              Mitigation; PTE Inversion
Vulnerability Mds:               Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Vulnerability Meltdown:          Mitigation; PTI
Vulnerability Mmio stale data:   Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Vulnerability Retbleed:          Vulnerable
Vulnerability Spec store bypass: Vulnerable
Vulnerability Spectre v1:        Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:        Mitigation; Retpolines, STIBP disabled, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds:             Not affected
Vulnerability Tsx async abort:   Not affected
Flags:                           fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single pti fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid mpx avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves ida arat pku ospke

Versions of relevant libraries:
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.26.0
[pip3] torch==2.1.0
[pip3] torchaudio==2.1.0
[pip3] torchvision==0.16.0
[pip3] triton==2.1.0
[conda] blas                      1.0                         mkl
[conda] ffmpeg                    4.3                  hf484d3e_0    pytorch
[conda] libjpeg-turbo             2.0.0                h9bf148f_0    pytorch
[conda] mkl                       2023.1.0         h213fc3f_46343
[conda] mkl-service               2.4.0           py311h5eee18b_1
[conda] mkl_fft                   1.3.8           py311h5eee18b_0
[conda] mkl_random                1.2.4           py311hdb19cb5_0
[conda] numpy                     1.26.0          py311h08b1b3b_0
[conda] numpy-base                1.26.0          py311hf175353_0
[conda] pytorch                   2.1.0           py3.11_cuda12.1_cudnn8.9.2_0    pytorch
[conda] pytorch-cuda              12.1                 ha16c6d3_5    pytorch
[conda] pytorch-mutex             1.0                        cuda    pytorch
[conda] torchaudio                2.1.0               py311_cu121    pytorch
[conda] torchtriton               2.1.0                     py311    pytorch
[conda] torchvision               0.16.0              py311_cu121    pytorch

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin

kwen2501 commented 11 months ago

Can you provide a reproducer? Thanks!

nairbv commented 11 months ago
import torch
from torch import nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch import distributed
import os

local_rank = int(os.getenv("LOCAL_RANK", 0))
world_size = int(os.getenv("WORLD_SIZE", 1))
distributed.init_process_group(backend='nccl', rank=local_rank, world_size=world_size)
device = torch.device('cuda', local_rank)

print(local_rank, world_size, device)

model = nn.Linear(100,100)
model.to(device)
model = FSDP(model, device_id=local_rank)
#torch.cuda.set_device(device)
output = model(torch.randn((1, 100), device=device))

Running it with srun -N 1 --gres=gpu:2 torchrun --nproc_per_node=2 test.py fails with:

Expects tensor to be on the compute device cuda:1
  File "/home/bvaughan/repos/newfms/test.py", line 18, in <module>
    output = model(torch.randn((1, 100)))
...
...followed by...
AssertionError: Expects tensor to be on the compute device cuda:1
[E ProcessGroupNCCL.cpp:915] [Rank 1] NCCL watchdog thread terminated with exception: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Exception raised from c10_cuda_check_implementation at /opt/conda/conda-bld/pytorch_1695392035891/work/c10/cuda/CUDAException.cpp:44 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x57 (0x7f3da91b5617 in /home/bvaughan/miniconda3/envs/pt21/lib/python3.11/site-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::string const&) + 0x64 (0x7f3da917098d in /home/bvaughan/miniconda3/envs/pt21/lib/python3.11/site-packages/to....

Uncommenting the torch.cuda.set_device(device) line fixes it, but I think there should be some way to use FSDP without setting a global device context like that.

I've tried many variants of the above; they lead to a variety of error messages, but the underlying problem is the same. Calling set_device (or using a with torch.cuda.device(rank): block) seems to be the only approach that works.
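A minimal sketch of the context-manager workaround, assuming the script is launched with torchrun on a single node with one GPU per process (torchrun sets LOCAL_RANK, RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT). The function name run_forward is hypothetical; the point is only that FSDP construction and the forward call happen inside with torch.cuda.device(local_rank):, so torch.cuda.current_device() matches device_id without a global torch.cuda.set_device call.

```python
import os

import torch
import torch.distributed as dist
from torch import nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def run_forward() -> torch.Tensor:
    """Hypothetical helper: build a tiny FSDP model and run one forward pass.

    Assumes a torchrun launch on one node with one GPU per process.
    """
    local_rank = int(os.environ["LOCAL_RANK"])
    device = torch.device("cuda", local_rank)
    # Rank and world size are read from the env vars that torchrun sets.
    dist.init_process_group(backend="nccl")
    # Scope the "current device" to this block instead of setting it
    # process-wide with torch.cuda.set_device(device).
    with torch.cuda.device(local_rank):
        model = FSDP(nn.Linear(100, 100).to(device), device_id=local_rank)
        output = model(torch.randn(1, 100, device=device))
    dist.destroy_process_group()
    return output


if __name__ == "__main__":
    print(run_forward().shape)
```

Per the report, this scoped form works as well as set_device; presumably any later backward pass or optimizer step would also need to run inside the same block.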

awgu commented 11 months ago

For some not-yet-known reason, a slice (or even a view) of padded_unsharded_flat_param results in a tensor on cuda:0 even when padded_unsharded_flat_param is on cuda:1. The relevant code is at https://github.com/pytorch/pytorch/blob/5da9abfec211f77a5803ac6a2af767d80f088bb3/torch/distributed/fsdp/_flat_param.py#L1364-L1366. I have encountered this before: https://github.com/pytorch/pytorch/issues/91661.

I cannot reproduce it in a simple script, though, so I need to investigate further what FSDP does differently that causes this.
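For reference, the plain-tensor version of this check looks like the sketch below (a hypothetical diagnostic, not FSDP code). Outside FSDP, a slice or view shares its base tensor's storage and should report the same device, which is consistent with the comment above that a simple script does not reproduce the problem. It falls back to the CPU when fewer than two GPUs are present.

```python
import torch


def slice_and_view_devices(
    flat: torch.Tensor, numel: int
) -> tuple[torch.device, torch.device, torch.device]:
    """Return the devices of a 1-D tensor, a slice of it, and a view of that slice."""
    sliced = flat[:numel]        # basic slicing shares storage with ``flat``
    viewed = sliced.view(-1, 1)  # a view of the (contiguous) slice also shares storage
    return flat.device, sliced.device, viewed.device


if __name__ == "__main__":
    # Use cuda:1 to mirror the report when two GPUs are available.
    dev = "cuda:1" if torch.cuda.device_count() > 1 else "cpu"
    devices = slice_and_view_devices(torch.zeros(16, device=dev), 8)
    print(devices)
    assert devices[0] == devices[1] == devices[2], f"device mismatch: {devices}"
```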

awgu commented 11 months ago

Update: I was able to produce a smaller repro: https://github.com/pytorch/pytorch/issues/113300

nairbv commented 10 months ago

@awgu I see the Storage.resize_() bug is fixed now. Did that also fix this issue?

awgu commented 9 months ago


Sorry for the delay. I have not yet had a chance to check whether that fix removes the assumption, throughout FSDP, that the current device has been set.