Hi,

Thank you for this great work! I am using this codebase to compute the FLOPs of ViT models from the timm (HuggingFace) library. I am wondering whether you plan to add support for models from HuggingFace. If so, I'd like to make a contribution here. By adding the counter hook below, the code is able to count FLOPs for all ViT models from timm.
import timm

# MODULES_MAPPING lives in ptflops/pytorch_ops.py
from ptflops.pytorch_ops import MODULES_MAPPING


def timm_attention_counter_hook(attention_module, input, output):
    flops = 0
    B, N, C = input[0].shape  # [batch_size, seq_len, embed_dim]
    # QKV projection is already covered by the nn.Linear hook in MODULES_MAPPING
    # Q scaling
    flops += N * attention_module.head_dim * attention_module.num_heads
    # Per-head attention FLOPs
    head_flops = (
        (N * N * attention_module.head_dim)    # Q @ K^T
        + (N * N)                              # softmax
        + (N * N * attention_module.head_dim)  # attn @ V
    )
    flops += head_flops * attention_module.num_heads
    # Final output projection is already covered by the nn.Linear hook in MODULES_MAPPING
    flops *= B
    attention_module.__flops__ += int(flops)


MODULES_MAPPING.update(
    {timm.models.vision_transformer.Attention: timm_attention_counter_hook}
)
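For context, here is a minimal usage sketch of how the hook can be exercised end to end with ptflops' standard entry point, get_model_complexity_info. The model name 'vit_base_patch16_224' and the 3x224x224 input resolution are only illustrative assumptions.

import timm
from ptflops import get_model_complexity_info

# Illustrative model choice; any timm ViT that uses
# timm.models.vision_transformer.Attention should work the same way.
model = timm.create_model('vit_base_patch16_224', pretrained=False)
model.eval()

# With the Attention hook registered in MODULES_MAPPING above,
# ptflops now accounts for the attention matmuls and softmax as well.
macs, params = get_model_complexity_info(
    model, (3, 224, 224),
    as_strings=True,
    print_per_layer_stat=False,
)
print(f'Computational complexity: {macs}')
print(f'Number of parameters:     {params}')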
For reference, the existing counter hooks are defined here:
https://github.com/sovrasov/flops-counter.pytorch/blob/316cda90fbabd42038647d346ce93f8d649a86b8/ptflops/pytorch_ops.py#L242-L292