facebookresearch / fvcore

Collection of common code that's shared among different research projects in FAIR computer vision team.
Apache License 2.0

Conflict with torch.distributed? #145

Open Wuzimeng opened 5 months ago

Wuzimeng commented 5 months ago

Hello, I encountered an error when calling flop_count_table() in my distributed training code. The error message is shown below. I checked the inputs to the allgather() call and didn't find anything unusual.

File "/xxx/anaconda3/envs/torch13/lib/python3.9/site-packages/torch/distributed/distributed_c10d.py", line 2275, in all_gather work = default_pg.allgather([tensor_list], [tensor]) RuntimeError: unsupported input list type: Tensor[] ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 0 (pid: 628698) of binary: /xxx/anaconda3/envs/torch13/bin/python

Here's a brief script that reproduces my error when launched with python -m torch.distributed.run --nproc_per_node=1 --master_port 10603 try.py

import torch
import torch.nn as nn

from fvcore.nn import FlopCountAnalysis, flop_count_str, flop_count_table

class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc = nn.Linear(10, 1)

    def forward(self, x):
        y = self.fc(x)
        concat_all_gather(y)
        return y.sum()

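# Gather `tensor` from every rank and concatenate along the batch dimension.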
@torch.no_grad()
def concat_all_gather(tensor):
    tensors_gather = [torch.ones_like(tensor)
        for _ in range(torch.distributed.get_world_size())]
    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
    output = torch.cat(tensors_gather, dim=0)
    return output    

torch.distributed.init_process_group(backend='nccl')
local_rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()

model = SimpleModel().cuda()
model = nn.parallel.DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank)

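# FlopCountAnalysis traces the unwrapped model with torch.jit, so the
# all_gather inside forward() also runs under the tracer.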
flop = FlopCountAnalysis(model.module, torch.randn(100, 10).cuda())
print(flop_count_table(flop, max_depth=7, show_param_shapes=True))

torch.distributed.destroy_process_group()

Additionally, my environment is Python 3.9.18, CUDA 11.7, fvcore==0.1.5.post20221221, and torch 1.13.

Another confusing thing is that in a Python 3.8.18 / CUDA 11.4 / torch 1.10 environment, the same code does not raise this error.

philipwan commented 1 month ago

I'm having the same problem. Have you solved it? In my case, it seems there is a conflict between the jit.trace module and dist.all_gather.
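
A possible workaround (untested, just a sketch): since fvcore's FLOP counter runs the forward pass under torch.jit tracing, one could skip the collective whenever torch.jit.is_tracing() reports that the call is happening under the tracer, e.g.

import torch
import torch.distributed as dist

@torch.no_grad()
def concat_all_gather(tensor):
    # all_gather fails under the torch.jit tracer in this setup (see the
    # traceback above), so fall back to the local tensor while tracing and
    # only gather in a real distributed run.
    if torch.jit.is_tracing() or not dist.is_initialized():
        return tensor
    tensors_gather = [torch.ones_like(tensor)
        for _ in range(dist.get_world_size())]
    dist.all_gather(tensors_gather, tensor, async_op=False)
    return torch.cat(tensors_gather, dim=0)

Note that this makes the traced model see only the local batch, so any FLOPs reported for ops downstream of the gather would be approximate.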