Swall0w / torchstat

Model analyzer in PyTorch
MIT License

itemsize = input[0].detach().numpy().itemsize AttributeError: 'list' object has no attribute 'detach' #44

Open Abdellah-Laassairi opened 1 year ago

Abdellah-Laassairi commented 1 year ago

Expected Behavior

Retrieving the statistics of the SMP model with FPN architecture and resnet34 encoder.

Actual Behavior

input[0] is a list rather than a tensor, so calling .detach() on it raises an AttributeError.

Code to Reproduce the Problem

import torch
import segmentation_models_pytorch as smp
from torchstat import stat

model = smp.FPN(
    encoder_name="resnet34",       
    encoder_weights="imagenet",
    in_channels=3,                
    classes=3,                    
)

device = torch.device("cpu")

model = model.to(device) 
stat(model, (3, 224, 224))

Specifications:

 File "repreduce.py", line 16, in <module>
    stat(model, (3, 224, 224))
  File "/home/eouser/anaconda3/envs/mlenv/lib/python3.8/site-packages/torchstat/statistics.py", line 71, in stat
    ms.show_report()
  File "/home/eouser/anaconda3/envs/mlenv/lib/python3.8/site-packages/torchstat/statistics.py", line 64, in show_report
    collected_nodes = self._analyze_model()
  File "/home/eouser/anaconda3/envs/mlenv/lib/python3.8/site-packages/torchstat/statistics.py", line 57, in _analyze_model
    model_hook = ModelHook(self._model, self._input_size)
  File "/home/eouser/anaconda3/envs/mlenv/lib/python3.8/site-packages/torchstat/model_hook.py", line 24, in __init__
    self._model(x)
  File "/home/eouser/anaconda3/envs/mlenv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/eouser/anaconda3/envs/mlenv/lib/python3.8/site-packages/segmentation_models_pytorch/base/model.py", line 30, in forward
    decoder_output = self.decoder(*features)
  File "/home/eouser/anaconda3/envs/mlenv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/eouser/anaconda3/envs/mlenv/lib/python3.8/site-packages/segmentation_models_pytorch/decoders/fpn/decoder.py", line 110, in forward
    x = self.merge(feature_pyramid)
  File "/home/eouser/anaconda3/envs/mlenv/lib/python3.8/site-packages/torchstat/model_hook.py", line 47, in wrap_call
    itemsize = input[0].detach().numpy().itemsize
AttributeError: 'list' object has no attribute 'detach'

Version: 0.0.7
Platform: Ubuntu 22.04.2 LTS

Fritingo commented 1 year ago

I tried modifying model_hook.py.

The hook's input is a tuple and input[0] is a list, so input[0] has no detach method.
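
To illustrate why this happens, here is a minimal sketch in plain Python. It assumes the hook collects positional arguments with *input, as the input[0] access in the traceback suggests. The strings are only stand-ins for the feature tensors that the FPN decoder passes to self.merge as one list.

```python
def wrap_call(module, *input):
    # *input packs all positional arguments into a tuple, so a single
    # list argument ends up as input[0].
    return type(input).__name__, type(input[0]).__name__


# Stand-ins for the tensors in feature_pyramid (hypothetical values).
feature_pyramid = ["p5", "p4", "p3", "p2"]

print(wrap_call(None, feature_pyramid))  # ('tuple', 'list')
```

A module called with one tensor gives a tensor at input[0]. A module called with one list of tensors gives a list there, and a list has no detach method.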

Change

itemsize = input[0].detach().numpy().itemsize

to

itemsize = 0
for i in range(len(input[0])):
    itemsize += input[0][i].detach().numpy().itemsize

and

module.input_shape = torch.from_numpy(
    np.array(input[0].size()[1:], dtype=np.int32))

to

module.input_shape = torch.from_numpy(
    np.array(input[0][0].size()[1:], dtype=np.int32))

This works for me.
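
A variant of the workaround above that handles both cases (a plain tensor or a list/tuple of tensors) could look like the sketch below. The helper name first_tensor is hypothetical and not part of torchstat. It uses duck typing (anything with a detach method counts as a tensor), so the helper itself needs no torch import. Note that summing itemsize over the list gives the per-element byte size multiplied by the number of tensors. This sketch reads itemsize from a single tensor instead, which is an assumption about what the hook intends.

```python
def first_tensor(obj):
    """Return the first tensor-like object inside obj, or None.

    obj may be a tensor or an arbitrarily nested list/tuple of tensors.
    "Tensor-like" means "has a detach method" (duck typing).
    """
    if hasattr(obj, "detach"):
        return obj
    if isinstance(obj, (list, tuple)):
        for item in obj:
            found = first_tensor(item)
            if found is not None:
                return found
    return None


# Hypothetical use inside wrap_call in model_hook.py:
#   x = first_tensor(input)
#   itemsize = x.detach().numpy().itemsize
#   module.input_shape = torch.from_numpy(
#       np.array(x.size()[1:], dtype=np.int32))
```

With this, modules that receive a single tensor behave as before, and modules that receive a list report the shape of the first tensor in that list, the same choice the input[0][0] change makes.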