Lyken17 / pytorch-OpCounter

Count the MACs / FLOPs of your PyTorch model.
MIT License

How to define the rule for a 3rd-party module? #172

Open moshicaixi opened 2 years ago

moshicaixi commented 2 years ago

Hi, thank you for your excellent work!

My research area is 3D point clouds, including shape classification and semantic segmentation. My PyTorch project uses some third-party libraries, such as the ball query operation from PointNet++, which is a custom CUDA function. In this situation, how can I define a rule for counting the MACs and params? I would appreciate any advice.

Lyken17 commented 2 years ago

For newly defined modules (whether written in CUDA, C, or Python), there should be a corresponding Python class wrapper. You can register a counting function for that wrapper class.

For reference, you can check how THOP defines counting rules for built-in PyTorch modules:

https://github.com/Lyken17/pytorch-OpCounter/blob/master/thop/profile.py#L21-L65
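Roughly speaking, the linked lines appear to build a table that maps built-in module types (such as `nn.Linear`) to counting functions. `profile` then attaches those functions to matching submodules as forward hooks. The stripped-down, torch-only re-creation below is not THOP's code. The names `RULES` and `count_macs`, and the choice to count `nn.ReLU` as zero MACs, are illustrative assumptions. It shows why a custom op only needs a Python wrapper class to become countable: hooks are attached by module type.

```python
import torch
import torch.nn as nn


def count_linear(m: nn.Linear, x, y) -> int:
    # One MAC per input feature for every output element.
    return m.in_features * y.numel()


def count_zero(m: nn.Module, x, y) -> int:
    # Assumed convention: cheap elementwise ops contribute no MACs.
    return 0


# Module type -> counting rule, similar in spirit to THOP's built-in table.
RULES = {nn.Linear: count_linear, nn.ReLU: count_zero}


def count_macs(model: nn.Module, *inputs, custom_rules=None):
    """Run one forward pass and sum the MACs reported by per-type hooks."""
    rules = {**RULES, **(custom_rules or {})}  # custom rules override built-ins
    total = {"macs": 0}
    handles = []

    def make_hook(rule):
        def hook(m, x, y):
            total["macs"] += int(rule(m, x, y))
        return hook

    for m in model.modules():
        rule = rules.get(type(m))
        if rule is not None:
            handles.append(m.register_forward_hook(make_hook(rule)))
    try:
        with torch.no_grad():
            model(*inputs)
    finally:
        for h in handles:
            h.remove()  # leave the model exactly as we found it
    # Counts every parameter in the model, hooked or not.
    params = sum(p.numel() for p in model.parameters())
    return total["macs"], params


model = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2))
print(count_macs(model, torch.randn(5, 8)))  # → (200, 46)
```

A third-party op plugs in the same way. Pass its wrapper class and a rule, e.g. `custom_rules={YourWrapper: your_rule}`, where both names are placeholders for your own code.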