Hi @crcrpar, Thanks for the comment!
It's very doable to add an additional flag to the class. However, I believe this could be the new default behavior simply because of its benefits. As for the stability issue, I would argue it's as stable as the original, since the reduction happens in the forward pass and we do not modify that.
The numerical difference between the two is small enough that, in my opinion, it is outweighed by the memory cost of the original.
Nevertheless, if you still think it is better to add an additional flag, I'll be very glad to do that!
I hear you and appreciate this improvement very much, but I want to be conservative here. These normalization layers are used by a fair number of models out there. Even if we make this path the default behavior in the end, I do think we should have a period during which users can become aware that the behavioral change is coming, by adding an argument and an appropriate warning message.
Hi @crcrpar ,
Yeah sure that makes sense! I'll add a flag very soon for this!
Hi @crcrpar ,
I believe these two commits allow an argument for switching behaviors!
Do we need to add a warning of some kind now?
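A rough usage sketch of the switch is below; the `memory_efficient` argument name and the default shown are placeholders for illustration, not necessarily the final API:

```python
# Rough usage sketch; the `memory_efficient` argument name and its default are
# placeholders for the switch discussed here, not necessarily the final API.
from apex.normalization import FusedLayerNorm

ln_new = FusedLayerNorm(1024, memory_efficient=True)  # opt into saving the output
ln_old = FusedLayerNorm(1024)                          # original behavior (default)
```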
On a side note, I was trying to do the same thing to `FastLayerNorm`, but I kept getting weird bugs even when doing something as simple as renaming `params.x` to `params.z` in `FwdParams`. Anyway, I would really appreciate it if you could take a look and give me some pointers here: https://github.com/RuiWang1998/apex/blob/19fe4f5c9de5ff32332ba59899c6731cf6bf3423/apex/contrib/csrc/layer_norm/ln_bwd_kernels.cuh#L96
(To test this, just use the test in the `contrib` folder; thanks in advance.)
Best, Rui
Hi @crcrpar ,
It seems I have successfully gotten the `fast_layer_norm` version to work as well! It turned out to be just a cache issue during compilation.
Since this changes the computation done in the backward pass somewhat, is there any intuition or data on the expected performance (speed) difference compared to the current implementation?
Hi @eqy ,
Thanks for the comment!
As you can probably see, the total number of FLOPs is almost identical, and the tensors read and written have almost the same shapes, albeit slightly different ones. For this reason, we believe the two should perform almost identically speed-wise, which is consistent with what we have been observing.
Hi @crcrpar ,
I additionally moved `memory_efficient` to a template parameter, which should in theory make it cost-free.
Hi @crcrpar ,
I have relaxed the test tolerance a tiny bit for the half-precision tests to pass.
Best.
I had been on vacation; excuse me for the delay.
Hi @crcrpar ,
Thanks so much! I don't know how I missed those. Anyways they are fixed now.
Best
Hi,
Content/motivation
This modifies the fused layernorm/rmsnorm implementation such that the output tensor is saved for the backward pass and the input can be freed. The motivation comes from the observation that the output is going to be saved somewhere anyway. For example, in pre-norm Transformers (GPT, PaLM, T5, and many others), the normalization is directly followed by a linear projection, either in attention or in the MLP. In post-norm models, the output is additionally used directly in the residual connection. In both scenarios, however, the input of the normalization is left unused, so saving it is redundant when we can instead save the output tensor that is going to be kept anyway.
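To illustrate the idea, here is a minimal PyTorch sketch of the saved-tensor change, not the actual fused CUDA kernels in this PR; the class name, `eps`, and the clamp threshold are illustrative:

```python
# Minimal sketch of a layer norm that saves the output y (plus gamma, beta, rstd)
# for backward instead of the input x. Illustrative only; not the fused kernels.
import torch


class MemoryEfficientLayerNormSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, gamma, beta, eps=1e-5):
        # The reductions (mean/variance) still happen in the forward pass,
        # exactly as in the original implementation.
        mu = x.mean(dim=-1, keepdim=True)
        rstd = torch.rsqrt(x.var(dim=-1, unbiased=False, keepdim=True) + eps)
        x_hat = (x - mu) * rstd
        y = gamma * x_hat + beta
        # Save the output instead of the input; the caller's x can now be freed.
        ctx.save_for_backward(y, gamma, beta, rstd)
        return y

    @staticmethod
    def backward(ctx, dy):
        y, gamma, beta, rstd = ctx.saved_tensors
        # Recompute x_hat from the saved output, clamping gamma's magnitude
        # (illustrative threshold) so the division stays well behaved.
        safe_gamma = torch.where(gamma.abs() >= 1e-6, gamma, torch.full_like(gamma, 1e-6))
        x_hat = (y - beta) / safe_gamma
        # Standard layer-norm backward expressed in terms of x_hat and rstd.
        dx_hat = dy * gamma
        m1 = dx_hat.mean(dim=-1, keepdim=True)
        m2 = (dx_hat * x_hat).mean(dim=-1, keepdim=True)
        dx = rstd * (dx_hat - m1 - x_hat * m2)
        reduce_dims = tuple(range(dy.dim() - 1))
        dgamma = (dy * x_hat).sum(dim=reduce_dims)
        dbeta = dy.sum(dim=reduce_dims)
        return dx, dgamma, dbeta, None
```

A quick way to sanity-check such a sketch is to compare its gradients against `torch.nn.functional.layer_norm` on random `float64` inputs.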
Effect
We observe a substantial reduction in memory usage when training large transformer models: an 80 GB run on one A100 now requires only 70 GB of memory, a reduction of roughly 1/8 of the total memory cost. Note that we use a configuration in which fewer normalizations are applied than usual, so the benefits should be even more pronounced in canonical models.
Notes on numerical precision
Since we save the output instead of the input tensor, the normalized tensor $\hat{x}$ needed in the backward pass is recomputed from the output, i.e., $\hat{x}=(y-\beta)/\gamma$ instead of $\hat{x}=\frac{x-\mathrm{E}[x]}{\sqrt{\mathrm{Var}[x]+\epsilon}}$. As a result, we are seeing numerical differences between the current implementation and the proposed change. However, note that the reduction part of the computation still happens in the forward pass, so there is no stability issue whatsoever (we also clamp the magnitude of $\gamma$ before dividing).
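To make the discrepancy concrete, here is a small illustrative check, not part of this PR or its tests, of the `float16` round-trip error from recovering $\hat{x}$ out of the output; the shapes and the near-one $\gamma$ initialization are assumptions:

```python
# Illustrative float16 comparison (not the PR's actual tests): how much x_hat
# changes when recovered from the output instead of being kept directly.
import torch

torch.manual_seed(0)
x = torch.randn(8, 1024).half()
gamma = (1.0 + 0.1 * torch.randn(1024)).half()  # trained LayerNorm weights are typically near 1
beta = (0.1 * torch.randn(1024)).half()

mu = x.float().mean(dim=-1, keepdim=True)
rstd = torch.rsqrt(x.float().var(dim=-1, unbiased=False, keepdim=True) + 1e-5)
x_hat = ((x.float() - mu) * rstd).half()  # what the current code effectively keeps
y = gamma * x_hat + beta                  # what this change saves instead
x_hat_recovered = (y - beta) / gamma      # recomputed in the backward pass

print((x_hat_recovered.float() - x_hat.float()).abs().max())  # small but nonzero
```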
Currently we are failing two test runs due to numerical discrepancies when using `float16`: