graspnet / graspnet-baseline

Baseline model for "GraspNet-1Billion: A Large-Scale Benchmark for General Object Grasping" (CVPR 2020)
https://graspnet.net/

Network Summary with Dict as input #70

Closed: 1mingW closed this issue 11 months ago

1mingW commented 11 months ago

Hi,

I am trying to summarize the network in the following form, to get an overview of its architecture and number of parameters:

================================================================================================================
Layer (type:depth-idx)          Input Shape          Output Shape         Param #            Mult-Adds
================================================================================================================
SingleInputNet                  [7, 1, 28, 28]       [7, 10]              --                 --
├─Conv2d: 1-1                   [7, 1, 28, 28]       [7, 10, 24, 24]      260                1,048,320
├─Conv2d: 1-2                   [7, 10, 12, 12]      [7, 20, 8, 8]        5,020              2,248,960
├─Dropout2d: 1-3                [7, 20, 8, 8]        [7, 20, 8, 8]        --                 --
├─Linear: 1-4                   [7, 320]             [7, 50]              16,050             112,350
├─Linear: 1-5                   [7, 50]              [7, 10]              510                3,570
================================================================================================================
Total params: 21,840
Trainable params: 21,840
Non-trainable params: 0
Total mult-adds (M): 3.41
================================================================================================================
Input size (MB): 0.02
Forward/backward pass size (MB): 0.40
Params size (MB): 0.09
Estimated Total Size (MB): 0.51
================================================================================================================
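The "Param #" and "Mult-Adds" columns in a table like this follow from simple arithmetic. The sketch below reproduces them for the example network above. It assumes 5x5 kernels with stride 1 and no padding (inferred from 28 -> 24 and 12 -> 8), float32 weights, and that mult-adds are counted as params x output positions x batch size, which is the convention that matches every row of the table. Dropout2d has no parameters, so it is left out.

```python
def conv2d_params(in_ch, out_ch, k):
    # weight tensor: out_ch * in_ch * k * k, plus one bias per output channel
    return out_ch * in_ch * k * k + out_ch


def linear_params(in_f, out_f):
    # weight matrix: out_f * in_f, plus one bias per output feature
    return out_f * in_f + out_f


batch = 7
layers = [
    # (name, params, output positions per sample at which the weights are applied)
    ("Conv2d 1-1", conv2d_params(1, 10, 5), 24 * 24),
    ("Conv2d 1-2", conv2d_params(10, 20, 5), 8 * 8),
    ("Linear 1-4", linear_params(320, 50), 1),
    ("Linear 1-5", linear_params(50, 10), 1),
]

total_params = sum(p for _, p, _ in layers)
total_mult_adds = sum(p * positions * batch for _, p, positions in layers)

for name, p, positions in layers:
    print(f"{name:<12} params={p:>7,} mult-adds={p * positions * batch:>10,}")
print(f"Total params: {total_params:,}")                    # 21,840
print(f"Total mult-adds (M): {total_mult_adds / 1e6:.2f}")  # 3.41
print(f"Params size (MB): {total_params * 4 / 1e6:.2f}")    # 4 bytes per float32 -> 0.09
```

None of this needs the model's input format, only the layer hyperparameters, which is why parameter counts remain obtainable even when the forward pass takes a dictionary.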

Printing the network lists its layers, but gives no other information. summary(model, ...) needs the size of the network's input, but I'm not sure how to specify a size for a dictionary, so this method doesn't seem suitable for a dictionary input. I also tried torchinfo (https://github.com/TylerYep/torchinfo), but it requires the inputs to forward() to be passed either as a list of positional args or as a dict of kwargs. How did you summarize the network, given that it takes a dictionary as input? Alternatively, is there another way to get the total number of parameters of the network?
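If only parameter counts are needed, no input is required at all: parameters are attributes of the module, not of the forward pass, so the dictionary input never comes into play. A minimal sketch, assuming a PyTorch-style model that exposes `named_parameters()` yielding `(name, tensor)` pairs (as `torch.nn.Module` does). The helpers `count_parameters` and `print_parameter_table` are hypothetical names, and this is not necessarily what the answer linked below does.

```python
def count_parameters(model):
    """Tally parameters without calling forward(), so no input is needed.

    Assumes model.named_parameters() yields (name, tensor) pairs whose
    tensors provide .numel() and .requires_grad, as torch.nn.Module does.
    """
    rows = []
    trainable = non_trainable = 0
    for name, param in model.named_parameters():
        n = param.numel()  # number of scalar entries in this tensor
        rows.append((name, n))
        if param.requires_grad:
            trainable += n
        else:
            non_trainable += n
    return rows, trainable, non_trainable


def print_parameter_table(model):
    """Print one row per parameter tensor plus totals; return the total."""
    rows, trainable, non_trainable = count_parameters(model)
    width = max((len(name) for name, _ in rows), default=10)
    for name, n in rows:
        print(f"{name:<{width}}  {n:>12,}")
    total = trainable + non_trainable
    print(f"Total params: {total:,}")
    print(f"Trainable params: {trainable:,}")
    print(f"Non-trainable params: {non_trainable:,}")
    return total
```

With a real network, `print_parameter_table(net)` prints one row per parameter tensor, and `sum(p.numel() for p in net.parameters())` gives just the total. For a full per-layer summary with shapes, torchinfo's README suggests wrapping a forward() that takes a single dict in a list, e.g. `summary(net, input_data=[inputs_dict])` (where `inputs_dict` is a placeholder for the model's input dictionary); check this against the installed torchinfo version.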

chenxi-wang commented 11 months ago

You can try this: https://stackoverflow.com/a/62508086

1mingW commented 11 months ago

Thank you! This works for me.