sksq96 / pytorch-summary

Model summary in PyTorch similar to `model.summary()` in Keras
MIT License
3.98k stars, 412 forks

fix: Fixed hook for DenseNet #116

Open frgfm opened 4 years ago

frgfm commented 4 years ago

Hello there,

I recently found that the main summary function does not work on torchvision's implementation of DenseNet.

from torchvision.models import densenet121
from torchsummary import summary
model = densenet121().eval().cuda()
summary(model, (3, 224, 224))

yields the following error:

~/Documents/pytorch-summary/torchsummary/torchsummary.py in hook(module, input, output)
     36             m_key = "%s-%i" % (class_name, module_idx + 1)
     37             summary[m_key] = OrderedDict()
---> 38             summary[m_key]["input_shape"] = list(input[0].size())
     39             summary[m_key]["input_shape"][0] = batch_size
     40             if isinstance(output, (list, tuple)):

AttributeError: 'list' object has no attribute 'size'

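The root cause is the _DenseLayer module in the architecture. Its forward pass is called with a list of tensors (the feature maps of all preceding layers in the block), so input[0] inside the hook is a list rather than a tensor. Since _DenseLayer does not inherit from torch.nn.Sequential or torch.nn.ModuleList, torchsummary does not skip it when registering hooks on every submodule.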
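A minimal sketch of the failure mode, assuming a toy module (the hypothetical ListInputLayer below) that, like torchvision's _DenseLayer, is called with a list of tensors. PyTorch passes a forward hook the positional arguments of forward() as a tuple, so inputs[0] is the list itself, and calling .size() on it fails the same way as in the traceback above.

```python
import torch
import torch.nn as nn


class ListInputLayer(nn.Module):
    """Hypothetical stand-in for torchvision's _DenseLayer: forward() takes a list of tensors."""

    def __init__(self, in_channels: int, growth: int = 4):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, growth, kernel_size=3, padding=1)
        self.relu = nn.ReLU()

    def forward(self, features):
        # Concatenate all previous feature maps along the channel axis.
        return self.relu(self.conv(torch.cat(features, dim=1)))


recorded = {}


def shape_hook(module, inputs, output):
    # `inputs` is the tuple of positional arguments passed to forward().
    # This mirrors torchsummary's `list(input[0].size())`; for ListInputLayer,
    # inputs[0] is a Python list, which has no .size() method.
    recorded[type(module).__name__] = list(inputs[0].size())


caught = None
layer = ListInputLayer(3)
handle = layer.register_forward_hook(shape_hook)
try:
    layer([torch.zeros(1, 3, 8, 8)])
except AttributeError as err:
    caught = err
    print(err)
finally:
    handle.remove()
```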

This PR fixes it by checking whether the current module has any child modules before registering a hook, so that only leaf modules are hooked.
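A minimal sketch of the leaf-only idea described above, not the PR's actual diff. The names register_leaf_hooks, ListInputLayer and ToyDenseBlock are hypothetical; the toy block mimics DenseNet by feeding each layer the list of all earlier feature maps. Because the list-taking ListInputLayer has children, it is skipped, while its Conv2d and ReLU children, which receive plain tensors, are recorded.

```python
from collections import OrderedDict

import torch
import torch.nn as nn


class ListInputLayer(nn.Module):
    """Hypothetical stand-in for torchvision's _DenseLayer (takes a list of tensors)."""

    def __init__(self, in_channels: int, growth: int = 4):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, growth, kernel_size=3, padding=1)
        self.relu = nn.ReLU()

    def forward(self, features):
        return self.relu(self.conv(torch.cat(features, dim=1)))


class ToyDenseBlock(nn.Module):
    """Each layer sees the concatenation of the block input and all earlier outputs."""

    def __init__(self, in_channels: int = 3, growth: int = 4):
        super().__init__()
        self.layer1 = ListInputLayer(in_channels, growth)
        self.layer2 = ListInputLayer(in_channels + growth, growth)

    def forward(self, x):
        features = [x]
        for layer in (self.layer1, self.layer2):
            features.append(layer(features))
        return torch.cat(features, dim=1)


def register_leaf_hooks(model: nn.Module, records: OrderedDict) -> list:
    """Record input/output shapes, hooking only modules that have no children."""

    def hook(module, inputs, output):
        key = f"{type(module).__name__}-{len(records) + 1}"
        records[key] = {
            "input_shape": list(inputs[0].size()),
            "output_shape": list(output.size()),
        }

    handles = []
    for module in model.modules():
        # Skip containers (ToyDenseBlock, ListInputLayer, Sequential, ...):
        # their inputs may be lists, while leaf layers here get plain tensors.
        if len(list(module.children())) == 0:
            handles.append(module.register_forward_hook(hook))
    return handles


records = OrderedDict()
model = ToyDenseBlock()
handles = register_leaf_hooks(model, records)
with torch.no_grad():
    out = model(torch.zeros(1, 3, 8, 8))
for h in handles:
    h.remove()

for name, info in records.items():
    print(name, info["input_shape"], "->", info["output_shape"])
```

One trade-off of this sketch: a module that has children and also registers parameters directly on itself would get no row of its own, so those parameters would not appear in the summary.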

Hope this helps! Cheers

leriomaggio commented 4 years ago

👍

kaixinbear commented 4 years ago

Hi, I installed your branch using pip install git+https://github.com/frgfm/pytorch-summary.git, but I still get the error "AttributeError: 'list' object has no attribute 'size'" when testing the DenseNet backbone. Could you help me?

frgfm commented 4 years ago

Hi @kaixinbear, thanks for pointing it out. Could you provide the exact code you ran to reproduce the error? Also, just checking: did you add any flags or options to the pip install git+ command? By default, it installs the master branch, but since this is a fix/feature, my edit is on the densenet-fix branch (cf. the PR header).

Cheers!

kaixinbear commented 4 years ago

Just now, I reran pip install git+https://github.com/frgfm/pytorch-summary.git:

Collecting git+https://github.com/frgfm/pytorch-summary.git
  Cloning https://github.com/frgfm/pytorch-summary.git to /tmp/pip-req-build-9fh9x80m
  Running command git clone -q https://github.com/frgfm/pytorch-summary.git /tmp/pip-req-build-9fh9x80m
Requirement already satisfied (use --upgrade to upgrade): torchsummary==1.5.1 from git+https://github.com/frgfm/pytorch-summary.git in ./anaconda3/lib/python3.7/site-packages
Building wheels for collected packages: torchsummary
  Building wheel for torchsummary (setup.py) ... done
  Created wheel for torchsummary: filename=torchsummary-1.5.1-cp37-none-any.whl size=2850 sha256=1011ad3c5742d616c11a7b211663761d4ce739545519d23b0a4e62fcf8289120
  Stored in directory: /tmp/pip-ephem-wheel-cache-3k2g7tqz/wheels/27/2e/88/8c4cca542c91043b0d6c5ca666e59a6abdeb8fe05e2a198db8
$ python
Python 3.7.4 (default, Aug 13 2019, 20:35:49) 
[GCC 7.3.0] :: Anaconda, Inc. on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torchsummary
>>> from torchvision.models import densenet121
>>> from torchsummary import summary
>>> model = densenet121().eval().cuda()
>>> summary(model, (3, 224, 224))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/kaixin1/anaconda3/lib/python3.7/site-packages/torchsummary/torchsummary.py", line 72, in summary
    model(*x)
  File "/home/kaixin1/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/kaixin1/anaconda3/lib/python3.7/site-packages/torchvision/models/densenet.py", line 194, in forward
    features = self.features(x)
  File "/home/kaixin1/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/kaixin1/anaconda3/lib/python3.7/site-packages/torch/nn/modules/container.py", line 100, in forward
    input = module(input)
  File "/home/kaixin1/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/kaixin1/anaconda3/lib/python3.7/site-packages/torchvision/models/densenet.py", line 111, in forward
    new_features = layer(features)
  File "/home/kaixin1/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 534, in __call__
    hook_result = hook(self, input, result)
  File "/home/kaixin1/anaconda3/lib/python3.7/site-packages/torchsummary/torchsummary.py", line 19, in hook
    summary[m_key]["input_shape"] = list(input[0].size())
AttributeError: 'list' object has no attribute 'size'

frgfm commented 4 years ago

Thanks for the code @kaixinbear! But as I mentioned previously, the command pip install git+https://github.com/frgfm/pytorch-summary.git installs the master branch, which is identical in every respect to the master branch of the original repo.

If you want to test this PR's changes, you need to install from my densenet-fix branch. To do so, first uninstall any existing installation, then run:

pip install git+https://github.com/frgfm/pytorch-summary@densenet-fix

Then run your code again; it should work.

If not, please paste your torch version and your OS.
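Since the reply asks for the torch version and OS, a short Python snippet can collect both. The helper name environment_report is hypothetical, and reporting CUDA availability is an addition of this sketch, included because the reproduction calls .cuda().

```python
import platform

import torch


def environment_report() -> dict:
    """Collect the torch version and OS (plus Python version and CUDA visibility)."""
    return {
        "torch": torch.__version__,
        "cuda_available": torch.cuda.is_available(),
        "python": platform.python_version(),
        "os": platform.platform(),
    }


for key, value in environment_report().items():
    print(f"{key}: {value}")
```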

kaixinbear commented 4 years ago

Got it. Thanks for your kind reply.

GregorKerr1996 commented 4 years ago

Hi there, I ran the following commands but still end up with the error discussed above:

!pip install git+https://github.com/frgfm/pytorch-summary@densenet-fix
from torchvision.models import densenet121
from torchsummary import summary
model = densenet121().eval().cuda()
summary(model, (3, 224, 224))

Any help would be appreciated

frgfm commented 4 years ago

Hello @GregorKerr1996, thanks for reporting it! My apologies; to properly install this PR's content, you should do the following:

pip uninstall torchsummary
git clone https://github.com/frgfm/pytorch-summary.git
cd pytorch-summary && git checkout densenet-fix
pip install -e .

Let me know if the error persists!
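Because pip reports "Requirement already satisfied" when torchsummary 1.5.1 is already installed (as in the log earlier in this thread), it can be unclear which copy Python actually imports. A minimal sketch, using only the standard library's importlib.util.find_spec, that shows where a module would be loaded from without importing it; after the editable install, the path should point inside the cloned pytorch-summary directory rather than site-packages. The helper name module_location is hypothetical.

```python
import importlib.util


def module_location(name: str) -> str:
    """Return the file Python would load for `name`, without importing it."""
    spec = importlib.util.find_spec(name)
    if spec is None or spec.origin is None:
        raise ModuleNotFoundError(f"No module named {name!r}")
    return spec.origin


try:
    # After `pip install -e .`, this should point into the cloned repo.
    print(module_location("torchsummary"))
except ModuleNotFoundError as err:
    print(err)
```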

frgfm commented 4 years ago

Besides, if you are interested, I made a Python library of my own for personal use that adds operation (OPs) estimates. I benchmarked my implementation against torchvision models, and the results are similar to those of other OPs estimation libraries.

Here it is: https://github.com/frgfm/torch-scan

harshraj22 commented 3 years ago

@sksq96 @Naireen summary() failing on DenseNet has been a well-known issue for a long time. Could you please review this PR and, if it looks good, merge it? I'd much rather get the summary from the library installed through pip than manually install it from git and switch to this branch.

frgfm commented 3 years ago

@harshraj22 this repo doesn't seem to be maintained anymore, but feel free to check out https://github.com/frgfm/torch-scan, which includes other features as well.