Stonesjtu / pytorch_memlab

Profiling and inspecting memory in pytorch
MIT License

`.report()` print issue with `torch.SymInt` / `torch.SymFloat` from `torch.compile`'d modules #60

Open kylevedder opened 1 month ago

kylevedder commented 1 month ago

Currently, `readable_size()` is defined as:

def readable_size(num_bytes: int) -> str:
    return '' if isnan(num_bytes) else '{:.2f}'.format(calmsize(num_bytes))

To debug the error, I changed it to:

def readable_size(num_bytes: int) -> str:
    calm_res = calmsize(num_bytes)
    try:
        return '' if isnan(num_bytes) else '{:.2f}'.format(calm_res)
    except Exception:
        # Drop into the debugger to inspect num_bytes and calm_res
        breakpoint()

This produces the following PDB session:

> /miniconda/lib/python3.11/site-packages/pytorch_memlab/utils.py(9)readable_size()->None
-> breakpoint()
(Pdb) num_bytes
512*ceiling(19*s0/128)
(Pdb) type(num_bytes)
<class 'torch.SymInt'>
(Pdb) calm_res
[rank0]:W0718 17:36:47.217000 140136914814784 torch/fx/experimental/symbolic_shapes.py:3991] Ignored guard Round(ceiling(19*s0/128)/2048) == 4, this could result in accuracy problems
4M<ByteSize amount=ceiling(19*s0/128)/2048>
(Pdb) type(calm_res)
<class 'calmsize.calmsize.ByteSize'>
(Pdb) '{:.2f}'.format(calm_res)
*** TypeError: unsupported format string passed to SymFloat.__format__

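This appears to be caused by calmsize lacking support for `torch.SymInt` values, which `torch.compile`'d modules produce. The resulting `ByteSize` wraps a symbolic amount, and formatting it with `'{:.2f}'` ends up in `SymFloat.__format__`, which rejects the `.2f` format spec.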
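The failure mode can be sketched without torch or calmsize. The error text matches what Python's default `object.__format__` raises for any non-empty format spec, which suggests `SymFloat` does not override `__format__`. In the sketch below, `FakeSymFloat`, `FakeByteSize`, `fake_calmsize` and `readable_size_guarded` are all hypothetical stand-ins, and the assumption that `ByteSize.__format__` forwards the spec to its wrapped amount is inferred from the trace, not taken from calmsize's source. The guarded variant formats plain numbers as before and falls back to `str()` for anything else, such as a symbolic size.

```python
from math import isnan


class FakeSymFloat:
    """Hypothetical stand-in for torch.SymFloat.

    It defines no __format__, so Python's default object.__format__ is used,
    and that rejects any non-empty format spec such as '.2f'.
    """

    def __init__(self, expr: str) -> None:
        self.expr = expr

    def __str__(self) -> str:
        return self.expr


class FakeByteSize:
    """Hypothetical stand-in for calmsize's ByteSize.

    Assumed behavior: __format__ forwards the spec to the wrapped amount,
    then appends the unit suffix.
    """

    def __init__(self, amount, unit: str) -> None:
        self.amount = amount
        self.unit = unit

    def __format__(self, spec: str) -> str:
        return format(self.amount, spec) + self.unit


def fake_calmsize(num_bytes) -> FakeByteSize:
    """Hypothetical stand-in for calmsize(); always reports MiB for brevity."""
    return FakeByteSize(num_bytes / 2**20, "M")


def readable_size_guarded(num_bytes) -> str:
    """Hypothetical guarded variant of readable_size().

    Only plain Python numbers go through the '.2f' formatting path; anything
    else (e.g. a torch.SymInt from a compiled module) is rendered with str()
    instead of raising inside __format__.
    """
    if not isinstance(num_bytes, (int, float)):
        return str(num_bytes)
    if isnan(num_bytes):
        return ''
    return '{:.2f}'.format(fake_calmsize(num_bytes))


if __name__ == "__main__":
    sym = FakeSymFloat("ceiling(19*s0/128)/2048")
    try:
        '{:.2f}'.format(FakeByteSize(sym, "M"))
    except TypeError as exc:
        print(exc)  # unsupported format string passed to FakeSymFloat.__format__
    print(readable_size_guarded(4 * 2**20))  # 4.00M
    print(readable_size_guarded(sym))        # ceiling(19*s0/128)/2048
```

Falling back to `str()` keeps the report printable without forcing a concrete value out of the symbolic size. Converting with `int()` instead could trigger the kind of guard warning seen in the trace.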

Stonesjtu commented 1 month ago

Can you post a minimal reproduction snippet that produces a SymInt?