facebookresearch / fvcore

Collection of common code that's shared among different research projects in FAIR computer vision team.

Wrong FLOPs count if some submodules of the model are wrapped by DataParallel #109

Open luowyang opened 2 years ago

luowyang commented 2 years ago

Suppose there is a model which has some submodules wrapped by DataParallel (DP), e.g.:

import torch
import torch.nn as nn
# FlopCountAnalysis here is detectron2's (used for the tables below);
# swapping in fvcore.nn.FlopCountAnalysis raises the error shown further down.
from detectron2.utils.analysis import FlopCountAnalysis
from fvcore.nn import flop_count_table

class Model(nn.Module):
    def __init__(self, use_dp):
        super(Model, self).__init__()
        self.use_dp = use_dp
        self.sub = nn.Sequential(
            nn.Linear(14*14, 256),
            nn.ReLU(),
            nn.Linear(256, 10)
        )
        self.other = nn.Sequential(
            nn.Linear(14 * 14, 128),
            nn.ReLU(),
            nn.Linear(128, 5)
        )
        if self.use_dp:
            # wrap only this branch in DataParallel across GPUs 0 and 1
            self.sub = nn.DataParallel(self.sub, device_ids=[0, 1])

    def forward(self, x):
        y1 = self.sub(x)  # (N, 10)
        y2 = self.other(x)  # (N, 5)
        return torch.cat([y1, y2], dim=-1)  # (N, 15)

def main():
    # build inputs
    inputs = torch.randn(64, 14*14)  # (64, 14*14)

    # w/o DP
    model = Model(False)
    print(flop_count_table(FlopCountAnalysis(model, (inputs,))))

    # with DP: the batch of 64 is split across GPUs 0 and 1 (32 per replica)
    inputs = inputs.cuda(0)
    model = Model(True)
    model.cuda(0)
    print(flop_count_table(FlopCountAnalysis(model, (inputs,))))

if __name__ == '__main__':
    main()

The above code uses FlopCountAnalysis from detectron2.

The printed results are (first the run without DP, then the run with DP):

| module            | #parameters or shape   | #flops   |
|:------------------|:-----------------------|:---------|
| model             | 78.863K                | 5.022M   |
|  sub              |  53.002K               |  3.375M  |
|   sub.0           |   50.432K              |   3.211M |
|    sub.0.weight   |    (256, 196)          |          |
|    sub.0.bias     |    (256,)              |          |
|   sub.2           |   2.57K                |   0.164M |
|    sub.2.weight   |    (10, 256)           |          |
|    sub.2.bias     |    (10,)               |          |
|  other            |  25.861K               |  1.647M  |
|   other.0         |   25.216K              |   1.606M |
|    other.0.weight |    (128, 196)          |          |
|    other.0.bias   |    (128,)              |          |
|   other.2         |   0.645K               |   40.96K |
|    other.2.weight |    (5, 128)            |          |
|    other.2.bias   |    (5,)                |          |

| module                 | #parameters or shape   | #flops   |
|:-----------------------|:-----------------------|:---------|
| model                  | 78.863K                | 1.647M   |
|  sub.module            |  53.002K               |  0       |
|   sub.module.0         |   50.432K              |          |
|    sub.module.0.weight |    (256, 196)          |          |
|    sub.module.0.bias   |    (256,)              |          |
|   sub.module.2         |   2.57K                |          |
|    sub.module.2.weight |    (10, 256)           |          |
|    sub.module.2.bias   |    (10,)               |          |
|  other                 |  25.861K               |  1.647M  |
|   other.0              |   25.216K              |   1.606M |
|    other.0.weight      |    (128, 196)          |          |
|    other.0.bias        |    (128,)              |          |
|   other.2              |   0.645K               |   40.96K |
|    other.2.weight      |    (5, 128)            |          |
|    other.2.bias        |    (5,)                |          |
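
For reference, the missing FLOPs of the DP-wrapped branch can be recovered by analysing the underlying module directly, bypassing the DataParallel wrapper. A minimal sanity-check sketch, assuming the same DP model and CUDA inputs as above:

# count only the wrapped branch; this should roughly match the 3.375M
# reported for `sub` in the non-DP table above
sub_flops = FlopCountAnalysis(model.sub.module, (inputs,)).total()
print(sub_flops)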

With fvcore's FlopCountAnalysis, the first (non-DP) table is printed, and then an error is raised for the DP case:

| module            | #parameters or shape   | #flops   |
|:------------------|:-----------------------|:---------|
| model             | 78.863K                | 5.022M   |
|  sub              |  53.002K               |  3.375M  |
|   sub.0           |   50.432K              |   3.211M |
|    sub.0.weight   |    (256, 196)          |          |
|    sub.0.bias     |    (256,)              |          |
|   sub.2           |   2.57K                |   0.164M |
|    sub.2.weight   |    (10, 256)           |          |
|    sub.2.bias     |    (10,)               |          |
|  other            |  25.861K               |  1.647M  |
|   other.0         |   25.216K              |   1.606M |
|    other.0.weight |    (128, 196)          |          |
|    other.0.bias   |    (128,)              |          |
|   other.2         |   0.645K               |   40.96K |
|    other.2.weight |    (5, 128)            |          |
|    other.2.bias   |    (5,)                |          |

Traceback (most recent call last):
  File "proof/flops.py", line 48, in <module>
    main()
  File "proof/flops.py", line 44, in main
    print(flop_count_table(FlopCountAnalysis(model, (inputs,))))
  File "/home/liuyufan/anaconda3/envs/torch/lib/python3.7/site-packages/fvcore/nn/print_model_statistics.py", line 632, in flop_count_table
    stats = {params_header: params, flops_header: flops.by_module()}
  File "/home/liuyufan/anaconda3/envs/torch/lib/python3.7/site-packages/fvcore/nn/jit_analysis.py", line 291, in by_module
    stats = self._analyze()
  File "/home/liuyufan/anaconda3/envs/torch/lib/python3.7/site-packages/fvcore/nn/jit_analysis.py", line 551, in _analyze
    graph = _get_scoped_trace_graph(self._model, self._inputs, self._aliases)
  File "/home/liuyufan/anaconda3/envs/torch/lib/python3.7/site-packages/fvcore/nn/jit_analysis.py", line 176, in _get_scoped_trace_graph
    graph, _ = _get_trace_graph(module, inputs)
  File "/home/liuyufan/anaconda3/envs/torch/lib/python3.7/site-packages/torch/jit/_trace.py", line 1166, in _get_trace_graph
    outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
  File "/home/liuyufan/anaconda3/envs/torch/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/liuyufan/anaconda3/envs/torch/lib/python3.7/site-packages/torch/jit/_trace.py", line 132, in forward
    self._force_outplace,
  File "/home/liuyufan/anaconda3/envs/torch/lib/python3.7/site-packages/torch/jit/_trace.py", line 118, in wrapper
    outs.append(self.inner(*trace_inputs))
  File "/home/liuyufan/anaconda3/envs/torch/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1128, in _call_impl
    result = forward_call(*input, **kwargs)
  File "/home/liuyufan/anaconda3/envs/torch/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1098, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "proof/flops.py", line 25, in forward
    y1 = self.sub(x)  # (N, 10)
  File "/home/liuyufan/anaconda3/envs/torch/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1128, in _call_impl
    result = forward_call(*input, **kwargs)
  File "/home/liuyufan/anaconda3/envs/torch/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1098, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/home/liuyufan/anaconda3/envs/torch/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 169, in forward
    return self.gather(outputs, self.output_device)
  File "/home/liuyufan/anaconda3/envs/torch/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 181, in gather
    return gather(outputs, output_device, dim=self.dim)
  File "/home/liuyufan/anaconda3/envs/torch/lib/python3.7/site-packages/torch/nn/parallel/scatter_gather.py", line 78, in gather
    res = gather_map(outputs)
  File "/home/liuyufan/anaconda3/envs/torch/lib/python3.7/site-packages/torch/nn/parallel/scatter_gather.py", line 63, in gather_map
    return Gather.apply(target_device, dim, *outputs)
RuntimeError: Cannot insert a Tensor that requires grad as a constant. Consider making it a parameter or input, or detaching the gradient
Tensor:
 [32 x 10 tensor values omitted]
[ torch.cuda.FloatTensor{32,10} ]

In either case, the FLOP count for the partially DP-wrapped model is not computed correctly: detectron2's analysis reports 0 FLOPs for the wrapped submodule, and fvcore's analysis crashes while tracing it. So the question is: is it feasible to count the FLOPs of a DP-wrapped model, for example via hooks?
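
One possible workaround, pending a fix in the library, is to strip the DataParallel wrappers before running the analysis: each DP replica performs the same computation as the module it wraps, so the unwrapped model should give the same totals as the non-DP run. A minimal sketch (the helper name unwrap_data_parallel is made up for illustration, and the analysis runs on a copy so the original DP model stays usable):

import copy

import torch.nn as nn
from fvcore.nn import FlopCountAnalysis, flop_count_table

def unwrap_data_parallel(module: nn.Module) -> nn.Module:
    # replace every nn.DataParallel (including a possible root wrapper)
    # with the plain module it wraps
    if isinstance(module, nn.DataParallel):
        module = module.module
    for name, child in module.named_children():
        setattr(module, name, unwrap_data_parallel(child))
    return module

# analyse an unwrapped copy; the DP-wrapped model itself is left untouched
plain_model = unwrap_data_parallel(copy.deepcopy(model)).cuda(0)
print(flop_count_table(FlopCountAnalysis(plain_model, (inputs,))))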