zhijian-liu / torchprofile

A general and accurate MACs / FLOPs profiler for PyTorch models
https://pypi.org/project/torchprofile/
MIT License

Torchprofile

This is a profiler that counts the number of MACs / FLOPs of PyTorch models. It is built on torch.jit.trace.
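A trace-based profiler records which operators run and with what tensor shapes, then applies a per-operator counting rule. The sketch below shows the standard textbook rules for two common operators, a fully connected layer and a 2D convolution. It is an illustration under stated assumptions, not torchprofile's actual implementation: the helper names `linear_macs` and `conv2d_macs` are hypothetical, and bias additions are ignored.

```python
from typing import Tuple


def linear_macs(in_features: int, out_features: int, batch: int = 1) -> int:
    # Each output element is a dot product of length in_features,
    # i.e. in_features multiply-accumulates per output element (bias ignored).
    return batch * out_features * in_features


def conv2d_macs(in_channels: int, out_channels: int, kernel: Tuple[int, int],
                out_hw: Tuple[int, int], groups: int = 1, batch: int = 1) -> int:
    # Each output element accumulates over (in_channels / groups) * kH * kW inputs
    # (bias ignored).
    kh, kw = kernel
    oh, ow = out_hw
    per_output = (in_channels // groups) * kh * kw
    return batch * out_channels * oh * ow * per_output


# ResNet-18's first layer: 3 -> 64 channels, 7x7 kernel, 112x112 output map
print(conv2d_macs(3, 64, (7, 7), (112, 112)))  # 118013952
# ResNet-18's final fully connected layer: 512 -> 1000
print(linear_macs(512, 1000))  # 512000
```

Summing such per-operator counts over every node in the traced graph gives the total for the whole model.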

Installation

pip install torchprofile

Getting Started

First, define your PyTorch model and a (dummy) input:

import torch
from torchvision.models import resnet18

model = resnet18()
inputs = torch.randn(1, 3, 224, 224)  # dummy batch: one 3-channel 224x224 image

You can then measure the number of MACs using profile_macs:

from torchprofile import profile_macs

macs = profile_macs(model, inputs)
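profile_macs reports MACs, while many papers report FLOPs. A common convention (an assumption here, not something torchprofile itself defines) counts one multiply-accumulate as two FLOPs. The hypothetical helpers below convert a raw count and format it for reporting; the sample value stands in for the result of profile_macs so the snippet runs without PyTorch installed.

```python
def macs_to_flops(macs: int) -> int:
    # Assumed convention: 1 MAC = 1 multiply + 1 add = 2 FLOPs.
    return 2 * macs


def human_readable(count: float, unit: str = "MACs") -> str:
    # Scale a raw count to a G/M/K prefix for reporting.
    for factor, prefix in ((1e9, "G"), (1e6, "M"), (1e3, "K")):
        if count >= factor:
            return f"{count / factor:.2f} {prefix}{unit}"
    return f"{count:.0f} {unit}"


macs = 1_500_000_000  # placeholder; use the value returned by profile_macs
print(human_readable(macs))                          # 1.50 GMACs
print(human_readable(macs_to_flops(macs), "FLOPs"))  # 3.00 GFLOPs
```

Because the count is measured on the traced input, it scales with the batch size of the dummy input.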

License

This repository is released under the MIT license. See LICENSE for additional details.