Lyken17 / pytorch-OpCounter

Count the MACs / FLOPs of your PyTorch model.
MIT License

The library does not handle buffers properly #228

Open: abhash-er opened this issue 2 months ago

abhash-er commented 2 months ago

None of the buffers that thop registers take the model's device into account. They are created on the CPU, so when the model's parameters are on the GPU, the two end up on different devices, which leads to an error. The workaround is to move the model to CUDA again during the forward pass. This call has to be repeated often when the model's FLOPs change, for example with a technique that prunes the model.
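A minimal sketch of the pattern described above, assuming (as the issue states) that a counter buffer is created without a device argument. The helper add_counter_like_thop is hypothetical and is not thop's actual code. The meta device stands in for cuda so the sketch runs on machines without a GPU.

```python
import torch
import torch.nn as nn


def add_counter_like_thop(module: nn.Module) -> None:
    # Hypothetical helper: the buffer is created with no device argument,
    # so it is allocated on the CPU no matter where the module lives.
    module.register_buffer("total_ops", torch.zeros(1, dtype=torch.float64))


model = nn.Linear(4, 2)
# "meta" plays the role of "cuda" here, so no GPU is needed.
model = model.to("meta")
add_counter_like_thop(model)

param_device = next(model.parameters()).device
buffer_device_before = model.total_ops.device  # CPU, unlike the parameters

# The workaround mentioned in the issue: move the whole model again,
# which also moves the freshly registered buffer.
model.to(param_device)
buffer_device_after = model.total_ops.device
```

The workaround works, but it has to be repeated every time new counters are registered, for example after each pruning step changes the model.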

It would be nice if all buffers, including those created inside the hooks, were initialized on a device that can be passed to the profile function. Alternatively, the device could be inferred from the model or the input passed to profile.
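A minimal sketch of what the proposal could look like. The helpers infer_device, register_counters and count_linear are hypothetical and not part of thop's API. The sketch resolves the device from an explicit argument first, then from the model, then from the inputs, and creates every counter (and every increment added in a hook) on that device.

```python
from typing import Optional, Sequence, Union

import torch
import torch.nn as nn


def infer_device(model: nn.Module, inputs: Sequence[object],
                 device: Optional[Union[str, torch.device]] = None) -> torch.device:
    # Hypothetical: explicit device > model parameters/buffers > first tensor input > CPU.
    if device is not None:
        return torch.device(device)
    for t in model.parameters():
        return t.device
    for t in model.buffers():
        return t.device
    for x in inputs:
        if isinstance(x, torch.Tensor):
            return x.device
    return torch.device("cpu")


def register_counters(model: nn.Module, device: torch.device) -> None:
    # Create each counter buffer directly on the target device.
    for m in model.modules():
        if not hasattr(m, "total_ops"):
            m.register_buffer(
                "total_ops", torch.zeros(1, dtype=torch.float64, device=device))


def count_linear(m: nn.Linear, inputs, output: torch.Tensor) -> None:
    # One multiply-accumulate per input feature for every output element.
    macs = m.in_features * output.numel()
    # Build the increment on the buffer's own device, not on the CPU.
    m.total_ops += torch.tensor([macs], dtype=torch.float64,
                                device=m.total_ops.device)


model = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))
x = torch.randn(5, 4)

device = infer_device(model, (x,))
register_counters(model, device)
handles = [m.register_forward_hook(count_linear)
           for m in model.modules() if isinstance(m, nn.Linear)]
with torch.no_grad():
    model(x)
for h in handles:
    h.remove()

total_macs = sum(m.total_ops.item()
                 for m in model.modules() if isinstance(m, nn.Linear))
```

Because both the buffers and the per-call increments follow the resolved device, moving the model to CUDA once before profiling would be enough, even if pruning later changes the layers being counted.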