TylerYep / torchinfo

View model summaries in PyTorch!
MIT License

'Conv2d' object has no attribute 'weight_mask' #107

Closed: alimoezzi closed this issue 2 years ago

alimoezzi commented 2 years ago

Hi, I'm getting the following error for a simple VGG16 implementation:

Traceback (most recent call last):
  File "/home/user/.local/lib/python3.8/site-packages/torchinfo/torchinfo.py", line 272, in forward_pass
    _ = model.to(device)(*x, **kwargs)
  File "/home/user/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/user/Projects/net/vgg.py", line 97, in forward
    out = self.features(x)
  File "/home/user/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1109, in _call_impl
    result = hook(self, input)
  File "/home/user/.local/lib/python3.8/site-packages/torchinfo/torchinfo.py", line 500, in pre_hook
    info.calculate_num_params()
  File "/home/user/.local/lib/python3.8/site-packages/torchinfo/layer_info.py", line 151, in calculate_num_params
    cur_params, name = self.get_param_count(name, param)
  File "/home/user/.local/lib/python3.8/site-packages/torchinfo/layer_info.py", line 139, in get_param_count
    torch.sum(rgetattr(self.module, f"{without_suffix}_mask"))
  File "/home/user/.local/lib/python3.8/site-packages/torchinfo/layer_info.py", line 19, in rgetattr
    module = getattr(module, attr_i)
  File "/home/user/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1177, in __getattr__
    raise AttributeError("'{}' object has no attribute '{}'".format(
AttributeError: 'Conv2d' object has no attribute 'weight_mask'

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/IPython/core/interactiveshell.py", line 3444, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-45-e9b0e1aa526c>", line 1, in <module>
    summary(v, (1, 3, 224, 224), depth=3, col_names=["input_size", "output_size", "kernel_size", "num_params"])
  File "/home/user/.local/lib/python3.8/site-packages/torchinfo/torchinfo.py", line 201, in summary
    summary_list = forward_pass(
  File "/home/user/.local/lib/python3.8/site-packages/torchinfo/torchinfo.py", line 281, in forward_pass
    raise RuntimeError(
RuntimeError: Failed to run torchinfo. See above stack traces for more details. Executed layers up to: []

TylerYep commented 2 years ago

Hi, can you post your full code so I can reproduce this error?

alimoezzi commented 2 years ago

@TylerYep I tried to get a summary for this VGG implementation.

TylerYep commented 2 years ago

@MajorCarrot It seems there is an issue with your implementation of pruned-model support.

MajorCarrot commented 2 years ago

Okay, so what I see is that the implementation of spectral normalization in this reference repo uses a weight_orig parameter, similar to what PyTorch does for masked (pruned) models. The code I added used the _orig suffix as the signal that the model is masked. One way to overcome this could be to remove this hardcoded dependency and let the user specify which layers are pruned (or simply that the model is pruned). This would be slightly janky, but it should work nonetheless.
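To make the collision concrete, here is a minimal sketch that inspects a pruned layer and a spectrally normalized layer using only standard PyTorch utilities (`torch.nn.utils.prune.l1_unstructured` and `torch.nn.utils.spectral_norm`). The helper `names` is hypothetical and exists only for this illustration. Both layers end up with a `weight_orig` parameter, but only the pruned one has the `weight_mask` buffer that the traceback above was looking for.

```python
from __future__ import annotations

import torch.nn as nn
import torch.nn.utils.prune as prune
from torch.nn.utils import spectral_norm


def names(module: nn.Module) -> tuple[list[str], list[str]]:
    """Hypothetical helper: names of a module's own parameters and buffers."""
    params = [n for n, _ in module.named_parameters(recurse=False)]
    buffers = [n for n, _ in module.named_buffers(recurse=False)]
    return params, buffers


# Pruning moves `weight` to the parameter `weight_orig` and adds a `weight_mask` buffer.
pruned = nn.Conv2d(3, 8, kernel_size=3)
prune.l1_unstructured(pruned, name="weight", amount=0.5)

# Spectral norm also moves `weight` to `weight_orig`, but stores `weight_u`/`weight_v`, not a mask.
normed = spectral_norm(nn.Conv2d(3, 8, kernel_size=3))

print(names(pruned))  # (['bias', 'weight_orig'], ['weight_mask'])
print(names(normed))  # (['bias', 'weight_orig'], ['weight_u', 'weight_v'])

# A lookup keyed only on the `_orig` suffix therefore fails for the spectral-norm layer.
print(hasattr(normed, "weight_mask"))  # False
```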

TylerYep commented 2 years ago

Added a temporary fix in https://github.com/TylerYep/torchinfo/commit/564aeca9ce8ca3d59e71970a9f80d296797457fe, but we should consider more robust ways of detecting pruned models. I will get a bug-fix release out soon.
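One more robust heuristic, sketched below, treats a `<name>_orig` parameter as pruned only when a matching `<name>_mask` buffer actually exists, and otherwise counts the full parameter. This is an assumption for illustration, not torchinfo's actual code or the fix in the commit above; `count_params` is a hypothetical helper that counts only a single module's own parameters.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
from torch.nn.utils import spectral_norm


def count_params(module: nn.Module) -> int:
    """Hypothetical helper (not torchinfo's code): count a module's effective parameters.

    A parameter named '<x>_orig' is treated as pruned only if a '<x>_mask' buffer exists;
    otherwise (e.g. spectral norm's 'weight_orig') its full size is counted.
    """
    buffers = dict(module.named_buffers(recurse=False))
    total = 0
    for name, param in module.named_parameters(recurse=False):
        if name.endswith("_orig"):
            mask = buffers.get(name[: -len("_orig")] + "_mask")
            if mask is not None:
                # Count only the entries the pruning mask keeps.
                total += int(torch.sum(mask).item())
                continue
        total += param.numel()
    return total


plain = nn.Conv2d(3, 8, kernel_size=3)  # 8*3*3*3 weights + 8 biases
pruned = nn.Conv2d(3, 8, kernel_size=3)
prune.l1_unstructured(pruned, name="weight", amount=0.5)  # zeroes 108 of 216 weights
normed = spectral_norm(nn.Conv2d(3, 8, kernel_size=3))

print(count_params(plain), count_params(pruned), count_params(normed))  # 224 116 224
```

Checking for the mask (rather than trusting the suffix) keeps the spectral-norm layer from being mistaken for a pruned one, without requiring the user to flag pruned layers by hand.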

TylerYep commented 2 years ago

Released a fix in v1.6.3. Can you try upgrading and see whether this fixes your issue, @realsarm?

alimoezzi commented 2 years ago

@TylerYep Yes, as far as I have tested, it is solved.

TylerYep commented 2 years ago

Thanks for the bug report!