TylerYep / torchinfo

View model summaries in PyTorch!

Estimated model size differs from nvidia-smi usage #187

Open YHYeooooong opened 1 year ago

YHYeooooong commented 1 year ago

Describe the bug: the model size estimated by torchinfo differs from the GPU memory usage reported by nvidia-smi.

To Reproduce

  1. Used code and command line (below)
  2. The code runs on the cuda:2 device
import torch
import torch.nn as nn
import timm
import torchvision

import argparse

from torchinfo import summary

# device setup
device_num = 2
device = torch.device("cuda:" + str(device_num))

num_classes = 2
# parse the model name, augmentation, and scheduler settings

parser = argparse.ArgumentParser()

parser.add_argument('--model', required=True)
args = parser.parse_args()

each_model = args.model
#best model

model = None  # selected below based on --model
if each_model == 'CvT-21':
    model = torch.load('../ref_model/whole_CvT-21-384x384-IN-1k_2class.pt')

elif each_model == 'MLP-Mixer-b16':
    model = timm.create_model('mixer_b16_224', pretrained=True, num_classes=num_classes)

elif each_model == 'Beit-base-patch16':
    model = timm.create_model('beit_base_patch16_224', pretrained=True, num_classes=num_classes)

elif each_model == 'ViT-base-16':
    model = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=num_classes)

elif each_model == 'ResNet101':
    model = timm.create_model('resnet101', pretrained=True, num_classes=num_classes)

elif each_model == 'MobileNetV2':
    model = timm.create_model('mobilenetv2_100', pretrained=True, num_classes=num_classes)

elif each_model == 'DenseNet121':
    model = timm.create_model('densenet121', pretrained=True, num_classes=num_classes)

elif each_model == 'EfficientNetB0':
    model = timm.create_model('efficientnet_b0', pretrained=True, num_classes=num_classes)

elif each_model == 'ShuffleNetV2':
    model = torchvision.models.shufflenet_v2_x1_0(pretrained=True)
    num_f = model.fc.in_features
    model.fc = nn.Linear(num_f, num_classes)  # replace the last linear layer with a 2-class output

elif each_model == 'gmlp_s16':
    model = timm.create_model('gmlp_s16_224', pretrained=True, num_classes=num_classes)

elif each_model == 'resmlp_24':
    model = timm.create_model('resmlp_24_224', pretrained=True, num_classes=num_classes)

elif each_model == 'mobilevit-s':
    model = timm.create_model('mobilevit_s', pretrained=True, num_classes=num_classes)

elif each_model == 'mobilevit-xs':
    model = timm.create_model('mobilevit_xs', pretrained=True, num_classes=num_classes)

elif each_model == 'mobilevit-xxs':
    model = timm.create_model('mobilevit_xxs', pretrained=True, num_classes=num_classes)

model = model.to(device)
model.eval()

summary(model, input_size=(1, 3, 224, 224), mode='eval', device=device)

batch_size = 1
data_shape = (3, 224, 224)
random_data = torch.rand((batch_size, *data_shape)).to(device)

with torch.no_grad():

    outputs = model(random_data)

python img1_test_original_testset_serve_2c_gpumem_forgit.py --model 'MobileNetV2'

Expected behavior: the nvidia-smi memory usage matches the estimated model size.

Screenshots: [torchinfo summary output and nvidia-smi readout]

Additional context: the two values differ by around 1000 MB. I also checked https://github.com/TylerYep/torchinfo/issues/149#issue-1291452433, but I could not reconcile the nvidia-smi GPU usage with torchinfo's estimated total size.

Are there any points I missed in the code, or is it really caused by something other than my code?

Thanks!

+ ShuffleNetV2 shows the same problem: python img1_test_original_testset_serve_2c_gpumem_forgit.py --model 'ShuffleNetV2'

[screenshots: torchinfo summary and nvidia-smi readout for ShuffleNetV2]

+2 I also did a simple check of just moving data onto the GPU device, and it yielded the GPU usage below.

import torch
device_num = 2
device = torch.device("cuda:"+str(device_num))

batch_size = 1
data_shape = (3, 224, 224)
random_data = torch.rand((batch_size, *data_shape)).to(device)

[screenshot: nvidia-smi readout after moving the tensor]
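For scale, the tensor itself is tiny; a quick check that could be appended to the snippet above:

# A (1, 3, 224, 224) float32 tensor is only ~0.57 MB, so nearly all of the
# usage reported by nvidia-smi must come from the CUDA context, not the data.
size_mb = random_data.element_size() * random_data.nelement() / 1024 ** 2
print(f"tensor size: {size_mb:.2f} MB")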

Might this be involved in the issue?

mert-kurttutan commented 1 year ago

Could you check whether there is any optimizer variable associated with your model (e.g. Adam stores gradient**2 values)? These also occupy memory, while the memory estimate provided by torchinfo covers only the input/output/intermediate/gradient/parameter values.
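For example (a minimal sketch with a toy model, assuming a CUDA device; Adam only allocates its state on the first step):

import torch
import torch.nn as nn

model = nn.Linear(1024, 1024).to('cuda')
optimizer = torch.optim.Adam(model.parameters())

loss = model(torch.randn(8, 1024, device='cuda')).sum()
loss.backward()
optimizer.step()  # the first step allocates exp_avg and exp_avg_sq per parameter

# Sum the sizes of all optimizer state tensors: roughly 2x the parameter size.
state_bytes = sum(
    t.numel() * t.element_size()
    for state in optimizer.state.values()
    for t in state.values()
    if torch.is_tensor(t)
)
print(f"optimizer state: {state_bytes / 1024 ** 2:.2f} MB")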

@TylerYep, maybe the memory estimation should be changed to account for this?

Also, do you know if there is a jit-compiled part in your model?

Just to make sure: what is the GPU usage when you run only the summary (without the following forward call)? Is it the same as before?

mert-kurttutan commented 1 year ago

Another reason might be that torchinfo does not record pure torch functions (e.g. adding input and output for a skip connection) that are not contained inside a module. So if you have such a function (and it is not in-place), it also allocates memory that torchinfo does not count.

YHYeooooong commented 1 year ago

Q: Could you check whether there is any optimizer variable associated with your model (e.g. Adam stores gradient**2 values)?
A: I didn't create an optimizer in the first code above, so no optimizer memory should be involved.

Q: Also, do you know if there is a jit-compiled part in your model?
A: I don't know for sure, but I think the timm library uses some JIT-related code.

Q: Just to make sure, what is the GPU usage when you run only the summary (without the following forward call)? Is it the same as before?
A: With the code below, I got the total estimated size and the nvidia-smi GPU usage shown in the screenshots.

from torchinfo import summary
import torchvision

model = torchvision.models.shufflenet_v2_x1_0(pretrained=True)

summary(model, input_size=(1, 3, 224, 224), mode='eval', device='cuda:2')

[screenshots: summary output and nvidia-smi readout]

I understand it as: some plain torch functions are not counted by the torchinfo summary. Did I get it right?

mert-kurttutan commented 1 year ago

I understand that some official torch functions are not calculated by the torchinfo summary. Did I get it right?

Exactly. For instance:

def forward(self, x):
    out = someHugeModule(x)  # a submodule call: torchinfo records this output
    out = out + x            # skip connection via a plain torch function
    return out

Here, since out is a tensor, the last + calls the torch.add function, whose result is not counted by torchinfo.

Could you also share the entire summary result (with as high a depth value as possible)? If the output shape for some modules is absent, it means those modules are compiled.
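Something like this should do it (a sketch assuming the input size from your script; torchinfo's depth defaults to 3):

from torchinfo import summary

# A large depth expands all nested submodules, so any module whose output
# shape is missing becomes visible.
summary(model, input_size=(1, 3, 224, 224), depth=10, mode='eval', device='cuda:2')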

To check for JIT modules, maybe you can look at the type of each module (and its submodules). If a type is something like ScriptModule, then it is a scripted or traced module.
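For example (a rough sketch, assuming model is your loaded model):

import torch

# Flag any submodule that is a torch.jit.ScriptModule, i.e. scripted or
# traced rather than a plain nn.Module.
for name, module in model.named_modules():
    if isinstance(module, torch.jit.ScriptModule):
        print(f"{name}: {type(module).__name__} (jit-compiled)")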

YHYeooooong commented 1 year ago

Thanks!

Could you also share the entire summary result (with as high a depth value as possible)? Because if the output shape for some modules is absent, it means some modules are compiled.

Here is the result for the shufflenetv2 model:

==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
ShuffleNetV2                             [1, 1000]                 --
├─Sequential: 1-1                        [1, 24, 112, 112]         --
│    └─Conv2d: 2-1                       [1, 24, 112, 112]         648
│    └─BatchNorm2d: 2-2                  [1, 24, 112, 112]         48
│    └─ReLU: 2-3                         [1, 24, 112, 112]         --
├─MaxPool2d: 1-2                         [1, 24, 56, 56]           --
├─Sequential: 1-3                        [1, 116, 28, 28]          --
│    └─InvertedResidual: 2-4             [1, 116, 28, 28]          --
│    │    └─Sequential: 3-1              [1, 58, 28, 28]           1,772
│    │    └─Sequential: 3-2              [1, 58, 28, 28]           5,626
│    └─InvertedResidual: 2-5             [1, 116, 28, 28]          --
│    │    └─Sequential: 3-3              [1, 58, 28, 28]           7,598
│    └─InvertedResidual: 2-6             [1, 116, 28, 28]          --
│    │    └─Sequential: 3-4              [1, 58, 28, 28]           7,598
│    └─InvertedResidual: 2-7             [1, 116, 28, 28]          --
│    │    └─Sequential: 3-5              [1, 58, 28, 28]           7,598
├─Sequential: 1-4                        [1, 232, 14, 14]          --
│    └─InvertedResidual: 2-8             [1, 232, 14, 14]          --
│    │    └─Sequential: 3-6              [1, 116, 14, 14]          14,964
│    │    └─Sequential: 3-7              [1, 116, 14, 14]          28,652
│    └─InvertedResidual: 2-9             [1, 232, 14, 14]          --
│    │    └─Sequential: 3-8              [1, 116, 14, 14]          28,652
│    └─InvertedResidual: 2-10            [1, 232, 14, 14]          --
│    │    └─Sequential: 3-9              [1, 116, 14, 14]          28,652
│    └─InvertedResidual: 2-11            [1, 232, 14, 14]          --
│    │    └─Sequential: 3-10             [1, 116, 14, 14]          28,652
│    └─InvertedResidual: 2-12            [1, 232, 14, 14]          --
│    │    └─Sequential: 3-11             [1, 116, 14, 14]          28,652
│    └─InvertedResidual: 2-13            [1, 232, 14, 14]          --
│    │    └─Sequential: 3-12             [1, 116, 14, 14]          28,652
│    └─InvertedResidual: 2-14            [1, 232, 14, 14]          --
│    │    └─Sequential: 3-13             [1, 116, 14, 14]          28,652
│    └─InvertedResidual: 2-15            [1, 232, 14, 14]          --
│    │    └─Sequential: 3-14             [1, 116, 14, 14]          28,652
├─Sequential: 1-5                        [1, 464, 7, 7]            --
│    └─InvertedResidual: 2-16            [1, 464, 7, 7]            --
│    │    └─Sequential: 3-15             [1, 232, 7, 7]            56,840
│    │    └─Sequential: 3-16             [1, 232, 7, 7]            111,128
│    └─InvertedResidual: 2-17            [1, 464, 7, 7]            --
│    │    └─Sequential: 3-17             [1, 232, 7, 7]            111,128
│    └─InvertedResidual: 2-18            [1, 464, 7, 7]            --
│    │    └─Sequential: 3-18             [1, 232, 7, 7]            111,128
│    └─InvertedResidual: 2-19            [1, 464, 7, 7]            --
│    │    └─Sequential: 3-19             [1, 232, 7, 7]            111,128
├─Sequential: 1-6                        [1, 1024, 7, 7]           --
│    └─Conv2d: 2-20                      [1, 1024, 7, 7]           475,136
│    └─BatchNorm2d: 2-21                 [1, 1024, 7, 7]           2,048
│    └─ReLU: 2-22                        [1, 1024, 7, 7]           --
├─Linear: 1-7                            [1, 1000]                 1,025,000
==========================================================================================
Total params: 2,278,604
Trainable params: 2,278,604
Non-trainable params: 0
Total mult-adds (M): 144.93
==========================================================================================
Input size (MB): 0.60
Forward/backward pass size (MB): 31.21
Params size (MB): 9.11
Estimated Total Size (MB): 40.93
==========================================================================================

To make sure, I tested with this code too:

from torchinfo import summary
import torchvision
import timm
num_classes = 2
model = timm.create_model('resnet101', pretrained=True, num_classes=num_classes)

summary(model, input_size=(1, 3, 224, 224), mode='eval', device='cuda:2')

and it produces this result:

==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
ResNet                                   [1, 2]                    --
├─Conv2d: 1-1                            [1, 64, 112, 112]         9,408
├─BatchNorm2d: 1-2                       [1, 64, 112, 112]         128
├─ReLU: 1-3                              [1, 64, 112, 112]         --
├─MaxPool2d: 1-4                         [1, 64, 56, 56]           --
├─Sequential: 1-5                        [1, 256, 56, 56]          --
│    └─Bottleneck: 2-1                   [1, 256, 56, 56]          --
│    │    └─Conv2d: 3-1                  [1, 64, 56, 56]           4,096
│    │    └─BatchNorm2d: 3-2             [1, 64, 56, 56]           128
│    │    └─ReLU: 3-3                    [1, 64, 56, 56]           --
│    │    └─Conv2d: 3-4                  [1, 64, 56, 56]           36,864
│    │    └─BatchNorm2d: 3-5             [1, 64, 56, 56]           128
│    │    └─Identity: 3-6                [1, 64, 56, 56]           --
│    │    └─ReLU: 3-7                    [1, 64, 56, 56]           --
│    │    └─Identity: 3-8                [1, 64, 56, 56]           --
│    │    └─Conv2d: 3-9                  [1, 256, 56, 56]          16,384
│    │    └─BatchNorm2d: 3-10            [1, 256, 56, 56]          512
│    │    └─Sequential: 3-11             [1, 256, 56, 56]          16,896
│    │    └─ReLU: 3-12                   [1, 256, 56, 56]          --
│    └─Bottleneck: 2-2                   [1, 256, 56, 56]          --
│    │    └─Conv2d: 3-13                 [1, 64, 56, 56]           16,384
│    │    └─BatchNorm2d: 3-14            [1, 64, 56, 56]           128
│    │    └─ReLU: 3-15                   [1, 64, 56, 56]           --
│    │    └─Conv2d: 3-16                 [1, 64, 56, 56]           36,864
│    │    └─BatchNorm2d: 3-17            [1, 64, 56, 56]           128
│    │    └─Identity: 3-18               [1, 64, 56, 56]           --
│    │    └─ReLU: 3-19                   [1, 64, 56, 56]           --
│    │    └─Identity: 3-20               [1, 64, 56, 56]           --
│    │    └─Conv2d: 3-21                 [1, 256, 56, 56]          16,384
│    │    └─BatchNorm2d: 3-22            [1, 256, 56, 56]          512
│    │    └─ReLU: 3-23                   [1, 256, 56, 56]          --
│    └─Bottleneck: 2-3                   [1, 256, 56, 56]          --
│    │    └─Conv2d: 3-24                 [1, 64, 56, 56]           16,384
│    │    └─BatchNorm2d: 3-25            [1, 64, 56, 56]           128
│    │    └─ReLU: 3-26                   [1, 64, 56, 56]           --
│    │    └─Conv2d: 3-27                 [1, 64, 56, 56]           36,864
│    │    └─BatchNorm2d: 3-28            [1, 64, 56, 56]           128
│    │    └─Identity: 3-29               [1, 64, 56, 56]           --
│    │    └─ReLU: 3-30                   [1, 64, 56, 56]           --
│    │    └─Identity: 3-31               [1, 64, 56, 56]           --
│    │    └─Conv2d: 3-32                 [1, 256, 56, 56]          16,384
│    │    └─BatchNorm2d: 3-33            [1, 256, 56, 56]          512
│    │    └─ReLU: 3-34                   [1, 256, 56, 56]          --
├─Sequential: 1-6                        [1, 512, 28, 28]          --
│    └─Bottleneck: 2-4                   [1, 512, 28, 28]          --
│    │    └─Conv2d: 3-35                 [1, 128, 56, 56]          32,768
│    │    └─BatchNorm2d: 3-36            [1, 128, 56, 56]          256
│    │    └─ReLU: 3-37                   [1, 128, 56, 56]          --
│    │    └─Conv2d: 3-38                 [1, 128, 28, 28]          147,456
│    │    └─BatchNorm2d: 3-39            [1, 128, 28, 28]          256
│    │    └─Identity: 3-40               [1, 128, 28, 28]          --
│    │    └─ReLU: 3-41                   [1, 128, 28, 28]          --
│    │    └─Identity: 3-42               [1, 128, 28, 28]          --
│    │    └─Conv2d: 3-43                 [1, 512, 28, 28]          65,536
│    │    └─BatchNorm2d: 3-44            [1, 512, 28, 28]          1,024
│    │    └─Sequential: 3-45             [1, 512, 28, 28]          132,096
│    │    └─ReLU: 3-46                   [1, 512, 28, 28]          --
│    └─Bottleneck: 2-5                   [1, 512, 28, 28]          --
│    │    └─Conv2d: 3-47                 [1, 128, 28, 28]          65,536
│    │    └─BatchNorm2d: 3-48            [1, 128, 28, 28]          256
│    │    └─ReLU: 3-49                   [1, 128, 28, 28]          --
│    │    └─Conv2d: 3-50                 [1, 128, 28, 28]          147,456
│    │    └─BatchNorm2d: 3-51            [1, 128, 28, 28]          256
│    │    └─Identity: 3-52               [1, 128, 28, 28]          --
│    │    └─ReLU: 3-53                   [1, 128, 28, 28]          --
│    │    └─Identity: 3-54               [1, 128, 28, 28]          --
│    │    └─Conv2d: 3-55                 [1, 512, 28, 28]          65,536
│    │    └─BatchNorm2d: 3-56            [1, 512, 28, 28]          1,024
│    │    └─ReLU: 3-57                   [1, 512, 28, 28]          --
│    └─Bottleneck: 2-6                   [1, 512, 28, 28]          --
│    │    └─Conv2d: 3-58                 [1, 128, 28, 28]          65,536
│    │    └─BatchNorm2d: 3-59            [1, 128, 28, 28]          256
│    │    └─ReLU: 3-60                   [1, 128, 28, 28]          --
│    │    └─Conv2d: 3-61                 [1, 128, 28, 28]          147,456
│    │    └─BatchNorm2d: 3-62            [1, 128, 28, 28]          256
│    │    └─Identity: 3-63               [1, 128, 28, 28]          --
│    │    └─ReLU: 3-64                   [1, 128, 28, 28]          --
│    │    └─Identity: 3-65               [1, 128, 28, 28]          --
│    │    └─Conv2d: 3-66                 [1, 512, 28, 28]          65,536
│    │    └─BatchNorm2d: 3-67            [1, 512, 28, 28]          1,024
│    │    └─ReLU: 3-68                   [1, 512, 28, 28]          --
│    └─Bottleneck: 2-7                   [1, 512, 28, 28]          --
│    │    └─Conv2d: 3-69                 [1, 128, 28, 28]          65,536
│    │    └─BatchNorm2d: 3-70            [1, 128, 28, 28]          256
│    │    └─ReLU: 3-71                   [1, 128, 28, 28]          --
│    │    └─Conv2d: 3-72                 [1, 128, 28, 28]          147,456
│    │    └─BatchNorm2d: 3-73            [1, 128, 28, 28]          256
│    │    └─Identity: 3-74               [1, 128, 28, 28]          --
│    │    └─ReLU: 3-75                   [1, 128, 28, 28]          --
│    │    └─Identity: 3-76               [1, 128, 28, 28]          --
│    │    └─Conv2d: 3-77                 [1, 512, 28, 28]          65,536
│    │    └─BatchNorm2d: 3-78            [1, 512, 28, 28]          1,024
│    │    └─ReLU: 3-79                   [1, 512, 28, 28]          --
├─Sequential: 1-7                        [1, 1024, 14, 14]         --
│    └─Bottleneck: 2-8                   [1, 1024, 14, 14]         --
│    │    └─Conv2d: 3-80                 [1, 256, 28, 28]          131,072
│    │    └─BatchNorm2d: 3-81            [1, 256, 28, 28]          512
│    │    └─ReLU: 3-82                   [1, 256, 28, 28]          --
│    │    └─Conv2d: 3-83                 [1, 256, 14, 14]          589,824
│    │    └─BatchNorm2d: 3-84            [1, 256, 14, 14]          512
│    │    └─Identity: 3-85               [1, 256, 14, 14]          --
│    │    └─ReLU: 3-86                   [1, 256, 14, 14]          --
│    │    └─Identity: 3-87               [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-88                 [1, 1024, 14, 14]         262,144
│    │    └─BatchNorm2d: 3-89            [1, 1024, 14, 14]         2,048
│    │    └─Sequential: 3-90             [1, 1024, 14, 14]         526,336
│    │    └─ReLU: 3-91                   [1, 1024, 14, 14]         --
│    └─Bottleneck: 2-9                   [1, 1024, 14, 14]         --
│    │    └─Conv2d: 3-92                 [1, 256, 14, 14]          262,144
│    │    └─BatchNorm2d: 3-93            [1, 256, 14, 14]          512
│    │    └─ReLU: 3-94                   [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-95                 [1, 256, 14, 14]          589,824
│    │    └─BatchNorm2d: 3-96            [1, 256, 14, 14]          512
│    │    └─Identity: 3-97               [1, 256, 14, 14]          --
│    │    └─ReLU: 3-98                   [1, 256, 14, 14]          --
│    │    └─Identity: 3-99               [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-100                [1, 1024, 14, 14]         262,144
│    │    └─BatchNorm2d: 3-101           [1, 1024, 14, 14]         2,048
│    │    └─ReLU: 3-102                  [1, 1024, 14, 14]         --
│    └─Bottleneck: 2-10                  [1, 1024, 14, 14]         --
│    │    └─Conv2d: 3-103                [1, 256, 14, 14]          262,144
│    │    └─BatchNorm2d: 3-104           [1, 256, 14, 14]          512
│    │    └─ReLU: 3-105                  [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-106                [1, 256, 14, 14]          589,824
│    │    └─BatchNorm2d: 3-107           [1, 256, 14, 14]          512
│    │    └─Identity: 3-108              [1, 256, 14, 14]          --
│    │    └─ReLU: 3-109                  [1, 256, 14, 14]          --
│    │    └─Identity: 3-110              [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-111                [1, 1024, 14, 14]         262,144
│    │    └─BatchNorm2d: 3-112           [1, 1024, 14, 14]         2,048
│    │    └─ReLU: 3-113                  [1, 1024, 14, 14]         --
│    └─Bottleneck: 2-11                  [1, 1024, 14, 14]         --
│    │    └─Conv2d: 3-114                [1, 256, 14, 14]          262,144
│    │    └─BatchNorm2d: 3-115           [1, 256, 14, 14]          512
│    │    └─ReLU: 3-116                  [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-117                [1, 256, 14, 14]          589,824
│    │    └─BatchNorm2d: 3-118           [1, 256, 14, 14]          512
│    │    └─Identity: 3-119              [1, 256, 14, 14]          --
│    │    └─ReLU: 3-120                  [1, 256, 14, 14]          --
│    │    └─Identity: 3-121              [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-122                [1, 1024, 14, 14]         262,144
│    │    └─BatchNorm2d: 3-123           [1, 1024, 14, 14]         2,048
│    │    └─ReLU: 3-124                  [1, 1024, 14, 14]         --
│    └─Bottleneck: 2-12                  [1, 1024, 14, 14]         --
│    │    └─Conv2d: 3-125                [1, 256, 14, 14]          262,144
│    │    └─BatchNorm2d: 3-126           [1, 256, 14, 14]          512
│    │    └─ReLU: 3-127                  [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-128                [1, 256, 14, 14]          589,824
│    │    └─BatchNorm2d: 3-129           [1, 256, 14, 14]          512
│    │    └─Identity: 3-130              [1, 256, 14, 14]          --
│    │    └─ReLU: 3-131                  [1, 256, 14, 14]          --
│    │    └─Identity: 3-132              [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-133                [1, 1024, 14, 14]         262,144
│    │    └─BatchNorm2d: 3-134           [1, 1024, 14, 14]         2,048
│    │    └─ReLU: 3-135                  [1, 1024, 14, 14]         --
│    └─Bottleneck: 2-13                  [1, 1024, 14, 14]         --
│    │    └─Conv2d: 3-136                [1, 256, 14, 14]          262,144
│    │    └─BatchNorm2d: 3-137           [1, 256, 14, 14]          512
│    │    └─ReLU: 3-138                  [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-139                [1, 256, 14, 14]          589,824
│    │    └─BatchNorm2d: 3-140           [1, 256, 14, 14]          512
│    │    └─Identity: 3-141              [1, 256, 14, 14]          --
│    │    └─ReLU: 3-142                  [1, 256, 14, 14]          --
│    │    └─Identity: 3-143              [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-144                [1, 1024, 14, 14]         262,144
│    │    └─BatchNorm2d: 3-145           [1, 1024, 14, 14]         2,048
│    │    └─ReLU: 3-146                  [1, 1024, 14, 14]         --
│    └─Bottleneck: 2-14                  [1, 1024, 14, 14]         --
│    │    └─Conv2d: 3-147                [1, 256, 14, 14]          262,144
│    │    └─BatchNorm2d: 3-148           [1, 256, 14, 14]          512
│    │    └─ReLU: 3-149                  [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-150                [1, 256, 14, 14]          589,824
│    │    └─BatchNorm2d: 3-151           [1, 256, 14, 14]          512
│    │    └─Identity: 3-152              [1, 256, 14, 14]          --
│    │    └─ReLU: 3-153                  [1, 256, 14, 14]          --
│    │    └─Identity: 3-154              [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-155                [1, 1024, 14, 14]         262,144
│    │    └─BatchNorm2d: 3-156           [1, 1024, 14, 14]         2,048
│    │    └─ReLU: 3-157                  [1, 1024, 14, 14]         --
│    └─Bottleneck: 2-15                  [1, 1024, 14, 14]         --
│    │    └─Conv2d: 3-158                [1, 256, 14, 14]          262,144
│    │    └─BatchNorm2d: 3-159           [1, 256, 14, 14]          512
│    │    └─ReLU: 3-160                  [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-161                [1, 256, 14, 14]          589,824
│    │    └─BatchNorm2d: 3-162           [1, 256, 14, 14]          512
│    │    └─Identity: 3-163              [1, 256, 14, 14]          --
│    │    └─ReLU: 3-164                  [1, 256, 14, 14]          --
│    │    └─Identity: 3-165              [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-166                [1, 1024, 14, 14]         262,144
│    │    └─BatchNorm2d: 3-167           [1, 1024, 14, 14]         2,048
│    │    └─ReLU: 3-168                  [1, 1024, 14, 14]         --
│    └─Bottleneck: 2-16                  [1, 1024, 14, 14]         --
│    │    └─Conv2d: 3-169                [1, 256, 14, 14]          262,144
│    │    └─BatchNorm2d: 3-170           [1, 256, 14, 14]          512
│    │    └─ReLU: 3-171                  [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-172                [1, 256, 14, 14]          589,824
│    │    └─BatchNorm2d: 3-173           [1, 256, 14, 14]          512
│    │    └─Identity: 3-174              [1, 256, 14, 14]          --
│    │    └─ReLU: 3-175                  [1, 256, 14, 14]          --
│    │    └─Identity: 3-176              [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-177                [1, 1024, 14, 14]         262,144
│    │    └─BatchNorm2d: 3-178           [1, 1024, 14, 14]         2,048
│    │    └─ReLU: 3-179                  [1, 1024, 14, 14]         --
│    └─Bottleneck: 2-17                  [1, 1024, 14, 14]         --
│    │    └─Conv2d: 3-180                [1, 256, 14, 14]          262,144
│    │    └─BatchNorm2d: 3-181           [1, 256, 14, 14]          512
│    │    └─ReLU: 3-182                  [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-183                [1, 256, 14, 14]          589,824
│    │    └─BatchNorm2d: 3-184           [1, 256, 14, 14]          512
│    │    └─Identity: 3-185              [1, 256, 14, 14]          --
│    │    └─ReLU: 3-186                  [1, 256, 14, 14]          --
│    │    └─Identity: 3-187              [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-188                [1, 1024, 14, 14]         262,144
│    │    └─BatchNorm2d: 3-189           [1, 1024, 14, 14]         2,048
│    │    └─ReLU: 3-190                  [1, 1024, 14, 14]         --
│    └─Bottleneck: 2-18                  [1, 1024, 14, 14]         --
│    │    └─Conv2d: 3-191                [1, 256, 14, 14]          262,144
│    │    └─BatchNorm2d: 3-192           [1, 256, 14, 14]          512
│    │    └─ReLU: 3-193                  [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-194                [1, 256, 14, 14]          589,824
│    │    └─BatchNorm2d: 3-195           [1, 256, 14, 14]          512
│    │    └─Identity: 3-196              [1, 256, 14, 14]          --
│    │    └─ReLU: 3-197                  [1, 256, 14, 14]          --
│    │    └─Identity: 3-198              [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-199                [1, 1024, 14, 14]         262,144
│    │    └─BatchNorm2d: 3-200           [1, 1024, 14, 14]         2,048
│    │    └─ReLU: 3-201                  [1, 1024, 14, 14]         --
│    └─Bottleneck: 2-19                  [1, 1024, 14, 14]         --
│    │    └─Conv2d: 3-202                [1, 256, 14, 14]          262,144
│    │    └─BatchNorm2d: 3-203           [1, 256, 14, 14]          512
│    │    └─ReLU: 3-204                  [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-205                [1, 256, 14, 14]          589,824
│    │    └─BatchNorm2d: 3-206           [1, 256, 14, 14]          512
│    │    └─Identity: 3-207              [1, 256, 14, 14]          --
│    │    └─ReLU: 3-208                  [1, 256, 14, 14]          --
│    │    └─Identity: 3-209              [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-210                [1, 1024, 14, 14]         262,144
│    │    └─BatchNorm2d: 3-211           [1, 1024, 14, 14]         2,048
│    │    └─ReLU: 3-212                  [1, 1024, 14, 14]         --
│    └─Bottleneck: 2-20                  [1, 1024, 14, 14]         --
│    │    └─Conv2d: 3-213                [1, 256, 14, 14]          262,144
│    │    └─BatchNorm2d: 3-214           [1, 256, 14, 14]          512
│    │    └─ReLU: 3-215                  [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-216                [1, 256, 14, 14]          589,824
│    │    └─BatchNorm2d: 3-217           [1, 256, 14, 14]          512
│    │    └─Identity: 3-218              [1, 256, 14, 14]          --
│    │    └─ReLU: 3-219                  [1, 256, 14, 14]          --
│    │    └─Identity: 3-220              [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-221                [1, 1024, 14, 14]         262,144
│    │    └─BatchNorm2d: 3-222           [1, 1024, 14, 14]         2,048
│    │    └─ReLU: 3-223                  [1, 1024, 14, 14]         --
│    └─Bottleneck: 2-21                  [1, 1024, 14, 14]         --
│    │    └─Conv2d: 3-224                [1, 256, 14, 14]          262,144
│    │    └─BatchNorm2d: 3-225           [1, 256, 14, 14]          512
│    │    └─ReLU: 3-226                  [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-227                [1, 256, 14, 14]          589,824
│    │    └─BatchNorm2d: 3-228           [1, 256, 14, 14]          512
│    │    └─Identity: 3-229              [1, 256, 14, 14]          --
│    │    └─ReLU: 3-230                  [1, 256, 14, 14]          --
│    │    └─Identity: 3-231              [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-232                [1, 1024, 14, 14]         262,144
│    │    └─BatchNorm2d: 3-233           [1, 1024, 14, 14]         2,048
│    │    └─ReLU: 3-234                  [1, 1024, 14, 14]         --
│    └─Bottleneck: 2-22                  [1, 1024, 14, 14]         --
│    │    └─Conv2d: 3-235                [1, 256, 14, 14]          262,144
│    │    └─BatchNorm2d: 3-236           [1, 256, 14, 14]          512
│    │    └─ReLU: 3-237                  [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-238                [1, 256, 14, 14]          589,824
│    │    └─BatchNorm2d: 3-239           [1, 256, 14, 14]          512
│    │    └─Identity: 3-240              [1, 256, 14, 14]          --
│    │    └─ReLU: 3-241                  [1, 256, 14, 14]          --
│    │    └─Identity: 3-242              [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-243                [1, 1024, 14, 14]         262,144
│    │    └─BatchNorm2d: 3-244           [1, 1024, 14, 14]         2,048
│    │    └─ReLU: 3-245                  [1, 1024, 14, 14]         --
│    └─Bottleneck: 2-23                  [1, 1024, 14, 14]         --
│    │    └─Conv2d: 3-246                [1, 256, 14, 14]          262,144
│    │    └─BatchNorm2d: 3-247           [1, 256, 14, 14]          512
│    │    └─ReLU: 3-248                  [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-249                [1, 256, 14, 14]          589,824
│    │    └─BatchNorm2d: 3-250           [1, 256, 14, 14]          512
│    │    └─Identity: 3-251              [1, 256, 14, 14]          --
│    │    └─ReLU: 3-252                  [1, 256, 14, 14]          --
│    │    └─Identity: 3-253              [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-254                [1, 1024, 14, 14]         262,144
│    │    └─BatchNorm2d: 3-255           [1, 1024, 14, 14]         2,048
│    │    └─ReLU: 3-256                  [1, 1024, 14, 14]         --
│    └─Bottleneck: 2-24                  [1, 1024, 14, 14]         --
│    │    └─Conv2d: 3-257                [1, 256, 14, 14]          262,144
│    │    └─BatchNorm2d: 3-258           [1, 256, 14, 14]          512
│    │    └─ReLU: 3-259                  [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-260                [1, 256, 14, 14]          589,824
│    │    └─BatchNorm2d: 3-261           [1, 256, 14, 14]          512
│    │    └─Identity: 3-262              [1, 256, 14, 14]          --
│    │    └─ReLU: 3-263                  [1, 256, 14, 14]          --
│    │    └─Identity: 3-264              [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-265                [1, 1024, 14, 14]         262,144
│    │    └─BatchNorm2d: 3-266           [1, 1024, 14, 14]         2,048
│    │    └─ReLU: 3-267                  [1, 1024, 14, 14]         --
│    └─Bottleneck: 2-25                  [1, 1024, 14, 14]         --
│    │    └─Conv2d: 3-268                [1, 256, 14, 14]          262,144
│    │    └─BatchNorm2d: 3-269           [1, 256, 14, 14]          512
│    │    └─ReLU: 3-270                  [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-271                [1, 256, 14, 14]          589,824
│    │    └─BatchNorm2d: 3-272           [1, 256, 14, 14]          512
│    │    └─Identity: 3-273              [1, 256, 14, 14]          --
│    │    └─ReLU: 3-274                  [1, 256, 14, 14]          --
│    │    └─Identity: 3-275              [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-276                [1, 1024, 14, 14]         262,144
│    │    └─BatchNorm2d: 3-277           [1, 1024, 14, 14]         2,048
│    │    └─ReLU: 3-278                  [1, 1024, 14, 14]         --
│    └─Bottleneck: 2-26                  [1, 1024, 14, 14]         --
│    │    └─Conv2d: 3-279                [1, 256, 14, 14]          262,144
│    │    └─BatchNorm2d: 3-280           [1, 256, 14, 14]          512
│    │    └─ReLU: 3-281                  [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-282                [1, 256, 14, 14]          589,824
│    │    └─BatchNorm2d: 3-283           [1, 256, 14, 14]          512
│    │    └─Identity: 3-284              [1, 256, 14, 14]          --
│    │    └─ReLU: 3-285                  [1, 256, 14, 14]          --
│    │    └─Identity: 3-286              [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-287                [1, 1024, 14, 14]         262,144
│    │    └─BatchNorm2d: 3-288           [1, 1024, 14, 14]         2,048
│    │    └─ReLU: 3-289                  [1, 1024, 14, 14]         --
│    └─Bottleneck: 2-27                  [1, 1024, 14, 14]         --
│    │    └─Conv2d: 3-290                [1, 256, 14, 14]          262,144
│    │    └─BatchNorm2d: 3-291           [1, 256, 14, 14]          512
│    │    └─ReLU: 3-292                  [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-293                [1, 256, 14, 14]          589,824
│    │    └─BatchNorm2d: 3-294           [1, 256, 14, 14]          512
│    │    └─Identity: 3-295              [1, 256, 14, 14]          --
│    │    └─ReLU: 3-296                  [1, 256, 14, 14]          --
│    │    └─Identity: 3-297              [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-298                [1, 1024, 14, 14]         262,144
│    │    └─BatchNorm2d: 3-299           [1, 1024, 14, 14]         2,048
│    │    └─ReLU: 3-300                  [1, 1024, 14, 14]         --
│    └─Bottleneck: 2-28                  [1, 1024, 14, 14]         --
│    │    └─Conv2d: 3-301                [1, 256, 14, 14]          262,144
│    │    └─BatchNorm2d: 3-302           [1, 256, 14, 14]          512
│    │    └─ReLU: 3-303                  [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-304                [1, 256, 14, 14]          589,824
│    │    └─BatchNorm2d: 3-305           [1, 256, 14, 14]          512
│    │    └─Identity: 3-306              [1, 256, 14, 14]          --
│    │    └─ReLU: 3-307                  [1, 256, 14, 14]          --
│    │    └─Identity: 3-308              [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-309                [1, 1024, 14, 14]         262,144
│    │    └─BatchNorm2d: 3-310           [1, 1024, 14, 14]         2,048
│    │    └─ReLU: 3-311                  [1, 1024, 14, 14]         --
│    └─Bottleneck: 2-29                  [1, 1024, 14, 14]         --
│    │    └─Conv2d: 3-312                [1, 256, 14, 14]          262,144
│    │    └─BatchNorm2d: 3-313           [1, 256, 14, 14]          512
│    │    └─ReLU: 3-314                  [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-315                [1, 256, 14, 14]          589,824
│    │    └─BatchNorm2d: 3-316           [1, 256, 14, 14]          512
│    │    └─Identity: 3-317              [1, 256, 14, 14]          --
│    │    └─ReLU: 3-318                  [1, 256, 14, 14]          --
│    │    └─Identity: 3-319              [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-320                [1, 1024, 14, 14]         262,144
│    │    └─BatchNorm2d: 3-321           [1, 1024, 14, 14]         2,048
│    │    └─ReLU: 3-322                  [1, 1024, 14, 14]         --
│    └─Bottleneck: 2-30                  [1, 1024, 14, 14]         --
│    │    └─Conv2d: 3-323                [1, 256, 14, 14]          262,144
│    │    └─BatchNorm2d: 3-324           [1, 256, 14, 14]          512
│    │    └─ReLU: 3-325                  [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-326                [1, 256, 14, 14]          589,824
│    │    └─BatchNorm2d: 3-327           [1, 256, 14, 14]          512
│    │    └─Identity: 3-328              [1, 256, 14, 14]          --
│    │    └─ReLU: 3-329                  [1, 256, 14, 14]          --
│    │    └─Identity: 3-330              [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-331                [1, 1024, 14, 14]         262,144
│    │    └─BatchNorm2d: 3-332           [1, 1024, 14, 14]         2,048
│    │    └─ReLU: 3-333                  [1, 1024, 14, 14]         --
├─Sequential: 1-8                        [1, 2048, 7, 7]           --
│    └─Bottleneck: 2-31                  [1, 2048, 7, 7]           --
│    │    └─Conv2d: 3-334                [1, 512, 14, 14]          524,288
│    │    └─BatchNorm2d: 3-335           [1, 512, 14, 14]          1,024
│    │    └─ReLU: 3-336                  [1, 512, 14, 14]          --
│    │    └─Conv2d: 3-337                [1, 512, 7, 7]            2,359,296
│    │    └─BatchNorm2d: 3-338           [1, 512, 7, 7]            1,024
│    │    └─Identity: 3-339              [1, 512, 7, 7]            --
│    │    └─ReLU: 3-340                  [1, 512, 7, 7]            --
│    │    └─Identity: 3-341              [1, 512, 7, 7]            --
│    │    └─Conv2d: 3-342                [1, 2048, 7, 7]           1,048,576
│    │    └─BatchNorm2d: 3-343           [1, 2048, 7, 7]           4,096
│    │    └─Sequential: 3-344            [1, 2048, 7, 7]           2,101,248
│    │    └─ReLU: 3-345                  [1, 2048, 7, 7]           --
│    └─Bottleneck: 2-32                  [1, 2048, 7, 7]           --
│    │    └─Conv2d: 3-346                [1, 512, 7, 7]            1,048,576
│    │    └─BatchNorm2d: 3-347           [1, 512, 7, 7]            1,024
│    │    └─ReLU: 3-348                  [1, 512, 7, 7]            --
│    │    └─Conv2d: 3-349                [1, 512, 7, 7]            2,359,296
│    │    └─BatchNorm2d: 3-350           [1, 512, 7, 7]            1,024
│    │    └─Identity: 3-351              [1, 512, 7, 7]            --
│    │    └─ReLU: 3-352                  [1, 512, 7, 7]            --
│    │    └─Identity: 3-353              [1, 512, 7, 7]            --
│    │    └─Conv2d: 3-354                [1, 2048, 7, 7]           1,048,576
│    │    └─BatchNorm2d: 3-355           [1, 2048, 7, 7]           4,096
│    │    └─ReLU: 3-356                  [1, 2048, 7, 7]           --
│    └─Bottleneck: 2-33                  [1, 2048, 7, 7]           --
│    │    └─Conv2d: 3-357                [1, 512, 7, 7]            1,048,576
│    │    └─BatchNorm2d: 3-358           [1, 512, 7, 7]            1,024
│    │    └─ReLU: 3-359                  [1, 512, 7, 7]            --
│    │    └─Conv2d: 3-360                [1, 512, 7, 7]            2,359,296
│    │    └─BatchNorm2d: 3-361           [1, 512, 7, 7]            1,024
│    │    └─Identity: 3-362              [1, 512, 7, 7]            --
│    │    └─ReLU: 3-363                  [1, 512, 7, 7]            --
│    │    └─Identity: 3-364              [1, 512, 7, 7]            --
│    │    └─Conv2d: 3-365                [1, 2048, 7, 7]           1,048,576
│    │    └─BatchNorm2d: 3-366           [1, 2048, 7, 7]           4,096
│    │    └─ReLU: 3-367                  [1, 2048, 7, 7]           --
├─SelectAdaptivePool2d: 1-9              [1, 2048]                 --
│    └─AdaptiveAvgPool2d: 2-34           [1, 2048, 1, 1]           --
│    └─Flatten: 2-35                     [1, 2048]                 --
├─Linear: 1-10                           [1, 2]                    4,098
==========================================================================================
Total params: 42,504,258
Trainable params: 42,504,258
Non-trainable params: 0
Total mult-adds (G): 7.80
==========================================================================================
Input size (MB): 0.60
Forward/backward pass size (MB): 259.71
Params size (MB): 170.02
Estimated Total Size (MB): 430.33
==========================================================================================

I think the timm library does not use JIT compilation by default.

mert-kurttutan commented 1 year ago

I just checked the model.

If you run the following code to look at the running mean of the first batch norm, you will see that it is not empty:

model = torchvision.models.shufflenet_v2_x1_0(pretrained=True)
model = model.eval().to('cuda')
model.conv1[1].running_mean  # non-empty: BatchNorm running statistics live on the GPU too

I also ran the torch profiler (on the same shufflenet model), which shows that the CUDA memory used by inference is nowhere near 1100 MB:
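The table below was produced roughly as follows (a minimal sketch; the exact invocation may have differed):

import torch
import torchvision
from torch.profiler import ProfilerActivity, profile, record_function

model = torchvision.models.shufflenet_v2_x1_0(pretrained=True).eval().to('cuda')

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
             profile_memory=True, record_shapes=True) as prof:
    with record_function("model_inference"):
        with torch.no_grad():
            model(torch.randn(1, 3, 224, 224).to('cuda'))

print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=100))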

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                        model_inference         0.24%       4.989ms       100.00%        2.039s        2.039s       0.000us         0.00%       2.999ms       2.999ms          -4 b    -588.26 Kb           0 b     -22.60 Mb             1  
                                           aten::conv2d         0.01%     159.000us        99.15%        2.022s      36.103ms       0.000us         0.00%       2.116ms      37.786us           0 b           0 b       7.45 Mb           0 b            56  
                                      aten::convolution         0.02%     471.000us        99.14%        2.022s      36.101ms       0.000us         0.00%       2.116ms      37.786us           0 b           0 b       7.45 Mb           0 b            56  
                                     aten::_convolution         0.02%     379.000us        99.11%        2.021s      36.092ms       0.000us         0.00%       2.116ms      37.786us           0 b           0 b       7.45 Mb           0 b            56  
                                aten::cudnn_convolution         4.60%      93.889ms        99.06%        2.020s      54.599ms       1.920ms        64.02%       1.920ms      51.892us           0 b           0 b       5.68 Mb     725.00 Kb            37  
                volta_scudnn_128x64_relu_interior_nn_v1         0.00%       0.000us         0.00%       0.000us       0.000us       1.486ms        49.55%       1.486ms      57.154us           0 b           0 b           0 b           0 b            26  
                                       aten::batch_norm         0.01%     109.000us         0.25%       5.060ms      90.357us       0.000us         0.00%     369.000us       6.589us           0 b           0 b       7.45 Mb           0 b            56  
                           aten::_batch_norm_impl_index         0.01%     186.000us         0.24%       4.951ms      88.411us       0.000us         0.00%     369.000us       6.589us           0 b           0 b       7.45 Mb           0 b            56  
                                 aten::cudnn_batch_norm         0.09%       1.871ms         0.23%       4.765ms      85.089us     369.000us        12.30%     369.000us       6.589us           0 b           0 b       7.45 Mb           0 b            56  
void cudnn::bn_fw_inf_1C11_kernel_NCHW<float, float,...         0.00%       0.000us         0.00%       0.000us       0.000us     369.000us        12.30%     369.000us       6.589us           0 b           0 b           0 b           0 b            56  
                                aten::_conv_depthwise2d         0.01%     229.000us         0.03%     634.000us      33.368us     196.000us         6.54%     196.000us      10.316us           0 b           0 b       1.77 Mb           0 b            19  
void at::native::(anonymous namespace)::conv_depthwi...         0.00%       0.000us         0.00%       0.000us       0.000us     196.000us         6.54%     196.000us      10.316us           0 b           0 b           0 b           0 b            19  
                                            aten::relu_         0.02%     387.000us         0.06%       1.241ms      33.541us       0.000us         0.00%     158.000us       4.270us           0 b           0 b           0 b           0 b            37  
                                       aten::clamp_min_         0.03%     553.000us         0.04%     854.000us      23.081us     158.000us         5.27%     158.000us       4.270us           0 b           0 b           0 b           0 b            37  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     158.000us         5.27%     158.000us       4.270us           0 b           0 b           0 b           0 b            37  
                                            aten::copy_         0.01%     260.000us         0.03%     659.000us      38.765us     155.000us         5.17%     155.000us       9.118us           0 b           0 b           0 b           0 b            17  
                                              aten::cat         0.02%     338.000us         0.02%     466.000us      29.125us     145.000us         4.83%     145.000us       9.062us           0 b           0 b       3.13 Mb       3.13 Mb            16  
void at::native::(anonymous namespace)::CatArrayBatc...         0.00%       0.000us         0.00%       0.000us       0.000us     145.000us         4.83%     145.000us       9.062us           0 b           0 b           0 b           0 b            16  
void implicit_convolve_sgemm<float, float, 128, 5, 5...         0.00%       0.000us         0.00%       0.000us       0.000us     140.000us         4.67%     140.000us      20.000us           0 b           0 b           0 b           0 b             7  
void cask_cudnn::computeOffsetsKernel<false, false>(...         0.00%       0.000us         0.00%       0.000us       0.000us     117.000us         3.90%     117.000us       4.034us           0 b           0 b           0 b           0 b            29  
                                       aten::contiguous         0.00%      26.000us         0.03%     652.000us      40.750us       0.000us         0.00%     103.000us       6.438us           0 b           0 b       3.13 Mb           0 b            16  
                                            aten::clone         0.00%      87.000us         0.03%     626.000us      39.125us       0.000us         0.00%     103.000us       6.438us           0 b           0 b       3.13 Mb           0 b            16  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     103.000us         3.43%     103.000us       6.438us           0 b           0 b           0 b           0 b            16  
volta_scudnn_128x32_sliced1x4_ldg4_relu_exp_interior...         0.00%       0.000us         0.00%       0.000us       0.000us      75.000us         2.50%      75.000us      75.000us           0 b           0 b           0 b           0 b             1  
                                               aten::to         0.07%       1.348ms         0.08%       1.639ms       1.639ms       0.000us         0.00%      52.000us      52.000us           0 b           0 b     588.00 Kb           0 b             1  
                                         aten::_to_copy         0.00%      19.000us         0.01%     291.000us     291.000us       0.000us         0.00%      52.000us      52.000us           0 b           0 b     588.00 Kb           0 b             1  
                       Memcpy HtoD (Pageable -> Device)         0.00%       0.000us         0.00%       0.000us       0.000us      52.000us         1.73%      52.000us      52.000us           0 b           0 b           0 b           0 b             1  
                volta_scudnn_128x32_relu_interior_nn_v1         0.00%       0.000us         0.00%       0.000us       0.000us      30.000us         1.00%      30.000us      30.000us           0 b           0 b           0 b           0 b             1  
                                           aten::linear         0.00%      22.000us         0.05%     921.000us     921.000us       0.000us         0.00%      30.000us      30.000us           0 b           0 b       4.00 Kb           0 b             1  
                                            aten::addmm         0.03%     645.000us         0.04%     886.000us     886.000us      30.000us         1.00%      30.000us      30.000us           0 b           0 b       4.00 Kb       4.00 Kb             1  
                  volta_scudnn_128x32_relu_medium_nn_v1         0.00%       0.000us         0.00%       0.000us       0.000us      29.000us         0.97%      29.000us      29.000us           0 b           0 b           0 b           0 b             1  
std::enable_if<!(false), void>::type internal::gemvx...         0.00%       0.000us         0.00%       0.000us       0.000us      26.000us         0.87%      26.000us      26.000us           0 b           0 b           0 b           0 b             1  
                                        Memset (Device)         0.00%       0.000us         0.00%       0.000us       0.000us      17.000us         0.57%      17.000us       4.250us           0 b           0 b           0 b           0 b             4  
                                       aten::max_pool2d         0.00%      16.000us         0.01%     292.000us     292.000us       0.000us         0.00%      15.000us      15.000us           0 b           0 b     882.00 Kb           0 b             1  
                          aten::max_pool2d_with_indices         0.00%      74.000us         0.01%     276.000us     276.000us      15.000us         0.50%      15.000us      15.000us           0 b           0 b     882.00 Kb     882.00 Kb             1  
void at::native::(anonymous namespace)::max_pool_for...         0.00%       0.000us         0.00%       0.000us       0.000us      15.000us         0.50%      15.000us      15.000us           0 b           0 b           0 b           0 b             1  
void implicit_convolve_sgemm<float, float, 1024, 5, ...         0.00%       0.000us         0.00%       0.000us       0.000us      14.000us         0.47%      14.000us      14.000us           0 b           0 b           0 b           0 b             1  
                                             aten::mean         0.00%      62.000us         0.00%      76.000us      76.000us      11.000us         0.37%      11.000us      11.000us           0 b           0 b       4.00 Kb       4.00 Kb             1  
void at::native::reduce_kernel<512, 1, at::native::R...         0.00%       0.000us         0.00%       0.000us       0.000us      11.000us         0.37%      11.000us      11.000us           0 b           0 b           0 b           0 b             1  
void cudnn::ops::nchwToNhwcKernel<float, float, floa...         0.00%       0.000us         0.00%       0.000us       0.000us       6.000us         0.20%       6.000us       6.000us           0 b           0 b           0 b           0 b             1  
void cudnn::ops::nhwcToNchwKernel<float, float, floa...         0.00%       0.000us         0.00%       0.000us       0.000us       6.000us         0.20%       6.000us       6.000us           0 b           0 b           0 b           0 b             1  
                         Memcpy DtoD (Device -> Device)         0.00%       0.000us         0.00%       0.000us       0.000us       4.000us         0.13%       4.000us       4.000us           0 b           0 b           0 b           0 b             1  
                                            aten::zeros         0.00%      31.000us         0.00%      79.000us      79.000us       0.000us         0.00%       0.000us       0.000us           4 b           0 b           0 b           0 b             1  
                                            aten::empty         0.09%       1.781ms         0.15%       2.995ms      10.017us       0.000us         0.00%       0.000us       0.000us     588.26 Kb     588.26 Kb      15.55 Mb      15.55 Mb           299  
                                            aten::zero_         0.00%       2.000us         0.00%       2.000us       2.000us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b             1  
                                            aten::randn         0.00%      29.000us         0.08%       1.532ms       1.532ms       0.000us         0.00%       0.000us       0.000us     588.00 Kb           0 b           0 b           0 b             1  
                                          aten::normal_         0.06%       1.296ms         0.06%       1.296ms       1.296ms       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b             1  
                                    aten::empty_strided         0.00%      24.000us         0.00%      24.000us      24.000us       0.000us         0.00%       0.000us       0.000us           0 b           0 b     588.00 Kb     588.00 Kb             1  
                                        cudaMemcpyAsync         0.01%     187.000us         0.01%     187.000us      93.500us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b             2  
                                  cudaStreamSynchronize         0.00%      60.000us         0.00%      60.000us      60.000us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b             1  
                                     cudaGetDeviceCount         0.00%       1.000us         0.00%       1.000us       1.000us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b             1  
                                   cudaDriverGetVersion         0.00%       0.000us         0.00%       0.000us       0.000us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b             1  
                                 cudaDeviceGetAttribute         0.00%       1.000us         0.00%       1.000us       0.031us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b            32  
                                cudaGetDeviceProperties         0.01%     124.000us         0.01%     124.000us     124.000us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b             1  
                              cudaStreamCreateWithFlags         9.23%     188.197ms         9.23%     188.197ms      23.525ms       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b             8  
                       cudaDeviceGetStreamPriorityRange         0.00%       2.000us         0.00%       2.000us       0.021us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b            94  
                           cudaStreamCreateWithPriority         0.01%     215.000us         0.01%     215.000us      53.750us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b             4  
                                             cudaMalloc         0.10%       2.026ms         0.10%       2.026ms     106.632us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b            19  
                                        cudaMemsetAsync         0.00%      52.000us         0.00%      52.000us      13.000us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b             4  
                                          cudaHostAlloc         0.05%       1.074ms         0.05%       1.074ms       1.074ms       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b             1  
                               cudaHostGetDevicePointer         0.00%       2.000us         0.00%       2.000us       2.000us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b             1  
                                               cudaFree        42.57%     868.160ms        42.57%     868.160ms     217.040ms       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b             4  
                                   cudaGetSymbolAddress         0.00%       1.000us         0.00%       1.000us       1.000us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b             1  
                                  cudaStreamIsCapturing         0.00%      13.000us         0.00%      13.000us       0.127us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b           102  
                                  cudaStreamGetPriority         0.00%       1.000us         0.00%       1.000us       0.011us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b            93  
                                       cudaLaunchKernel        42.61%     868.871ms        42.61%     868.871ms       4.041ms       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b           215  
                                       aten::empty_like         0.01%     188.000us         0.09%       1.780ms      24.722us       0.000us         0.00%       0.000us       0.000us           0 b           0 b      10.58 Mb           0 b            72  
                                             aten::view         0.01%     262.000us         0.01%     262.000us       2.977us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b            88  
                                          aten::resize_         0.01%     132.000us         0.01%     132.000us       6.947us       0.000us         0.00%       0.000us       0.000us           0 b           0 b       1.77 Mb       1.77 Mb            19  
                                        aten::transpose         0.00%      98.000us         0.01%     127.000us       7.471us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b            17  
                                       aten::as_strided         0.00%      45.000us         0.00%      45.000us       1.000us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b            45  
                                            aten::chunk         0.00%      26.000us         0.01%     211.000us      16.231us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b            13  
                                            aten::split         0.00%      97.000us         0.01%     185.000us      14.231us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b            13  
                                           aten::narrow         0.00%      34.000us         0.00%      88.000us       3.385us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b            26  
                                            aten::slice         0.00%      40.000us         0.00%      54.000us       2.077us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b            26  
                                                aten::t         0.00%       9.000us         0.00%      13.000us      13.000us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b             1  
                                           aten::expand         0.00%       3.000us         0.00%       3.000us       3.000us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b             1  
cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFla...         0.00%       9.000us         0.00%       9.000us       9.000us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b             1  
                                  cudaDeviceSynchronize         0.00%      10.000us         0.00%      10.000us      10.000us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 2.039s
Self CUDA time total: 2.999ms

I also observed that merely loading the model occupies around 600 MB of CUDA memory. I think memory usage in this range is dominated by the CUDA context itself. That would be my best guess.
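One way to see this split (a minimal sketch, assuming cuda:0): PyTorch's caching allocator reports only tens of MB for the model and activations, while nvidia-smi additionally counts the CUDA context and library workspaces.

import torch
import torchvision

device = torch.device('cuda:0')
model = torchvision.models.shufflenet_v2_x1_0(pretrained=True).eval().to(device)

with torch.no_grad():
    model(torch.rand(1, 3, 224, 224, device=device))

# What the PyTorch allocator tracks; nvidia-smi adds the CUDA context on top.
print(f"allocated: {torch.cuda.memory_allocated(device) / 1024 ** 2:.1f} MB")
print(f"reserved:  {torch.cuda.memory_reserved(device) / 1024 ** 2:.1f} MB")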