Open sgimmini opened 3 years ago
This is a feature request. In #55 I mentioned there is only preliminary support for JIT (as in, no crashes). I am unfamiliar with the mechanics of JIT, so maybe you have some ideas on how I can improve this: is there a different way to calculate the input/output sizes in JIT models, since module.register_forward_pre_hook(pre_hook) and module.register_forward_hook(hook) don't work?
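For context, shape reporting of this kind relies on those hooks firing during a forward pass. Below is a minimal eager-mode sketch (plain PyTorch, with hypothetical layer sizes) of that mechanism; scripted/traced JIT modules historically did not invoke these Python hooks, which is why output shapes come back empty for them:

```python
# Minimal sketch of the forward-hook mechanism that shape-summary tools
# rely on. Hypothetical layer sizes; works on eager nn.Module instances.
import torch
import torch.nn as nn

shapes = []

def record_shape(module, inputs, output):
    # Called after each submodule's forward(); capture the output shape.
    shapes.append((type(module).__name__, tuple(output.shape)))

model = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.ReLU())
handles = [m.register_forward_hook(record_shape) for m in model]

model(torch.randn(1, 3, 16, 16))
print(shapes)  # [('Conv2d', (1, 8, 14, 14)), ('ReLU', (1, 8, 14, 14))]

for h in handles:  # always detach hooks when done
    h.remove()
```

For a module produced by torch.jit.script or torch.jit.trace, the same registration either fails or the hooks simply never run, so the shape column stays empty.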
Just came across this one while using CLIP. It should be working now with the latest version. Here is my output when I run @smn57's code:
==========================================================================================
Layer (type:depth-idx) Output Shape Param #
==========================================================================================
ModifiedResNet [32, 1024] --
├─Conv2d: 1-1 [32, 32, 112, 112] 864
├─BatchNorm2d: 1-2 [32, 32, 112, 112] 64
├─ReLU: 1-3 [32, 32, 112, 112] --
├─Conv2d: 1-4 [32, 32, 112, 112] 9,216
├─BatchNorm2d: 1-5 [32, 32, 112, 112] 64
├─ReLU: 1-6 [32, 32, 112, 112] --
├─Conv2d: 1-7 [32, 64, 112, 112] 18,432
├─BatchNorm2d: 1-8 [32, 64, 112, 112] 128
├─ReLU: 1-9 [32, 64, 112, 112] --
├─AvgPool2d: 1-10 [32, 64, 56, 56] --
├─Sequential: 1-11 [32, 256, 56, 56] --
│ └─Bottleneck: 2-1 [32, 256, 56, 56] --
│ │ └─Conv2d: 3-1 [32, 64, 56, 56] 4,096
│ │ └─BatchNorm2d: 3-2 [32, 64, 56, 56] 128
│ │ └─ReLU: 3-3 [32, 64, 56, 56] --
│ │ └─Conv2d: 3-4 [32, 64, 56, 56] 36,864
│ │ └─BatchNorm2d: 3-5 [32, 64, 56, 56] 128
│ │ └─ReLU: 3-6 [32, 64, 56, 56] --
│ │ └─Identity: 3-7 [32, 64, 56, 56] --
│ │ └─Conv2d: 3-8 [32, 256, 56, 56] 16,384
│ │ └─BatchNorm2d: 3-9 [32, 256, 56, 56] 512
│ │ └─Sequential: 3-10 [32, 256, 56, 56] 16,896
│ │ └─ReLU: 3-11 [32, 256, 56, 56] --
│ └─Bottleneck: 2-2 [32, 256, 56, 56] --
│ │ └─Conv2d: 3-12 [32, 64, 56, 56] 16,384
│ │ └─BatchNorm2d: 3-13 [32, 64, 56, 56] 128
│ │ └─ReLU: 3-14 [32, 64, 56, 56] --
│ │ └─Conv2d: 3-15 [32, 64, 56, 56] 36,864
│ │ └─BatchNorm2d: 3-16 [32, 64, 56, 56] 128
│ │ └─ReLU: 3-17 [32, 64, 56, 56] --
│ │ └─Identity: 3-18 [32, 64, 56, 56] --
│ │ └─Conv2d: 3-19 [32, 256, 56, 56] 16,384
│ │ └─BatchNorm2d: 3-20 [32, 256, 56, 56] 512
│ │ └─ReLU: 3-21 [32, 256, 56, 56] --
│ └─Bottleneck: 2-3 [32, 256, 56, 56] --
│ │ └─Conv2d: 3-22 [32, 64, 56, 56] 16,384
│ │ └─BatchNorm2d: 3-23 [32, 64, 56, 56] 128
│ │ └─ReLU: 3-24 [32, 64, 56, 56] --
│ │ └─Conv2d: 3-25 [32, 64, 56, 56] 36,864
│ │ └─BatchNorm2d: 3-26 [32, 64, 56, 56] 128
│ │ └─ReLU: 3-27 [32, 64, 56, 56] --
│ │ └─Identity: 3-28 [32, 64, 56, 56] --
│ │ └─Conv2d: 3-29 [32, 256, 56, 56] 16,384
│ │ └─BatchNorm2d: 3-30 [32, 256, 56, 56] 512
│ │ └─ReLU: 3-31 [32, 256, 56, 56] --
├─Sequential: 1-12 [32, 512, 28, 28] --
│ └─Bottleneck: 2-4 [32, 512, 28, 28] --
│ │ └─Conv2d: 3-32 [32, 128, 56, 56] 32,768
│ │ └─BatchNorm2d: 3-33 [32, 128, 56, 56] 256
│ │ └─ReLU: 3-34 [32, 128, 56, 56] --
│ │ └─Conv2d: 3-35 [32, 128, 56, 56] 147,456
│ │ └─BatchNorm2d: 3-36 [32, 128, 56, 56] 256
│ │ └─ReLU: 3-37 [32, 128, 56, 56] --
│ │ └─AvgPool2d: 3-38 [32, 128, 28, 28] --
│ │ └─Conv2d: 3-39 [32, 512, 28, 28] 65,536
│ │ └─BatchNorm2d: 3-40 [32, 512, 28, 28] 1,024
│ │ └─Sequential: 3-41 [32, 512, 28, 28] 132,096
│ │ └─ReLU: 3-42 [32, 512, 28, 28] --
│ └─Bottleneck: 2-5 [32, 512, 28, 28] --
│ │ └─Conv2d: 3-43 [32, 128, 28, 28] 65,536
│ │ └─BatchNorm2d: 3-44 [32, 128, 28, 28] 256
│ │ └─ReLU: 3-45 [32, 128, 28, 28] --
│ │ └─Conv2d: 3-46 [32, 128, 28, 28] 147,456
│ │ └─BatchNorm2d: 3-47 [32, 128, 28, 28] 256
│ │ └─ReLU: 3-48 [32, 128, 28, 28] --
│ │ └─Identity: 3-49 [32, 128, 28, 28] --
│ │ └─Conv2d: 3-50 [32, 512, 28, 28] 65,536
│ │ └─BatchNorm2d: 3-51 [32, 512, 28, 28] 1,024
│ │ └─ReLU: 3-52 [32, 512, 28, 28] --
│ └─Bottleneck: 2-6 [32, 512, 28, 28] --
│ │ └─Conv2d: 3-53 [32, 128, 28, 28] 65,536
│ │ └─BatchNorm2d: 3-54 [32, 128, 28, 28] 256
│ │ └─ReLU: 3-55 [32, 128, 28, 28] --
│ │ └─Conv2d: 3-56 [32, 128, 28, 28] 147,456
│ │ └─BatchNorm2d: 3-57 [32, 128, 28, 28] 256
│ │ └─ReLU: 3-58 [32, 128, 28, 28] --
│ │ └─Identity: 3-59 [32, 128, 28, 28] --
│ │ └─Conv2d: 3-60 [32, 512, 28, 28] 65,536
│ │ └─BatchNorm2d: 3-61 [32, 512, 28, 28] 1,024
│ │ └─ReLU: 3-62 [32, 512, 28, 28] --
│ └─Bottleneck: 2-7 [32, 512, 28, 28] --
│ │ └─Conv2d: 3-63 [32, 128, 28, 28] 65,536
│ │ └─BatchNorm2d: 3-64 [32, 128, 28, 28] 256
│ │ └─ReLU: 3-65 [32, 128, 28, 28] --
│ │ └─Conv2d: 3-66 [32, 128, 28, 28] 147,456
│ │ └─BatchNorm2d: 3-67 [32, 128, 28, 28] 256
│ │ └─ReLU: 3-68 [32, 128, 28, 28] --
│ │ └─Identity: 3-69 [32, 128, 28, 28] --
│ │ └─Conv2d: 3-70 [32, 512, 28, 28] 65,536
│ │ └─BatchNorm2d: 3-71 [32, 512, 28, 28] 1,024
│ │ └─ReLU: 3-72 [32, 512, 28, 28] --
├─Sequential: 1-13 [32, 1024, 14, 14] --
│ └─Bottleneck: 2-8 [32, 1024, 14, 14] --
│ │ └─Conv2d: 3-73 [32, 256, 28, 28] 131,072
│ │ └─BatchNorm2d: 3-74 [32, 256, 28, 28] 512
│ │ └─ReLU: 3-75 [32, 256, 28, 28] --
│ │ └─Conv2d: 3-76 [32, 256, 28, 28] 589,824
│ │ └─BatchNorm2d: 3-77 [32, 256, 28, 28] 512
│ │ └─ReLU: 3-78 [32, 256, 28, 28] --
│ │ └─AvgPool2d: 3-79 [32, 256, 14, 14] --
│ │ └─Conv2d: 3-80 [32, 1024, 14, 14] 262,144
│ │ └─BatchNorm2d: 3-81 [32, 1024, 14, 14] 2,048
│ │ └─Sequential: 3-82 [32, 1024, 14, 14] 526,336
│ │ └─ReLU: 3-83 [32, 1024, 14, 14] --
│ └─Bottleneck: 2-9 [32, 1024, 14, 14] --
│ │ └─Conv2d: 3-84 [32, 256, 14, 14] 262,144
│ │ └─BatchNorm2d: 3-85 [32, 256, 14, 14] 512
│ │ └─ReLU: 3-86 [32, 256, 14, 14] --
│ │ └─Conv2d: 3-87 [32, 256, 14, 14] 589,824
│ │ └─BatchNorm2d: 3-88 [32, 256, 14, 14] 512
│ │ └─ReLU: 3-89 [32, 256, 14, 14] --
│ │ └─Identity: 3-90 [32, 256, 14, 14] --
│ │ └─Conv2d: 3-91 [32, 1024, 14, 14] 262,144
│ │ └─BatchNorm2d: 3-92 [32, 1024, 14, 14] 2,048
│ │ └─ReLU: 3-93 [32, 1024, 14, 14] --
│ └─Bottleneck: 2-10 [32, 1024, 14, 14] --
│ │ └─Conv2d: 3-94 [32, 256, 14, 14] 262,144
│ │ └─BatchNorm2d: 3-95 [32, 256, 14, 14] 512
│ │ └─ReLU: 3-96 [32, 256, 14, 14] --
│ │ └─Conv2d: 3-97 [32, 256, 14, 14] 589,824
│ │ └─BatchNorm2d: 3-98 [32, 256, 14, 14] 512
│ │ └─ReLU: 3-99 [32, 256, 14, 14] --
│ │ └─Identity: 3-100 [32, 256, 14, 14] --
│ │ └─Conv2d: 3-101 [32, 1024, 14, 14] 262,144
│ │ └─BatchNorm2d: 3-102 [32, 1024, 14, 14] 2,048
│ │ └─ReLU: 3-103 [32, 1024, 14, 14] --
│ └─Bottleneck: 2-11 [32, 1024, 14, 14] --
│ │ └─Conv2d: 3-104 [32, 256, 14, 14] 262,144
│ │ └─BatchNorm2d: 3-105 [32, 256, 14, 14] 512
│ │ └─ReLU: 3-106 [32, 256, 14, 14] --
│ │ └─Conv2d: 3-107 [32, 256, 14, 14] 589,824
│ │ └─BatchNorm2d: 3-108 [32, 256, 14, 14] 512
│ │ └─ReLU: 3-109 [32, 256, 14, 14] --
│ │ └─Identity: 3-110 [32, 256, 14, 14] --
│ │ └─Conv2d: 3-111 [32, 1024, 14, 14] 262,144
│ │ └─BatchNorm2d: 3-112 [32, 1024, 14, 14] 2,048
│ │ └─ReLU: 3-113 [32, 1024, 14, 14] --
│ └─Bottleneck: 2-12 [32, 1024, 14, 14] --
│ │ └─Conv2d: 3-114 [32, 256, 14, 14] 262,144
│ │ └─BatchNorm2d: 3-115 [32, 256, 14, 14] 512
│ │ └─ReLU: 3-116 [32, 256, 14, 14] --
│ │ └─Conv2d: 3-117 [32, 256, 14, 14] 589,824
│ │ └─BatchNorm2d: 3-118 [32, 256, 14, 14] 512
│ │ └─ReLU: 3-119 [32, 256, 14, 14] --
│ │ └─Identity: 3-120 [32, 256, 14, 14] --
│ │ └─Conv2d: 3-121 [32, 1024, 14, 14] 262,144
│ │ └─BatchNorm2d: 3-122 [32, 1024, 14, 14] 2,048
│ │ └─ReLU: 3-123 [32, 1024, 14, 14] --
│ └─Bottleneck: 2-13 [32, 1024, 14, 14] --
│ │ └─Conv2d: 3-124 [32, 256, 14, 14] 262,144
│ │ └─BatchNorm2d: 3-125 [32, 256, 14, 14] 512
│ │ └─ReLU: 3-126 [32, 256, 14, 14] --
│ │ └─Conv2d: 3-127 [32, 256, 14, 14] 589,824
│ │ └─BatchNorm2d: 3-128 [32, 256, 14, 14] 512
│ │ └─ReLU: 3-129 [32, 256, 14, 14] --
│ │ └─Identity: 3-130 [32, 256, 14, 14] --
│ │ └─Conv2d: 3-131 [32, 1024, 14, 14] 262,144
│ │ └─BatchNorm2d: 3-132 [32, 1024, 14, 14] 2,048
│ │ └─ReLU: 3-133 [32, 1024, 14, 14] --
├─Sequential: 1-14 [32, 2048, 7, 7] --
│ └─Bottleneck: 2-14 [32, 2048, 7, 7] --
│ │ └─Conv2d: 3-134 [32, 512, 14, 14] 524,288
│ │ └─BatchNorm2d: 3-135 [32, 512, 14, 14] 1,024
│ │ └─ReLU: 3-136 [32, 512, 14, 14] --
│ │ └─Conv2d: 3-137 [32, 512, 14, 14] 2,359,296
│ │ └─BatchNorm2d: 3-138 [32, 512, 14, 14] 1,024
│ │ └─ReLU: 3-139 [32, 512, 14, 14] --
│ │ └─AvgPool2d: 3-140 [32, 512, 7, 7] --
│ │ └─Conv2d: 3-141 [32, 2048, 7, 7] 1,048,576
│ │ └─BatchNorm2d: 3-142 [32, 2048, 7, 7] 4,096
│ │ └─Sequential: 3-143 [32, 2048, 7, 7] 2,101,248
│ │ └─ReLU: 3-144 [32, 2048, 7, 7] --
│ └─Bottleneck: 2-15 [32, 2048, 7, 7] --
│ │ └─Conv2d: 3-145 [32, 512, 7, 7] 1,048,576
│ │ └─BatchNorm2d: 3-146 [32, 512, 7, 7] 1,024
│ │ └─ReLU: 3-147 [32, 512, 7, 7] --
│ │ └─Conv2d: 3-148 [32, 512, 7, 7] 2,359,296
│ │ └─BatchNorm2d: 3-149 [32, 512, 7, 7] 1,024
│ │ └─ReLU: 3-150 [32, 512, 7, 7] --
│ │ └─Identity: 3-151 [32, 512, 7, 7] --
│ │ └─Conv2d: 3-152 [32, 2048, 7, 7] 1,048,576
│ │ └─BatchNorm2d: 3-153 [32, 2048, 7, 7] 4,096
│ │ └─ReLU: 3-154 [32, 2048, 7, 7] --
│ └─Bottleneck: 2-16 [32, 2048, 7, 7] --
│ │ └─Conv2d: 3-155 [32, 512, 7, 7] 1,048,576
│ │ └─BatchNorm2d: 3-156 [32, 512, 7, 7] 1,024
│ │ └─ReLU: 3-157 [32, 512, 7, 7] --
│ │ └─Conv2d: 3-158 [32, 512, 7, 7] 2,359,296
│ │ └─BatchNorm2d: 3-159 [32, 512, 7, 7] 1,024
│ │ └─ReLU: 3-160 [32, 512, 7, 7] --
│ │ └─Identity: 3-161 [32, 512, 7, 7] --
│ │ └─Conv2d: 3-162 [32, 2048, 7, 7] 1,048,576
│ │ └─BatchNorm2d: 3-163 [32, 2048, 7, 7] 4,096
│ │ └─ReLU: 3-164 [32, 2048, 7, 7] --
├─AttentionPool2d: 1-15 [32, 1024] 14,789,632
==========================================================================================
Total params: 38,316,896
Trainable params: 38,316,896
Non-trainable params: 0
Total mult-adds (G): 171.75
==========================================================================================
Input size (MB): 19.27
Forward/backward pass size (MB): 3185.57
Params size (MB): 47.16
Estimated Total Size (MB): 3252.00
==========================================================================================
It looks like CLIP switched to using non-JIT by default, which is why this code started working on recent versions: https://github.com/openai/CLIP/commit/db20393f4affd4158528bd868478e516ebed0944
@mert-kurttutan Loading the model with clip.load(..., jit=True) would likely resurface the original issue? I'll edit the original question to specify this.
You are right, it appears again with jit=True.
Trying to summarize a CLIP model results in an output where there are no values listed in the Output Shape column. I installed torchinfo through pip, and also tried the latest version of this git repo. This probably relates to #55, as they use JIT as well, but it does not look solved. Since #55 does not show the in/outputs (after the fix), I'm not sure whether this issue is just a duplicate. Also, I'm not sure whether this is a bug or a feature request; sorry if wrongly assigned.

Code to reproduce: