microsoft / DeepSpeed

DeepSpeed is a deep learning optimization library that makes distributed training and inference easy, efficient, and effective.
https://www.deepspeed.ai/
Apache License 2.0

[BUG] cannot get the same attention calculation results from deepspeed/dense_blocked_attention and flash_attn/flash_attn_varlen_func #4967

Closed. xinji1 closed this issue 9 months ago

xinji1 commented 9 months ago

Describe the bug Cannot get the same attention calculation results from DeepSpeed's `dense_blocked_attention` and flash_attn's `flash_attn_varlen_func`. I just pass in my own parameters.

To Reproduce


# The following functions are the same as the ones in
# `deepspeed/inference/v2/modules/implementations/attention/dense_blocked_attention.py`.

import itertools
import random
from typing import List, Optional, Tuple
import pytest
import torch
torch.manual_seed(42)
from deepspeed.accelerator import get_accelerator
from deepspeed.inference.v2.modules import ConfigBundle
from deepspeed.inference.v2.modules.configs import DSSelfAttentionConfig, PositionalEmbeddingType, RotateHalfConfig
from deepspeed.inference.v2.modules.interfaces import DSSelfAttentionRegistry, DSSelfAttentionBase

from deepspeed.inference.v2.ragged import (
    AllocationMode,
    DSSequenceDescriptor,
    DSStateManager,
    DSStateManagerConfig,
    KVCacheConfig,
    MemoryConfig,
    PlaceholderSequenceDescriptor,
    RaggedBatchWrapper,
)
def build_batch_and_manager(
    seq_params: List[Tuple[int, int]],
    head_size: int,
    n_heads_kv: int,
    kv_block_size: int,
    vocab_range: Optional[int] = 100,
    padding: Optional[bool] = False,
    kv_fill: Optional[List[torch.Tensor]] = None
) -> Tuple[RaggedBatchWrapper, DSStateManager, List[DSSequenceDescriptor]]:

    seq_lens = [seq_param[0] for seq_param in seq_params]
    fill_lens = [seq_param[1] for seq_param in seq_params]
    max_created_batch_len = max(sum(seq_lens), sum(fill_lens))
    total_tokens = max(max_created_batch_len, 1024)
    n_seqs = max(len(seq_lens), 128)

    req_kv_blocks = [None] * n_seqs
    total_kv_blocks = 0
    for i, (seq_len, n_seen_tokens) in enumerate(seq_params):
        req_kv_blocks[i] = (seq_len + n_seen_tokens + kv_block_size - 1) // kv_block_size
        total_kv_blocks += req_kv_blocks[i]

    kv_config = KVCacheConfig(block_size=kv_block_size,
                              num_allocation_groups=1,
                              cache_shape=(1, n_heads_kv, head_size))
    memory_config = MemoryConfig(mode=AllocationMode.ALLOCATE, size=total_kv_blocks)

    config = DSStateManagerConfig(max_tracked_sequences=n_seqs,
                                  max_ragged_sequence_count=n_seqs,
                                  max_ragged_batch_size=total_tokens,
                                  memory_config=memory_config)

    batch = RaggedBatchWrapper(config)
    state_manager = DSStateManager(config, (kv_config, ))

    all_allocs = []
    for _ in range(20):
        decision = random.randint(0, 1)

        if decision == 0:
            blocks_to_allocate = random.randint(0, total_kv_blocks)
            if blocks_to_allocate <= state_manager.free_blocks[0] and blocks_to_allocate > 0:
                all_allocs.append(state_manager.allocate_blocks(blocks_to_allocate))
        else:
            if len(all_allocs) > 0:
                idx = random.randint(0, len(all_allocs) - 1)
                state_manager._kv_cache.free(all_allocs[idx])

                del all_allocs[idx]

    for alloc in all_allocs:
        state_manager._kv_cache.free(alloc)

    assert state_manager.free_blocks[0] == total_kv_blocks

    batch.clear()
    seq_descs = []

    if kv_fill is None or sum(fill_lens) == 0:
        for i, (seq_len, n_seen_tokens) in enumerate(seq_params):
            # Create empty descriptor
            seq_desc = state_manager.get_or_create_sequence(i)

            # Update `seen_tokens` in the descriptor
            seq_desc.pre_forward(n_seen_tokens)
            seq_desc.post_forward()

            # Ensure there's enough KV-cache for the sequence
            kv_block_ids = state_manager.allocate_blocks(req_kv_blocks[i])
            print(f"Allocated {req_kv_blocks[i]} blocks for sequence {i}: {kv_block_ids}")
            seq_desc.extend_kv_cache(kv_block_ids)

            # Insert sequence into batch
            tokens = torch.randint(0, vocab_range, (seq_len, ))
            batch.insert_sequence(seq_desc, tokens)
            seq_desc.pre_forward(seq_len)
            seq_descs.append(seq_desc)
    else:
        qkv = torch.empty((total_tokens, (n_heads_kv * 3) * head_size),
                          dtype=torch.float16,
                          device=get_accelerator().current_device())
        fills_as_tensor = torch.tensor(fill_lens, dtype=torch.int32)
        fill_cumsum = torch.cat((torch.tensor([0], dtype=torch.int32), torch.cumsum(fills_as_tensor, dim=0)))

        for i, (_, n_seen_tokens) in enumerate(seq_params):
            # Create empty descriptor
            seq_desc = state_manager.get_or_create_sequence(i)

            # Update `seen_tokens` in the descriptor
            if n_seen_tokens > 0:
                dummy_fill_toks = torch.randint(0, vocab_range, (n_seen_tokens, ))
                batch.insert_sequence(seq_desc, dummy_fill_toks)
                seq_desc.pre_forward(n_seen_tokens)

            # Ensure there's enough KV-cache for the sequence
            kv_block_ids = state_manager.allocate_blocks(req_kv_blocks[i])
            print(f"Allocated {req_kv_blocks[i]} blocks for sequence {i}: {kv_block_ids}")
            seq_desc.extend_kv_cache(kv_block_ids)
            seq_descs.append(seq_desc)

            if n_seen_tokens == 0:
                continue

            assert kv_fill[i].shape[0] == n_seen_tokens
            assert kv_fill[i].shape[1] == n_heads_kv * head_size * 2

            local_q = torch.randn((n_seen_tokens, n_heads_kv * head_size), dtype=torch.float16, device=qkv.device)
            local_qkv = torch.cat((local_q, kv_fill[i]), dim=1)
            qkv[fill_cumsum[i]:fill_cumsum[i + 1]] = local_qkv

        batch.finalize(padding=padding)

        from deepspeed.inference.v2.kernels.ragged_ops import LinearBlockedKVCopy
        kv_copy = LinearBlockedKVCopy(head_size, n_heads_kv, n_heads_kv, torch.float16)
        kv_cache = state_manager.get_cache(0)
        kv_copy(kv_cache, qkv, batch)

        for seq_desc in seq_descs:
            if seq_desc.in_flight_tokens > 0:
                seq_desc.post_forward()

        batch.clear()

        for i, (seq_len, _) in enumerate(seq_params):
            seq_desc = state_manager.get_or_create_sequence(i)
            tokens = torch.randint(0, vocab_range, (seq_len, ))
            batch.insert_sequence(seq_desc, tokens)
            seq_desc.pre_forward(seq_len)

            # We will skip KV cache allocation here because we did a lump allocation above
            # for both the fill and the sequence itself.

    batch.finalize(padding=padding)

    return batch, state_manager, seq_descs

# allclose helper

TOLERANCES = None

def get_tolerances():
    global TOLERANCES
    if TOLERANCES is None:
        TOLERANCES = {torch.float32: (5e-4, 5e-5), torch.float16: (3e-2, 2e-3)}
        if get_accelerator().is_bf16_supported():
            # Note: BF16 tolerance is higher than FP16 because of the lower precision (7 (+1) bits vs
            # 10 (+1) bits)
            TOLERANCES[torch.bfloat16] = (4.8e-1, 3.2e-2)
    return TOLERANCES
def allclose(x, y, tolerances: Optional[Tuple[float, float]] = None):
    assert x.dtype == y.dtype
    if tolerances is None:
        rtol, atol = get_tolerances()[x.dtype]
    else:
        rtol, atol = tolerances
    return torch.allclose(x, y, rtol=rtol, atol=atol)

try:
    from flash_attn.flash_attn_interface import flash_attn_varlen_func
    validate_accuracy = True
    print("Using Flash Attention for validation")
except ImportError:
    validate_accuracy = False
    print("Not using Flash Attention")

def _blocked_flash_testing_helper(head_size: int,
                                  n_heads_q: int,
                                  n_heads_kv: int,
                                  seq_params: List[Tuple[int, int]],
                                  trained_freqs: Optional[bool] = None) -> None:
    """
    Helper function for testing blocked flash attention. This implementation is based on
    the implementation in ``unit.inference.kernels.ragged_ops.test_blocked_flash`` but
    integrates functionality to validate the composability.
    """
    if trained_freqs is None:
        embed_type = PositionalEmbeddingType.none
        embed_args = None
    else:
        embed_type = PositionalEmbeddingType.rotate_half
        embed_args = RotateHalfConfig(use_trained_freqs=trained_freqs)

    attn_config = DSSelfAttentionConfig(max_tokens=2048,
                                        n_heads_q=n_heads_q,
                                        n_heads_kv=n_heads_kv,
                                        head_size=head_size,
                                        max_sequences=32,
                                        positional_embedding_type=embed_type,
                                        positional_embedding_config=embed_args)

    config = ConfigBundle(name='dense_blocked_attention', config=attn_config)
    attn_module: DSSelfAttentionBase = DSSelfAttentionRegistry.instantiate_config(config)

    kv_block_size = attn_module.kv_block_size

    kvs = []
    for _, history_len in seq_params:
        if history_len > 0:
            kvs.append(
                torch.randn((history_len, 2 * n_heads_kv * head_size),
                            device=get_accelerator().current_device(),
                            dtype=torch.float16))
        else:
            kvs.append(None)

    batch, state_manager, _ = build_batch_and_manager(seq_params, head_size, n_heads_kv, kv_block_size, kv_fill=kvs)

    qkv = torch.randn((batch.current_tokens, (n_heads_q + 2 * n_heads_kv) * head_size),
                      device=get_accelerator().current_device(),
                      dtype=torch.float16)

    kv_cache = state_manager.get_cache(0)

    attn_module.build_atoms(batch)
    if not trained_freqs:
        out = attn_module(qkv, kv_cache, batch)
    else:
        inv_freqs = torch.randn((head_size // 2, ), device=get_accelerator().current_device(), dtype=torch.float16)
        out = attn_module(qkv, kv_cache, batch, inv_freqs)

    if validate_accuracy and trained_freqs is None:
        cu_seqlens_q = torch.tensor([0] + list(itertools.accumulate([seq[0] for seq in seq_params])),
                                    dtype=torch.int32,
                                    device=get_accelerator().current_device())
        cu_seqlens_kv = torch.tensor([0] + list(itertools.accumulate([seq[1] + seq[0] for seq in seq_params])),
                                     dtype=torch.int32,
                                     device=get_accelerator().current_device())

        inflight_kv = qkv[:, head_size * n_heads_q:]
        full_kvs = []
        for i, kv in enumerate(kvs):
            if kv is not None:
                full_kvs.append(torch.cat([kv, inflight_kv[cu_seqlens_q[i]:cu_seqlens_q[i + 1]]], dim=0))
            else:
                full_kvs.append(inflight_kv[cu_seqlens_q[i]:cu_seqlens_q[i + 1]])
        run_kvs = torch.cat(full_kvs, dim=0)
        k = run_kvs[:, :head_size * n_heads_kv]
        v = run_kvs[:, head_size * n_heads_kv:]

        q = qkv[:, :head_size * n_heads_q]
        q_ref = q.reshape((batch.current_tokens, n_heads_q, head_size))
        k_ref = k.reshape((k.shape[0], n_heads_kv, head_size))
        v_ref = v.reshape((v.shape[0], n_heads_kv, head_size))

        max_seqlen_q = max([seq[0] for seq in seq_params])
        max_seqlen_kv = max([seq[1] + seq[0] for seq in seq_params])

        ref_o = flash_attn_varlen_func(q_ref,
                                       k_ref,
                                       v_ref,
                                       cu_seqlens_q,
                                       cu_seqlens_kv,
                                       max_seqlen_q,
                                       max_seqlen_kv,
                                       softmax_scale=1.0,
                                       causal=True)

        ref_o = ref_o.reshape(batch.current_tokens, head_size * n_heads_q)

        assert allclose(out, ref_o)

    get_accelerator().synchronize()

head_size = 128
n_heads_q = 40 
n_heads_kv = 40

seq_params = [(332, 628)]

_blocked_flash_testing_helper(head_size, n_heads_q, n_heads_kv, seq_params, trained_freqs=None)

Expected behavior No AssertionError: out and ref_o should be close to each other.
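
When the assertion does fire, this is the kind of quick check I use to quantify the gap (a sketch only; it is meant to be dropped just before the `assert allclose(out, ref_o)` inside `_blocked_flash_testing_helper`, where `out` and `ref_o` are defined):

# Hypothetical debugging lines: report the worst absolute and relative error
# between the DeepSpeed output and the flash_attn reference.
abs_err = (out.float() - ref_o.float()).abs()
rel_err = abs_err / ref_o.float().abs().clamp_min(1e-6)
print(f"max abs err = {abs_err.max().item():.3e}, max rel err = {rel_err.max().item():.3e}")
print("allclose:", allclose(out, ref_o))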

ds_report output

[2024-01-17 06:42:46,418] [INFO] [real_accelerator.py:161:get_accelerator] Setting ds_accelerator to cuda (auto detect)
--------------------------------------------------
DeepSpeed C++/CUDA extension op report
--------------------------------------------------
NOTE: Ops not installed will be just-in-time (JIT) compiled at
      runtime if needed. Op compatibility means that your system
      meet the required dependencies to JIT install the op.
--------------------------------------------------
JIT compiled ops requires ninja
ninja .................. [OKAY]
--------------------------------------------------
op name ................ installed .. compatible
--------------------------------------------------
 [WARNING]  async_io requires the dev libaio .so object and headers but these were not found.
 [WARNING]  async_io: please install the libaio-dev package with apt
 [WARNING]  If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found.
async_io ............... [NO] ....... [NO]
fused_adam ............. [NO] ....... [OKAY]
cpu_adam ............... [NO] ....... [OKAY]
cpu_adagrad ............ [NO] ....... [OKAY]
cpu_lion ............... [NO] ....... [OKAY]
 [WARNING]  Please specify the CUTLASS repo directory as environment variable $CUTLASS_PATH
evoformer_attn ......... [NO] ....... [NO]
fused_lamb ............. [NO] ....... [OKAY]
fused_lion ............. [NO] ....... [OKAY]
inference_core_ops ..... [NO] ....... [OKAY]
cutlass_ops ............ [NO] ....... [OKAY]
quantizer .............. [NO] ....... [OKAY]
ragged_device_ops ...... [NO] ....... [OKAY]
ragged_ops ............. [NO] ....... [OKAY]
random_ltd ............. [NO] ....... [OKAY]
 [WARNING]  sparse_attn requires a torch version >= 1.5 and < 2.0 but detected 2.1
 [WARNING]  using untested triton version (2.1.0), only 1.0.0 is known to be compatible
sparse_attn ............ [NO] ....... [NO]
spatial_inference ...... [NO] ....... [OKAY]
transformer ............ [NO] ....... [OKAY]
stochastic_transformer . [NO] ....... [OKAY]
transformer_inference .. [NO] ....... [OKAY]
--------------------------------------------------
DeepSpeed general environment info:
torch install path ............... ['/usr/local/lib/python3.10/dist-packages/torch']
torch version .................... 2.1.0a0+32f93b1
deepspeed install path ........... ['/usr/local/lib/python3.10/dist-packages/deepspeed']
deepspeed info ................... 0.12.7+81cc3207, 81cc3207, xinji1
torch cuda version ............... 12.2
torch hip version ................ None
nvcc version ..................... 12.2
deepspeed wheel compiled w. ...... torch 2.1, cuda 12.2
shared memory (/dev/shm) size .... 216.50 GB

Docker context Yes: https://hub.docker.com/repository/docker/xinji1/tensorrt_llm/general

Additional context I've tried both causal=True and causal=False for the flash_attn_varlen_func reference call, but neither gives matching results.
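
Concretely, this is the variation I tried inside `_blocked_flash_testing_helper` (a sketch: it just re-runs the existing reference call with both settings of the flag and compares against `out`; every other tensor is unchanged):

# Sketch: inside _blocked_flash_testing_helper, replace the single reference
# call with a loop over both causal settings and compare each result to `out`.
for causal in (True, False):
    ref_o = flash_attn_varlen_func(q_ref, k_ref, v_ref,
                                   cu_seqlens_q, cu_seqlens_kv,
                                   max_seqlen_q, max_seqlen_kv,
                                   softmax_scale=1.0,
                                   causal=causal)
    ref_o = ref_o.reshape(batch.current_tokens, head_size * n_heads_q)
    print(f"causal={causal}: allclose(out, ref_o) -> {allclose(out, ref_o)}")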

xinji1 commented 9 months ago

Closed. I just found that my old flash_attn version was 2.0.4. The mismatch is fixed after updating it to 2.4.2.
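
For anyone who hits the same assertion: check the installed flash-attn version first. flash-attn changed the causal-mask alignment for seqlen_q != seqlen_kv in v2.1, which is exactly the prefill-with-history shape used here, so an old 2.0.x build would likely produce a different reference output. A minimal check (sketch; the upgrade command is the usual one and may need adjusting for your CUDA/torch build):

# Sketch: verify the flash-attn version before comparing against DeepSpeed's kernel.
import flash_attn
print(flash_attn.__version__)  # should be a recent 2.x release (2.4.2 works for me)

# Typical upgrade command (adjust for your environment):
#   pip install -U flash-attn --no-build-isolation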