Closed: WeixuanXiong closed this issue 1 year ago
@WeixuanXiong Are you using flash attention? If you have already done so, the largest model I know of that fits on 8x80GB would be 13B. Beyond that model size you get OOM.
I tried to apply the flash attention or xformers monkey patch to my llama model, but it didn't make much of a difference. I noticed that even with two A100s and an 8k context I can still train, so the OOM seems unrelated to the number of GPUs; using more GPUs also doesn't let me extend the context length. What confuses me is that my dataset is very small, the model is only 7B, and max_seq_len only affects the intermediate variables during training. Logically, ZeRO-3 should distribute these variables across the GPUs, yet I can still fine-tune the model normally when I reduce the number of cards from 8 to 2. Should I really feed inputs of max_seq_len=16k during training (with truncation or padding), or should I feed 2k-length text but set scale=8? I wonder if I have missed something.
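For intuition on the activation question above: ZeRO-3 shards parameters, gradients, and optimizer states, but each GPU still holds the full activations for its own micro-batch, which is why adding GPUs does not extend the trainable context length. A rough back-of-the-envelope sketch (hypothetical LLaMA-7B-like shapes, fp16) of just the naive attention score matrices at 16k:

```python
# Rough estimate of the attention-score activations for naive (non-flash)
# attention; hypothetical LLaMA-7B-like shapes (32 layers, 32 heads), fp16.
layers, heads, seqlen, bytes_per_el = 32, 32, 16384, 2

# One (seqlen x seqlen) score matrix per head per layer, per sample,
# saved for the backward pass unless recomputed.
scores_bytes = layers * heads * seqlen * seqlen * bytes_per_el
print(f"{scores_bytes / 2**30:.0f} GiB")  # → 512 GiB per sample
```

This quadratic term is exactly what flash attention avoids materializing, which is why a correctly applied patch should make a huge difference at long context.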
@WeixuanXiong 7B at 32k should definitely work on 8*80GB; we should aim at debugging to achieve this rather than settling for a shorter length. (1) Please check that flash attention is correctly used; there should be a huge difference. (2) Consider using FSDP instead of ZeRO-3. They are functionally equivalent, but many people report DeepSpeed is hard to configure correctly.
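One quick way to check whether a monkey patch actually took effect is to inspect the attention class's `forward` attribute after patching; its `__module__`/`__qualname__` should point at the patch, not the original modeling file. A minimal stand-in demonstration of the pattern (the class and function names here are illustrative, not from any real library):

```python
# Demonstrates the patch-verification pattern with a stand-in class:
# after monkey-patching, the method's __qualname__ reveals whether
# the replacement actually took effect.
class Attention:
    def forward(self):
        return "naive"

def flash_forward(self):  # imagine this lives in your flash-attn patch module
    return "flash"

Attention.forward = flash_forward  # apply the "patch"

assert Attention().forward() == "flash"
print(Attention.forward.__qualname__)  # → "flash_forward", not "Attention.forward"
```

Applied to a real model, the same check would be printing e.g. `LlamaAttention.forward.__module__` before and after importing the patch and confirming it changed.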
Thanks! Though no bug is reported when using flash attention like this:
```python
def _flash_attention(self, xq, keys, values, mask=None):
    bsz, seqlen = xq.shape[0], xq.shape[1]
    k_bsz, k_seqlen = keys.shape[0], keys.shape[1]
    # Cumulative sequence lengths for the varlen kernel: every sample in the
    # batch has the same (padded) length here.
    cu_seqlens = torch.arange(
        0, (bsz + 1) * seqlen, step=seqlen, dtype=torch.int32, device=xq.device
    )
    k_cu_seqlens = torch.arange(
        0, (k_bsz + 1) * k_seqlen, step=k_seqlen, dtype=torch.int32, device=xq.device
    )
    scale = 1.0 / math.sqrt(self.head_dim)
    # Flatten (bsz, seqlen, n_heads, head_dim) -> (bsz * seqlen, n_heads, head_dim).
    q = xq.reshape(bsz * seqlen, xq.shape[2], xq.shape[3])
    k = keys.reshape(k_bsz * k_seqlen, keys.shape[2], keys.shape[3])
    v = values.reshape(k_bsz * k_seqlen, values.shape[2], values.shape[3])
    output = flash_attn_varlen_func(
        q, k, v, cu_seqlens, k_cu_seqlens, seqlen, k_seqlen,
        0.0,  # dropout
        softmax_scale=scale,
        causal=mask is not None,
    )
    return output.reshape(bsz, seqlen, -1)

def forward(
    self,
    x: torch.Tensor,
    kv_mask: torch.Tensor,
    freqs_cis: torch.Tensor,
    cache_k: Optional[torch.Tensor] = None,
    cache_v: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
    start_pos = 0
    bsz, seqlen, _ = x.shape
    xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
    xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
    xk = xk.view(bsz, seqlen, self.n_local_heads, self.head_dim)
    xv = xv.view(bsz, seqlen, self.n_local_heads, self.head_dim)
    xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
    # Modified to allow training; caching is not useful for training.
    if (cache_k is None) != (cache_v is None):
        raise ValueError("cache_k and cache_v must both be None or both be set")
    if cache_k is None:
        keys = xk
        values = xv
    else:
        # .to() is not in-place; the result must be reassigned.
        cache_k = cache_k.to(xk.device)
        cache_v = cache_v.to(xv.device)
        cache_k[:bsz, start_pos : start_pos + seqlen] = xk  # noqa: E203
        cache_v[:bsz, start_pos : start_pos + seqlen] = xv  # noqa: E203
        keys = cache_k[:bsz, : start_pos + seqlen]  # noqa: E203
        values = cache_v[:bsz, : start_pos + seqlen]  # noqa: E203
    # Use flash attention when the kernel was successfully imported.
    if flash_attn_varlen_func is not None:
        output = self._flash_attention(xq, keys, values, kv_mask)
    else:
        output = self._attention(xq, keys, values, kv_mask)
    if cache_k is None:
        return self.wo(output), None, None
    return self.wo(output), cache_k, cache_v
```
Any advice on how to confirm flash attn is correctly used?
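One practical check is numerical equivalence: run your attention path and a trusted reference on the same small tensors and compare outputs. A CPU-runnable sketch of the pattern, comparing naive causal attention against PyTorch's `scaled_dot_product_attention` (the same idea applies to `flash_attn_varlen_func` on GPU):

```python
import math
import torch
import torch.nn.functional as F

# Correctness-check pattern: compare a hand-rolled attention against a
# trusted reference on small random tensors.
torch.manual_seed(0)
bsz, heads, seqlen, head_dim = 2, 4, 16, 8
q = torch.randn(bsz, heads, seqlen, head_dim)
k = torch.randn(bsz, heads, seqlen, head_dim)
v = torch.randn(bsz, heads, seqlen, head_dim)

# Naive causal attention: mask out future positions, softmax, weighted sum.
scale = 1.0 / math.sqrt(head_dim)
scores = (q @ k.transpose(-2, -1)) * scale
causal_mask = torch.triu(torch.ones(seqlen, seqlen, dtype=torch.bool), diagonal=1)
scores = scores.masked_fill(causal_mask, float("-inf"))
naive = torch.softmax(scores, dim=-1) @ v

ref = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(torch.allclose(naive, ref, atol=1e-5))  # → True
```

On GPU, a second signal is memory: log `torch.cuda.max_memory_allocated()` while doubling the sequence length. With flash attention active, attention memory should grow roughly linearly in sequence length rather than quadratically.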
Excellent work!
My device is 8*A100 80G and I want to expand the context window size to 32k. When I set it to 32k with batch size 1, an OOM error occurred. I'm wondering whether it will work if I increase the number of GPUs. Will the parameters of the model be sharded across every GPU I use?
Any advice will help me a lot.
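On the sharding question: under ZeRO-3 or FSDP the model states (parameters, gradients, Adam optimizer states) are indeed sharded across ranks, so that portion of memory shrinks with GPU count, but each GPU still holds the full activations of its own micro-batch. A rough sketch under the usual mixed-precision accounting (hypothetical 7B model; fp16 params and grads, fp32 master weights plus two Adam moments, about 16 bytes per parameter in total):

```python
# Per-GPU model-state memory under ZeRO-3 / FSDP sharding.
# Assumed breakdown: fp16 params (2B) + fp16 grads (2B)
# + fp32 master weights, momentum, variance (4 + 4 + 4 B) = 16 bytes/param.
params = 7e9
bytes_per_param = 16

for gpus in (2, 4, 8):
    per_gpu = params * bytes_per_param / gpus
    print(f"{gpus} GPUs: {per_gpu / 2**30:.1f} GiB of model states per GPU")
```

So on 8 GPUs the sharded model states are a modest fraction of an 80GB card; the OOM at 32k comes from unsharded activations, which is consistent with the earlier observation that changing the GPU count barely moves the limit.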