Open minostauros opened 1 year ago
This happens when the model input is a list of strings, as is common for language models.
Possible dirty workaround
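import sys
from collections.abc import Iterable, Mapping
from typing import Any, cast

from torchinfo.torchinfo import CORRECTED_INPUT_DATA_TYPE, traverse_input_data


def get_total_memory_used(data: CORRECTED_INPUT_DATA_TYPE) -> int:
    """Calculates the total memory of all tensors (and other leaves, e.g. strings) stored in data."""

    def sum_sizes(items: Iterable[Any]) -> int:
        # Tensors were already turned into ints by action_fn. Any other leaf
        # (e.g. a str) comes back unchanged, so measure it with sys.getsizeof.
        return sum(x if isinstance(x, int) else sys.getsizeof(x) for x in items)

    result = traverse_input_data(
        data,
        action_fn=lambda data: sys.getsizeof(
            data.untyped_storage()
            if hasattr(data, "untyped_storage")
            else data.storage()
        ),
        # Decided per collection: the lambda parameter is the collection being
        # aggregated, not the outer `data` argument. We don't need the
        # dictionary keys in this case.
        aggregate_fn=lambda data: (
            (lambda d: sum_sizes(d.values()))
            if isinstance(data, Mapping)
            else sum_sizes
        ),
    )
    return cast(int, result)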
https://github.com/TylerYep/torchinfo/blob/73ed5687acfd6199b77fa1dcb65aa54762c1b720/torchinfo/torchinfo.py#L501
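At the linked line, action_fn is applied only to tensors, not to str. Strings therefore come back from traverse_input_data unchanged and are never measured with sys.getsizeof, and the default aggregate_fn (sum) then raises a TypeError when it tries to add a str to an int.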
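The failure can be reproduced without torch. The sketch below is a minimal, simplified stand-in for torchinfo's traversal, not its actual code: the hypothetical traverse function mirrors the recursion pattern of traverse_input_data, a bytearray plays the role of a tensor leaf, and len plays the role of the storage-size action_fn. It contrasts the plain sum aggregator with one that falls back to sys.getsizeof for leaves that were not converted to ints.

```python
import sys
from collections.abc import Iterable, Mapping
from typing import Any, Callable


def traverse(data: Any, action_fn: Callable[[Any], Any],
             aggregate_fn: Callable[[Any], Callable[..., Any]]) -> Any:
    """Hypothetical, simplified stand-in for torchinfo's traverse_input_data."""
    if isinstance(data, bytearray):  # stand-in for a torch.Tensor leaf
        return action_fn(data)
    aggregate = aggregate_fn(data)
    if isinstance(data, Mapping):
        return aggregate({k: traverse(v, action_fn, aggregate_fn) for k, v in data.items()})
    if isinstance(data, Iterable) and not isinstance(data, str):
        return aggregate([traverse(d, action_fn, aggregate_fn) for d in data])
    # Strings and other non-collection leaves are returned unchanged.
    return data


def naive_aggregate(data: Any) -> Callable[..., Any]:
    # Same shape as the upstream aggregator: plain sum for non-mappings.
    return (lambda d: sum(d.values())) if isinstance(data, Mapping) else sum


def sum_sizes(items: Iterable[Any]) -> int:
    # Keep ints produced by action_fn; measure any untouched leaf directly.
    return sum(x if isinstance(x, int) else sys.getsizeof(x) for x in items)


def safe_aggregate(data: Any) -> Callable[..., Any]:
    return (lambda d: sum_sizes(d.values())) if isinstance(data, Mapping) else sum_sizes


batch = [bytearray(8), "hello world"]
try:
    traverse(batch, len, naive_aggregate)
except TypeError as exc:
    print("naive failed:", exc)  # unsupported operand type(s) for +: 'int' and 'str'
print("safe:", traverse(batch, len, safe_aggregate))
```

The exact number printed for the safe case depends on the Python version, because sys.getsizeof reports the in-memory size of the str object, which changed between CPython releases.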