This issue may or may not contain my notes about implementing and training reversible models.
From correspondence with @TimDettmers, @yhn112, @borzunov, @mryab
Why? Reversible models are one of the few ways to fit large transformers into low-end DL rigs (single GPU, 8-12 GB VRAM, 16-32 GB RAM). The alternatives are nested checkpointing [slower], checkpoint quantization [may affect convergence, still uses more memory], or checkpoint offloading [RAM is already used up for master params and optimizer state]. Hence, reversibles are the most memory-efficient training strategy, provided they can match the quality of regular models.
As of c076ba7, we support reversible=True, which enables the reversible transformer as defined in Reformer. However, this is not the only possible design. Existing alternatives are:
- Running both Attention and FFN in each branch (the default Reformer has only attention in the F branch and only FFN in the G branch)
- MomentumNet: https://arxiv.org/abs/1907.11818
- Adjusted MomentumNet from homebrewnlp - see Lucas' notes from revlib - ported to LearnTransformer
- Simple running average: code in LearnTransformer

Source: running average coupling
```
import torch


class MeanCoupling:
    """Couple the two reversible streams as a running average of layer outputs."""

    def __init__(self, layer_index: int):
        assert layer_index > 0, "layer with index zero must be applied before reversible (e.g. embeddings)"
        self.layer_index = layer_index

    def forward(self, other_stream: torch.Tensor, fn_out: torch.Tensor) -> torch.Tensor:
        # running average: out = (i * other_stream + fn_out) / (i + 1)
        i = self.layer_index
        return other_stream * (i / (i + 1)) + fn_out * (1 / (i + 1))

    def inverse(self, forward_output: torch.Tensor, fn_out: torch.Tensor) -> torch.Tensor:
        # exact inverse of forward: recovers other_stream without storing it
        i = self.layer_index
        return (forward_output - fn_out * (1 / (i + 1))) / (i / (i + 1))
```
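A quick sanity check of the coupling (a minimal sketch, assuming the class above is used as-is; shapes are just for illustration):

```
import torch

# verify that inverse() recovers the other stream exactly,
# so this branch's activations never need to be stored
coupling = MeanCoupling(layer_index=3)
other_stream = torch.randn(2, 16, 64)   # (batch, seq, hidden)
fn_out = torch.randn(2, 16, 64)         # output of the F/G sub-layer

out = coupling.forward(other_stream, fn_out)
recovered = coupling.inverse(out, fn_out)
print(torch.allclose(recovered, other_stream, atol=1e-6))  # True, up to float error
```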
Furthermore, it is unclear how best to use the reversible model's two inputs and two outputs with transformers. If we cannot afford to double the layer sizes, this raises the following questions:
- Which sub-branch of the reversible modules is more informative: X, Y, X+Y, X+Y-common_input, or some learned gates? (see the sketch below)
- How best to initialize the two inputs: both equal, or set one to zeros?
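To make the options concrete, here is a minimal sketch of the two choices (names like `run_reversible` and `rev_blocks` are placeholders for illustration, not code from the repo):

```
import torch

def run_reversible(embeddings: torch.Tensor, rev_blocks, init: str = "equal", output: str = "x+y-init"):
    # two inputs: either duplicate the embeddings or zero out one stream
    if init == "equal":
        x, y = embeddings, embeddings
    else:  # "zeros"
        x, y = embeddings, torch.zeros_like(embeddings)

    for block in rev_blocks:  # each block updates (x, y) reversibly
        x, y = block(x, y)

    # two outputs: pick which sub-branch feeds the LM head
    if output == "x":
        return x
    elif output == "y":
        return y
    elif output == "x+y":
        return x + y
    else:  # "x+y-init": subtract the shared input so it is not counted twice
        return x + y - embeddings
```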
Finally, there are implementation hacks that can affect the training throughput and memory requirements:
- Couplings are elementwise operations: should we apply torch.jit.script to them?
- The FFN layer can be computed a few vectors at a time (over batch size and sequence length), as originally proposed in Reformer (see the sketch after this list).
- The attention layer can be computed one head at a time, since it is additive w.r.t. heads, but that may be difficult with some sparse/low-rank projections.
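For the FFN chunking trick, a minimal sketch of the idea (not the Reformer implementation; `chunked_ffn` is a hypothetical helper):

```
import torch

def chunked_ffn(ffn: torch.nn.Module, x: torch.Tensor, chunk_size: int = 64) -> torch.Tensor:
    # x: (batch, seq, hidden); a position-wise FFN can be applied to any slice of tokens,
    # so we only ever materialize the 4*hidden intermediate for chunk_size tokens at a time
    flat = x.reshape(-1, x.shape[-1])
    out = torch.cat([ffn(flat[i:i + chunk_size]) for i in range(0, flat.shape[0], chunk_size)])
    return out.reshape_as(x)
```

The memory saving matters most when the forward pass is recomputed chunk-by-chunk during the reversible backward (or under checkpointing); in a plain autograd forward the chunk intermediates would still be kept for backward.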
Model: equivalent of GPT-3 medium (tested learning-curve match with the transformer LM) + rotary embeddings.
Training config: based on GPT-3 paper hparams, see details at the end. All model / training hyperparameters are kept equal except for the reversible type.
Limitations:
- Some runs were stopped prematurely due to poor performance. There is a small but nonzero chance they could have recovered.
- We are using the X+Y-Initial hack for Reformer, which is standard practice, but was not part of the original paper.
- We did not tune hyperparameters individually for each model. There is a chance that some reversibles need different hparams.
- We used beta=0.9 for momentum-based models (the default from the repo and paper) and did not tune it.
Results:
tl;dr reformer is still the best
- baseline = GPT3-medium + rotary
- baseline-rev = GPT3-medium with Reformer-like layer order: the closest, but still marginally worse than the baseline. The margin is ~1/5 of that from switching to GPT3-XL (double hidden size and heads).
- baseline-rev-aaff = GPT3-medium with Reformer-like layer order, but using [Attn, Attn, FFN, FFN] layers instead of [Attn, FFN, Attn, FFN], so that each of the two reversible-residual branches gets both attention and FFN outputs. Starts slightly worse and never catches up.
- baseline-rev0.9-out0 = same as the previous row, but using the X branch as model outputs (aka the cumulative sum of moving averages). Began training faster, but diverged just after passing 4k steps. Hypothesis: the revlib default momentum may be numerically unstable because of the 1 / beta ** num_layers term. Will investigate alternative momentumnet implementations to rule this out.
- baseline-rev-average = simple running average, using the MeanCoupling code from the first message. Trains significantly slower than the baseline. Killed after 3k steps.
Validation loss closely tracks the training loss.
The gradient norm supports the hypothesis of numerical instability in momentum-reversible models.
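For scale, a back-of-the-envelope check of that term (assuming 24 layers, i.e. GPT-3 medium depth, and beta = 0.9; not a measurement from the run):

```
# the rescaling term grows exponentially with depth
beta, num_layers = 0.9, 24
print(1 / beta ** num_layers)  # ~12.5 already at 24 layers, and it keeps growing for deeper stacks
```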