ssgosh opened this issue 6 months ago
Sounds like a good idea. I have personally helped people who ran into this issue.
This seems a bit tricky to implement. Currently TorchFix doesn't know the types of objects, so it's hard to find lists of torch.nn.Module objects.
Pyre and libcst's TypeInferenceProvider (https://libcst.readthedocs.io/en/latest/metadata.html#libcst.metadata.TypeInferenceProvider) can probably help here, but that is a separate feature to implement.
Yikes! Perhaps it can be done on a best-effort basis for some commonly used classes, such as Linear, Conv2d, and the other subclasses of torch.nn.Module listed at https://pytorch.org/docs/stable/nn.html? Maybe it could be done only for list comprehensions? I would imagine that's a common idiom that many people use.
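The best-effort idea could look roughly like the sketch below. It is a hypothetical illustration, not TorchFix's implementation: it uses Python's standard-library ast module for brevity instead of libcst (which TorchFix is built on), and KNOWN_MODULES is a hand-picked, assumed allowlist. It flags only self.<attr> = [...] assignments whose list literal or list comprehension directly calls one of those class names.

```python
import ast

# Hypothetical, hand-picked allowlist of common torch.nn class names (not exhaustive).
KNOWN_MODULES = {"Linear", "Conv1d", "Conv2d", "BatchNorm2d", "ReLU", "Sequential"}


def _is_known_module_call(expr):
    """True for calls such as nn.Linear(...), torch.nn.Linear(...) or Linear(...)."""
    if not isinstance(expr, ast.Call):
        return False
    func = expr.func
    if isinstance(func, ast.Attribute):
        # Purely name-based: any `<something>.Linear(...)` matches (best effort).
        return func.attr in KNOWN_MODULES
    if isinstance(func, ast.Name):
        return func.id in KNOWN_MODULES
    return False


def _builds_module_list(value):
    """True if value is a list literal or list comprehension of known modules."""
    if isinstance(value, ast.List):
        return any(_is_known_module_call(e) for e in value.elts)
    if isinstance(value, ast.ListComp):
        return _is_known_module_call(value.elt)
    return False


def find_unregistered_module_lists(source):
    """Return sorted line numbers of `self.<attr> = [...]` holding known nn modules."""
    hits = []
    for node in ast.walk(ast.parse(source)):
        if not isinstance(node, ast.Assign):
            continue  # annotated assignments (ast.AnnAssign) are not handled here
        targets_self_attr = any(
            isinstance(t, ast.Attribute)
            and isinstance(t.value, ast.Name)
            and t.value.id == "self"
            for t in node.targets
        )
        if targets_self_attr and _builds_module_list(node.value):
            hits.append(node.lineno)
    return sorted(hits)


SAMPLE = '''
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.convs = [nn.Conv2d(3, 3, 3) for _ in range(4)]  # flagged
        self.heads = [nn.Linear(8, 2), nn.Linear(8, 3)]       # flagged
        self.blocks = nn.ModuleList([nn.Linear(8, 8)])       # fine
        self.sizes = [1, 2, 3]                               # fine
'''

print(find_unregistered_module_lists(SAMPLE))  # [7, 8]
```

Because it matches names syntactically, this sketch misses aliased imports and custom nn.Module subclasses, and it can misfire on unrelated classes that happen to share a name; that is the trade-off of working without type information.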
I'll contribute this rule. I've got it working locally and am just waiting for the open PRs to be reviewed and merged.
There is a real-world example in transformers (its impact is mitigated by the subsequent add_module calls). Other than that, violations of this rule are fairly rare in larger projects but moderately common in smaller repos (10+ examples).
Inside a model definition, torch.nn.Module objects stored in a plain Python list do not get their parameters registered. As a result, those parameters are not returned by model.parameters() and are never trained by the optimizer, even though the modules are called in forward(). TorchFix should flag this; currently it gives no warning for this pattern.
Example:
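The following is a minimal illustrative sketch with hypothetical class names. It compares a plain Python list of nn.Linear layers with nn.ModuleList and counts the parameters each model exposes to an optimizer.

```python
import torch.nn as nn


class PlainListModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Plain Python list: nn.Module does not look inside lists,
        # so these layers are NOT registered as submodules.
        self.layers = [nn.Linear(4, 4) for _ in range(2)]

    def forward(self, x):
        for layer in self.layers:  # the layers still run in forward()
            x = layer(x)
        return x


class ModuleListModel(nn.Module):
    def __init__(self):
        super().__init__()
        # nn.ModuleList registers every contained module as a submodule.
        self.layers = nn.ModuleList(nn.Linear(4, 4) for _ in range(2))

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


def count_params(model):
    """Number of scalar parameters that model.parameters() yields."""
    return sum(p.numel() for p in model.parameters())


plain = PlainListModel()
fixed = ModuleListModel()
print(count_params(plain))  # 0: an optimizer built from plain.parameters() gets nothing
print(count_params(fixed))  # 40: 2 layers x (4*4 weights + 4 biases)
```

Calling self.add_module(name, layer) for each list element also registers the layers, which is why the add_module calls in the transformers example mitigate the problem.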
Torchfix output: