mert-kurttutan / torchview

torchview: visualize pytorch models
https://torchview.dev
MIT License

Allow ndarrays in process_inputs #105

Open Erotemic opened 1 year ago

Erotemic commented 1 year ago

Currently, if an input to the model's forward pass contains an ndarray, it is treated as a generic iterable, which causes the following error:

File ~/code/torchview/torchview/torchview.py:372, in traverse_data(data, action_fn, aggregate_fn)
    368     return aggregate(
    369         *(traverse_data(d, action_fn, aggregate_fn) for d in data)
    370     )
    371 if isinstance(data, Iterable) and not isinstance(data, str):
--> 372     return aggregate(
    373         [traverse_data(d, action_fn, aggregate_fn) for d in data]
    374     )
    375 # Data is neither a tensor nor a collection
    376 return data

TypeError: 'numpy.float64' object cannot be interpreted as an integer

This patch fixes it by simply letting ndarrays pass through as-is. Associated tests are added.
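
To illustrate the idea, here is a minimal, stdlib-only sketch of the pass-through check. It is not the code from this patch: `traverse_data_sketch`, `leaf_types`, and `passthrough_types` are hypothetical names, containers are rebuilt with `type(data)` rather than a separate `aggregate_fn`, and `array.array` stands in for `numpy.ndarray` so the example runs without third-party packages.

```python
import array
from collections.abc import Iterable, Mapping
from typing import Any, Callable, Tuple


def traverse_data_sketch(
    data: Any,
    action_fn: Callable[[Any], Any],
    leaf_types: Tuple[type, ...],
    passthrough_types: Tuple[type, ...] = (),
) -> Any:
    """Apply action_fn to every leaf inside a nested structure (hypothetical helper)."""
    if isinstance(data, leaf_types):
        return action_fn(data)
    # Opaque values are returned untouched *before* the generic Iterable
    # branch gets a chance to iterate over them and rebuild the container.
    if isinstance(data, passthrough_types):
        return data
    if isinstance(data, Mapping):
        return type(data)(
            {
                k: traverse_data_sketch(v, action_fn, leaf_types, passthrough_types)
                for k, v in data.items()
            }
        )
    if isinstance(data, Iterable) and not isinstance(data, (str, bytes)):
        return type(data)(
            [
                traverse_data_sketch(d, action_fn, leaf_types, passthrough_types)
                for d in data
            ]
        )
    # Data is neither a leaf nor a collection
    return data


# array.array is iterable, but (like an ndarray) its constructor does not
# accept a plain list of elements, so rebuilding it generically fails.
buf = array.array("d", [1.0, 2.0])
batch = {"x": [1, 2], "buf": buf}

result = traverse_data_sketch(
    batch,
    lambda v: v * 10,
    leaf_types=(int,),
    passthrough_types=(array.array,),
)
```

In torchview itself the leaf type would presumably be `torch.Tensor` and the pass-through type `np.ndarray`; that mapping is an assumption for illustration. Leaving `passthrough_types` empty in this sketch sends `buf` into the generic iterable branch, where rebuilding the container raises a `TypeError`, much like the traceback above.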

vercel[bot] commented 1 year ago

Deployment status:

torchview: ✅ Ready (updated Sep 2, 2023 10:55pm UTC)