sksq96 / pytorch-summary

Model summary in PyTorch similar to `model.summary()` in Keras
MIT License
4.01k stars 412 forks source link

LSTM cannot be used #143

Open · opened by z-a-f 4 years ago

z-a-f commented 4 years ago

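If a model contains an LSTM layer, `summary()` fails. I suspect this is related to #130. `nn.LSTM` returns a nested tuple `(output, (h_n, c_n))`. The forward hook in torchsummary calls `.size()` on every element of a returned tuple, so it reaches the inner `(h_n, c_n)` tuple and raises `AttributeError`.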

Minimum failing example:

class Model(nn.Module):
    """Minimal module wrapping a single LSTM layer, used to reproduce the bug.

    ``forward`` passes its input straight to the LSTM and returns whatever
    ``nn.LSTM`` returns, i.e. the tuple ``(output, (h_n, c_n))``. The second
    element is itself a tuple, not a tensor.
    """

    def __init__(self):
        super().__init__()
        # One LSTM layer: 5 input features -> 5 hidden units.
        self.lstm = nn.LSTM(input_size=5, hidden_size=5)

    def forward(self, x):
        lstm_result = self.lstm(x)
        return lstm_result

# Build the example model and ask torchsummary for a summary with
# input_size (3, 5) on the CPU. This summary() call is the one that raises
# the AttributeError shown in the traceback.
model = Model()
summary(model, (3, 5), device='cpu')

Running this raises the following error:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-87-26dd936cf377> in <module>
      8 
      9 model = Model()
---> 10 summary(model, (3, 5), device='cpu')

~/miniconda3/envs/pytorch-dev/lib/python3.6/site-packages/torchsummary/torchsummary.py in summary(model, input_size, batch_size, device)
     70     # make a forward pass
     71     # print(x.shape)
---> 72     model(*x)
     73 
     74     # remove these hooks

~/Git/pytorch-dev/pytorch/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    725             result = self._slow_forward(*input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(
    729                 _global_forward_hooks.values(),

<ipython-input-87-26dd936cf377> in forward(self, x)
      5 
      6     def forward(self, x):
----> 7         return self.lstm(x)
      8 
      9 model = Model()

~/Git/pytorch-dev/pytorch/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    729                 _global_forward_hooks.values(),
    730                 self._forward_hooks.values()):
--> 731             hook_result = hook(self, input, result)
    732             if hook_result is not None:
    733                 result = hook_result

~/miniconda3/envs/pytorch-dev/lib/python3.6/site-packages/torchsummary/torchsummary.py in hook(module, input, output)
     21             if isinstance(output, (list, tuple)):
     22                 summary[m_key]["output_shape"] = [
---> 23                     [-1] + list(o.size())[1:] for o in output
     24                 ]
     25             else:

~/miniconda3/envs/pytorch-dev/lib/python3.6/site-packages/torchsummary/torchsummary.py in <listcomp>(.0)
     21             if isinstance(output, (list, tuple)):
     22                 summary[m_key]["output_shape"] = [
---> 23                     [-1] + list(o.size())[1:] for o in output
     24                 ]
     25             else:

AttributeError: 'tuple' object has no attribute 'size'