mosaicml / composer

Supercharge Your Model Training
http://docs.mosaicml.com
Apache License 2.0

Refactor stochastic depth to generalize to some novel models #253

Closed: Landanjs closed this 1 year ago

Landanjs commented 2 years ago

🚀 Feature Request

Use torch.fx to automatically detect residual blocks and residual connections in a model, then manipulate these components to perform stochastic depth.

Motivation

Right now, stochastic depth replaces a hard-coded module (i.e. composer.models.resnets.Bottleneck) with a manually defined stochastic version of that module. The stochastic module is designed to randomly skip the main computation (the residual branch) during training and, at evaluation time, to scale that branch's output by its survival probability (the probability of not skipping it).
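For reference, the behavior of such a stochastic module can be sketched in plain Python on scalar inputs. This is a minimal sketch assuming the standard stochastic depth formulation (skip the residual branch during training, scale it by its survival probability at evaluation); `stochastic_residual`, `branch`, and `survival_prob` are hypothetical names, not Composer's API.

```python
import random
from typing import Callable, Optional


def stochastic_residual(
    x: float,
    branch: Callable[[float], float],
    survival_prob: float,
    training: bool,
    rng: Optional[random.Random] = None,
) -> float:
    """Hypothetical scalar residual block with stochastic depth.

    Training: run the residual branch with probability `survival_prob`;
    otherwise skip it entirely and return only the identity (shortcut).
    Evaluation: always run the branch and scale its output by `survival_prob`,
    which matches the expected training-time output.
    """
    if not 0.0 <= survival_prob <= 1.0:
        raise ValueError("survival_prob must be in [0, 1]")
    if training:
        rng = rng or random.Random()
        keep = rng.random() < survival_prob  # Bernoulli(survival_prob) draw
        return x + branch(x) if keep else x  # a skipped branch is never computed
    return x + survival_prob * branch(x)


def double(v: float) -> float:
    return 2 * v


rng = random.Random(0)
samples = [stochastic_residual(3.0, double, 0.75, training=True, rng=rng) for _ in range(10_000)]
print(sum(samples) / len(samples))  # approximately 7.5
print(stochastic_residual(3.0, double, 0.75, training=False))  # 7.5
```

Averaging many training-time outputs recovers the evaluation-time output, which is why the scaling uses the survival probability rather than the drop probability.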

To avoid manually specifying both the module to replace and its stochastic counterpart, residual blocks and residual connections need to be identified and manipulated automatically. torch.fx may provide the tools to do this.

Automatic identification and manipulation would allow stochastic depth to be applied to several models without hard-coded specification:

Implementation

Vague idea for how to do this:

  1. Identify two-argument add operations in a model architecture
  2. Trace the arguments to a single point in the computation graph -> the start of the residual block
  3. Add a conditional statement based on a Bernoulli variable after the single point in the computation graph
  4. Scale the residual connection by the probability of entering the conditional statement
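The four steps above can be illustrated on a toy computation graph in plain Python. This is a minimal sketch and not torch.fx: the hypothetical `Node` class only mimics the shape of an fx graph (a topologically ordered list of nodes whose `args` point at earlier nodes), and `find_residual_blocks`, `run`, and the rule "the input with more nodes after the fork is the main branch" are illustrative assumptions. `find_residual_blocks` covers steps 1 and 2, taking the start of the block to be the latest node shared by both inputs of the add; `run` applies steps 3 and 4 while interpreting the graph, assuming the standard stochastic depth convention (Bernoulli skip in training, scaling by the survival probability outside training).

```python
import random
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Optional, Set

@dataclass(eq=False)  # identity-based hashing so nodes can be used in sets/dicts
class Node:
    name: str
    op: str  # "input", "fn" (one-argument function), or "add" (two arguments)
    args: List["Node"] = field(default_factory=list)
    fn: Optional[Callable[[float], float]] = None

@dataclass
class ResidualBlock:
    add: Node          # two-argument add that closes the block
    fork: Node         # single point where the block starts
    main: Node         # input of `add` produced by the main (residual) branch
    branch: Set[Node]  # nodes after `fork` that `main` depends on, including `main`

def ancestors(node: Node) -> Set[Node]:
    """Every node that `node` depends on, including `node` itself."""
    seen: Set[Node] = set()
    stack = [node]
    while stack:
        n = stack.pop()
        if n not in seen:
            seen.add(n)
            stack.extend(n.args)
    return seen

def find_residual_blocks(nodes: List[Node]) -> List[ResidualBlock]:
    """Steps 1 and 2. Assumes `nodes` is topologically sorted."""
    order = {n: i for i, n in enumerate(nodes)}
    blocks = []
    for node in nodes:
        if node.op != "add" or len(node.args) != 2:  # step 1: two-argument adds
            continue
        a, b = node.args
        common = ancestors(a) & ancestors(b)  # step 2: trace both inputs back
        if not common:
            continue
        fork = max(common, key=order.__getitem__)  # latest node both inputs share
        before = ancestors(fork)
        # Heuristic: the input with more nodes after the fork is the main branch.
        main = max((a, b), key=lambda arg: len(ancestors(arg) - before))
        branch = ancestors(main) - before
        if branch:
            blocks.append(ResidualBlock(node, fork, main, branch))
    return blocks

def run(nodes: List[Node], x: float, survival_prob: float, training: bool,
        rng: Optional[random.Random] = None) -> float:
    """Evaluate the graph, applying steps 3 and 4 to every detected block."""
    rng = rng or random.Random()
    blocks = {blk.add: blk for blk in find_residual_blocks(nodes)}
    skipped: Set[Node] = set()
    for blk in blocks.values():
        # Step 3: one Bernoulli draw per block; in training, a failed draw
        # means none of the block's branch nodes are executed.
        if training and rng.random() >= survival_prob:
            skipped |= blk.branch
    values: Dict[Node, float] = {}
    for node in nodes:
        if node in skipped:
            continue
        if node.op == "input":
            values[node] = x
        elif node.op == "fn":
            values[node] = node.fn(values[node.args[0]])
        elif node.op == "add":
            a, b = node.args
            blk = blocks.get(node)
            if blk is None:  # not a residual block: plain addition
                values[node] = values[a] + values[b]
                continue
            shortcut = b if a is blk.main else a
            if blk.main in skipped:
                values[node] = values[shortcut]
            else:  # Step 4: outside training, scale the main branch by survival_prob
                scale = 1.0 if training else survival_prob
                values[node] = scale * values[blk.main] + values[shortcut]
    return values[nodes[-1]]

# Usage: a ResNet-style block with an identity shortcut.
x = Node("x", "input")
conv1 = Node("conv1", "fn", [x], lambda v: 2 * v)
relu = Node("relu", "fn", [conv1], lambda v: max(v, 0.0))
conv2 = Node("conv2", "fn", [relu], lambda v: v + 1)
out = Node("add", "add", [conv2, x])
graph = [x, conv1, relu, conv2, out]

blk = find_residual_blocks(graph)[0]
print(blk.fork.name, blk.main.name)  # x conv2
print(run(graph, 3.0, survival_prob=0.5, training=False))  # 0.5 * 7 + 3 = 6.5
```

Because `run` is an ordinary Python interpreter, the conditional is trivial to express here; whether torch.fx can express it inside a traced graph is a separate question this toy does not answer. A projection shortcut (a downsampling node on the shortcut path) is handled the same way, since the fork point is still the last node both inputs share.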

Caveat: I don't know if conditional statements can be added with torch.fx.

Discussion

I don't expect this to work for every architecture, but it should at least generalize better than the current implementation. The alternative is to update stochastic depth for every new target architecture, which is not sustainable.

@hanlint Do you know of anyone on research eng I could talk to about this?

hanlint commented 2 years ago

Looking at torch.fx, it supports graph rewrites and even subgraph pattern matching (https://pytorch.org/docs/stable/fx.html#torch.fx.replace_pattern). The result looks like a torch.nn.Module (torch.fx.GraphModule subclasses it), but we'll need to check its compatibility with the other surgery algorithms we have.

This is probably lower priority for now, given our push towards usability.

mvpatel2000 commented 1 year ago

Closing. Tracking elsewhere as low pri