Hello, I encountered an error when calling flop_count_table() in my distributed training code. The error message is below. I checked the inputs to all_gather() and didn't find anything unusual.
File "/xxx/anaconda3/envs/torch13/lib/python3.9/site-packages/torch/distributed/distributed_c10d.py", line 2275, in all_gather
work = default_pg.allgather([tensor_list], [tensor])
RuntimeError: unsupported input list type: Tensor[]
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 0 (pid: 628698) of binary: /xxx/anaconda3/envs/torch13/bin/python
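My guess at the context: fvcore's FlopCountAnalysis seems to run the forward pass under torch.jit tracing, so the tensors that reach all_gather during counting are traced values rather than ordinary ones, which the process group's allgather may reject. A minimal probe to check this (the Probe module is hypothetical, and the jit-tracing behavior of fvcore is my assumption):

import torch
from fvcore.nn import FlopCountAnalysis

class Probe(torch.nn.Module):
    def forward(self, x):
        # If fvcore traces the forward with torch.jit, this prints True
        # during FlopCountAnalysis and False in a plain call.
        print("jit tracing:", torch.jit.is_tracing())
        return x * 2

probe = Probe()
probe(torch.randn(1))                                # expect: jit tracing: False
FlopCountAnalysis(probe, (torch.randn(1),)).total()  # expect: jit tracing: True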
Here's a short script (try.py) that reproduces the error when launched with:
python -m torch.distributed.run --nproc_per_node=1 --master_port 10603 try.py
import torch
import torch.nn as nn
from fvcore.nn import FlopCountAnalysis, flop_count_str, flop_count_table


class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc = nn.Linear(10, 1)

    def forward(self, x):
        y = self.fc(x)
        concat_all_gather(y)  # the collective that triggers the error
        return y.sum()


@torch.no_grad()
def concat_all_gather(tensor):
    # Gather the tensor from every rank and concatenate along dim 0.
    tensors_gather = [torch.ones_like(tensor)
                      for _ in range(torch.distributed.get_world_size())]
    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
    output = torch.cat(tensors_gather, dim=0)
    return output


torch.distributed.init_process_group(backend='nccl')
local_rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
model = SimpleModel().cuda()
model = nn.parallel.DistributedDataParallel(model, device_ids=[local_rank],
                                            output_device=local_rank)
# FLOP counting runs the forward pass, including the all_gather call.
flop = FlopCountAnalysis(model.module, torch.randn(100, 10).cuda())
print(flop_count_table(flop, max_depth=7, show_param_shapes=True))
torch.distributed.destroy_process_group()
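As a stopgap I can avoid the crash by skipping the collective while fvcore traces the model; the gather contributes no FLOPs anyway. This is a sketch of the guard, again assuming FlopCountAnalysis runs under torch.jit tracing:

@torch.no_grad()
def concat_all_gather(tensor):
    # Assumption: FlopCountAnalysis traces with torch.jit, so is_tracing()
    # is True during counting; return the local tensor and skip the collective.
    if torch.jit.is_tracing():
        return tensor
    tensors_gather = [torch.ones_like(tensor)
                      for _ in range(torch.distributed.get_world_size())]
    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
    return torch.cat(tensors_gather, dim=0)

With this guard the table prints and real training still performs the gather, but I'd still like to understand why the unguarded version fails in torch 1.13.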
Additionally, my environment is: Python 3.9.18, CUDA 11.7, fvcore==0.1.5.post20221221, torch 1.13.
Another confusing thing is that in a Python 3.8.18 / CUDA 11.4 / torch 1.10 environment, the same script does not raise an error.