jzhang38 / EasyContext

Memory optimization and training recipes to extrapolate language models' context length to 1 million tokens, with minimal hardware.
Apache License 2.0

Error when the model vocabulary is larger than 120k #14

Closed microhu closed 2 months ago

microhu commented 2 months ago

Dear author,

When using EasyContext to evaluate a model with a vocabulary larger than 120K, I came across a strange problem:

The prediction is correct when a single GPU is used, but wrong when 8 GPUs are used: the 'pred' values in the code below are all zeros, which is quite strange.

Is there any limit on the vocabulary size in the ring-attention implementation, and what could be the possible reason?

BTW, below is what I have tried:

  1. I tried the model you provided; it is correct in both single-GPU and multi-GPU mode.
  2. To rule out a tokenizer issue, I pass input_ids directly as the input to the eval_forward function (the input_ids are identical for single-GPU and multi-GPU inference), as in the snippet below.

    with torch.inference_mode():
        logits = self.model(
            local_input_ids,
            position_ids=local_position_ids,
            use_cache=False,
        ).logits
        pred = logits.argmax(dim=-1)

the value of pred ids: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], device='cuda:6')Traceback (most recent call last):

jzhang38 commented 2 months ago

when using the easycontext to evaluate a model with more than 120K vocabuary

Which model are you using? We only support llama architecture for now.

https://github.com/jzhang38/EasyContext/blob/f81973b1caea65e8480df0ad619f420711411a09/eval_needle.py#L22

jzhang38 commented 2 months ago

To add on, many open-source models nowadays (yi/qwenvl ...) can be converted to the Llama architecture. Just make sure you load them as LlamaForCausalLM.
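
For instance, here is a minimal sketch of loading a converted checkpoint through the Llama classes (the path below is a placeholder, not a specific model from this thread):

    import torch
    from transformers import AutoTokenizer, LlamaForCausalLM

    # Placeholder path; point this at your converted, Llama-compatible checkpoint.
    model_path = "path/to/converted-llama-checkpoint"

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = LlamaForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16)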

microhu commented 2 months ago

To add on, many open-source models nowadays (yi/qwenvl ...) can be converted to the Llama architecture. Just make sure you load them as LlamaForCausalLM.

Yeah, it is LlamaForCausalLM, but the tokenizer is customized by ourselves to support a much larger vocabulary.

I also tried the following:

  1. With the default flash-attention-2, running on 8 GPUs is correct.
  2. With 'zigzag_ring_attn', running on 2 GPUs is correct, but with 3/4/8 GPUs the prediction is all zeros. Quite strange.

Could there be something wrong in the distributed communication?

microhu commented 2 months ago

To add more details:

I added some print statements before and after the self_attn call in monkey_patch.py (see the code below) and found that at around layer 14, the outputs of self_attn are 'nan'. I tried different models configured with the same 120K vocabulary, and the 'nan' values appear at around layers 14-15 in all of them. Any idea about this strange error? @jzhang38


print('the last 10 values of hidden states before self_attn {0}'.format(hidden_states[:, :, -10:]))

# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
    hidden_states=hidden_states,
    attention_mask=attention_mask,
    position_ids=position_ids,
    past_key_value=past_key_value,
    output_attentions=output_attentions,
    use_cache=use_cache,
    cache_position=cache_position,
    **kwargs,
)

print('the last 10 values of hidden states after self_attn {0}'.format(hidden_states[:, :, -10:]))

----------- the print info ----------------

    the last 10 values of hidden states before self_attn tensor([[[ 0.1074, 0.6445, -0.3945, ..., -0.2256, 0.1797, 0.4160], [-0.2070, 0.3789, -0.3262, ..., -0.0786, 0.0825, 0.3125], [ 0.1270, 0.6250, -0.2109, ..., -0.2480, 0.2373, 0.3809], ..., [ 0.3633, 0.2891, 0.1797, ..., -0.4258, 0.1650, -0.0198], [ 0.3672, 0.4453, 0.2012, ..., -0.3535, 0.3047, 0.0481], [ 0.5352, 0.3223, 0.0894, ..., -0.3320, 0.7070, 0.0143]]], device='cuda:1', dtype=torch.bfloat16)
    the last 10 values of hidden states before self_attn tensor([[[ 0.4141, 0.2734, 0.0825, ..., 0.1396, 0.4219, 0.3984], [ 0.3809, -0.1226, -0.0160, ..., -0.0737, 0.2129, 0.3164], [ 0.3223, 0.0947, -0.0527, ..., -0.1748, 0.2129, 0.3340], ..., [ 0.2471, 0.2148, 0.1040, ..., -0.4258, 0.2236, -0.0635], [ 0.3027, 0.3945, 0.1387, ..., -0.3359, 0.3691, 0.0947], [ 0.4766, 0.2930, 0.0447, ..., -0.3203, 0.7305, 0.0126]]], device='cuda:5', dtype=torch.bfloat16)
    the last 10 values of hidden states after self_attn tensor([[[ nan, nan, nan, ..., nan, nan, nan], [ nan, nan, nan, ..., nan, nan, nan], [ nan, nan, nan, ..., nan, nan, nan], ..., [-0.3340, -0.0013, -0.3496, ..., -0.4707, -0.4551, -0.5234], [ 0.0747, 0.1836, 0.3652, ..., -0.3750, -0.7539, -0.4531], [-0.4141, 0.2051, 0.9102, ..., -0.4238, -1.1953, -0.4453]]], device='cuda:7', dtype=torch.bfloat16)
    the last 10 values of hidden states after self_attn tensor([[[ nan, nan, nan, ..., nan, nan, nan], [ nan, nan, nan, ..., nan, nan, nan], [ 0.2754, 0.4375, -0.1777, ..., -0.1108, -1.1562, -0.1797], ..., [ 0.4668, 0.2041, 0.1973, ..., 0.6641, -1.7500, -0.2188], [ 0.0452, -0.0967, -0.5352, ..., 0.5234, -1.1875, 0.1514], [ 0.2041, 0.9570, -0.2393, ..., 0.3477, -1.5625, -0.0337]]], device='cuda:4', dtype=torch.bfloat16)
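
(For anyone debugging something similar: rather than eyeballing the printed tensors, a small per-rank check like the sketch below makes the first NaN-producing layer easier to spot. It is not part of monkey_patch.py; it only assumes the same hidden_states variable as the code above.)

    # Sketch: flag NaNs right after the self_attn call above, per rank.
    import torch
    import torch.distributed as dist

    if torch.isnan(hidden_states).any():
        rank = dist.get_rank() if dist.is_initialized() else 0
        print(f"NaN in hidden_states after self_attn on rank {rank}")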

jzhang38 commented 2 months ago

It is probably caused by a numerical error in the current zigzag_ring_attn implementation:

https://github.com/zhuzilin/ring-flash-attention?tab=readme-ov-file#limits

"There are some arithmetic errors with the current implementation. The reason for them is probably that flash attention will return bf16 value for each block, so we cannot accumluate the values with the original fp32 ones."

microhu commented 2 months ago

It is probably caused by a numerical error in the current zigzag_ring_attn implementation:

https://github.com/zhuzilin/ring-flash-attention?tab=readme-ov-file#limits

"There are some arithmetic errors with the current implementation. The reason for them is probably that flash attention will return bf16 value for each block, so we cannot accumluate the values with the original fp32 ones."

I see. Any plan to fix it?

jzhang38 commented 2 months ago

It would be nice if you could give me some detailed instructions (including model weights) about how to reproduce your error first.

microhu commented 2 months ago

It would be nice if you could give me some detailed instructions (including model weights) about how to reproduce your error first.

Thanks for your kind help. I finally figured out the problem. It is not caused by bf16 vs. fp32, but by a torch.exp overflow in ring-flash-attention, at line 19 of https://github.com/zhuzilin/ring-flash-attention/blob/main/ring_flash_attn/utils.py:

    new_lse = lse + torch.log(1 + torch.exp(block_lse - lse))

I fixed it by using logsigmoid:

    _logSigmoid = torch.nn.LogSigmoid()
    new_lse = lse - _logSigmoid(lse - block_lse)
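
(To spell out why this works, here is a standalone sketch rather than the actual ring_flash_attn/utils.py code: log(1 + exp(x)) equals -logsigmoid(-x), but the exp in the first form overflows to inf once block_lse - lse gets large, while logsigmoid is computed in a numerically stable way.)

    import torch

    _logSigmoid = torch.nn.LogSigmoid()

    # Standalone toy values, not the real attention LSE tensors.
    lse = torch.tensor([10.0, 10.0])
    block_lse = torch.tensor([12.0, 200.0])   # the large gap makes exp() overflow

    # Original formulation: torch.exp(block_lse - lse) overflows to inf.
    naive = lse + torch.log(1 + torch.exp(block_lse - lse))

    # logsigmoid formulation: same quantity, no large exponentials.
    stable = lse - _logSigmoid(lse - block_lse)

    print(naive)    # tensor([12.1269, inf])
    print(stable)   # tensor([12.1269, 200.0000])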

jzhang38 commented 2 months ago

@microhu Glad you resolved the issue! Would you consider submitting a PR to the ring-flash-attention repo? That would be very helpful!

microhu commented 2 months ago

@microhu Glad you resolved the issue! Would you consider submitting a PR to the ring-flash-attention repo? That would be very helpful!

Done