TylerYep / torchinfo

View model summaries in PyTorch!
MIT License

fails on model returning ndarray(s) #152

Closed: smidm closed this issue 1 year ago

smidm commented 2 years ago

Describe the bug

torchinfo.summary() fails on a model that returns NumPy ndarrays instead of tensors.

To Reproduce

First install MMPose (https://mmpose.readthedocs.io) and MIM (the openmim package), which the script uses to download the model.

from pathlib import Path

import mim
import torch
import torchinfo
from mmpose.apis import init_pose_model

# Download the config and checkpoint of a pretrained top-down MSPN-50 model
model_name = 'topdown_heatmap_mspn50_coco_256x192'
models_dir = Path('models')
models_dir.mkdir(exist_ok=True)
checkpoint_filename = mim.commands.download('mmpose', [model_name], models_dir)[0]

# Build the model on CPU from the downloaded config and checkpoint
mmpose_model = init_pose_model(
    str(models_dir / (model_name + '.py')),
    str(models_dir / checkpoint_filename),
    device='cpu'
)

# Inference-mode input: one image plus its per-image metadata
sample_input_data = dict(
    img=torch.rand(1, 3, 192, 256),
    img_metas=[dict(image_file='xxx', center=(2, 4), scale=(1, 1), rotation=0,
                    bbox_score=0.9, flip_pairs=[[1, 2], [3, 4]])],
    return_loss=False,
    return_heatmap=False)
torchinfo.summary(mmpose_model, input_data=sample_input_data, mode='eval')

The code above results in:

AttributeError                            Traceback (most recent call last)
~/anaconda3/envs/mm/lib/python3.7/site-packages/torchinfo/torchinfo.py in forward_pass(model, x, batch_dim, cache_forward_pass, device, mode, **kwargs)
    291             elif isinstance(x, dict):
--> 292                 _ = model.to(device)(**x, **kwargs)
    293             else:

~/anaconda3/envs/mm/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1130             for hook in (*_global_forward_hooks.values(), *self._forward_hooks.values()):
-> 1131                 hook_result = hook(self, input, result)
   1132                 if hook_result is not None:

~/anaconda3/envs/mm/lib/python3.7/site-packages/torchinfo/torchinfo.py in hook(module, inputs, outputs)
    540         info.input_size, _ = info.calculate_size(inputs, batch_dim)
--> 541         info.output_size, elem_bytes = info.calculate_size(outputs, batch_dim)
    542         info.output_bytes = elem_bytes * prod(info.output_size)

~/anaconda3/envs/mm/lib/python3.7/site-packages/torchinfo/layer_info.py in calculate_size(inputs, batch_dim)
    129             size = []
--> 130             elem_bytes = list(inputs.values())[0].element_size()
    131             for _, output in inputs.items():

AttributeError: 'numpy.ndarray' object has no attribute 'element_size'
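The failing line in calculate_size calls element_size() on the first value of the output dict. That method belongs to torch.Tensor; a numpy.ndarray has no such method, but it exposes the same per-element byte count as itemsize. A minimal check, assuming only NumPy is installed and using a zero array shaped like the model's 'preds' output:

```python
import numpy as np

# A float32 array with the same shape as the model's 'preds' output (1 x 17 x 3).
preds = np.zeros((1, 17, 3), dtype=np.float32)

# torchinfo 1.7.0 calls .element_size(), which only torch.Tensor provides.
print(hasattr(preds, "element_size"))  # False

# NumPy's equivalents: bytes per element, and total bytes for the array.
print(preds.itemsize)  # 4
print(preds.nbytes)    # 204  (1 * 17 * 3 elements * 4 bytes)
```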

The model's output, a dict whose 'preds' and 'boxes' values are NumPy arrays:

{'preds': array([[[-24.041672  ,  82.90625   ,   0.50608665],
        [-21.958336  ,  79.78125   ,   0.50380564],
        [-24.041672  ,  79.78125   ,   0.5038374 ],
        [-44.875004  ,  78.21875   ,   0.504611  ],
        [-49.041668  , -78.03125   ,   0.50395286],
        [ 72.83333   , -92.875     ,   0.50649935],
        [-82.375     , -23.34375   ,   0.50320053],
        [-78.208336  , -60.84375   ,   0.50331956],
        [-76.125     , -60.84375   ,   0.50289243],
        [ 78.04166   , -88.96875   ,   0.51409924],
        [-86.541664  , -88.96875   ,   0.51163995],
        [-19.875     ,  95.40625   ,   0.5046092 ],
        [-28.208336  ,  90.71875   ,   0.504157  ],
        [-55.291668  ,  45.40625   ,   0.5035337 ],
        [ 44.70833   ,  45.40625   ,   0.50295764],
        [  3.0416641 , -84.28125   ,   0.5037741 ],
        [  3.0416641 , -84.28125   ,   0.5039975 ]]], dtype=float32), 'boxes': array([[2.e+00, 4.e+00, 1.e+00, 1.e+00, 4.e+04, 9.e-01]], dtype=float32), 'image_paths': ['xxx'], 'bbox_ids': None, 'output_heatmap': None}

torchinfo 1.7.0

TylerYep commented 2 years ago

Thanks for reporting this issue. torchinfo does not currently support NumPy arrays, including computing their sizes. PRs to add support for this output type are much appreciated!
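One possible shape for such support is a size helper that duck-types its leaves instead of assuming torch.Tensor. The sketch below is not torchinfo's code: the names elem_bytes and collect_sizes are hypothetical, and it covers only the shape/byte bookkeeping that fails in the traceback, not how torchinfo builds its summary table. It walks nested dicts, lists and tuples, uses element_size() when present (torch.Tensor) and falls back to itemsize (numpy.ndarray), and skips leaves such as None and strings.

```python
from typing import Any, List, Tuple

import numpy as np


def elem_bytes(leaf: Any) -> int:
    """Bytes per element for a torch.Tensor or a numpy.ndarray (duck-typed)."""
    if hasattr(leaf, "element_size"):  # torch.Tensor
        return int(leaf.element_size())
    return int(leaf.itemsize)  # numpy.ndarray


def collect_sizes(obj: Any) -> Tuple[List[List[int]], int]:
    """Recursively collect (shapes, total_bytes) from a nested model output.

    Array-like leaves (anything with .shape plus element_size or itemsize)
    are counted; other leaves such as None, strings and numbers are skipped.
    """
    if hasattr(obj, "shape") and (hasattr(obj, "element_size") or hasattr(obj, "itemsize")):
        shape = [int(d) for d in obj.shape]
        numel = 1
        for d in shape:
            numel *= d
        return [shape], numel * elem_bytes(obj)
    if isinstance(obj, dict):
        children = list(obj.values())
    elif isinstance(obj, (list, tuple)):
        children = list(obj)
    else:
        return [], 0
    shapes: List[List[int]] = []
    total = 0
    for child in children:
        child_shapes, child_bytes = collect_sizes(child)
        shapes.extend(child_shapes)
        total += child_bytes
    return shapes, total


# An output dict with the same structure as the mmpose result in this issue.
output = {
    "preds": np.zeros((1, 17, 3), dtype=np.float32),
    "boxes": np.zeros((1, 6), dtype=np.float32),
    "image_paths": ["xxx"],
    "bbox_ids": None,
    "output_heatmap": None,
}
shapes, total_bytes = collect_sizes(output)
print(shapes)       # [[1, 17, 3], [1, 6]]
print(total_bytes)  # 228
```

Duck typing keeps the helper free of a hard NumPy or torch dependency. The trade-off is that any object with a shape and an itemsize is treated as an array.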