sovrasov / flops-counter.pytorch

Flops counter for convolutional networks in pytorch framework
MIT License

Support ViT from timm (HuggingFace) #124

Closed: zhuoyan-xu closed this issue 1 year ago

zhuoyan-xu commented 1 year ago

Hi,

Thank you for this great work! I am using this codebase to compute the FLOPs of ViT models from the timm (HuggingFace) package. I am wondering whether you plan to add support for models from HuggingFace; if so, I'd like to contribute. With the counter hook below, the code is able to count all ViT models from timm.

import timm
from ptflops.pytorch_ops import MODULES_MAPPING

def timm_attention_counter_hook(attention_module, input, output):
    flops = 0
    B, N, C = input[0].shape  # [batch_size, seq_len, embed_dim]

    # QKV projection is already covered in MODULES_MAPPING

    # Q scaling
    flops += N * attention_module.head_dim * attention_module.num_heads

    # per-head attention flops
    head_flops = (
        (N * N * attention_module.head_dim)  # QK^T
        + (N * N)  # softmax
        + (N * N * attention_module.head_dim)  # AV
    )
    flops += head_flops * attention_module.num_heads

    # Final projection is already covered in MODULES_MAPPING

    flops *= B
    attention_module.__flops__ += int(flops)

MODULES_MAPPING.update(
    {timm.models.vision_transformer.Attention: timm_attention_counter_hook})
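
For reference, here is a minimal usage sketch that registers the hook through ptflops' public API instead of patching MODULES_MAPPING directly. It assumes get_model_complexity_info accepts a custom_modules_hooks dict (available in recent ptflops releases) and that vit_base_patch16_224 exists in the installed timm; the model name is just an example.

import timm
from ptflops import get_model_complexity_info

# Hypothetical usage: vit_base_patch16_224 is an example model name.
model = timm.create_model('vit_base_patch16_224')
macs, params = get_model_complexity_info(
    model, (3, 224, 224), as_strings=True,
    custom_modules_hooks={
        timm.models.vision_transformer.Attention: timm_attention_counter_hook,
    },
)
print(f'MACs: {macs}, params: {params}')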

https://github.com/sovrasov/flops-counter.pytorch/blob/316cda90fbabd42038647d346ce93f8d649a86b8/ptflops/pytorch_ops.py#L242-L292
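
As a quick sanity check of what the hook counts, consider a single attention block of ViT-B/16 at 224x224 input (N = 197 tokens, 12 heads, head_dim = 64). The numbers below are worked out by hand from the formula above, not taken from ptflops output:

N, num_heads, head_dim = 197, 12, 64
q_scaling = N * head_dim * num_heads                    # 151,296
per_head = N * N * head_dim + N * N + N * N * head_dim  # 5,006,361
total = q_scaling + per_head * num_heads                # 60,227,628, i.e. ~60.2 MMac per image
print(total)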

sovrasov commented 1 year ago

Hi! That's a nice proposal, thanks! Could you make a PR? I'll think about adding some tests later.