TylerYep / torchinfo

View model summaries in PyTorch!
MIT License

Detecting dynamic flow? #310

Open eliegoudout opened 3 months ago

eliegoudout commented 3 months ago

Hi,

I came across the notion of static/dynamic control flow in PyTorch's docs. I realize that dynamic flow (that is, when the module calls made may differ between inputs) poses an obvious problem for torchinfo. Indeed, the summary is computed from a single forward pass on one input (random, all zeros, or something else; I haven't checked your code to see which), but another input might trigger a different module execution path.
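To make the concern concrete, here is a minimal, hypothetical module with data-dependent control flow: which submodule runs depends on the input values, so a summary traced with all-zeros input would only ever see one branch.

```python
import torch
import torch.nn as nn


class DynamicNet(nn.Module):
    """Toy module whose execution path depends on the input values."""

    def __init__(self) -> None:
        super().__init__()
        self.pos_branch = nn.Linear(4, 4)
        self.neg_branch = nn.Linear(4, 4)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Data-dependent branch: a different submodule runs for
        # positive-sum vs. non-positive-sum inputs.
        if x.sum() > 0:
            return self.pos_branch(x)
        return self.neg_branch(x)
```

With an all-zeros input, `x.sum() > 0` is false, so only `neg_branch` would appear in the traced execution, even though `pos_branch` runs for other inputs.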

As such, I think it would be wise to issue a warning or raise an error when dynamic flow is detected; otherwise, the output may be misleading.

I chose the "Feature Request" tag, but it might also be considered a "Bug Report" since it's about an elementary vulnerability.

Cheers!

TylerYep commented 3 months ago

That sounds reasonable. PRs addressing this are welcome!

eliegoudout commented 3 months ago

I would consider adding something like this in summary:

import warnings

import torch

warning = """
The control flow of the target module may be dynamic. As such, the
summary may vary for different inputs. For more information, see
https://pytorch.org/docs/stable/fx.html#limitations-of-symbolic-tracing
"""
try:
    torch.fx.symbolic_trace(model)  # Fails for dynamic control flow
except torch.fx.proxy.TraceError:
    warnings.warn(warning)

but I'm a bit uncomfortable pushing this as-is because I'm not entirely sure my understanding of the linked page on control flow and tracing is correct; my knowledge here is very limited. Also, I don't know the cost of torch.fx.symbolic_trace(model), but I'd guess it's roughly equivalent to a normal forward pass, so it could add a bit of overhead.
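For what it's worth, here is a small sketch of the detection idea in isolation (the Static/Dynamic module names are made up for illustration): torch.fx.symbolic_trace succeeds on a module with static control flow but raises torch.fx.proxy.TraceError when a branch depends on tensor values.

```python
import torch
import torch.fx
import torch.nn as nn


class Static(nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(x)


class Dynamic(nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Branching on a tensor value: symbolic tracing cannot
        # evaluate the condition, so tracing fails here.
        if x.sum() > 0:
            return x + 1
        return x - 1


torch.fx.symbolic_trace(Static())  # traces fine

try:
    torch.fx.symbolic_trace(Dynamic())
except torch.fx.proxy.TraceError:
    print("dynamic control flow detected")
```

Note that symbolic tracing runs the forward once with proxy values rather than real tensors, so this check stays independent of any particular input.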