Currently, if the input to the model's forward pass contains a NumPy ndarray, traverse_data treats it as a generic iterable and recurses into its elements, which causes the following error:
File ~/code/torchview/torchview/torchview.py:372, in traverse_data(data, action_fn, aggregate_fn)
368 return aggregate(
369 *(traverse_data(d, action_fn, aggregate_fn) for d in data)
370 )
371 if isinstance(data, Iterable) and not isinstance(data, str):
--> 372 return aggregate(
373 [traverse_data(d, action_fn, aggregate_fn) for d in data]
374 )
375 # Data is neither a tensor nor a collection
376 return data
TypeError: 'numpy.float64' object cannot be interpreted as an integer
This patch fixes it by letting ndarrays pass through traverse_data unchanged instead of iterating over them. Associated tests are added.
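For reviewers, here is a minimal stdlib-only sketch of the pass-through idea. It is not the actual torchview code. The `leaf_types` and `passthrough_types` parameters are hypothetical names introduced for illustration. `array.array` stands in for `numpy.ndarray` so the snippet runs without NumPy, and `float` stands in for the tensor leaf type. The sketch assumes that `aggregate_fn(data)` returns a constructor used to rebuild the container, as the traceback suggests.

```python
from array import array
from collections.abc import Iterable


def traverse_data(data, action_fn, aggregate_fn,
                  leaf_types=(), passthrough_types=()):
    """Apply action_fn to leaves and rebuild containers around the results.

    leaf_types and passthrough_types are hypothetical parameters used only
    in this sketch; the real function may hard-code these checks instead.
    """
    # Pass-through check comes first, so these objects are never iterated
    # even though they satisfy isinstance(data, Iterable).
    if isinstance(data, passthrough_types):
        return data
    if isinstance(data, leaf_types):
        return action_fn(data)
    if isinstance(data, Iterable) and not isinstance(data, str):
        # Assumption: aggregate_fn(data) yields a constructor for the container.
        aggregate = aggregate_fn(data)
        return aggregate(
            [traverse_data(d, action_fn, aggregate_fn,
                           leaf_types, passthrough_types) for d in data]
        )
    # Data is neither a leaf nor a collection
    return data


# array.array plays the role of numpy.ndarray here: it is iterable, but its
# constructor does not accept a plain list of elements as its only argument.
batch = [1.0, (2.0, 3.0), array("d", [4.0, 5.0])]
result = traverse_data(
    batch,
    lambda x: x * 2,   # stand-in for the per-tensor action
    type,              # rebuild each container with its own type
    leaf_types=(float,),
    passthrough_types=(array,),
)
print(result)
```

With `passthrough_types` left empty, the same call raises a `TypeError` when the sketch tries to rebuild the array from a list, which is analogous to the failure in the traceback above.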