Landanjs closed this issue 1 year ago
Looking at torch.fx, they do support graph re-writes and even subgraph pattern matching (https://pytorch.org/docs/stable/fx.html#torch.fx.replace_pattern). It looks like they return a torch.nn.Module, but we'll need to check its compatibility with the other surgery algorithms that we have.
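For reference, `torch.fx.replace_pattern` swaps every occurrence of a traced subgraph for a replacement subgraph. A minimal sketch with a toy module (none of this is Composer code; the 0.5 scale is just a stand-in for a real rewrite):

```python
import torch
import torch.fx as fx

class Residual(torch.nn.Module):
    def forward(self, x, y):
        return torch.relu(x + y)

def pattern(x, y):
    # subgraph to search for
    return torch.relu(x + y)

def replacement(x, y):
    # stand-in rewrite: scale the second input before the add
    return torch.relu(x + 0.5 * y)

gm = fx.symbolic_trace(Residual())
matches = fx.replace_pattern(gm, pattern, replacement)

# the rewritten graph is still an ordinary torch.nn.Module
assert isinstance(gm, torch.nn.Module)
```

This also confirms the point above: the result is a `GraphModule`, which subclasses `torch.nn.Module`, so in principle other surgery algorithms could run on it afterwards.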
This is probably lower priority for now, given our push towards usability.
Closing. Tracking elsewhere as low pri
🚀 Feature Request
Use torch.fx to automatically detect residual blocks and residual connections in a model, then manipulate these components to perform stochastic depth.
Motivation
Right now, stochastic depth replaces a hard-coded module (i.e. composer.models.resnets.Bottleneck) with a manually defined stochastic version of that module. The stochastic module randomly skips the main computation during training and multiplies the residual connection by the probability of skipping.
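As a reference point, the behavior described above can be sketched roughly as follows (the class name is hypothetical and the eval-time scaling follows the common stochastic-depth formulation, not necessarily Composer's exact implementation):

```python
import torch

class StochasticBlock(torch.nn.Module):
    # Hypothetical sketch: wraps a residual block's main branch `block`
    # and skips it with probability `drop_rate` during training.
    def __init__(self, block: torch.nn.Module, drop_rate: float = 0.2):
        super().__init__()
        self.block = block
        self.drop_rate = drop_rate

    def forward(self, x):
        if self.training:
            if torch.rand(()).item() < self.drop_rate:
                # skip the main computation; keep only the shortcut
                return x
            return x + self.block(x)
        # at eval time, scale the main branch by its survival probability
        return x + (1.0 - self.drop_rate) * self.block(x)
```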
In order to avoid manually specifying the module to replace and a corresponding stochastic module, the residual blocks and residual connections need to be automatically identified and manipulated. torch.fx may provide the tools to do this.
Automatic identification and manipulation would allow stochastic depth to be applied to several models without hard-coded specification.
Implementation
Vague idea for how to do this: identify the `add` operations in a model architecture, since these likely mark residual connections, then manipulate the graph around them to insert the stochastic skipping logic.
Caveat: I don't know if conditional statements can be added with torch.fx.
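To make the detection step concrete, here is one sketch: symbolically trace a toy model and collect the `add` nodes, which are candidate residual connections (the model below is a made-up example, not one of Composer's):

```python
import operator
import torch
import torch.fx as fx

class ResidualBlock(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(4, 4, 3, padding=1)

    def forward(self, x):
        return x + self.conv(x)  # the `+` is the residual connection

class ToyNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.block1 = ResidualBlock()
        self.block2 = ResidualBlock()

    def forward(self, x):
        return self.block2(self.block1(x))

gm = fx.symbolic_trace(ToyNet())
# `add` nodes in the traced graph mark candidate residual connections
add_nodes = [n for n in gm.graph.nodes
             if n.op == 'call_function' and n.target in (operator.add, torch.add)]
```

The caveat about conditionals is real: `symbolic_trace` cannot trace data-dependent control flow, so any stochastic skipping would likely have to live inside a module or wrapped function rather than in the traced graph itself.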
Discussion
I don't expect this to work for every architecture, but it should at least be more generalizable than the current implementation. The alternative, updating stochastic depth for every new target architecture, is not sustainable.
@hanlint Do you know of anyone on research eng I could talk to about this?