microsoft / DeepSpeed

DeepSpeed is a deep learning optimization library that makes distributed training and inference easy, efficient, and effective.
https://www.deepspeed.ai/
Apache License 2.0

[BUG] Flash attention seemingly cannot be integrated with pipeline parallelism due to missing input grad #3868

Open SparkJiao opened 1 year ago

SparkJiao commented 1 year ago

Describe the bug: Flash attention cannot be integrated with LLaMA pipeline-parallel training, using either the original flash-attn implementation or torch.nn.functional.scaled_dot_product_attention from PyTorch 2.0.

ds_report output

--------------------------------------------------
DeepSpeed C++/CUDA extension op report
--------------------------------------------------
NOTE: Ops not installed will be just-in-time (JIT) compiled at
      runtime if needed. Op compatibility means that your system
      meet the required dependencies to JIT install the op.
--------------------------------------------------
JIT compiled ops requires ninja
ninja .................. [OKAY]
--------------------------------------------------
op name ................ installed .. compatible
--------------------------------------------------
 [WARNING]  async_io requires the dev libaio .so object and headers but these were not found.
 [WARNING]  async_io: please install the libaio-dev package with apt
 [WARNING]  If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found.
async_io ............... [NO] ....... [NO]
cpu_adagrad ............ [NO] ....... [OKAY]
cpu_adam ............... [NO] ....... [OKAY]
fused_adam ............. [NO] ....... [OKAY]
fused_lamb ............. [NO] ....... [OKAY]
quantizer .............. [NO] ....... [OKAY]
random_ltd ............. [NO] ....... [OKAY]
 [WARNING]  sparse_attn requires a torch version >= 1.5 and < 2.0 but detected 2.0
 [WARNING]  using untested triton version (2.0.0), only 1.0.0 is known to be compatible
sparse_attn ............ [NO] ....... [NO]
spatial_inference ...... [NO] ....... [OKAY]
transformer ............ [NO] ....... [OKAY]
stochastic_transformer . [NO] ....... [OKAY]
transformer_inference .. [NO] ....... [OKAY]
--------------------------------------------------
No CUDA runtime is found, using CUDA_HOME='/cm/shared/apps/cuda11.6/toolkit/11.6.0'
DeepSpeed general environment info:
torch install path ............... ['/export/home2/fangkai/anaconda3/envs/torch2.0/lib/python3.9/site-packages/torch']
torch version .................... 2.0.0+cu117
deepspeed install path ........... ['/export/home2/fangkai/anaconda3/envs/torch2.0/lib/python3.9/site-packages/deepspeed']
deepspeed info ................... 0.9.5, unknown, unknown
torch cuda version ............... 11.7
torch hip version ................ None
nvcc version ..................... 11.6
deepspeed wheel compiled w. ...... torch 2.0, cuda 11.7

Screenshots: The error information is as follows:

/export/home2/fangkai/merit-v2/trainer_base_ds_mp.py:408 in main

   405             logger.info("Resuming training from the latest checkpoint:
   406             continue_from_global_step = int(checkpoint.split('-')[-1])
   407
 ❱ 408         global_step, tr_loss = train(cfg, model_pipe, tokenizer, conti
   409         logger.info(" global_step = %s, average loss = %s", global_ste
   410
   411

/export/home2/fangkai/merit-v2/trainer_base_ds_mp.py:298 in train

   295                     continue
   296
   297                 model.train()
 ❱ 298                 loss = model.train_batch(data_iter=sub_train_dataloade
   299                 global_step += 1
   300
   301                 tr_loss += loss.item()

/export/home2/fangkai/anaconda3/envs/torch2.0/lib/python3.9/site-packages/deepspeed/runtime/pipe/engine.py:336 in train_batch

    333         sched = schedule.TrainSchedule(micro_batches=self.micro_batch
    334                                        stages=self.num_stages,
    335                                        stage_id=self.stage_id)
 ❱  336         self._exec_schedule(sched)
    337         self.agg_train_loss = self._aggregate_total_loss()
    338
    339         self.timers('train_batch').stop()

/export/home2/fangkai/anaconda3/envs/torch2.0/lib/python3.9/site-packages/deepspeed/runtime/pipe/engine.py:1307 in _exec_schedule

   1304
   1305                 # Equivalent to: self._exec_forward_pass(buffer_id=0)
   1306                 self._exec_instr = MethodType(self._INSTRUCTION_MAP[t
 ❱ 1307                 self._exec_instr(**cmd.kwargs)
   1308

/export/home2/fangkai/anaconda3/envs/torch2.0/lib/python3.9/site-packages/deepspeed/runtime/pipe/engine.py:996 in _exec_send_grads

    993                     if not buffer.is_floating_point():
    994                         assert buffer.grad is None
    995                         continue
 ❱  996                     assert buffer.grad is not None
    997                     p2p.send(buffer.grad, self.prev_stage)
    998
    999         # We can free up the input buffer now
AssertionError
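The failing check is the one shown at engine.py:993-997 in the traceback: after the backward pass, every floating-point buffer a stage received from the previous stage must have a gradient, while non-floating-point buffers are skipped. The pure-Python sketch below mirrors that rule with a hypothetical `RecvBuffer` record (not a DeepSpeed class) and collects offenders instead of asserting. It assumes, without confirming, that the pipeline sends `attention_mask` between stages as a floating-point tensor. Both patched forwards below never use `attention_mask` (the flash-attn version overwrites it with `torch.ones`), so such a buffer would never receive a gradient.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class RecvBuffer:
    """Hypothetical stand-in for a tensor received from the previous pipeline stage."""
    name: str
    is_floating_point: bool
    grad: Optional[list] = None  # None means backward never reached this buffer


def check_send_grads(buffers: List[RecvBuffer]) -> List[str]:
    """Mirror the checks at engine.py:993-997; return names of buffers missing a grad."""
    missing = []
    for buf in buffers:
        if not buf.is_floating_point:
            # Integer/bool inputs are never differentiated, so no grad is expected.
            assert buf.grad is None
            continue
        # Floating-point inputs are expected to have received a gradient.
        if buf.grad is None:
            missing.append(buf.name)
    return missing


# hidden_states feeds the layer output, so backward reaches it.
hidden = RecvBuffer("hidden_states", True, grad=[0.1, -0.2])
# A float mask that the patched forward ignores never gets a gradient (assumption).
float_mask = RecvBuffer("attention_mask", True, grad=None)
# The same mask sent as an integer tensor is skipped by the check.
int_mask = RecvBuffer("attention_mask", False, grad=None)

print(check_send_grads([hidden, float_mask]))  # ['attention_mask']
print(check_send_grads([hidden, int_mask]))    # []
```

Under that assumption, passing the mask between stages as an integer tensor, or not passing it at all, would keep it out of the floating-point check; this is a hypothesis to test, not a confirmed fix.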

System info (please complete the following information):

Launcher context: deepspeed launcher

The code for implementing flash attention in my own project is as follows:

""" https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama_flash_attn_monkey_patch.py.
"""

from typing import List, Optional, Tuple, Dict

import torch
import transformers
import torch.nn.functional as F
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb

from einops import rearrange
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
from flash_attn.bert_padding import unpad_input, pad_input

def smart_tokenizer_and_embedding_resize(
        special_tokens_dict: Dict,
        tokenizer: transformers.PreTrainedTokenizer,
        model: transformers.PreTrainedModel,
):
    """Resize tokenizer and embedding.

    Note: This is the unoptimized version that may leave the embedding size not divisible by 64.
    """
    # TODO: pad the embedding size so that it is divisible by 64.
    num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
    model.resize_token_embeddings(len(tokenizer))

    if num_new_tokens > 0:
        input_embeddings = model.get_input_embeddings().weight.data
        output_embeddings = model.get_output_embeddings().weight.data

        input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
        output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)

        input_embeddings[-num_new_tokens:] = input_embeddings_avg
        output_embeddings[-num_new_tokens:] = output_embeddings_avg
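`smart_tokenizer_and_embedding_resize` initializes the input and output embedding of each newly added token to the mean of the pre-existing rows. A minimal pure-Python sketch of that averaging step on a list-of-lists matrix follows; the helper name `init_new_rows_with_mean` is hypothetical.

```python
from typing import List


def init_new_rows_with_mean(matrix: List[List[float]], num_new: int) -> None:
    """Overwrite the last `num_new` rows in place with the column-wise mean of the older rows."""
    if num_new <= 0:
        return
    old_rows = matrix[:-num_new]
    # Column-wise mean over the original vocabulary, like .mean(dim=0, keepdim=True).
    avg = [sum(col) / len(old_rows) for col in zip(*old_rows)]
    for i in range(len(matrix) - num_new, len(matrix)):
        matrix[i] = list(avg)


emb = [[1.0, 2.0], [3.0, 4.0], [0.0, 0.0]]  # last row is a freshly added token
init_new_rows_with_mean(emb, 1)
print(emb)  # [[1.0, 2.0], [3.0, 4.0], [2.0, 3.0]]
```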

def llama_flash_attn_forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """Input shape: Batch x Time x Channel

    attention_mask: [bsz, q_len]
    """
    bsz, q_len, _ = hidden_states.size()

    query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    # [bsz, q_len, nh, hd]
    # [bsz, nh, q_len, hd]

    kv_seq_len = key_states.shape[-2]
    assert past_key_value is None, "past_key_value is not supported"

    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
    # [bsz, nh, t, hd]
    assert not output_attentions, "output_attentions is not supported"
    assert not use_cache, "use_cache is not supported"

    # Flash attention codes from
    # https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py

    # transform the data into the format required by flash attention
    qkv = torch.stack([query_states, key_states, value_states], dim=2)  # [bsz, nh, 3, q_len, hd]
    qkv = qkv.transpose(1, 3)  # [bsz, q_len, 3, nh, hd]
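    # We have disabled _prepare_decoder_attention_mask in LlamaModel,
    # so the attention_mask should be the same as the key_padding_mask.
    # NOTE: the next line discards the incoming attention_mask and treats every
    # token as non-padding, so the `key_padding_mask is None` branch below never runs.
    attention_mask = torch.ones((bsz, q_len), device=qkv.device)
    key_padding_mask = attention_mask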

    if key_padding_mask is None:
        qkv = rearrange(qkv, 'b s ... -> (b s) ...')
        max_s = q_len
        cu_q_lens = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32,
                                 device=qkv.device)
        output = flash_attn_unpadded_qkvpacked_func(
            qkv, cu_q_lens, max_s, 0.0,
            softmax_scale=None, causal=True
        )
        output = rearrange(output, '(b s) ... -> b s ...', b=bsz)
    else:
        nheads = qkv.shape[-2]
        x = rearrange(qkv, 'b s three h d -> b s (three h d)')
        x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
        x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads)
        output_unpad = flash_attn_unpadded_qkvpacked_func(
            x_unpad, cu_q_lens, max_s, 0.0,
            softmax_scale=None, causal=True
        )
        output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'),
                                     indices, bsz, q_len), 'b s (h d) -> b s h d', h=nheads)
    return self.o_proj(rearrange(output, 'b s h d -> b s (h d)')), None, None
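In the padded branch, `unpad_input` packs the non-padding tokens of the batch into one flat sequence and returns their flat `indices`, the cumulative sequence lengths `cu_q_lens` (with a leading 0), and the longest sequence length `max_s`, which `flash_attn_unpadded_qkvpacked_func` uses to locate each sequence's boundaries. The pure-Python sketch below reproduces that bookkeeping for a 0/1 padding mask; it is an illustration, not the flash-attn implementation, and `unpad_bookkeeping` is a hypothetical name. With an all-ones mask, the offsets reduce to the same `[0, q_len, 2*q_len, ...]` that the `torch.arange` call in the `key_padding_mask is None` branch produces.

```python
from itertools import accumulate
from typing import List, Tuple


def unpad_bookkeeping(mask: List[List[int]]) -> Tuple[List[int], List[int], int]:
    """Return (indices, cu_seqlens, max_seqlen) for a [bsz][seq_len] 0/1 padding mask."""
    seq_len = len(mask[0])
    # Flat positions (row * seq_len + col) of every non-padding token, in row-major order.
    indices = [b * seq_len + s for b, row in enumerate(mask) for s, keep in enumerate(row) if keep]
    seqlens = [sum(row) for row in mask]
    # Prefix sums with a leading 0: sequence b occupies [cu[b], cu[b + 1]) in the packed buffer.
    cu_seqlens = [0] + list(accumulate(seqlens))
    return indices, cu_seqlens, max(seqlens)


# Two sequences of length 4; the second has one padding position at the end.
indices, cu, max_s = unpad_bookkeeping([[1, 1, 1, 1], [1, 1, 1, 0]])
print(indices)  # [0, 1, 2, 3, 4, 5, 6]
print(cu)       # [0, 4, 7]
print(max_s)    # 4

# With no padding, the offsets match range(0, (bsz + 1) * q_len, q_len).
bsz, q_len = 3, 5
_, cu_full, _ = unpad_bookkeeping([[1] * q_len for _ in range(bsz)])
assert cu_full == list(range(0, (bsz + 1) * q_len, q_len))
```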

def llama_flash_attn_forward_pytorch(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """Input shape: Batch x Time x Channel

    attention_mask: [bsz, q_len]
    """
    bsz, q_len, _ = hidden_states.size()

    query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    # [bsz, q_len, nh, hd]
    # [bsz, nh, q_len, hd]

    kv_seq_len = key_states.shape[-2]
    assert past_key_value is None, "past_key_value is not supported"

    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
    # [bsz, nh, t, hd]
    assert not output_attentions, "output_attentions is not supported"
    assert not use_cache, "use_cache is not supported"

    with torch.backends.cuda.sdp_kernel(
            enable_flash=True,
            enable_math=False,
            enable_mem_efficient=False
    ):
        out = F.scaled_dot_product_attention(
            query_states, key_states, value_states,
            is_causal=True,
        )

    out = out.transpose(1, 2)
    out = out.reshape(bsz, q_len, self.hidden_size)

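    # Apply the output projection, as the flash-attn version does; without it
    # o_proj is skipped entirely.
    return self.o_proj(out), None, None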
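`F.scaled_dot_product_attention(query_states, key_states, value_states, is_causal=True)` computes softmax(Q Kᵀ / sqrt(d)) V per head, where query position i may only attend to key positions j ≤ i and d is the head dimension (the default scale when none is passed). The naive single-head reference below, on plain Python lists, spells out that math; it is an O(T²) illustration for checking small cases, not how the flash kernel computes it, and `causal_attention` is a hypothetical name.

```python
import math
from typing import List


def causal_attention(q: List[List[float]], k: List[List[float]], v: List[List[float]]) -> List[List[float]]:
    """Naive single-head causal scaled dot-product attention on [T][d] lists."""
    d = len(q[0])
    scale = 1.0 / math.sqrt(d)  # default softmax scale: 1/sqrt(head_dim)
    out = []
    for i, qi in enumerate(q):
        # Causal mask: query i only sees keys 0..i.
        scores = [scale * sum(a * b for a, b in zip(qi, k[j])) for j in range(i + 1)]
        m = max(scores)  # subtract the max for numerical stability
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        out.append([sum(w * v[j][c] for j, w in enumerate(weights)) / total for c in range(len(v[0]))])
    return out


q = k = [[1.0, 0.0], [0.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0]]
o = causal_attention(q, k, v)
print(o[0])  # [1.0, 2.0]: the first position can only attend to itself
```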

# Hack: call the following function in main.py to monkey-patch LlamaAttention.forward.
def replace_llama_attn_with_flash_attn():
    transformers.models.llama.modeling_llama.LlamaAttention.forward = llama_flash_attn_forward
    # transformers.models.llama.modeling_llama.LlamaAttention.forward = llama_flash_attn_forward_pytorch
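`replace_llama_attn_with_flash_attn` works by assigning a new function to the class attribute `LlamaAttention.forward`. Because Python looks methods up through the class at call time, the patch affects every instance, including ones created before the patch, as long as no instance defines its own `forward` attribute. The sketch below shows this with a hypothetical `Attention` class standing in for the real transformers module.

```python
class Attention:
    """Hypothetical stand-in for transformers' LlamaAttention."""

    def forward(self, x):
        return ("eager", x)


def flash_forward(self, x):
    # Replacement with the same signature; `self` is bound automatically on lookup.
    return ("flash", x)


def replace_attn_with_flash():
    # Patch the class attribute, not an instance, so all instances pick it up.
    Attention.forward = flash_forward


layer_before = Attention()      # created before patching
replace_attn_with_flash()
layer_after = Attention()
print(layer_before.forward(1))  # ('flash', 1)
print(layer_after.forward(2))   # ('flash', 2)
```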
sx1999 commented 7 months ago

Hello, I also encountered this problem. Could you tell me how to solve it?