QUVA-Lab / e2cnn

E(2)-Equivariant CNNs Library for Pytorch
https://quva-lab.github.io/e2cnn/

Counting FLOPs for e2cnn #52

Closed. shim94kr closed this issue 2 years ago.

shim94kr commented 2 years ago

Hello,

Thank you for sharing this amazing work!

While playing with it a little, I wondered how many FLOPs the network consumes.

I found that existing libraries (e.g., https://github.com/1adrianb/pytorch-estimate-flops) do not work here, as they don't support equivariant modules.

Is there a simple way to do this? If counting the kernel-construction part is hard, I would just like to know how to convert the model into a plain PyTorch model, as it is used at test time, so that I can count FLOPs with an existing library.

Gabri95 commented 2 years ago

Hi @shim94kr,

Hmm, estimating the number of FLOPs for a general equivariant model can be quite hard, as it would require ad hoc code.
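For a single convolution, at least the inference-time cost can be worked out by hand: once kernel expansion is out of the picture, an equivariant convolution costs the same as an ordinary convolution whose channel counts are the total sizes of its input and output field types. The sketch below is a rough illustration of that kind of ad hoc counting. The helper `conv2d_cost` and the C8 example layer are hypothetical, and counting FLOPs as 2 × MACs plus one add per output for the bias is an assumed convention (some tools report MACs as "FLOPs").

```python
def conv2d_cost(c_in, c_out, kernel_size, h_out, w_out, groups=1, bias=True):
    """Return (MACs, FLOPs) for one forward pass of a plain 2D convolution.

    Assumed convention: FLOPs = 2 * MACs (one multiply plus one add per MAC),
    plus one add per output element when a bias is present.
    """
    outputs = c_out * h_out * w_out
    # each output element needs (c_in / groups) * k * k multiply-accumulates
    macs = outputs * (c_in // groups) * kernel_size * kernel_size
    flops = 2 * macs + (outputs if bias else 0)
    return macs, flops


# Hypothetical layer for the cyclic group C8: the input is 1 trivial field
# (size 1) and the output is 8 regular fields (size 8 each), so the
# equivalent plain convolution has 1 input and 64 output channels.
N = 8
c_in = 1 * 1
c_out = 8 * N

# 5x5 kernel with padding 2 on a 28x28 image gives a 28x28 output
macs, flops = conv2d_cost(c_in, c_out, kernel_size=5, h_out=28, w_out=28)
print(macs, flops)
```

This ignores the cost of expanding the kernels from the steerable basis, which is only paid while training (or whenever the filter is rebuilt), not at inference time.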

However, if your model is sufficiently simple, you can call the .export() method to turn it into a pure PyTorch model and then try that library on the converted model. This also gives a more reliable estimate: at test/inference time, the kernels do not need to be constructed or expanded, so that overhead should not be counted.

Hope this helps!

Best, Gabriele