A library for accelerating Transformer models on NVIDIA GPUs, including using 8-bit floating point (FP8) precision on Hopper and Ada GPUs, to provide better performance with lower memory utilization in both training and inference.
In `AttnFuncWithCP.forward`, at most 3 buffers (two being computed on and one receiving) are in use at any time, but `p2p_comm_buffers` grows to `cp_size` entries, which I believe wastes memory. Simply adding `buffer[i] = None` and `kv_input[i%2] = None` after `_flash_attn_forward` could lower memory usage: the reference count of each already-consumed KV block would then drop to zero, so the caching allocator could reuse its memory.
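The reference-count argument can be sanity-checked without a GPU. The following is a pure-Python simulation, not Transformer Engine code: `FakeBuffer` and `peak_live_buffers` are hypothetical names, and a buffer is assumed to be freed as soon as its CPython reference count reaches zero (roughly how a PyTorch tensor returns its block to the caching allocator). The loop is sequential, so it does not model the overlap between the two CUDA streams.

```python
class FakeBuffer:
    """Hypothetical stand-in for a KV communication buffer (e.g. a CUDA tensor)."""

    live = 0  # number of FakeBuffer objects currently alive

    def __init__(self) -> None:
        FakeBuffer.live += 1

    def __del__(self) -> None:
        # CPython calls this as soon as the last reference is dropped.
        FakeBuffer.live -= 1


def peak_live_buffers(cp_size: int, release: bool) -> int:
    """Mimic the mock ring loop and return the peak number of live buffers."""
    start = FakeBuffer.live
    buffers = [None for _ in range(cp_size)]
    buffers[0] = FakeBuffer()
    peak = FakeBuffer.live - start
    for i in range(cp_size):
        if i < cp_size - 1:
            # "Receive" the next KV block into a freshly allocated buffer.
            buffers[i + 1] = FakeBuffer()
        peak = max(peak, FakeBuffer.live - start)
        kv = buffers[i]  # attention compute on kv is elided
        if release:
            # Drop both the list slot and the local alias of the consumed block.
            buffers[i] = None
            kv = None
    return peak


if __name__ == "__main__":
    for release in (False, True):
        print(f"release={release}: peak live buffers = {peak_live_buffers(16, release)}")
```

With `cp_size = 16`, keeping every buffer referenced peaks at 16 live buffers, while clearing both the list slot and the local `kv` alias keeps the peak at 2. Clearing only `buffers[i]` would leave `kv` holding the previous block until the next iteration reassigns it (a peak of 3 in this sketch), which mirrors the suggestion to also clear `kv_input[i%2]`.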
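Here is some mock code:

- Without buffer release:
```python
import torch
device = 'cuda:7'
torch.cuda.set_device(device)  # create the streams below on the same GPU as the buffers
torch.cuda.memory._record_memory_history(max_entries=100000)
cp_size = 16
streams = [torch.cuda.current_stream(), torch.cuda.Stream()]
buffers = [None for _ in range(cp_size)]
buffers[0] = torch.empty(100, 100, dtype=torch.float32, device=device)
torch.cuda.memory._dump_snapshot('dev_cuda_mem_1.pkl')
for i in range(cp_size):
    with torch.cuda.stream(streams[i % 2]):
        if i < cp_size - 1:
            buffers[i + 1] = torch.empty(100, 100, dtype=torch.float32, device=device)
        torch.cuda.memory._dump_snapshot('cuda_mem_snapshot.pkl')
        kv = buffers[i]
        # ....
        for _ in range(10000):
            buffers[i] *= buffers[i]
        # buffers[i] = None
```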
<img width="1366" alt="image" src="https://github.com/NVIDIA/TransformerEngine/assets/10850020/ee0f12c5-735d-40f9-895a-5fc16c962785">
- With buffer release:
```python
import torch
device = 'cuda:7'
torch.cuda.set_device(device)  # create the streams below on the same GPU as the buffers
torch.cuda.memory._record_memory_history(max_entries=100000)
cp_size = 16
streams = [torch.cuda.current_stream(), torch.cuda.Stream()]
buffers = [None for _ in range(cp_size)]
buffers[0] = torch.empty(100, 100, dtype=torch.float32, device=device)
torch.cuda.memory._dump_snapshot('dev_cuda_mem_1.pkl')
for i in range(cp_size):
    with torch.cuda.stream(streams[i % 2]):
        if i < cp_size - 1:
            buffers[i + 1] = torch.empty(100, 100, dtype=torch.float32, device=device)
        torch.cuda.memory._dump_snapshot('cuda_mem_snapshot.pkl')
        kv = buffers[i]
        # ....
        for _ in range(10000):
            buffers[i] *= buffers[i]
        buffers[i] = None  # drop the list's reference to the consumed block
```
Link to PR #951