Here is the minimal example to reproduce this bug:
import os
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
import torch.distributed as dist
import torch.nn as nn

import analog


def init_distributed_mode():
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    torch.distributed.init_process_group(backend='nccl')


# Define the model
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        return x


def main():
    init_distributed_mode()
    run = analog.init("fsdp")

    model = MyModel()
    # Wrapping each Conv2d in its own FSDP unit triggers the error.
    fsdp_model = FSDP(model.cuda(), auto_wrap_policy=ModuleWrapPolicy({torch.nn.Conv2d}))
    # fsdp_model = FSDP(model.cuda())  # no auto_wrap_policy -> no error

    analog.watch(fsdp_model)
    run.setup({"log": "grad"})
    with run(data_id=["1", "2"]):
        x = torch.randn(2, 3, 2, 2).cuda()
        out = fsdp_model(x)
        loss = out.pow(2).sum()
        loss.backward()

    log = run.get_log()


if __name__ == "__main__":
    main()
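For reference, the script reads LOCAL_RANK from the environment, so it would typically be launched with PyTorch's distributed launcher, e.g. torchrun --nproc_per_node=<num_gpus> repro.py (the script name here is illustrative).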
The above error only occurs when using auto_wrap_policy in FSDP.
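For context, here is a minimal sketch of the difference between the two wrapping modes (reusing MyModel from the repro above; this only illustrates FSDP's wrapping behavior, not analog's internals):

# Without auto_wrap_policy: one flat FSDP unit around the whole model.
flat = FSDP(MyModel().cuda())

# With ModuleWrapPolicy({nn.Conv2d}): each Conv2d becomes its own nested FSDP unit.
nested = FSDP(MyModel().cuda(), auto_wrap_policy=ModuleWrapPolicy({nn.Conv2d}))

# Printing the module tree shows nested FullyShardedDataParallel wrappers
# around each conv layer in the second case.
print(nested)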
Interestingly, this error occurred at the line where model.state_dict() is called, similar to the issue in the original PyTorch repo. When I replaced state_dict with named_modules, the error disappeared.
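For completeness, a minimal sketch of that workaround (the actual call site inside analog is not shown in this issue, so the loop body below is illustrative):

# Failing pattern under auto_wrap_policy, as described above:
# state = fsdp_model.state_dict()

# Workaround: enumerate modules directly instead of materializing a state dict.
for name, module in fsdp_model.named_modules():
    if isinstance(module, nn.Conv2d):
        print(name)  # e.g., record or hook the module for gradient logging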
The above error occurs when using HF Trainer + FSDP + auto_wrap_policy. I'm not sure if this is related to this issue in the original PyTorch repo.