TylerYep / torchinfo

View model summaries in PyTorch!
MIT License

summary does not work for CLIP models (jit) #56

Open sgimmini opened 3 years ago

sgimmini commented 3 years ago

Trying to summarize a CLIP model results in an output where there are no values listed in the Output Shape column:

==========================================================================================                                                                                                                           
Layer (type:depth-idx)                   Output Shape              Param #                                                                                                                                           
==========================================================================================                                                                                                                           
ModifiedResNet                           --                        --                                                                                                                                                
├─Conv2d: 1-1                            --                        864                                                                                                                                               
├─BatchNorm2d: 1-2                       --                        64                                                                                                                                                
...

I installed torchinfo through pip as well as the latest version of this repo. This is probably related to #55, which also uses JIT, but it does not look solved: since #55 still does not report input/output shapes (after the fix), I'm not sure whether this issue is just a duplicate. I'm also not sure whether this is a bug or a feature request, sorry if it's wrongly assigned.

Code to reproduce:

import clip 
from torchinfo import summary

model, preprocess = clip.load("RN50", device="cuda", jit=True)
summary(model.visual, input_size=(32, 3, 224, 224))
TylerYep commented 3 years ago

This is a feature request; in #55 I mentioned that there is only preliminary support for JIT (as in, no crashes).

I am unfamiliar with the mechanics of JIT, so maybe you have some ideas on how I can improve this - is there a different way to calculate the input/output sizes for JIT models, since module.register_forward_pre_hook(pre_hook) and module.register_forward_hook(hook) don't work?
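For context, here is a minimal sketch (plain PyTorch, not torchinfo internals) of how forward hooks recover per-layer output shapes on an eager nn.Module. This is the mechanism in question: as noted above, these Python-level hooks are not invoked for JIT-compiled modules.

```python
import torch
import torch.nn as nn

# Minimal sketch: collect per-layer output shapes with forward hooks on an
# eager nn.Module (a toy model, not CLIP). On a JIT ScriptModule these
# Python-level hooks don't work, which is the limitation discussed here.
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(8 * 4 * 4, 2),
)

shapes = {}
handles = []
for name, module in model.named_modules():
    if name:  # skip the top-level container itself
        def hook(mod, inputs, output, name=name):
            shapes[name] = tuple(output.shape)
        handles.append(module.register_forward_hook(hook))

model(torch.randn(1, 3, 4, 4))
for h in handles:
    h.remove()  # always detach hooks after use
```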

mert-kurttutan commented 2 years ago

Just came across this one while using CLIP. It should be working now with the latest version.

My output when I run @smn57's code:

==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
ModifiedResNet                           [32, 1024]                --
├─Conv2d: 1-1                            [32, 32, 112, 112]        864
├─BatchNorm2d: 1-2                       [32, 32, 112, 112]        64
├─ReLU: 1-3                              [32, 32, 112, 112]        --
├─Conv2d: 1-4                            [32, 32, 112, 112]        9,216
├─BatchNorm2d: 1-5                       [32, 32, 112, 112]        64
├─ReLU: 1-6                              [32, 32, 112, 112]        --
├─Conv2d: 1-7                            [32, 64, 112, 112]        18,432
├─BatchNorm2d: 1-8                       [32, 64, 112, 112]        128
├─ReLU: 1-9                              [32, 64, 112, 112]        --
├─AvgPool2d: 1-10                        [32, 64, 56, 56]          --
├─Sequential: 1-11                       [32, 256, 56, 56]         --
│    └─Bottleneck: 2-1                   [32, 256, 56, 56]         --
│    │    └─Conv2d: 3-1                  [32, 64, 56, 56]          4,096
│    │    └─BatchNorm2d: 3-2             [32, 64, 56, 56]          128
│    │    └─ReLU: 3-3                    [32, 64, 56, 56]          --
│    │    └─Conv2d: 3-4                  [32, 64, 56, 56]          36,864
│    │    └─BatchNorm2d: 3-5             [32, 64, 56, 56]          128
│    │    └─ReLU: 3-6                    [32, 64, 56, 56]          --
│    │    └─Identity: 3-7                [32, 64, 56, 56]          --
│    │    └─Conv2d: 3-8                  [32, 256, 56, 56]         16,384
│    │    └─BatchNorm2d: 3-9             [32, 256, 56, 56]         512
│    │    └─Sequential: 3-10             [32, 256, 56, 56]         16,896
│    │    └─ReLU: 3-11                   [32, 256, 56, 56]         --
│    └─Bottleneck: 2-2                   [32, 256, 56, 56]         --
│    │    └─Conv2d: 3-12                 [32, 64, 56, 56]          16,384
│    │    └─BatchNorm2d: 3-13            [32, 64, 56, 56]          128
│    │    └─ReLU: 3-14                   [32, 64, 56, 56]          --
│    │    └─Conv2d: 3-15                 [32, 64, 56, 56]          36,864
│    │    └─BatchNorm2d: 3-16            [32, 64, 56, 56]          128
│    │    └─ReLU: 3-17                   [32, 64, 56, 56]          --
│    │    └─Identity: 3-18               [32, 64, 56, 56]          --
│    │    └─Conv2d: 3-19                 [32, 256, 56, 56]         16,384
│    │    └─BatchNorm2d: 3-20            [32, 256, 56, 56]         512
│    │    └─ReLU: 3-21                   [32, 256, 56, 56]         --
│    └─Bottleneck: 2-3                   [32, 256, 56, 56]         --
│    │    └─Conv2d: 3-22                 [32, 64, 56, 56]          16,384
│    │    └─BatchNorm2d: 3-23            [32, 64, 56, 56]          128
│    │    └─ReLU: 3-24                   [32, 64, 56, 56]          --
│    │    └─Conv2d: 3-25                 [32, 64, 56, 56]          36,864
│    │    └─BatchNorm2d: 3-26            [32, 64, 56, 56]          128
│    │    └─ReLU: 3-27                   [32, 64, 56, 56]          --
│    │    └─Identity: 3-28               [32, 64, 56, 56]          --
│    │    └─Conv2d: 3-29                 [32, 256, 56, 56]         16,384
│    │    └─BatchNorm2d: 3-30            [32, 256, 56, 56]         512
│    │    └─ReLU: 3-31                   [32, 256, 56, 56]         --
├─Sequential: 1-12                       [32, 512, 28, 28]         --
│    └─Bottleneck: 2-4                   [32, 512, 28, 28]         --
│    │    └─Conv2d: 3-32                 [32, 128, 56, 56]         32,768
│    │    └─BatchNorm2d: 3-33            [32, 128, 56, 56]         256
│    │    └─ReLU: 3-34                   [32, 128, 56, 56]         --
│    │    └─Conv2d: 3-35                 [32, 128, 56, 56]         147,456
│    │    └─BatchNorm2d: 3-36            [32, 128, 56, 56]         256
│    │    └─ReLU: 3-37                   [32, 128, 56, 56]         --
│    │    └─AvgPool2d: 3-38              [32, 128, 28, 28]         --
│    │    └─Conv2d: 3-39                 [32, 512, 28, 28]         65,536
│    │    └─BatchNorm2d: 3-40            [32, 512, 28, 28]         1,024
│    │    └─Sequential: 3-41             [32, 512, 28, 28]         132,096
│    │    └─ReLU: 3-42                   [32, 512, 28, 28]         --
│    └─Bottleneck: 2-5                   [32, 512, 28, 28]         --
│    │    └─Conv2d: 3-43                 [32, 128, 28, 28]         65,536
│    │    └─BatchNorm2d: 3-44            [32, 128, 28, 28]         256
│    │    └─ReLU: 3-45                   [32, 128, 28, 28]         --
│    │    └─Conv2d: 3-46                 [32, 128, 28, 28]         147,456
│    │    └─BatchNorm2d: 3-47            [32, 128, 28, 28]         256
│    │    └─ReLU: 3-48                   [32, 128, 28, 28]         --
│    │    └─Identity: 3-49               [32, 128, 28, 28]         --
│    │    └─Conv2d: 3-50                 [32, 512, 28, 28]         65,536
│    │    └─BatchNorm2d: 3-51            [32, 512, 28, 28]         1,024
│    │    └─ReLU: 3-52                   [32, 512, 28, 28]         --
│    └─Bottleneck: 2-6                   [32, 512, 28, 28]         --
│    │    └─Conv2d: 3-53                 [32, 128, 28, 28]         65,536
│    │    └─BatchNorm2d: 3-54            [32, 128, 28, 28]         256
│    │    └─ReLU: 3-55                   [32, 128, 28, 28]         --
│    │    └─Conv2d: 3-56                 [32, 128, 28, 28]         147,456
│    │    └─BatchNorm2d: 3-57            [32, 128, 28, 28]         256
│    │    └─ReLU: 3-58                   [32, 128, 28, 28]         --
│    │    └─Identity: 3-59               [32, 128, 28, 28]         --
│    │    └─Conv2d: 3-60                 [32, 512, 28, 28]         65,536
│    │    └─BatchNorm2d: 3-61            [32, 512, 28, 28]         1,024
│    │    └─ReLU: 3-62                   [32, 512, 28, 28]         --
│    └─Bottleneck: 2-7                   [32, 512, 28, 28]         --
│    │    └─Conv2d: 3-63                 [32, 128, 28, 28]         65,536
│    │    └─BatchNorm2d: 3-64            [32, 128, 28, 28]         256
│    │    └─ReLU: 3-65                   [32, 128, 28, 28]         --
│    │    └─Conv2d: 3-66                 [32, 128, 28, 28]         147,456
│    │    └─BatchNorm2d: 3-67            [32, 128, 28, 28]         256
│    │    └─ReLU: 3-68                   [32, 128, 28, 28]         --
│    │    └─Identity: 3-69               [32, 128, 28, 28]         --
│    │    └─Conv2d: 3-70                 [32, 512, 28, 28]         65,536
│    │    └─BatchNorm2d: 3-71            [32, 512, 28, 28]         1,024
│    │    └─ReLU: 3-72                   [32, 512, 28, 28]         --
├─Sequential: 1-13                       [32, 1024, 14, 14]        --
│    └─Bottleneck: 2-8                   [32, 1024, 14, 14]        --
│    │    └─Conv2d: 3-73                 [32, 256, 28, 28]         131,072
│    │    └─BatchNorm2d: 3-74            [32, 256, 28, 28]         512
│    │    └─ReLU: 3-75                   [32, 256, 28, 28]         --
│    │    └─Conv2d: 3-76                 [32, 256, 28, 28]         589,824
│    │    └─BatchNorm2d: 3-77            [32, 256, 28, 28]         512
│    │    └─ReLU: 3-78                   [32, 256, 28, 28]         --
│    │    └─AvgPool2d: 3-79              [32, 256, 14, 14]         --
│    │    └─Conv2d: 3-80                 [32, 1024, 14, 14]        262,144
│    │    └─BatchNorm2d: 3-81            [32, 1024, 14, 14]        2,048
│    │    └─Sequential: 3-82             [32, 1024, 14, 14]        526,336
│    │    └─ReLU: 3-83                   [32, 1024, 14, 14]        --
│    └─Bottleneck: 2-9                   [32, 1024, 14, 14]        --
│    │    └─Conv2d: 3-84                 [32, 256, 14, 14]         262,144
│    │    └─BatchNorm2d: 3-85            [32, 256, 14, 14]         512
│    │    └─ReLU: 3-86                   [32, 256, 14, 14]         --
│    │    └─Conv2d: 3-87                 [32, 256, 14, 14]         589,824
│    │    └─BatchNorm2d: 3-88            [32, 256, 14, 14]         512
│    │    └─ReLU: 3-89                   [32, 256, 14, 14]         --
│    │    └─Identity: 3-90               [32, 256, 14, 14]         --
│    │    └─Conv2d: 3-91                 [32, 1024, 14, 14]        262,144
│    │    └─BatchNorm2d: 3-92            [32, 1024, 14, 14]        2,048
│    │    └─ReLU: 3-93                   [32, 1024, 14, 14]        --
│    └─Bottleneck: 2-10                  [32, 1024, 14, 14]        --
│    │    └─Conv2d: 3-94                 [32, 256, 14, 14]         262,144
│    │    └─BatchNorm2d: 3-95            [32, 256, 14, 14]         512
│    │    └─ReLU: 3-96                   [32, 256, 14, 14]         --
│    │    └─Conv2d: 3-97                 [32, 256, 14, 14]         589,824
│    │    └─BatchNorm2d: 3-98            [32, 256, 14, 14]         512
│    │    └─ReLU: 3-99                   [32, 256, 14, 14]         --
│    │    └─Identity: 3-100              [32, 256, 14, 14]         --
│    │    └─Conv2d: 3-101                [32, 1024, 14, 14]        262,144
│    │    └─BatchNorm2d: 3-102           [32, 1024, 14, 14]        2,048
│    │    └─ReLU: 3-103                  [32, 1024, 14, 14]        --
│    └─Bottleneck: 2-11                  [32, 1024, 14, 14]        --
│    │    └─Conv2d: 3-104                [32, 256, 14, 14]         262,144
│    │    └─BatchNorm2d: 3-105           [32, 256, 14, 14]         512
│    │    └─ReLU: 3-106                  [32, 256, 14, 14]         --
│    │    └─Conv2d: 3-107                [32, 256, 14, 14]         589,824
│    │    └─BatchNorm2d: 3-108           [32, 256, 14, 14]         512
│    │    └─ReLU: 3-109                  [32, 256, 14, 14]         --
│    │    └─Identity: 3-110              [32, 256, 14, 14]         --
│    │    └─Conv2d: 3-111                [32, 1024, 14, 14]        262,144
│    │    └─BatchNorm2d: 3-112           [32, 1024, 14, 14]        2,048
│    │    └─ReLU: 3-113                  [32, 1024, 14, 14]        --
│    └─Bottleneck: 2-12                  [32, 1024, 14, 14]        --
│    │    └─Conv2d: 3-114                [32, 256, 14, 14]         262,144
│    │    └─BatchNorm2d: 3-115           [32, 256, 14, 14]         512
│    │    └─ReLU: 3-116                  [32, 256, 14, 14]         --
│    │    └─Conv2d: 3-117                [32, 256, 14, 14]         589,824
│    │    └─BatchNorm2d: 3-118           [32, 256, 14, 14]         512
│    │    └─ReLU: 3-119                  [32, 256, 14, 14]         --
│    │    └─Identity: 3-120              [32, 256, 14, 14]         --
│    │    └─Conv2d: 3-121                [32, 1024, 14, 14]        262,144
│    │    └─BatchNorm2d: 3-122           [32, 1024, 14, 14]        2,048
│    │    └─ReLU: 3-123                  [32, 1024, 14, 14]        --
│    └─Bottleneck: 2-13                  [32, 1024, 14, 14]        --
│    │    └─Conv2d: 3-124                [32, 256, 14, 14]         262,144
│    │    └─BatchNorm2d: 3-125           [32, 256, 14, 14]         512
│    │    └─ReLU: 3-126                  [32, 256, 14, 14]         --
│    │    └─Conv2d: 3-127                [32, 256, 14, 14]         589,824
│    │    └─BatchNorm2d: 3-128           [32, 256, 14, 14]         512
│    │    └─ReLU: 3-129                  [32, 256, 14, 14]         --
│    │    └─Identity: 3-130              [32, 256, 14, 14]         --
│    │    └─Conv2d: 3-131                [32, 1024, 14, 14]        262,144
│    │    └─BatchNorm2d: 3-132           [32, 1024, 14, 14]        2,048
│    │    └─ReLU: 3-133                  [32, 1024, 14, 14]        --
├─Sequential: 1-14                       [32, 2048, 7, 7]          --
│    └─Bottleneck: 2-14                  [32, 2048, 7, 7]          --
│    │    └─Conv2d: 3-134                [32, 512, 14, 14]         524,288
│    │    └─BatchNorm2d: 3-135           [32, 512, 14, 14]         1,024
│    │    └─ReLU: 3-136                  [32, 512, 14, 14]         --
│    │    └─Conv2d: 3-137                [32, 512, 14, 14]         2,359,296
│    │    └─BatchNorm2d: 3-138           [32, 512, 14, 14]         1,024
│    │    └─ReLU: 3-139                  [32, 512, 14, 14]         --
│    │    └─AvgPool2d: 3-140             [32, 512, 7, 7]           --
│    │    └─Conv2d: 3-141                [32, 2048, 7, 7]          1,048,576
│    │    └─BatchNorm2d: 3-142           [32, 2048, 7, 7]          4,096
│    │    └─Sequential: 3-143            [32, 2048, 7, 7]          2,101,248
│    │    └─ReLU: 3-144                  [32, 2048, 7, 7]          --
│    └─Bottleneck: 2-15                  [32, 2048, 7, 7]          --
│    │    └─Conv2d: 3-145                [32, 512, 7, 7]           1,048,576
│    │    └─BatchNorm2d: 3-146           [32, 512, 7, 7]           1,024
│    │    └─ReLU: 3-147                  [32, 512, 7, 7]           --
│    │    └─Conv2d: 3-148                [32, 512, 7, 7]           2,359,296
│    │    └─BatchNorm2d: 3-149           [32, 512, 7, 7]           1,024
│    │    └─ReLU: 3-150                  [32, 512, 7, 7]           --
│    │    └─Identity: 3-151              [32, 512, 7, 7]           --
│    │    └─Conv2d: 3-152                [32, 2048, 7, 7]          1,048,576
│    │    └─BatchNorm2d: 3-153           [32, 2048, 7, 7]          4,096
│    │    └─ReLU: 3-154                  [32, 2048, 7, 7]          --
│    └─Bottleneck: 2-16                  [32, 2048, 7, 7]          --
│    │    └─Conv2d: 3-155                [32, 512, 7, 7]           1,048,576
│    │    └─BatchNorm2d: 3-156           [32, 512, 7, 7]           1,024
│    │    └─ReLU: 3-157                  [32, 512, 7, 7]           --
│    │    └─Conv2d: 3-158                [32, 512, 7, 7]           2,359,296
│    │    └─BatchNorm2d: 3-159           [32, 512, 7, 7]           1,024
│    │    └─ReLU: 3-160                  [32, 512, 7, 7]           --
│    │    └─Identity: 3-161              [32, 512, 7, 7]           --
│    │    └─Conv2d: 3-162                [32, 2048, 7, 7]          1,048,576
│    │    └─BatchNorm2d: 3-163           [32, 2048, 7, 7]          4,096
│    │    └─ReLU: 3-164                  [32, 2048, 7, 7]          --
├─AttentionPool2d: 1-15                  [32, 1024]                14,789,632
==========================================================================================
Total params: 38,316,896
Trainable params: 38,316,896
Non-trainable params: 0
Total mult-adds (G): 171.75
==========================================================================================
Input size (MB): 19.27
Forward/backward pass size (MB): 3185.57
Params size (MB): 47.16
Estimated Total Size (MB): 3252.00
==========================================================================================
TylerYep commented 2 years ago

It looks like CLIP switched to using non-JIT by default, which is why this code started working on recent versions: https://github.com/openai/CLIP/commit/db20393f4affd4158528bd868478e516ebed0944

@mert-kurttutan Loading the model with clip.load(..., jit=True) would likely resurface the original issue?

I'll edit the original question to specify this.

mert-kurttutan commented 2 years ago

> It looks like CLIP switched to using non-JIT by default, which is why this code started working on recent versions: openai/CLIP@db20393
>
> @mert-kurttutan Loading the model with clip.load(..., jit=True) would likely resurface the original issue?
>
> I'll edit the original question to specify this.

You are right. The issue appears again with jit=True.
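For anyone hitting this today, a practical workaround is to summarize the eager model first and JIT-compile afterwards (with CLIP, load with clip.load("RN50", jit=False), summarize model.visual, then script it yourself if needed). A generic sketch of that ordering, using a toy model as a stand-in for CLIP's visual backbone:

```python
import torch
import torch.nn as nn

# Workaround sketch: run the summary on the eager module, then compile.
# A toy model stands in for CLIP's visual backbone here; with CLIP the
# equivalent is clip.load(..., jit=False) followed by torch.jit.script.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

# torchinfo.summary(model, input_size=(1, 4))  # works on the eager module

scripted = torch.jit.script(model)  # compile afterwards for deployment
out = scripted(torch.randn(1, 4))
```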