NVIDIA / TransformerEngine

A library for accelerating Transformer models on NVIDIA GPUs, including using 8-bit floating point (FP8) precision on Hopper and Ada GPUs, to provide better performance with lower memory utilization in both training and inference.
https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/index.html
Apache License 2.0

AttnFuncWithCP can use less memory #952

Open · i4never opened this issue 1 week ago

i4never commented 1 week ago

In AttnFuncWithCP.forward, at most 3 buffers (two being computed on and one receiving) are in use simultaneously, but p2p_comm_buffers grows to cp_size entries, which I believe is a waste. Simply setting buffer[i] = None and kv_input[i%2] = None after _flash_attn_forward may lower memory usage, since the reference count of the already-computed kv drops to zero and the caching allocator can reuse its block.
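
One caveat worth noting (my assumption, not something stated in the issue): since AttnFuncWithCP runs on two streams, a buffer allocated on one stream but consumed on another should be registered with `Tensor.record_stream()` before its last reference is dropped, so the caching allocator does not hand the block out again while a kernel is still in flight. A minimal sketch:

```python
import torch

side_stream = torch.cuda.Stream()
buf = torch.empty(100, 100, device='cuda')  # allocated on the current stream

with torch.cuda.stream(side_stream):
    # buf is consumed on side_stream; record that usage so the allocator
    # waits for side_stream before reusing buf's block after it is freed
    buf.record_stream(side_stream)
    buf.mul_(2.0)

buf = None  # now safe to drop the reference
```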

Here's some mock code:

- Without buffer release:
```python
import torch

torch.cuda.memory._record_memory_history(max_entries=100000)

device = 'cuda:7'
cp_size = 16
streams = [torch.cuda.current_stream(), torch.cuda.Stream()]
buffers = [None for _ in range(cp_size)]
buffers[0] = torch.empty(100, 100, dtype=torch.float32, device=device)

torch.cuda.memory._dump_snapshot('dev_cuda_mem_1.pkl')
for i in range(cp_size):
    with torch.cuda.stream(streams[i % 2]):
        if i < cp_size - 1:
            buffers[i + 1] = torch.empty(100, 100, dtype=torch.float32, device=device)
            torch.cuda.memory._dump_snapshot('cuda_mem_snapshot.pkl')

        kv = buffers[i]
        # ....
        for _ in range(10000):
            buffers[i] *= buffers[i]

        # buffers[i] = None
```
[screenshot: memory snapshot without buffer release; allocations grow to cp_size buffers]

- With buffer release:
```python
import torch

torch.cuda.memory._record_memory_history(max_entries=100000)

device = 'cuda:7'
cp_size = 16
streams = [torch.cuda.current_stream(), torch.cuda.Stream()]
buffers = [None for _ in range(cp_size)]
buffers[0] = torch.empty(100, 100, dtype=torch.float32, device=device)

torch.cuda.memory._dump_snapshot('dev_cuda_mem_1.pkl')
for i in range(cp_size):
    with torch.cuda.stream(streams[i % 2]):
        if i < cp_size - 1:
            buffers[i + 1] = torch.empty(100, 100, dtype=torch.float32, device=device)
            torch.cuda.memory._dump_snapshot('cuda_mem_snapshot.pkl')

        kv = buffers[i]
        # ....
        for _ in range(10000):
            buffers[i] *= buffers[i]

        buffers[i] = None
```
[screenshot: memory snapshot with buffer release]
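
To complement the snapshots, here is a self-contained variant of the mock above (my addition, not from the issue; single stream for simplicity) that prints peak allocated memory with and without releasing the buffers:

```python
import torch

device = 'cuda'
cp_size = 16

for release in (False, True):
    buffers = [None for _ in range(cp_size)]
    buffers[0] = torch.empty(1024, 1024, dtype=torch.float32, device=device)
    torch.cuda.synchronize()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    for i in range(cp_size):
        if i < cp_size - 1:
            buffers[i + 1] = torch.empty(1024, 1024, dtype=torch.float32, device=device)
        buffers[i] *= buffers[i]  # stand-in for the attention compute
        if release:
            buffers[i] = None  # drop the reference once this step is done
    torch.cuda.synchronize()
    peak_mib = torch.cuda.max_memory_allocated() / 2**20
    # without release the peak covers roughly cp_size buffers;
    # with release only about two buffers are live at any time
    print(f'release={release}: peak allocated ~{peak_mib:.0f} MiB')
```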

Link to PR #951