My model class contains a submodule whose output is a 0-dim tensor (here, the scalar loss returned by nn.BCEWithLogitsLoss). When batch_dim is not None, summary raises an IndexError.
For example:
import torch
import torch.nn as nn
from torchinfo import summary


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.criterion = nn.BCEWithLogitsLoss()

    def forward(self, x, label):
        loss = self.criterion(x, label)
        return loss


if __name__ == "__main__":
    model = Model()
    x = torch.rand(32, 100, dtype=torch.float)
    label = torch.randint(0, 2, (32, 100), dtype=torch.float)
    summary(model, input_data=(x, label), batch_dim=0)
Traceback (most recent call last):
  File "D:\Zhang\Codes\torchinfo\torchinfo\torchinfo.py", line 175, in summary
    _ = model.to(device)(*x, **kwargs)
  File "D:\Zhang\Codes\torchinfo\venv\lib\site-packages\torch\nn\modules\module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "D:/Zhang/Codes/torchinfo/test.py", line 12, in forward
    loss = self.criterion(x, label)
  File "D:\Zhang\Codes\torchinfo\venv\lib\site-packages\torch\nn\modules\module.py", line 893, in _call_impl
    hook_result = hook(self, input, result)
  File "D:\Zhang\Codes\torchinfo\torchinfo\torchinfo.py", line 367, in hook
    info.output_size = info.calculate_size(outputs, batch_dim)
  File "D:\Zhang\Codes\torchinfo\torchinfo\layer_info.py", line 85, in calculate_size
    size[batch_dim] = 1
IndexError: list assignment index out of range

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "D:/Zhang/Codes/torchinfo/test.py", line 20, in <module>
    summary(model, input_data=(x, label), batch_dim=0)
  File "D:\Zhang\Codes\torchinfo\torchinfo\torchinfo.py", line 184, in summary
    raise RuntimeError(
RuntimeError: Failed to run torchinfo. See above stack traces for more details. Executed layers up to: []
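I found that the size of the loss module's output is [], so the assignment size[batch_dim] = 1 in calculate_size raises the IndexError whenever batch_dim is not None: a 0-dim tensor has no batch dimension to overwrite.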
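A minimal sketch of one possible guard, assuming the output size is handled as a plain Python list, as the traceback's size[batch_dim] = 1 suggests. The helper name replace_batch_dim is hypothetical and not part of torchinfo, and this is not necessarily the approach taken in PR #37. It only overwrites the batch dimension when that index actually exists, so 0-dim outputs such as a reduced loss pass through unchanged.

```python
from typing import List, Optional, Sequence


def replace_batch_dim(shape: Sequence[int], batch_dim: Optional[int]) -> List[int]:
    """Return a copy of `shape` with the batch dimension set to 1.

    Hypothetical helper for illustration only. If `batch_dim` is None, or
    the shape has no such dimension (e.g. shape == [] for a 0-dim loss
    tensor), the shape is returned unchanged instead of raising IndexError.
    """
    size = list(shape)
    # Only touch the batch dimension if the index is valid for this shape
    # (negative indices are allowed, as with normal list indexing).
    if batch_dim is not None and -len(size) <= batch_dim < len(size):
        size[batch_dim] = 1
    return size


if __name__ == "__main__":
    print(replace_batch_dim([32, 100], 0))  # regular output with a batch dim
    print(replace_batch_dim([], 0))         # 0-dim output: no batch dim to replace
```

With a tensor, the shape would come from list(tensor.size()), which is [] for the scalar returned by nn.BCEWithLogitsLoss with its default reduction.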
I have created PR #37 to fix the problem.