NVIDIA / TransformerEngine

A library for accelerating Transformer models on NVIDIA GPUs, including using 8-bit floating point (FP8) precision on Hopper and Ada GPUs, to provide better performance with lower memory utilization in both training and inference.
https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/index.html
Apache License 2.0

Shouldn't memory consumption drop when using FP8? #1261

Open JayC1208 opened 1 month ago

JayC1208 commented 1 month ago

Hi, I am trying the example provided (https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/te_llama/tutorial_accelerate_hf_llama_with_te.html) with the Llama 2 model.

Since it is a 7B model, I assumed the model's GPU memory usage would be around 14 GB with FP16 (the default) and around 7 GB with FP8. However, it still shows about 14 GB (I used model.get_memory_footprint() and nvidia-smi to check allocated memory). Also, when I print the dtype of the layers' hidden states, it shows bfloat16.

Is this normal, or is something not working on my side? Please correct me if I have misunderstood something.

Thanks.

timmoon10 commented 1 month ago

Your memory usage is expected, although it also depends on your workflow. First, let's talk about FP8 weights. One limitation of FP8 support in Hopper and Lovelace is that the Tensor Cores only support TN GEMMs for FP8 (see the cuBLAS docs). This is fine for the forward pass, since that is the native format for torch.nn.Linear, but it means we require transposes for the backward pass. Our default behavior when casting to FP8 is to use a fused cast-transpose kernel, which performs better but keeps both an FP8 copy and its FP8 transpose: 2 bytes per element, the same memory footprint as a single FP16/BF16 copy. If you are performing inference and know you will not need any backpropagation, you can initialize the model within a torch.no_grad context, and TE will not allocate memory for the transposes. You can also do this if you are willing to accept the performance penalty of unfused transpose kernels. We are aware this can be unintuitive, so we're working on nicer ways to specify this within the FP8 recipe. Note that these considerations may change with upcoming hardware and low-precision formats.
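The layout constraint and its memory cost can be illustrated with a small NumPy model. This is a simplified sketch, not TE or cuBLAS code: `tn_gemm` is a hypothetical stand-in that computes `a @ b.T` and only accepts row-major operands whose reduction dimension is the last axis, which is roughly what the TN restriction amounts to for row-major data. The forward GEMM can use the weight as stored, while the dgrad and wgrad GEMMs need transposed copies. The byte count at the end assumes 7e9 parameters and decimal gigabytes, matching the numbers in the question.

```python
import numpy as np

def tn_gemm(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Hypothetical stand-in for an FP8 Tensor Core GEMM: returns a @ b.T and only
    accepts contiguous (row-major) operands whose last axis is the reduction dim."""
    if not (a.flags["C_CONTIGUOUS"] and b.flags["C_CONTIGUOUS"]):
        raise ValueError("operands must be contiguous (row-major)")
    if a.shape[1] != b.shape[1]:
        raise ValueError("both operands must have the reduction dimension last")
    return a @ b.T

rng = np.random.default_rng(0)
batch, d_in, d_out = 4, 8, 16
x = rng.standard_normal((batch, d_in))
w = rng.standard_normal((d_out, d_in))   # torch.nn.Linear layout: (out_features, in_features)
dy = rng.standard_normal((batch, d_out))

# Forward pass, y = x @ w.T: w is already in the required layout.
y = tn_gemm(x, w)

# Backward pass: these GEMMs reduce over other axes, so they need transposed copies.
w_t = np.ascontiguousarray(w.T)          # the extra copy a fused cast-transpose keeps around
dx = tn_gemm(dy, w_t)                    # dgrad: dy @ w
dw = tn_gemm(np.ascontiguousarray(dy.T), np.ascontiguousarray(x.T))  # wgrad: dy.T @ x

def weight_gb(n_params: float, bytes_per_elem: int, copies: int = 1) -> float:
    """Weight memory in decimal GB for `copies` stored versions of the weights."""
    return n_params * bytes_per_elem * copies / 1e9

n = 7e9
print(weight_gb(n, 2))      # BF16/FP16 weights
print(weight_gb(n, 1))      # FP8 weights alone
print(weight_gb(n, 1, 2))   # FP8 weights plus FP8 transposes
```

The last two lines show why the FP8 footprint matches BF16 once the transposes are kept: one byte per element, stored twice.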

Next, activation tensors are usually not returned in FP8. We mostly use FP8 as GEMM inputs to take advantage of Tensor Cores, and we're generally cautious about using it for more numerically sensitive operations (e.g. activation functions, LayerNorm, GEMM outputs). Also, the internals of TE modules are somewhat messy, especially how they handle the per-tensor scaling factors needed for FP8 training. If you would like to manually access FP8 values, it might be worth looking into the experimental operation-based API (see the Quantize op from https://github.com/NVIDIA/TransformerEngine/pull/1033), which makes use of a Float8Tensor class that mimics a plain PyTorch tensor.
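To make the per-tensor scaling idea concrete, here is a minimal NumPy emulation of FP8 E4M3 quantization with a single scale per tensor. It is an illustrative sketch only: `round_to_e4m3`, `quantize_per_tensor`, and `dequantize` are hypothetical helpers, and the scale is computed from the current tensor's amax, which is a simplification. TE's recipes (for example, delayed scaling based on an amax history) may derive the scale differently. The point is the bookkeeping: FP8 values only make sense alongside the scale used to produce them.

```python
import numpy as np

E4M3_MAX = 448.0  # largest finite FP8 E4M3 value

def round_to_e4m3(v: np.ndarray) -> np.ndarray:
    """Emulate rounding to the E4M3 grid (3 mantissa bits), saturating at +/-448."""
    v = np.clip(v, -E4M3_MAX, E4M3_MAX)
    mag = np.abs(v)
    # Exponent of each value; anything below the smallest normal (2**-6) uses subnormal spacing.
    exp = np.floor(np.log2(np.maximum(mag, 2.0**-6)))
    step = 2.0 ** (exp - 3)  # spacing between neighbouring E4M3 values at this exponent
    return np.sign(v) * np.round(mag / step) * step

def quantize_per_tensor(x: np.ndarray):
    """Scale the whole tensor so its amax maps to E4M3_MAX, then round to FP8."""
    amax = float(np.max(np.abs(x)))
    scale = E4M3_MAX / amax if amax > 0 else 1.0
    return round_to_e4m3(x * scale), scale

def dequantize(q: np.ndarray, scale: float) -> np.ndarray:
    return q / scale

x = np.array([1.0, -2.0, 0.5, 0.3])
q, scale = quantize_per_tensor(x)   # FP8-representable values plus one scale factor
x_hat = dequantize(q, scale)
print(scale, q, x_hat)
```

Values that land exactly on the E4M3 grid after scaling round-trip exactly, while others (such as 0.3 here) pick up a rounding error of at most about 2**-4 relative, which is one reason numerically sensitive operations are kept in higher precision.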