TylerYep / torchinfo

View model summaries in PyTorch!

calculate size error when batch_dim is not None and input is 0 dim tensor #38

Closed: zmzhang2000 closed this issue 3 years ago

zmzhang2000 commented 3 years ago

My model contains a module whose output is a 0-dim tensor (a scalar loss). When batch_dim is not None, calling summary raises an IndexError.

For example:

import torch
import torch.nn as nn
from torchinfo import summary

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.criterion = nn.BCEWithLogitsLoss()

    def forward(self, x, label):
        loss = self.criterion(x, label)
        return loss

if __name__ == "__main__":
    model = Model()
    x = torch.rand(32, 100, dtype=torch.float)
    label = torch.randint(0, 2, (32, 100), dtype=torch.float)
    summary(model, input_data=(x, label), batch_dim=0)

Running this script produces:

Traceback (most recent call last):
  File "D:\Zhang\Codes\torchinfo\torchinfo\torchinfo.py", line 175, in summary
    _ = model.to(device)(*x, **kwargs)
  File "D:\Zhang\Codes\torchinfo\venv\lib\site-packages\torch\nn\modules\module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "D:/Zhang/Codes/torchinfo/test.py", line 12, in forward
    loss = self.criterion(x, label)
  File "D:\Zhang\Codes\torchinfo\venv\lib\site-packages\torch\nn\modules\module.py", line 893, in _call_impl
    hook_result = hook(self, input, result)
  File "D:\Zhang\Codes\torchinfo\torchinfo\torchinfo.py", line 367, in hook
    info.output_size = info.calculate_size(outputs, batch_dim)
  File "D:\Zhang\Codes\torchinfo\torchinfo\layer_info.py", line 85, in calculate_size
    size[batch_dim] = 1
IndexError: list assignment index out of range

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "D:/Zhang/Codes/torchinfo/test.py", line 20, in <module>
    summary(model, input_data=(x, label), batch_dim=0)
  File "D:\Zhang\Codes\torchinfo\torchinfo\torchinfo.py", line 184, in summary
    raise RuntimeError(
RuntimeError: Failed to run torchinfo. See above stack traces for more details. Executed layers up to: []

I found that the size of the module output is [] (an empty list, because the loss is a 0-dim tensor), so this assignment in calculate_size raises the IndexError:

size[batch_dim] = 1

I have created PR #37 to fix the problem.
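
To make the root cause concrete, here is a minimal sketch in plain Python. It first reproduces the failing assignment on an empty list, then shows one possible guard. The function name calculate_size_sketch is hypothetical and is not torchinfo's code, and the guard is only an assumption about how such a fix could look; it is not necessarily what the PR does.

```python
from typing import List, Optional, Sequence


def calculate_size_sketch(shape: Sequence[int], batch_dim: Optional[int]) -> List[int]:
    """Hypothetical stand-in for the size calculation (not torchinfo's actual code)."""
    size = list(shape)
    # A 0-dim tensor has an empty shape, so `size` is [] and there is no
    # batch dimension to overwrite. Only assign when the index is in range.
    if batch_dim is not None and -len(size) <= batch_dim < len(size):
        size[batch_dim] = 1
    return size


# Reproduce the original failure: the shape of a 0-dim tensor is an empty list,
# and assigning to index 0 of an empty list raises IndexError.
scalar_shape: List[int] = []
try:
    scalar_shape[0] = 1
    failed = False
except IndexError:
    failed = True

print(failed)                               # True
print(calculate_size_sketch([], 0))         # []
print(calculate_size_sketch([32, 100], 0))  # [1, 100]
```

With this guard, a scalar output keeps its empty size, while ordinary outputs still have their batch dimension set to 1.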

TylerYep commented 3 years ago

Thanks for the contribution! I merged the change; expect a new release soon.