NVIDIA / apex

A PyTorch Extension: Tools for easy mixed precision and distributed training in Pytorch
BSD 3-Clause "New" or "Revised" License

Massively reduce LayerNorm/RMSNorm GPU memory usage in modern networks by tricking torch autograd #1715

Closed RuiWang1998 closed 10 months ago

RuiWang1998 commented 11 months ago

Hi,

Content/motivation

This modifies the fused LayerNorm/RMSNorm implementation so that the output tensor is saved for the backward pass and the input can be freed. The motivation comes from the observation that the output is going to be kept somewhere anyway. For example, in pre-norm Transformers (GPT, PaLM, T5, and many others), the normalization is directly followed by a linear projection in either the attention block or the MLP; in post-norm models, the output additionally feeds the residual connection. In both scenarios the input of the normalization is left unused afterwards, so keeping it around for the backward pass is redundant: we can instead save the output tensor that is going to be kept anyway, as sketched below.
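To make the idea concrete, here is a minimal pure-PyTorch sketch (for RMSNorm, which keeps the math short) of an autograd function that saves the output instead of the input. It illustrates the approach only, not the fused CUDA kernel in this PR, and the clamp threshold is illustrative:

```python
import torch

class MemoryEfficientRMSNorm(torch.autograd.Function):
    """Sketch: save the output (and the per-row statistic), not the input."""

    @staticmethod
    def forward(ctx, x, weight, eps):
        inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
        y = x * inv_rms * weight
        # The input is not saved, so autograd can free it once the next layer
        # (e.g. the following linear projection) is done with it.
        ctx.save_for_backward(y, weight, inv_rms)
        return y

    @staticmethod
    def backward(ctx, grad_y):
        y, weight, inv_rms = ctx.saved_tensors
        # Recover x_hat = x * inv_rms from the saved output y = x_hat * weight,
        # clamping |weight| away from zero before dividing (threshold illustrative).
        safe_w = torch.where(weight >= 0, weight.clamp(min=1e-6), weight.clamp(max=-1e-6))
        x_hat = y / safe_w
        grad_x_hat = grad_y * weight
        grad_x = inv_rms * (
            grad_x_hat - x_hat * (grad_x_hat * x_hat).mean(dim=-1, keepdim=True)
        )
        grad_weight = (grad_y * x_hat).reshape(-1, x_hat.shape[-1]).sum(dim=0)
        return grad_x, grad_weight, None

# Usage: y = MemoryEfficientRMSNorm.apply(x, weight, 1e-5)
```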

Effect

We observe a substantial reduction in memory usage when training large Transformer models: a run that previously took 80 GB on one A100 now only requires 70 GB, a reduction of 1/8 of the total memory cost. We use a configuration in which relatively few normalizations are applied, so the benefit should be even more pronounced in canonical models.

Notes on numerical precision

Since we save the output instead of the input tensor, the normalized tensor $\hat{x}$ needed in the backward pass is recomputed from the output, i.e., $\hat{x}=(y-\beta)/\gamma$ instead of $\hat{x}=\frac{x-\mathrm{E}[x]}{\sqrt{\mathrm{Var}[x]+\epsilon}}$. As a result we see numerical differences between the current implementation and the proposed change. However, note that the reduction part of the computation still happens in the forward pass, so this should not introduce any stability issue (we also clamp $\gamma$ by magnitude before dividing).
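For LayerNorm with an affine transform, the recomputation in the backward pass might look like the following fragment (variable names and the clamp threshold are illustrative, not the kernel code):

```python
# y was saved in the forward pass: y = x_hat * gamma + beta.
# Recover x_hat from y, clamping |gamma| away from zero before dividing.
safe_gamma = torch.where(gamma >= 0, gamma.clamp(min=1e-6), gamma.clamp(max=-1e-6))
x_hat = (y - beta) / safe_gamma
```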

Currently we are failing two tests due to numerical discrepancies when using float16:

======================================================================
FAIL: test_autocast_fused_layer_norm_float16_elementwise_affine_False_cuda_float16 (__main__.TestFusedLayerNormCUDA.test_autocast_fused_layer_norm_float16_elementwise_affine_False_cuda_float16)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/rui/anaconda3/envs/layer-norm/lib/python3.11/site-packages/torch/testing/_internal/common_utils.py", line 2081, in wrapper
    method(*args, **kwargs)
  File "/home/rui/anaconda3/envs/layer-norm/lib/python3.11/site-packages/torch/testing/_internal/common_device_type.py", line 401, in instantiated_test
    result = test(self, **param_kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rui/opensource/apex/tests/L0/run_fused_layer_norm/test_fused_layer_norm.py", line 234, in test_autocast_fused_layer_norm
    torch.testing.assert_close(native_x.grad, fused_x.grad, **tols, check_dtype=False)
  File "/home/rui/anaconda3/envs/layer-norm/lib/python3.11/site-packages/torch/testing/_comparison.py", line 1511, in assert_close
    raise error_metas[0].to_error(msg)
AssertionError: Tensor-likes are not close!

Mismatched elements: 20 / 8192 (0.2%)
Greatest absolute difference: 4.57763671875e-05 at index (1, 21, 12) (up to 1e-05 allowed)
Greatest relative difference: 0.01692047377326565 at index (7, 25, 8) (up to 0.001 allowed)

======================================================================
FAIL: test_autocast_fused_layer_norm_float16_elementwise_affine_True_cuda_float16 (__main__.TestFusedLayerNormCUDA.test_autocast_fused_layer_norm_float16_elementwise_affine_True_cuda_float16)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/rui/anaconda3/envs/layer-norm/lib/python3.11/site-packages/torch/testing/_internal/common_utils.py", line 2081, in wrapper
    method(*args, **kwargs)
  File "/home/rui/anaconda3/envs/layer-norm/lib/python3.11/site-packages/torch/testing/_internal/common_device_type.py", line 401, in instantiated_test
    result = test(self, **param_kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rui/opensource/apex/tests/L0/run_fused_layer_norm/test_fused_layer_norm.py", line 234, in test_autocast_fused_layer_norm
    torch.testing.assert_close(native_x.grad, fused_x.grad, **tols, check_dtype=False)
  File "/home/rui/anaconda3/envs/layer-norm/lib/python3.11/site-packages/torch/testing/_comparison.py", line 1511, in assert_close
    raise error_metas[0].to_error(msg)
AssertionError: Tensor-likes are not close!

Mismatched elements: 20 / 8192 (0.2%)
Greatest absolute difference: 4.57763671875e-05 at index (1, 21, 12) (up to 1e-05 allowed)
Greatest relative difference: 0.01692047377326565 at index (7, 25, 8) (up to 0.001 allowed)

RuiWang1998 commented 11 months ago

Hi @crcrpar, Thanks for the comment!

It's very doable to add an additional flag to the class. However, I believe this could be the new default behavior simply because of its benefits. As for the stability concern, I would argue it is as stable as the original, since the reduction happens in the forward pass and we do not modify that.

In my opinion, the numerical difference between the two is small enough that it is outweighed by the memory cost of the original implementation.

Nevertheless, if you still think it is better to add an additional flag, I'll be very glad to do that!

crcrpar commented 11 months ago

I hear you and appreciate this improvement very much, but I want to be conservative here. These normalization layers are used by a fair number of models out there. Even if we make this path the default behavior in the end, I do think we should have a period during which users can become aware that the behavioral change is coming, by adding an argument and an appropriate warning message.

RuiWang1998 commented 11 months ago

Hi @crcrpar ,

Yeah sure that makes sense! I'll add a flag very soon for this!

RuiWang1998 commented 11 months ago

Hi @crcrpar ,

I believe these two commits add an argument for switching behaviors!

Do we also need to add a warning now?
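For reference, once the flag lands, usage might look like the following (the constructor keyword here assumes the memory_efficient switch discussed in this thread):

```python
import torch
from apex.normalization import FusedLayerNorm, FusedRMSNorm

# Opt in to saving the output instead of the input; the default stays as before.
ln = FusedLayerNorm(4096, memory_efficient=True).cuda()
rms = FusedRMSNorm(4096, memory_efficient=True).cuda()

x = torch.randn(8, 1024, 4096, device="cuda", requires_grad=True)
y = ln(x)
```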

On a side note, I was trying to do the same thing for FastLayerNorm but kept getting weird bugs even when doing something as simple as renaming params.x to params.z in FwdParams. Anyway, I would really appreciate it if you could take a look and give me some pointers here: https://github.com/RuiWang1998/apex/blob/19fe4f5c9de5ff32332ba59899c6731cf6bf3423/apex/contrib/csrc/layer_norm/ln_bwd_kernels.cuh#L96 (To test this, just use the test in the contrib folder; thanks in advance.)

Best, Rui

RuiWang1998 commented 11 months ago

Hi @crcrpar ,

It seems I have successfully gotten the fast_layer_norm version to work as well! It turned out to be just a cache issue during compilation.

eqy commented 11 months ago

Since this is changing the computation done on the backward pass somewhat, is there any intuition or data on the expected performance (speed) difference compared to the current implementation?

RuiWang1998 commented 11 months ago

Hi @eqy ,

Thanks for the comment!

As you can probably see, the total number of FLOPs is almost identical, and the tensors read and written have almost the same shapes, albeit slightly different. For this reason, we believe the two should be almost identical in speed, which is consistent with what we have been observing.
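A rough way to sanity-check the speed claim (assuming an apex build that exposes the memory_efficient switch discussed above) could be a simple CUDA-event timing loop:

```python
import torch
from apex.normalization import FusedLayerNorm

def time_fwd_bwd(memory_efficient, hidden=4096, batch_shape=(8, 1024), iters=100):
    ln = FusedLayerNorm(hidden, memory_efficient=memory_efficient).cuda()
    x = torch.randn(*batch_shape, hidden, device="cuda", requires_grad=True)
    for _ in range(10):                      # warm-up
        ln(x).sum().backward()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        ln(x).sum().backward()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters   # ms per forward+backward

print(time_fwd_bwd(False), time_fwd_bwd(True))
```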

RuiWang1998 commented 10 months ago

Hi @crcrpar ,

I have additionally moved memory_efficient to a template parameter, which should in theory make it cost-free.

RuiWang1998 commented 10 months ago

Hi @crcrpar ,

I have relaxed the test tolerances a tiny bit so the half-precision tests pass.
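In torch.testing terms, relaxing a tolerance amounts to passing wider rtol/atol values to assert_close; the values below are illustrative only, not the ones used in the commit:

```python
# Illustrative only: slightly wider float16 tolerances than the defaults.
torch.testing.assert_close(native_x.grad, fused_x.grad, rtol=2e-3, atol=1e-4, check_dtype=False)
```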

Best.

crcrpar commented 10 months ago

I had been on vacation; excuse me for the delay.

RuiWang1998 commented 10 months ago

Hi @crcrpar ,

Thanks so much! I don't know how I missed those. Anyways they are fixed now.

Best