Lyken17 / pytorch-OpCounter

Count the MACs / FLOPs of your PyTorch model.
MIT License
4.9k stars 528 forks source link

I got different results using thop and torchinfo #224

Open zz-2024 opened 5 months ago

zz-2024 commented 5 months ago

My code:

import torch 
from PIL import Image
from thop import profile, clever_format
from torchinfo import summary
import cn_clip.clip as clip
from cn_clip.clip import load_from_name, available_models
print("Available models:", available_models())
# At the time of this report the call above listed:
# ['ViT-B-16', 'ViT-L-14', 'ViT-L-14-336', 'ViT-H-14', 'RN50']

# Prefer the GPU when one is available; the model and all inputs go there.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the RN50 Chinese-CLIP checkpoint and switch it to inference mode.
# (download_root='./' presumably controls where the weights are stored — confirm.)
model, preprocess = load_from_name("RN50", device=device, download_root='./')
model.eval()

# One preprocessed image (batch dimension of 1 added via unsqueeze) and four
# tokenized candidate labels, all moved to the chosen device.
with Image.open("examples/pokemon.jpeg") as pil_image:
    image = preprocess(pil_image).unsqueeze(0).to(device)
candidate_labels = ["杰尼龟", "妙蛙种子", "小火龙", "皮卡丘"]
text = clip.tokenize(candidate_labels).to(device)
print(f"device:{device}")

with torch.no_grad():
    image_embeddings = model.encode_image(image)
    text_embeddings = model.encode_text(text)
    # L2-normalize the features in place; the normalized image/text features
    # are the ones to use for downstream tasks.
    image_embeddings /= image_embeddings.norm(dim=-1, keepdim=True)
    text_embeddings /= text_embeddings.norm(dim=-1, keepdim=True)

    # get_similarity is given the raw image/text tensors (not the normalized
    # features above); only the image-to-text logits are kept.
    logits_per_image, _ = model.get_similarity(image, text)
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()

print("Label probs:", probs)  # observed: [[1.268734e-03 5.436878e-02 6.795761e-04 9.436829e-01]]


def _print_thop_stats(module, inputs):
    """Profile ``module`` on ``inputs`` with thop and print its counts.

    Both the operation count and the parameter count are passed through
    ``clever_format`` with three decimals before printing.
    """
    total_flops, total_params = profile(module, inputs=inputs, verbose=False)
    flops_text, params_text = clever_format([total_flops, total_params], "%.3f")
    print(f"Model FLOPs: {flops_text}, Model Params:{params_text}")


# Compute FLOPs three ways: thop on the whole model, thop on the vision tower
# alone, and torchinfo's summary of the whole model.
# NOTE(review): thop's figure is likely MACs rather than FLOPs, and its params
# total may only cover modules thop has counting rules for (e.g. it may skip
# Embedding layers or parameters registered directly on CLIP). Compare it with
# sum(p.numel() for p in model.parameters()) — confirm against the thop docs.
print(f"image:{image.shape}, text:{text.shape}")
_print_thop_stats(model, (image, text))
print("========seperate==========")
_print_thop_stats(model.visual, (image,))
print("=========another method=============")
summary(model, input_data=[image, text])

My results:

Model FLOPs: 9.839G, Model Params:44.792M
========seperate==========
Model FLOPs: 5.418G, Model Params:23.527M
=========another method=============
=========================================================================================================
Layer (type:depth-idx)                                  Output Shape              Param #
=========================================================================================================
CLIP                                                    [1, 1024]                 786,433
├─ModifiedResNet: 1-1                                   [1, 1024]                 --
│    └─Conv2d: 2-1                                      [1, 32, 112, 112]         864
│    └─BatchNorm2d: 2-2                                 [1, 32, 112, 112]         64
│    └─ReLU: 2-3                                        [1, 32, 112, 112]         --
│    └─Conv2d: 2-4                                      [1, 32, 112, 112]         9,216
│    └─BatchNorm2d: 2-5                                 [1, 32, 112, 112]         64
│    └─ReLU: 2-6                                        [1, 32, 112, 112]         --
│    └─Conv2d: 2-7                                      [1, 64, 112, 112]         18,432
│    └─BatchNorm2d: 2-8                                 [1, 64, 112, 112]         128
│    └─ReLU: 2-9                                        [1, 64, 112, 112]         --
│    └─AvgPool2d: 2-10                                  [1, 64, 56, 56]           --
│    └─Sequential: 2-11                                 [1, 256, 56, 56]          --
│    │    └─Bottleneck: 3-1                             [1, 256, 56, 56]          75,008
│    │    └─Bottleneck: 3-2                             [1, 256, 56, 56]          70,400
│    │    └─Bottleneck: 3-3                             [1, 256, 56, 56]          70,400
│    └─Sequential: 2-12                                 [1, 512, 28, 28]          --
│    │    └─Bottleneck: 3-4                             [1, 512, 28, 28]          379,392
│    │    └─Bottleneck: 3-5                             [1, 512, 28, 28]          280,064
│    │    └─Bottleneck: 3-6                             [1, 512, 28, 28]          280,064
│    │    └─Bottleneck: 3-7                             [1, 512, 28, 28]          280,064
│    └─Sequential: 2-13                                 [1, 1024, 14, 14]         --
│    │    └─Bottleneck: 3-8                             [1, 1024, 14, 14]         1,512,448
│    │    └─Bottleneck: 3-9                             [1, 1024, 14, 14]         1,117,184
│    │    └─Bottleneck: 3-10                            [1, 1024, 14, 14]         1,117,184
│    │    └─Bottleneck: 3-11                            [1, 1024, 14, 14]         1,117,184
│    │    └─Bottleneck: 3-12                            [1, 1024, 14, 14]         1,117,184
│    │    └─Bottleneck: 3-13                            [1, 1024, 14, 14]         1,117,184
│    └─Sequential: 2-14                                 [1, 2048, 7, 7]           --
│    │    └─Bottleneck: 3-14                            [1, 2048, 7, 7]           6,039,552
│    │    └─Bottleneck: 3-15                            [1, 2048, 7, 7]           4,462,592
│    │    └─Bottleneck: 3-16                            [1, 2048, 7, 7]           4,462,592
│    └─AttentionPool2d: 2-15                            [1, 1024]                 14,789,632
├─BertModel: 1-2                                        [4, 52, 768]              --
│    └─BertEmbeddings: 2-16                             [4, 52, 768]              --
│    │    └─Embedding: 3-17                             [4, 52, 768]              16,226,304
│    │    └─Embedding: 3-18                             [4, 52, 768]              393,216
│    │    └─Embedding: 3-19                             [4, 52, 768]              1,536
│    │    └─LayerNorm: 3-20                             [4, 52, 768]              1,536
│    │    └─Dropout: 3-21                               [4, 52, 768]              --
│    └─BertEncoder: 2-17                                [4, 52, 768]              --
│    │    └─ModuleList: 3-22                            --                        21,263,616
=========================================================================================================
Total params: 76,989,537
Trainable params: 76,989,537
Non-trainable params: 0
Total mult-adds (G): 5.52
=========================================================================================================
Input size (MB): 0.60
Forward/backward pass size (MB): 123.19
Params size (MB): 122.93
Estimated Total Size (MB): 246.73
=========================================================================================================

Why does the number of parameters differ so much (44.792M vs. 77M)? And why are the mult-adds so different (9.839G vs. 5.52G)?

zz-2024 commented 5 months ago

Also, if I use `sum(x.numel() for x in model.parameters())`, I get 77M parameters.