Describe the bug
I cannot get matching attention results from DeepSpeed's `dense_blocked_attention` module and flash_attn's `flash_attn_varlen_func`; I only changed the parameters passed to the test helpers below.
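In condensed form, the comparison that fails is (names are those used in the full script under "To Reproduce"; this is only a summary, not a standalone snippet):

out = attn_module(qkv, kv_cache, batch)          # DeepSpeed's dense_blocked_attention module
ref_o = flash_attn_varlen_func(q_ref, k_ref, v_ref,
                               cu_seqlens_q, cu_seqlens_kv,
                               max_seqlen_q, max_seqlen_kv,
                               softmax_scale=1.0, causal=True)
assert allclose(out, ref_o.reshape(out.shape))   # raises AssertionError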
To Reproduce
# The following functions are the same as the ones in
# `deepspeed/inference/v2/modules/implementations/attention/dense_blocked_attention.py`.
import itertools
from typing import List, Tuple
import pytest
import torch
torch.manual_seed(42)
from deepspeed.accelerator import get_accelerator
from deepspeed.inference.v2.modules import ConfigBundle
from deepspeed.inference.v2.modules.configs import DSSelfAttentionConfig, PositionalEmbeddingType, RotateHalfConfig
from deepspeed.inference.v2.modules.interfaces import DSSelfAttentionRegistry, DSSelfAttentionBase
import random
from typing import List, Optional, Tuple
import torch
from deepspeed.accelerator import get_accelerator
from deepspeed.inference.v2.ragged import (
AllocationMode,
DSSequenceDescriptor,
DSStateManager,
DSStateManagerConfig,
KVCacheConfig,
MemoryConfig,
PlaceholderSequenceDescriptor,
RaggedBatchWrapper,
)
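# Builds a RaggedBatchWrapper plus a DSStateManager with a blocked KV cache for the given
# (seq_len, n_seen_tokens) pairs. When `kv_fill` is supplied, the history K/V values are
# written into the blocked cache via LinearBlockedKVCopy before the in-flight tokens are inserted.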
def build_batch_and_manager(
        seq_params: List[Tuple[int, int]],
        head_size: int,
        n_heads_kv: int,
        kv_block_size: int,
        vocab_range: Optional[int] = 100,
        padding: Optional[bool] = False,
        kv_fill: Optional[List[torch.Tensor]] = None
) -> Tuple[RaggedBatchWrapper, DSStateManager, List[DSSequenceDescriptor]]:
    seq_lens = [seq_param[0] for seq_param in seq_params]
    fill_lens = [seq_param[1] for seq_param in seq_params]
    max_created_batch_len = max(sum(seq_lens), sum(fill_lens))
    total_tokens = max(max_created_batch_len, 1024)
    n_seqs = max(len(seq_lens), 128)

    req_kv_blocks = [None] * n_seqs
    total_kv_blocks = 0
    for i, (seq_len, n_seen_tokens) in enumerate(seq_params):
        req_kv_blocks[i] = (seq_len + n_seen_tokens + kv_block_size - 1) // kv_block_size
        total_kv_blocks += req_kv_blocks[i]

    kv_config = KVCacheConfig(block_size=kv_block_size,
                              num_allocation_groups=1,
                              cache_shape=(1, n_heads_kv, head_size))
    memory_config = MemoryConfig(mode=AllocationMode.ALLOCATE, size=total_kv_blocks)

    config = DSStateManagerConfig(max_tracked_sequences=n_seqs,
                                  max_ragged_sequence_count=n_seqs,
                                  max_ragged_batch_size=total_tokens,
                                  memory_config=memory_config)

    batch = RaggedBatchWrapper(config)
    state_manager = DSStateManager(config, (kv_config, ))

    all_allocs = []
    for _ in range(20):
        decision = random.randint(0, 1)

        if decision == 0:
            blocks_to_allocate = random.randint(0, total_kv_blocks)
            if blocks_to_allocate <= state_manager.free_blocks[0] and blocks_to_allocate > 0:
                all_allocs.append(state_manager.allocate_blocks(blocks_to_allocate))
        else:
            if len(all_allocs) > 0:
                idx = random.randint(0, len(all_allocs) - 1)
                state_manager._kv_cache.free(all_allocs[idx])
                del all_allocs[idx]

    for alloc in all_allocs:
        state_manager._kv_cache.free(alloc)

    assert state_manager.free_blocks[0] == total_kv_blocks

    batch.clear()
    seq_descs = []

    if kv_fill is None or sum(fill_lens) == 0:
        for i, (seq_len, n_seen_tokens) in enumerate(seq_params):
            # Create empty descriptor
            seq_desc = state_manager.get_or_create_sequence(i)

            # Update `seen_tokens` in the descriptor
            seq_desc.pre_forward(n_seen_tokens)
            seq_desc.post_forward()

            # Ensure there's enough KV-cache for the sequence
            kv_block_ids = state_manager.allocate_blocks(req_kv_blocks[i])
            print(f"Allocated {req_kv_blocks[i]} blocks for sequence {i}: {kv_block_ids}")
            seq_desc.extend_kv_cache(kv_block_ids)

            # Insert sequence into batch
            tokens = torch.randint(0, vocab_range, (seq_len, ))
            batch.insert_sequence(seq_desc, tokens)
            seq_desc.pre_forward(seq_len)
            seq_descs.append(seq_desc)
    else:
        qkv = torch.empty((total_tokens, (n_heads_kv * 3) * head_size),
                          dtype=torch.float16,
                          device=get_accelerator().current_device())
        fills_as_tensor = torch.tensor(fill_lens, dtype=torch.int32)
        fill_cumsum = torch.cat((torch.tensor([0], dtype=torch.int32), torch.cumsum(fills_as_tensor, dim=0)))

        for i, (_, n_seen_tokens) in enumerate(seq_params):
            # Create empty descriptor
            seq_desc = state_manager.get_or_create_sequence(i)

            # Update `seen_tokens` in the descriptor
            if n_seen_tokens > 0:
                dummy_fill_toks = torch.randint(0, vocab_range, (n_seen_tokens, ))
                batch.insert_sequence(seq_desc, dummy_fill_toks)
                seq_desc.pre_forward(n_seen_tokens)

            # Ensure there's enough KV-cache for the sequence
            kv_block_ids = state_manager.allocate_blocks(req_kv_blocks[i])
            print(f"Allocated {req_kv_blocks[i]} blocks for sequence {i}: {kv_block_ids}")
            seq_desc.extend_kv_cache(kv_block_ids)
            seq_descs.append(seq_desc)

            if n_seen_tokens == 0:
                continue

            assert kv_fill[i].shape[0] == n_seen_tokens
            assert kv_fill[i].shape[1] == n_heads_kv * head_size * 2

            local_q = torch.randn((n_seen_tokens, n_heads_kv * head_size), dtype=torch.float16, device=qkv.device)
            local_qkv = torch.cat((local_q, kv_fill[i]), dim=1)
            qkv[fill_cumsum[i]:fill_cumsum[i + 1]] = local_qkv

        batch.finalize(padding=padding)

        from deepspeed.inference.v2.kernels.ragged_ops import LinearBlockedKVCopy
        kv_copy = LinearBlockedKVCopy(head_size, n_heads_kv, n_heads_kv, torch.float16)
        kv_cache = state_manager.get_cache(0)
        kv_copy(kv_cache, qkv, batch)

        for seq_desc in seq_descs:
            if seq_desc.in_flight_tokens > 0:
                seq_desc.post_forward()

        batch.clear()

        for i, (seq_len, _) in enumerate(seq_params):
            seq_desc = state_manager.get_or_create_sequence(i)
            tokens = torch.randint(0, vocab_range, (seq_len, ))
            batch.insert_sequence(seq_desc, tokens)
            seq_desc.pre_forward(seq_len)

            # We will skip KV cache allocation here because we did a lump allocation above
            # for both the fill and the sequence itself.

    batch.finalize(padding=padding)

    return batch, state_manager, seq_descs
# all_close
from typing import Tuple
import torch
from deepspeed.accelerator import get_accelerator
TOLERANCES = None
def get_tolerances():
    global TOLERANCES
    if TOLERANCES is None:
        TOLERANCES = {torch.float32: (5e-4, 5e-5), torch.float16: (3e-2, 2e-3)}
        if get_accelerator().is_bf16_supported():
            # Note: BF16 tolerance is higher than FP16 because of the lower precision (7 (+1) bits vs
            # 10 (+1) bits)
            TOLERANCES[torch.bfloat16] = (4.8e-1, 3.2e-2)
    return TOLERANCES


def allclose(x, y, tolerances: Tuple[int, int] = None):
    assert x.dtype == y.dtype
    if tolerances is None:
        rtol, atol = get_tolerances()[x.dtype]
    else:
        rtol, atol = tolerances
    return torch.allclose(x, y, rtol=rtol, atol=atol)
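# Note: with the fp16 tensors used below, allclose() compares with rtol=3e-2 and atol=2e-3.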
try:
    from flash_attn.flash_attn_interface import flash_attn_varlen_func
    validate_accuracy = True
    print("Using Flash Attention for validation")
except ImportError:
    validate_accuracy = False
    print("Not using Flash Attention")
def _blocked_flash_testing_helper(head_size: int,
                                  n_heads_q: int,
                                  n_heads_kv: int,
                                  seq_params: List[Tuple[int, int]],
                                  trained_freqs: bool = None) -> None:
    """
    Helper function for testing blocked flash attention. This implementation is based on
    the implementation in ``unit.inference.kernels.ragged_ops.test_blocked_flash`` but
    integrates functionality to validate the composability.
    """
    if trained_freqs is None:
        embed_type = PositionalEmbeddingType.none
        embed_args = None
    else:
        embed_type = PositionalEmbeddingType.rotate_half
        embed_args = RotateHalfConfig(use_trained_freqs=trained_freqs)

    attn_config = DSSelfAttentionConfig(max_tokens=2048,
                                        n_heads_q=n_heads_q,
                                        n_heads_kv=n_heads_kv,
                                        head_size=head_size,
                                        max_sequences=32,
                                        positional_embedding_type=embed_type,
                                        positional_embedding_config=embed_args)

    config = ConfigBundle(name='dense_blocked_attention', config=attn_config)
    attn_module: DSSelfAttentionBase = DSSelfAttentionRegistry.instantiate_config(config)

    kv_block_size = attn_module.kv_block_size

    kvs = []
    for _, history_len in seq_params:
        if history_len > 0:
            kvs.append(
                torch.randn((history_len, 2 * n_heads_kv * head_size),
                            device=get_accelerator().current_device(),
                            dtype=torch.float16))
        else:
            kvs.append(None)

    batch, state_manager, _ = build_batch_and_manager(seq_params, head_size, n_heads_kv, kv_block_size, kv_fill=kvs)

    qkv = torch.randn((batch.current_tokens, (n_heads_q + 2 * n_heads_kv) * head_size),
                      device=get_accelerator().current_device(),
                      dtype=torch.float16)

    kv_cache = state_manager.get_cache(0)

    attn_module.build_atoms(batch)
    if not trained_freqs:
        out = attn_module(qkv, kv_cache, batch)
    else:
        inv_freqs = torch.randn((head_size // 2, ), device=get_accelerator().current_device(), dtype=torch.float16)
        out = attn_module(qkv, kv_cache, batch, inv_freqs)

    if validate_accuracy and trained_freqs is None:
        cu_seqlens_q = torch.tensor([0] + list(itertools.accumulate([seq[0] for seq in seq_params])),
                                    dtype=torch.int32,
                                    device=get_accelerator().current_device())
        cu_seqlens_kv = torch.tensor([0] + list(itertools.accumulate([seq[1] + seq[0] for seq in seq_params])),
                                     dtype=torch.int32,
                                     device=get_accelerator().current_device())

        inflight_kv = qkv[:, head_size * n_heads_q:]
        full_kvs = []
        for i, kv in enumerate(kvs):
            if kv is not None:
                full_kvs.append(torch.cat([kv, inflight_kv[cu_seqlens_q[i]:cu_seqlens_q[i + 1]]], dim=0))
            else:
                full_kvs.append(inflight_kv[cu_seqlens_q[i]:cu_seqlens_q[i + 1]])
        run_kvs = torch.cat(full_kvs, dim=0)
        k = run_kvs[:, :head_size * n_heads_kv]
        v = run_kvs[:, head_size * n_heads_kv:]

        q = qkv[:, :head_size * n_heads_q]
        q_ref = q.reshape((batch.current_tokens, n_heads_q, head_size))
        k_ref = k.reshape((k.shape[0], n_heads_kv, head_size))
        v_ref = v.reshape((v.shape[0], n_heads_kv, head_size))

        max_seqlen_q = max([seq[0] for seq in seq_params])
        max_seqlen_kv = max([seq[1] + seq[0] for seq in seq_params])

        ref_o = flash_attn_varlen_func(q_ref,
                                       k_ref,
                                       v_ref,
                                       cu_seqlens_q,
                                       cu_seqlens_kv,
                                       max_seqlen_q,
                                       max_seqlen_kv,
                                       softmax_scale=1.0,
                                       causal=True)

        ref_o = ref_o.reshape(batch.current_tokens, head_size * n_heads_q)
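        # Optional diagnostic (not part of the original test): print the worst-case error
        # before the assert below, e.g.
        #     max_abs_err = (out - ref_o).abs().max().item()
        #     print(f"max abs err vs flash_attn reference: {max_abs_err}")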
        assert allclose(out, ref_o)

    get_accelerator().synchronize()
head_size = 128
n_heads_q = 40
n_heads_kv = 40
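# Each seq_param tuple is (tokens in this forward pass, tokens already seen, i.e. in the KV
# history): here 332 new tokens attend over a 628-token history (960 KV tokens in total).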
seq_params = [(332, 628)]
print(_blocked_flash_testing_helper(head_size, n_heads_q, n_heads_kv, seq_params, trained_freqs=None))
Expected behavior
No AssertionError: out and ref_o should be close to each other.
ds_report output
[2024-01-17 06:42:46,418] [INFO] [real_accelerator.py:161:get_accelerator] Setting ds_accelerator to cuda (auto detect)
--------------------------------------------------
DeepSpeed C++/CUDA extension op report
--------------------------------------------------
NOTE: Ops not installed will be just-in-time (JIT) compiled at
runtime if needed. Op compatibility means that your system
meet the required dependencies to JIT install the op.
--------------------------------------------------
JIT compiled ops requires ninja
ninja .................. [OKAY]
--------------------------------------------------
op name ................ installed .. compatible
--------------------------------------------------
[WARNING] async_io requires the dev libaio .so object and headers but these were not found.
[WARNING] async_io: please install the libaio-dev package with apt
[WARNING] If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found.
async_io ............... [NO] ....... [NO]
fused_adam ............. [NO] ....... [OKAY]
cpu_adam ............... [NO] ....... [OKAY]
cpu_adagrad ............ [NO] ....... [OKAY]
cpu_lion ............... [NO] ....... [OKAY]
[WARNING] Please specify the CUTLASS repo directory as environment variable $CUTLASS_PATH
evoformer_attn ......... [NO] ....... [NO]
fused_lamb ............. [NO] ....... [OKAY]
fused_lion ............. [NO] ....... [OKAY]
inference_core_ops ..... [NO] ....... [OKAY]
cutlass_ops ............ [NO] ....... [OKAY]
quantizer .............. [NO] ....... [OKAY]
ragged_device_ops ...... [NO] ....... [OKAY]
ragged_ops ............. [NO] ....... [OKAY]
random_ltd ............. [NO] ....... [OKAY]
[WARNING] sparse_attn requires a torch version >= 1.5 and < 2.0 but detected 2.1
[WARNING] using untested triton version (2.1.0), only 1.0.0 is known to be compatible
sparse_attn ............ [NO] ....... [NO]
spatial_inference ...... [NO] ....... [OKAY]
transformer ............ [NO] ....... [OKAY]
stochastic_transformer . [NO] ....... [OKAY]
transformer_inference .. [NO] ....... [OKAY]
--------------------------------------------------
DeepSpeed general environment info:
torch install path ............... ['/usr/local/lib/python3.10/dist-packages/torch']
torch version .................... 2.1.0a0+32f93b1
deepspeed install path ........... ['/usr/local/lib/python3.10/dist-packages/deepspeed']
deepspeed info ................... 0.12.7+81cc3207, 81cc3207, xinji1
torch cuda version ............... 12.2
torch hip version ................ None
nvcc version ..................... 12.2
deepspeed wheel compiled w. ...... torch 2.1, cuda 12.2
shared memory (/dev/shm) size .... 216.50 GB
Screenshots
If applicable, add screenshots to help explain your problem.
System info (please complete the following information):
OS: Ubuntu 22.04
GPU count and types: 1x A100 80GB
(if applicable) what DeepSpeed-MII version are you using
(if applicable) Hugging Face Transformers/Accelerate/etc. versions: Transformers 4.36.2
Docker context
Are you using a specific docker image that you can share? https://hub.docker.com/repository/docker/xinji1/tensorrt_llm/general
Additional context
I have tried both `causal=True` and `causal=False` for `flash_attn_varlen_func`, but neither gives the expected (matching) results.
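Concretely, the variant tried on the reference side only flips the `causal` flag (a sketch; it assumes all other arguments stay exactly as in the script above):

ref_o = flash_attn_varlen_func(q_ref, k_ref, v_ref,
                               cu_seqlens_q, cu_seqlens_kv,
                               max_seqlen_q, max_seqlen_kv,
                               softmax_scale=1.0,
                               causal=False)  # still does not match `out`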