mert-kurttutan / torchview

torchview: visualize pytorch models
https://torchview.dev
MIT License

No Support for FastRCNN Based Models #87

Open RoyiAvital opened 1 year ago

RoyiAvital commented 1 year ago

Describe the bug: Faster R-CNN based models take a list as input (the batch is a list of tensors).
Hence the input cannot be described using the current API, which only lets the user set the tensor dimensions.

To Reproduce

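  1. Build the Faster R-CNN model:
import torch
import torchvision
from torchvision.models.detection import FasterRCNN_ResNet50_FPN_Weights

model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=FasterRCNN_ResNet50_FPN_Weights.DEFAULT)
# For inference
model.eval()
# The batch is a list of (C, H, W) tensors, and the images may differ in size
x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
predictions = model(x)
  2. The input cannot be described using the input size API.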
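To illustrate why a size-only description falls short for this batch, here is a minimal plain-Python sketch (no PyTorch needed). It assumes, for illustration only, that a size description must resolve to one fixed (N, C, H, W) shape; `batch_shape_from_sizes` is a hypothetical helper, not part of torchview. With the two image shapes from the reproduction, no single stacked shape exists.

```python
from typing import List, Optional, Tuple

Shape = Tuple[int, ...]


def batch_shape_from_sizes(image_shapes: List[Shape]) -> Optional[Shape]:
    """Hypothetical helper: collapse per-image (C, H, W) shapes into one
    stacked (N, C, H, W) shape, or return None if that is impossible."""
    if not image_shapes:
        return None
    first = image_shapes[0]
    if any(shape != first for shape in image_shapes):
        # Images of different sizes cannot share one batch tensor shape
        return None
    return (len(image_shapes),) + first


# Shapes of the two images in the reproduction above
print(batch_shape_from_sizes([(3, 300, 400), (3, 500, 400)]))  # None
print(batch_shape_from_sizes([(3, 300, 400), (3, 300, 400)]))  # (2, 3, 300, 400)
```

This is why passing a concrete example (the list `x` itself) is the natural way to describe such inputs.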

Expected behavior: being able to set the input by passing a specific example.

mert-kurttutan commented 1 year ago

Thanks for the report! It turns out the model raises a different error when using the input_data API. This happens because torchvision wraps tensors in an ImageList object when returning them from modules. As of now, torchview only covers iterable and mappable objects, and ImageList is neither.
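A simplified sketch of the limitation described above, assuming tensor discovery recurses only into mappings and iterables (this is not torchview's actual code). `FakeTensor` and `ImageListLike` are hypothetical stand-ins so the example runs without PyTorch; `ImageListLike` mirrors the two attributes of torchvision's ImageList, `tensors` and `image_sizes`.

```python
from dataclasses import dataclass
from typing import Any, List


class FakeTensor:
    """Hypothetical stand-in for torch.Tensor."""

    def __init__(self, name: str) -> None:
        self.name = name


@dataclass
class ImageListLike:
    """Hypothetical stand-in for ImageList: holds a tensor as an
    attribute and is neither iterable nor a mapping."""

    tensors: FakeTensor
    image_sizes: list


def collect_tensors(obj: Any) -> List[FakeTensor]:
    """Find tensors by recursing only into dicts, lists and tuples."""
    if isinstance(obj, FakeTensor):
        return [obj]
    if isinstance(obj, dict):
        return [t for value in obj.values() for t in collect_tensors(value)]
    if isinstance(obj, (list, tuple)):
        return [t for item in obj for t in collect_tensors(item)]
    # Any other object, including ImageListLike, is treated as opaque
    return []


images = ImageListLike(FakeTensor("batched"), [(300, 400), (500, 400)])
print([t.name for t in collect_tensors({"out": [FakeTensor("a")]})])  # ['a']
print(collect_tensors(images))  # [] : the wrapped tensor is never found
```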

Relevant piece of code: https://github.com/pytorch/vision/blob/6ca9c76adb6daf2695d603ad623a9cf1c4f4806f/torchvision/models/detection/generalized_rcnn.py#L83

My first attempt at solving this would be to iterate over the attributes of output and input objects, so that all objects are covered, not just iterable and mappable ones.
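One way the attribute-walking idea could look, as a minimal sketch and not torchview's implementation: besides dicts and sequences, the traversal also descends into an object's instance attributes via `vars()`. `FakeTensor`, `ImageListLike` and `collect_tensors_deep` are hypothetical names used so the example runs without PyTorch.

```python
from dataclasses import dataclass
from typing import Any, List, Optional, Set


class FakeTensor:
    """Hypothetical stand-in for torch.Tensor."""

    def __init__(self, name: str) -> None:
        self.name = name


@dataclass
class ImageListLike:
    """Hypothetical stand-in for torchvision's ImageList."""

    tensors: FakeTensor
    image_sizes: list


def collect_tensors_deep(obj: Any, _seen: Optional[Set[int]] = None) -> List[FakeTensor]:
    """Find tensors in mappings, sequences and object attributes."""
    if _seen is None:
        _seen = set()
    # Track visited objects by id so reference cycles cannot recurse forever
    if id(obj) in _seen:
        return []
    _seen.add(id(obj))
    if isinstance(obj, FakeTensor):
        return [obj]
    if isinstance(obj, dict):
        children = list(obj.values())
    elif isinstance(obj, (list, tuple, set)):
        children = list(obj)
    elif hasattr(obj, "__dict__"):
        # The new case: walk the instance attributes of arbitrary objects
        children = list(vars(obj).values())
    else:
        return []
    found: List[FakeTensor] = []
    for child in children:
        found.extend(collect_tensors_deep(child, _seen))
    return found


images = ImageListLike(FakeTensor("batched"), [(300, 400), (500, 400)])
print([t.name for t in collect_tensors_deep({"features": images})])  # ['batched']
```

The visited set matters because attribute walking, unlike plain list/dict traversal, can easily reach objects that point back to their parents.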

mert-kurttutan commented 1 year ago

I will try a few solutions and let you know when a relevant update is complete.

RoyiAvital commented 1 year ago

Great!

By the way, I wrote about your package on StackExchange Data Science, in the thread "How Do You Visualize Neural Network Architectures?".

mert-kurttutan commented 1 year ago

Thanks for spreading the word :)