nairbv opened this issue 11 months ago
Can you provide a reproducer? Thanks!
import os

import torch
from torch import distributed, nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# torchrun sets LOCAL_RANK and WORLD_SIZE for each spawned worker
local_rank = int(os.getenv("LOCAL_RANK", 0))
world_size = int(os.getenv("WORLD_SIZE", 1))

distributed.init_process_group(backend="nccl", rank=local_rank, world_size=world_size)
device = torch.device("cuda", local_rank)
print(local_rank, world_size, device)

# Move the module to the per-rank device before wrapping, and pass device_id explicitly
model = nn.Linear(100, 100)
model.to(device)
model = FSDP(model, device_id=local_rank)
# torch.cuda.set_device(device)  # uncommenting this line makes the forward pass succeed
output = model(torch.randn((1, 100), device=device))
run with:
srun -N 1 --gres=gpu:2 torchrun --nproc_per_node=2 test.py
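(The Slurm wrapper is incidental; on a machine with two visible GPUs the equivalent direct launch would just be:

torchrun --nproc_per_node=2 test.py
)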
fails with:
Expects tensor to be on the compute device cuda:1
File "/home/bvaughan/repos/newfms/test.py", line 18, in <module>
output = model(torch.randn((1, 100)))
...
...followed by...
AssertionError: Expects tensor to be on the compute device cuda:1
[E ProcessGroupNCCL.cpp:915] [Rank 1] NCCL watchdog thread terminated with exception: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
Exception raised from c10_cuda_check_implementation at /opt/conda/conda-bld/pytorch_1695392035891/work/c10/cuda/CUDAException.cpp:44 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x57 (0x7f3da91b5617 in /home/bvaughan/miniconda3/envs/pt21/lib/python3.11/site-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::string const&) + 0x64 (0x7f3da917098d in /home/bvaughan/miniconda3/envs/pt21/lib/python3.11/site-packages/to....
Uncommenting the torch.cuda.set_device(device) line fixes it, but I think there should be some way to use FSDP without setting device context like that. I've tried lots of different variants of the above, which lead to a variety of error messages, but the idea is the same: set_device (or with torch.cuda.device(rank):) seems to be the only approach that works.
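For concreteness, here is a minimal sketch of the working variant (same script as above, just with the device context set before the forward pass):

import os

import torch
from torch import distributed, nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

local_rank = int(os.getenv("LOCAL_RANK", 0))
world_size = int(os.getenv("WORLD_SIZE", 1))
distributed.init_process_group(backend="nccl", rank=local_rank, world_size=world_size)
device = torch.device("cuda", local_rank)

# Setting the per-rank device context up front is what makes the forward pass succeed;
# a `with torch.cuda.device(local_rank):` block around the forward call also works.
torch.cuda.set_device(device)

model = FSDP(nn.Linear(100, 100).to(device), device_id=local_rank)
output = model(torch.randn((1, 100), device=device))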
For some not-yet-known reason, a slice (or even a view) on padded_unsharded_flat_param results in a tensor on cuda:0 even if padded_unsharded_flat_param was on cuda:1:
https://github.com/pytorch/pytorch/blob/5da9abfec211f77a5803ac6a2af767d80f088bb3/torch/distributed/fsdp/_flat_param.py#L1364-L1366
I encountered this before: https://github.com/pytorch/pytorch/issues/91661.
I cannot reproduce it in a simple script though, so I would need to investigate further what differs in FSDP that causes this to happen.
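For reference, a naive stand-alone check of the slice/view behavior (with a plain tensor standing in for padded_unsharded_flat_param) passes, which is consistent with a simple script not reproducing the problem:

import torch

# Plain tensor standing in for padded_unsharded_flat_param on a non-default device
t = torch.randn(128, device="cuda:1")
view = t[:100]  # slice, analogous to the linked _flat_param.py lines

# Outside of FSDP the view stays on cuda:1, so this assertion holds
assert view.device == torch.device("cuda", 1), view.device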
Update: I was able to produce a smaller repro: https://github.com/pytorch/pytorch/issues/113300
@awgu I see the Storage.resize_() bug is fixed now. Did that also fix this issue?
Sorry for the delay. I have not had a chance to see if that fixed the assumption of having the device set through all of FSDP.
🐛 Describe the bug

The only way to call an FSDP model (e.g. fsdp_model(inputs)) seems to be if torch.cuda.current_device() returns the rank/id of the current process/device, regardless of what device the model is on and regardless of what device_id is passed to the FSDP constructor (i.e. by either first calling torch.cuda.set_device(device), which is "discouraged", or by using a context manager like with torch.cuda.device(local_rank):).

Without this kind of device context there will be some error of the form "Expects tensor to be on the compute device cuda:2", or "An FSDP-managed module unexpectedly has parameters on …. Make sure to move the module to … before training", or "Inconsistent compute device and device_id on rank", or ...

I suspect this is just a bug in the set of asserts across _flat_param.py, _runtime_utils.py, and/or _init_utils.py. If the requirement that torch.cuda.current_device() return the current rank is intentional though, I think we should call it out more explicitly in the docs and tutorials (and error messages).

Versions
cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin