DachengLi1 / LongChat

Official repository for LongChat and LongEval

OOM issue #28

Closed · WeixuanXiong closed this issue 1 year ago

WeixuanXiong commented 1 year ago

Excellent work!

My setup is 8*A100 80G and I want to expand the context window size to 32k. When I set it to 32k with batch size 1, an OOM error occurred. I'm wondering if it would work if I added more GPUs. Will the parameters of the model be sharded across every GPU I use?
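
For reference, I'm training with DeepSpeed ZeRO-3, with roughly this kind of config (just a sketch; the values here are placeholders rather than my real settings):

    # Sketch of a DeepSpeed ZeRO-3 config as a Python dict (placeholder values).
    # Stage 3 shards parameters, gradients and optimizer states across all GPUs,
    # which is why I expected adding GPUs to help with memory.
    ds_config = {
        "zero_optimization": {
            "stage": 3,
            "overlap_comm": True,
            "contiguous_gradients": True,
        },
        "bf16": {"enabled": True},
        "train_micro_batch_size_per_gpu": 1,
        "gradient_accumulation_steps": 8,  # placeholder
    }
    # Passed to the HF Trainer via TrainingArguments(deepspeed=ds_config, ...)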

Any advice will help me a lot.

DachengLi1 commented 1 year ago

@WeixuanXiong Please use flash attention. If you have already done so, the largest model I know of that fits on 8x80GB is 13B; anything beyond that model size gives OOM.
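
A minimal sketch of how the patch is usually applied (the exact import path depends on your checkout, so treat it as an assumption; the checkpoint name is a placeholder):

    # Apply the flash-attention monkey patch *before* the model is created,
    # so the LLaMA attention forward is replaced with the flash-attn kernel.
    # The import path is an assumption (FastChat-style patch); adjust it to
    # wherever the patch lives in your checkout.
    import torch
    from transformers import AutoModelForCausalLM
    from fastchat.train.llama_flash_attn_monkey_patch import (
        replace_llama_attn_with_flash_attn,
    )

    replace_llama_attn_with_flash_attn()

    model = AutoModelForCausalLM.from_pretrained(
        "lmsys/longchat-7b-16k",      # placeholder checkpoint name
        torch_dtype=torch.bfloat16,   # flash-attn kernels need fp16/bf16
    )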

WeixuanXiong commented 1 year ago

I tried applying the flash attention / xformers monkey patch to my LLaMA model, but it didn't make much of a difference. I noticed that even with only two A100s at 8k context I can still train, so the OOM seems unrelated to the number of GPUs; adding more GPUs doesn't let me extend the context length either. What confuses me is that my dataset is very small, the model is only 7B, and max_seq_len only affects the intermediate activations during training. Logically, ZeRO-3 should distribute those across the GPUs, yet I can still fine-tune normally when I reduce the number of cards from 8 to 2. Should I actually feed inputs of max_seq_len=16k during training (with truncation or padding), or just feed 2k-length text but set scale=8? I wonder if I have missed something.
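
To be concrete, by scale=8 I mean condensing the rotary positions by a ratio, roughly like this (just a sketch of the idea, not the repo's actual monkey patch):

    import torch

    # Rough sketch of "condensed" rotary embeddings: divide the position
    # indices by a ratio so that positions 0..16k map back into the original
    # 0..2k range. Function name and defaults are placeholders.
    def condensed_rope_freqs(dim, max_seq_len, ratio=8, base=10000.0):
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        positions = torch.arange(max_seq_len).float() / ratio  # interpolated positions
        freqs = torch.outer(positions, inv_freq)               # (max_seq_len, dim // 2)
        return torch.cos(freqs), torch.sin(freqs)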

DachengLi1 commented 1 year ago

@WeixuanXiong 7B at 32k should definitely work on 8*80GB, so we should aim at debugging until that works rather than settling for a shorter length. (1) Please check that flash attention is actually being used; it should make a huge difference. (2) Consider using FSDP instead of ZeRO-3. They shard in the same way, but many people report that DeepSpeed is hard to configure correctly.
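
Assuming you are launching through the HF Trainer, switching to FSDP is mostly a matter of the training arguments. A minimal sketch (placeholder hyperparameters; the wrap class name assumes the HF LlamaDecoderLayer implementation):

    from transformers import TrainingArguments

    # Sketch of an FSDP setup roughly equivalent to ZeRO-3 with the HF Trainer.
    # "full_shard" shards parameters, gradients and optimizer states across
    # GPUs; "auto_wrap" wraps each transformer block separately.
    training_args = TrainingArguments(
        output_dir="out",                  # placeholder
        per_device_train_batch_size=1,
        gradient_accumulation_steps=8,     # placeholder
        bf16=True,
        fsdp="full_shard auto_wrap",
        fsdp_transformer_layer_cls_to_wrap="LlamaDecoderLayer",
        gradient_checkpointing=True,       # also important at 32k context
    )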

WeixuanXiong commented 1 year ago

Thanks! Though no errors are reported when using flash attention like this:

def _flash_attention(self, xq, keys, values, mask=None):
    bsz, seqlen = xq.shape[0], xq.shape[1]
    k_bsz, k_seqlen = keys.shape[0], keys.shape[1]

    # Cumulative sequence lengths for the varlen kernel; every sequence in the
    # batch has the same (padded) length here.
    cu_seqlens = torch.arange(0, (bsz + 1) * seqlen, step=seqlen,
                              dtype=torch.int32, device=xq.device)
    k_cu_seqlens = torch.arange(0, (k_bsz + 1) * k_seqlen, step=k_seqlen,
                                dtype=torch.int32, device=xq.device)
    scale = 1.0 / math.sqrt(self.head_dim)

    # Pack (bsz, seqlen, n_heads, head_dim) -> (bsz * seqlen, n_heads, head_dim)
    q = torch.reshape(xq, (xq.shape[0] * xq.shape[1], xq.shape[2], xq.shape[3]))
    k = torch.reshape(keys, (keys.shape[0] * keys.shape[1], keys.shape[2], keys.shape[3]))
    v = torch.reshape(values, (values.shape[0] * values.shape[1], values.shape[2], values.shape[3]))

    output = flash_attn_varlen_func(
        q, k, v, cu_seqlens, k_cu_seqlens, seqlen, k_seqlen,
        0.0,  # attention dropout
        softmax_scale=scale, causal=mask is not None,
    )
    output = torch.reshape(output, (bsz, seqlen, -1))

    return output

def forward(
    self,
    x: torch.Tensor,
    kv_mask: torch.Tensor,
    freqs_cis: torch.Tensor,
    cache_k: Optional[torch.Tensor] = None,
    cache_v: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    start_pos = 0  # no KV-cache offset during training

    bsz, seqlen, _ = x.shape
    xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

    xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
    xk = xk.view(bsz, seqlen, self.n_local_heads, self.head_dim)
    xv = xv.view(bsz, seqlen, self.n_local_heads, self.head_dim)

    xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

    # Modified code to allow training, caching is not good for training
    if (cache_k is None) != (cache_v is None):
        raise ValueError("cache_k and cache_v must both be provided or both be None")
    if cache_k is None:
        keys = xk
        values = xv
    else:
        cache_k = cache_k.to(xk.device)
        cache_v = cache_v.to(xv.device)
        cache_k[:bsz, start_pos : start_pos + seqlen] = xk  # noqa E203
        cache_v[:bsz, start_pos : start_pos + seqlen] = xv  # noqa E203
        keys = cache_k[:bsz, : start_pos + seqlen]  # noqa E203
        values = cache_v[:bsz, : start_pos + seqlen]  # noqa E203

    # dispatch to flash attention if the kernel was imported, else the eager path
    if flash_attn_varlen_func:
        output = self._flash_attention(xq, keys, values, kv_mask)
    else:
        output = self._attention(xq, keys, values, kv_mask)

    if cache_k is None:
        return self.wo(output), None, None
    else:
        return self.wo(output), cache_k, cache_v

Any advice on how to confirm flash attn is correctly used?
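
The only check I could think of so far is comparing the two paths directly on random inputs, roughly like this (a rough sketch; it assumes both methods take the same arguments as in forward() above and that the eager output can be reshaped to match the flash output):

    # Rough sanity check (sketch): feed the same random half-precision tensors
    # through the flash path and the eager path and compare the outputs.
    # `attn` is assumed to be an already-constructed instance of the attention
    # module above; shapes follow the (bsz, seqlen, n_local_heads, head_dim)
    # layout used in forward().
    attn = attn.cuda().half()
    bsz, seqlen = 2, 128
    shape = (bsz, seqlen, attn.n_local_heads, attn.head_dim)
    xq = torch.randn(shape, dtype=torch.float16, device="cuda")
    keys = torch.randn(shape, dtype=torch.float16, device="cuda")
    values = torch.randn(shape, dtype=torch.float16, device="cuda")

    with torch.no_grad():
        out_flash = attn._flash_attention(xq, keys, values, mask=None)
        out_eager = attn._attention(xq, keys, values, mask=None)

    # assumes out_eager has the same number of elements as out_flash
    print("max abs diff:", (out_flash - out_eager.reshape(out_flash.shape)).abs().max().item())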