facebookresearch / fvcore

Collection of common code that's shared among different research projects in FAIR computer vision team.
Apache License 2.0

Error while calculating FLOPs of a transformer model #129

Closed NeelKanwal closed 1 year ago

NeelKanwal commented 1 year ago

Hi, I am trying to calculate FLOPs for a CNN and a ViT using different libraries. Apart from getting different numbers, I am encountering unsupported-operator messages like the ones below.

I am using a ViT from the timm library, which is based on PyTorch.

model = timm.create_model('vit_tiny_patch16_224', pretrained=True, num_classes=2)
flops = FlopCountAnalysis(model, torch.rand(1, 3, 224, 224))
flops.total()

Output is:
Unsupported operator aten::add encountered 25 time(s)
Unsupported operator aten::div encountered 12 time(s)
Unsupported operator aten::mul encountered 12 time(s)
Unsupported operator aten::softmax encountered 12 time(s)
Unsupported operator aten::gelu encountered 12 time(s)

FLOPs 1258219584

From the thop library:

flops, params = profile(model, inputs=(torch.randn(1, 3, 224, 224),), verbose=False)

The answer is:

flops 1078442112

ppwwyyxx commented 1 year ago

Output is: Unsupported operator

This is expected. These ops are often negligible and don't have a corresponding flop counter implemented. If you want to count them, you can register a counter for them manually.
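
For example, registering a rough handle for one of the reported ops could look like this (a minimal sketch; the factor of 3 used for softmax is only a rough assumption, not a cost defined by fvcore):

from typing import Any, List
import numpy as np
import timm
import torch
from fvcore.nn import FlopCountAnalysis
from fvcore.nn.jit_handles import get_shape

def softmax_flop_jit(inputs: List[Any], outputs: List[Any]) -> int:
    # Approximate softmax as a few elementwise passes (exp, sum, divide)
    # over the output tensor; a rough estimate, not an exact count.
    return 3 * int(np.prod(get_shape(outputs[0])))

model = timm.create_model('vit_tiny_patch16_224', pretrained=True, num_classes=2)
flops = FlopCountAnalysis(model, torch.rand(1, 3, 224, 224))
flops.set_op_handle("aten::softmax", softmax_flop_jit)
print(flops.total())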

From the thop library:
flops, params = profile(model, inputs=(torch.randn(1, 3, 224, 224),), verbose=False)
The answer is: flops 1078442112

Most likely the thop library is wrong.

NeelKanwal commented 1 year ago

Thanks,

Can you please let me know how to find operations in each of these layers?

Most other libraries raise an error or an unsupported-operator message.

ppwwyyxx commented 1 year ago

I may not fully understand your question. As https://github.com/facebookresearch/fvcore/blob/main/docs/flop_count.md#our-work says, flops.by_module_and_operator() shows the per-operator flops within each layer.
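
For example (a small sketch, reusing the model and input from the first post):

import torch
import timm
from fvcore.nn import FlopCountAnalysis

model = timm.create_model('vit_tiny_patch16_224', pretrained=True, num_classes=2)
flops = FlopCountAnalysis(model, torch.rand(1, 3, 224, 224))

# Maps each submodule name to a Counter of {operator name: flops},
# so you can see which ops contribute to every layer.
print(flops.by_module_and_operator())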

NeelKanwal commented 1 year ago

Sorry for not making it clear.

I mean that I would like to write a function which manually adds the operations for unsupported layers, but I cannot figure out how to estimate the number of ops in aten::add, aten::softmax, etc.

I think flops.by_module_and_operator() would also fail for the unsupported layers.

ppwwyyxx commented 1 year ago
from typing import Any, List
import numpy as np
from fvcore.nn import FlopCountAnalysis
from fvcore.nn.jit_handles import get_shape

def add_flop_jit(inputs: List[Any], outputs: List[Any]) -> int:
    # Elementwise add: count one flop per element of the output tensor.
    return int(np.prod(get_shape(outputs[0])))

flops = FlopCountAnalysis(model, inputs).set_op_handle("aten::add", add_flop_jit)

https://detectron2.readthedocs.io/en/latest/modules/fvcore.html#fvcore.nn.ActivationCountAnalysis.set_op_handle has the documentation.
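
To also cover the other ops reported as unsupported (a sketch reusing the imports above; counting one flop per output element for these elementwise ops is a rough convention, not a cost defined by fvcore):

def elementwise_flop_jit(inputs: List[Any], outputs: List[Any]) -> int:
    # One flop per element of the output tensor, a common rough convention
    # for elementwise ops such as mul, div and gelu.
    return int(np.prod(get_shape(outputs[0])))

flops.set_op_handle(
    "aten::mul", elementwise_flop_jit,
    "aten::div", elementwise_flop_jit,
    "aten::gelu", elementwise_flop_jit,
)
print(flops.total())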

NeelKanwal commented 1 year ago

Thanks.

It looks like add_flop_jit only relies on the output of the module/layer. Would the computation be correct if the input dimensions are ignored?

ppwwyyxx commented 1 year ago

Some operations may use input shapes. For the add operation, the above add_flop_jit is correct.
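
For illustration, here is a sketch of a handle that does need the input shapes (fvcore ships its own handle for matmul, so this is only to show the idea; it reuses the imports from the snippet above):

def my_matmul_flop_jit(inputs: List[Any], outputs: List[Any]) -> int:
    # For C = A @ B with A of shape (..., m, k), each element of C costs k
    # multiply-adds, and k is only visible from the input shapes.
    k = get_shape(inputs[0])[-1]
    return int(np.prod(get_shape(outputs[0]))) * k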

chengyangfu commented 1 year ago

Thanks @ppwwyyxx for the explanation! I am going to close this issue.

gokul-uf commented 1 year ago

Hi @ppwwyyxx

This is expected. These ops are often negligible and don't have a corresponding flop counter implemented. If you want to count them, you can register a counter for them manually.

I am counting FLOPs for a ViT-L/16, and the numbers from fvcore and this script (based on https://dev-discuss.pytorch.org/t/the-ideal-pytorch-flop-counter-with-torch-dispatch/505/1) are very different.

Since transformer layers are widely used, it would help improve the flop counter's adoption if it supported the above ops.
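
For anyone cross-checking, a dispatch-based count along the lines of that post is also available as torch.utils.flop_counter.FlopCounterMode in recent PyTorch versions (a sketch, assuming a new enough PyTorch and taking vit_large_patch16_224 from timm as the ViT-L/16; this is separate from fvcore):

import torch
import timm
from torch.utils.flop_counter import FlopCounterMode

model = timm.create_model('vit_large_patch16_224', pretrained=False)
inp = torch.rand(1, 3, 224, 224)

# FlopCounterMode intercepts ops via __torch_dispatch__, so matmuls, attention,
# etc. are counted without registering per-op handles.
with FlopCounterMode(display=False) as counter:
    model(inp)
print(counter.get_total_flops())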