huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

modeling_llama - LlamaAttention attempts to subscript `None` position_ids #22407

Closed: cheald closed this issue 1 year ago

cheald commented 1 year ago

Who can help?

@gante

Reproduction

When trying to quantize LLaMA weights with https://github.com/qwopqwop200/GPTQ-for-LLaMa, I encountered the following error:

❯ CUDA_VISIBLE_DEVICES=0 python llama.py ./models/hf/13B/llama-13b c4 --wbits 4 --true-sequential --act-order --new-eval --save_safetensors llama-13b-4bit.safetensors
Starting ...
Ready.
Traceback (most recent call last):
  File "./GPTQ-for-LLaMA/llama.py", line 449, in <module>
    quantizers = llama_sequential(model, dataloader, DEV)
  File "./GPTQ-for-LLaMA/venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "./GPTQ-for-LLaMA/llama.py", line 100, in llama_sequential
    outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0]
  File "./GPTQ-for-LLaMA/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "./GPTQ-for-LLaMA/venv/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 311, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "./GPTQ-for-LLaMA/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "./GPTQ-for-LLaMA/venv/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 220, in forward
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
  File "./GPTQ-for-LLaMA/venv/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 132, in apply_rotary_pos_emb
    gather_indices = position_ids[:, None, :, None]  # [bs, 1, seq_len, 1]
TypeError: 'NoneType' object is not subscriptable

This appears to be caused by a recent change in commit 7dcd8703ef904adc3ac19b47f769879221c33849: LlamaAttention now passes position_ids to apply_rotary_pos_emb, but its forward signature defaults position_ids to None and does not generate them when they are missing (unlike LlamaModel, which appears to generate them).
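For comparison, here is a pure-Python sketch of the kind of default LlamaModel builds when position_ids is None: a range of positions continuing from the length of the key/value cache, with a leading batch dimension of 1. Lists stand in for tensors so it runs without PyTorch, the function name is hypothetical, and the arange-based behavior it mirrors is an assumption based on reading LlamaModel rather than a copy of it. LlamaAttention has no equivalent fallback, so when a caller omits position_ids, the None value reaches `position_ids[:, None, :, None]` and fails with the TypeError above.

```python
from typing import List


def default_position_ids(seq_length: int, past_key_values_length: int = 0) -> List[List[int]]:
    """Return position IDs shaped [1, seq_length].

    Hypothetical helper mirroring (as an assumption) the torch.arange-based
    default in LlamaModel: positions continue from the number of tokens
    already stored in the key/value cache.
    """
    positions = list(range(past_key_values_length, past_key_values_length + seq_length))
    # Leading batch dimension of 1; a tensor version relies on broadcasting
    # to apply the same positions to every sequence in the batch.
    return [positions]
```

With no cache, `default_position_ids(4)` gives `[[0, 1, 2, 3]]`; with three cached tokens, `default_position_ids(2, past_key_values_length=3)` gives `[[3, 4]]`.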

Expected behavior

A None value for position_ids should not be passed to apply_rotary_pos_emb.

I'm not sure what the right fix is here, but at a minimum, if callers are expected to provide position_ids, I suspect that defaulting the parameter to None is incorrect.

cheald commented 1 year ago

Sorry, I failed to autocomplete @gante's handle in the initial ticket. Adding a comment for the tag.

gante commented 1 year ago

Hey @cheald 👋

For context, position_ids is required for correct behavior with left-padding, which in turn is needed for batched generation. Having a look at the issue!
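Why left-padding needs explicit position IDs: in a left-padded batch, the real tokens of a shorter sequence start after the padding, so a plain range would shift their positions by the padding length. A common way to derive correct positions is a cumulative sum over the attention mask, with padded slots set to a placeholder. The sketch below is a pure-Python version of that idea, assuming a 0/1 mask where 0 marks padding; the function name is hypothetical and lists stand in for tensors.

```python
from typing import List


def position_ids_from_attention_mask(attention_mask: List[List[int]]) -> List[List[int]]:
    """Derive per-sequence position IDs that ignore left padding.

    Hypothetical helper: position = (number of real tokens seen so far) - 1,
    and padded slots (mask == 0) get the placeholder value 1.
    """
    batch_position_ids = []
    for row in attention_mask:
        real_tokens_seen = 0
        positions = []
        for mask_value in row:
            real_tokens_seen += mask_value
            # Padded slots are masked out of attention, so any placeholder
            # works there; real tokens are numbered from 0.
            positions.append(real_tokens_seen - 1 if mask_value == 1 else 1)
        batch_position_ids.append(positions)
    return batch_position_ids
```

For the mask `[[0, 0, 1, 1, 1], [1, 1, 1, 1, 1]]` this returns `[[1, 1, 0, 1, 2], [0, 1, 2, 3, 4]]`, so the real tokens of both sequences start at position 0.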

cheald commented 1 year ago

Yup. I don't have the context to know where they should be created and passed, but at a minimum it seems like an interface error to make a parameter optional and then use it as if it were required.

gante commented 1 year ago

@cheald The issue stems from the GPTQ-for-Llama package, which should capture all intermediate inputs to each layer for proper quantization. I've opened an issue there. You can follow it and apply the corresponding changes locally, which should work 🤗
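GPTQ-style quantization runs calibration data through the model, intercepts the inputs to the first decoder layer, and then replays each layer on those inputs. The traceback shows the replay call forwarding only attention_mask (`layer(inps[j].unsqueeze(0), attention_mask=attention_mask)`), so position_ids falls back to None. Below is a minimal sketch of capturing every keyword argument instead; all classes and functions are hypothetical stand-ins, not GPTQ-for-LLaMa's actual code.

```python
from typing import Any, Dict, List


class StopForward(Exception):
    """Raised by the catcher to abort the forward pass once inputs are captured."""


class InputCatcher:
    """Hypothetical stand-in for the first decoder layer during calibration.

    Rather than keeping only attention_mask, it records every keyword
    argument the model passes (attention_mask, position_ids, ...).
    """

    def __init__(self) -> None:
        self.hidden_states: List[Any] = []
        self.layer_kwargs: Dict[str, Any] = {}

    def __call__(self, hidden_states: Any, **kwargs: Any) -> None:
        self.hidden_states.append(hidden_states)
        self.layer_kwargs = dict(kwargs)
        raise StopForward


def fake_layer(hidden_states, attention_mask=None, position_ids=None):
    """Hypothetical decoder layer that, like LlamaAttention after the change, needs position_ids."""
    if position_ids is None:
        raise TypeError("position_ids is required")
    return hidden_states, position_ids


def fake_model_forward(layer, hidden_states):
    """Hypothetical model forward that builds position_ids before calling the layer."""
    seq_len = len(hidden_states)
    return layer(hidden_states, attention_mask=[1] * seq_len, position_ids=[list(range(seq_len))])


catcher = InputCatcher()
try:
    fake_model_forward(catcher, [0.1, 0.2, 0.3])
except StopForward:
    pass

# Replay the layer with everything that was captured, not just attention_mask.
out = fake_layer(catcher.hidden_states[0], **catcher.layer_kwargs)
```

Replaying with `**catcher.layer_kwargs` means any argument the model starts passing in a later version is forwarded automatically, instead of each one having to be added to the quantization script by hand.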

However, the ball is in their court: the changes we made are backward compatible with our public API and, while we try to avoid creating this sort of issue, we have no bandwidth to fix problems caused by the use of internal variables/methods.

Is there anything else I can help you with? :)

cheald commented 1 year ago

All good. I'd suggest changing the LlamaAttention interface to remove the None default for position_ids, making the parameter required. It seems like a bit of a landmine to have a nominally optional argument that causes an exception when it isn't provided. Failing that, there should at least be an explicit check that raises a clear error when position_ids is missing.
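The explicit check suggested above could look like the following: a hypothetical helper (not part of the transformers API) that replaces the bare `'NoneType' object is not subscriptable` TypeError with a message saying what is missing and where it normally comes from.

```python
from typing import Any, Optional


def require_position_ids(position_ids: Optional[Any]) -> Any:
    """Fail fast with an actionable message instead of a bare TypeError.

    Hypothetical helper illustrating the suggestion; an attention module
    would call it before indexing position_ids.
    """
    if position_ids is None:
        raise ValueError(
            "position_ids is required by this attention layer; pass the same "
            "position_ids the parent model computes (for example, derived "
            "from the attention mask)."
        )
    return position_ids
```

Making the argument required in the signature would catch the mistake even earlier, at call time, but would break callers that currently rely on the default; the runtime check keeps the signature unchanged.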

If the answer is "no, for the purposes of API compatibility", then that's fine, but at least this ticket might help the next person who runs into it!

Thanks so much - I realize this is cut-myself-on-the-bleeding-edge stuff, but I appreciate the swift help!

gante commented 1 year ago

@cheald Due to Llama's popularity, I've made an exception: this PR should make it backward compatible. Would you be able to test it on your end? 🤗

cheald commented 1 year ago

I'll test it in a bit. Thank you so much (for this, and for all the amazing work you do on the transformers project!)

cheald commented 1 year ago

My quantization pass is still running (it takes quite some time), but it appears this is working as intended. Thank you! 🎉

gante commented 1 year ago

@cheald hehe it turns out it is no longer needed, as the maintainers of GPTQ-for-Llama have pushed a fix on their end!