rayleizhu / BiFormer

[CVPR 2023] Official code release of our paper "BiFormer: Vision Transformer with Bi-Level Routing Attention"
https://arxiv.org/abs/2303.08810
MIT License

Counting FLOPs #33

Closed 1299361191 closed 10 months ago

1299361191 commented 10 months ago

Hi, I'd like to ask how the model's computational cost (FLOPs) was counted. Did you use a library for it?

rayleizhu commented 10 months ago

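fvcore is a very reliable tool. It counts computation at the GPU-kernel level, so as long as you haven't written any custom kernels yourself, it can record everything: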

https://github.com/rayleizhu/BiFormer/blob/1697bbbeafb8680524898f1dcaac10defd0604be/main.py#L27
https://github.com/rayleizhu/BiFormer/blob/1697bbbeafb8680524898f1dcaac10defd0604be/main.py#L281

1299361191 commented 10 months ago

Thanks!

1299361191 commented 10 months ago

One more question: when I run the code below, some warnings are printed, and I'm not sure whether they affect the FLOP count:

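    import torch
    from fvcore.nn import FlopCountAnalysis, flop_count_table
    from models.biformer import biformer_small  # assumes the script runs from the BiFormer repo root

    model = biformer_small().cuda()
    x = torch.randn(1, 3, 224, 224).cuda()
    flops = FlopCountAnalysis(model, x)
    # print(flop_count_table(flops))
    print(flops.total())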

Output:

    Unsupported operator aten::add_ encountered 6 time(s)
    Unsupported operator aten::gelu encountered 31 time(s)
    Unsupported operator aten::add encountered 120 time(s)
    Unsupported operator aten::mul encountered 1048 time(s)
    Unsupported operator aten::mul_ encountered 242 time(s)
    Unsupported operator aten::mean encountered 53 time(s)
    Unsupported operator aten::topk encountered 26 time(s)
    Unsupported operator aten::softmax encountered 30 time(s)
    The following submodules of the model were never called during the trace of the graph. They may be unused, or they were accessed by direct calls to .forward() or via other python methods. In the latter case they will have zeros for statistics, though their statistics will still contribute to their parent calling module.
    stages.0.0.attn.router.routing_act, stages.0.1.attn.router.routing_act, stages.0.2.attn.router.routing_act, stages.0.3.attn.router.routing_act, stages.1.0.attn.router.routing_act, stages.1.1.attn.router.routing_act, stages.1.2.attn.router.routing_act, stages.1.3.attn.router.routing_act, stages.2.0.attn.router.routing_act, stages.2.1.attn.router.routing_act, stages.2.10.attn.router.routing_act, stages.2.11.attn.router.routing_act, stages.2.12.attn.router.routing_act, stages.2.13.attn.router.routing_act, stages.2.14.attn.router.routing_act, stages.2.15.attn.router.routing_act, stages.2.16.attn.router.routing_act, stages.2.17.attn.router.routing_act, stages.2.2.attn.router.routing_act, stages.2.3.attn.router.routing_act, stages.2.4.attn.router.routing_act, stages.2.5.attn.router.routing_act, stages.2.6.attn.router.routing_act, stages.2.7.attn.router.routing_act, stages.2.8.attn.router.routing_act, stages.2.9.attn.router.routing_act
    4489721088

rayleizhu commented 10 months ago

> Output:
>
>     Unsupported operator aten::add_ encountered 6 time(s)
>     Unsupported operator aten::gelu encountered 31 time(s)
>     Unsupported operator aten::add encountered 120 time(s)
>     Unsupported operator aten::mul encountered 1048 time(s)
>     Unsupported operator aten::mul_ encountered 242 time(s)
>     Unsupported operator aten::mean encountered 53 time(s)
>     Unsupported operator aten::topk encountered 26 time(s)
>     Unsupported operator aten::softmax encountered 30 time(s)

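  1. I'm not sure whether "Unsupported operator" means fvcore doesn't recognize these operators and therefore can't count them. If so, they would be missed. But it's also possible that lower-level operators capture them.
  2. These are common kernels. It may be that some kernels were renamed in newer PyTorch versions, so fvcore no longer recognizes them. You could try installing the PyTorch version listed in environment.yaml; I don't recall seeing this warning with it.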

> The following submodules of the model were never called during the trace of the graph. They may be unused, or they were accessed by direct calls to .forward() or via other python methods. In the latter case they will have zeros for statistics, though their statistics will still contribute to their parent calling module.
> stages.0.0.attn.router.routing_act, stages.0.1.attn.router.routing_act, stages.0.2.attn.router.routing_act, stages.0.3.attn.router.routing_act, stages.1.0.attn.router.routing_act, stages.1.1.attn.router.routing_act, stages.1.2.attn.router.routing_act, stages.1.3.attn.router.routing_act, stages.2.0.attn.router.routing_act, stages.2.1.attn.router.routing_act, stages.2.10.attn.router.routing_act, stages.2.11.attn.router.routing_act, stages.2.12.attn.router.routing_act, stages.2.13.attn.router.routing_act, stages.2.14.attn.router.routing_act, stages.2.15.attn.router.routing_act, stages.2.16.attn.router.routing_act, stages.2.17.attn.router.routing_act, stages.2.2.attn.router.routing_act, stages.2.3.attn.router.routing_act, stages.2.4.attn.router.routing_act, stages.2.5.attn.router.routing_act, stages.2.6.attn.router.routing_act, stages.2.7.attn.router.routing_act, stages.2.8.attn.router.routing_act, stages.2.9.attn.router.routing_act
> 4489721088

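This part definitely has no effect. It appears because I call the forward() function explicitly; as the warning itself notes, the statistics of those modules still contribute to their parent module.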

1299361191 commented 10 months ago

Thanks for your answer!