ContextualAI / gritlm

Generative Representational Instruction Tuning
https://arxiv.org/abs/2402.09906
MIT License
479 stars, 33 forks

TypeError: forward() got an unexpected keyword argument 'is_causal' #24

Closed. DreamInvoker closed this issue 2 months ago.

DreamInvoker commented 2 months ago

When I set attn=bbcc, I encountered this error. Do you have any advice?

torch==2.0.1 transformers==4.37.2

Muennighoff commented 2 months ago

Yes, for Mistral you need to use this modeling file: https://github.com/ContextualAI/gritlm/blob/main/scripts/modeling_mistral_gritlm.py. If you want to train a different model, you need to add is_causal to the keyword arguments of its forward method and use bidirectional attention whenever it is False.
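To illustrate why the error appears and what the fix looks like, here is a minimal sketch in plain Python. The classes PlainAttention and PatchedAttention and the helper call_with_is_causal are hypothetical stand-ins, not code from gritlm or transformers; they only show that a forward method without an is_causal parameter rejects the keyword, while a patched one accepts it and switches behavior.

```python
class PlainAttention:
    """Hypothetical stand-in for a stock model: forward() has no is_causal."""

    def forward(self, hidden_states, attention_mask=None):
        return "causal"


class PatchedAttention:
    """Hypothetical stand-in for a patched model: forward() accepts is_causal."""

    def forward(self, hidden_states, attention_mask=None, is_causal=True):
        # Bidirectional attention is used whenever is_causal is False.
        return "causal" if is_causal else "bidirectional"


def call_with_is_causal(model, is_causal):
    """Call forward() the way a caller that passes is_causal would."""
    try:
        return model.forward([0.1, 0.2], is_causal=is_causal)
    except TypeError as err:
        # The stock signature rejects the unknown keyword argument.
        return f"TypeError: {err}"


print(call_with_is_causal(PlainAttention(), False))
print(call_with_is_causal(PatchedAttention(), False))
```

The first call reports a TypeError that mentions is_causal, which matches the error in the issue title; the second call takes the bidirectional branch.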

Muennighoff commented 2 months ago

I detailed it here: https://github.com/ContextualAI/gritlm?tab=readme-ov-file#models, but maybe it is not visible enough.

DreamInvoker commented 2 months ago

Thank you! Should I implement the _prepare_4d_causal_attention_mask function for another model?

Muennighoff commented 2 months ago

I think you can just import it from transformers, as is done in the Mistral modeling file. You then just need to add an if/else clause like the one here: https://github.com/ContextualAI/gritlm/blob/01b2ed4799f09df2d9d41a9870176606b32568aa/scripts/modeling_mistral_gritlm.py#L1032
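As a rough illustration of what such an if/else clause decides, the sketch below builds a toy attention mask on plain Python lists. The function build_attention_mask is hypothetical and is not the code at the linked line; a real modeling file would work on tensors and use the helper imported from transformers for the causal branch. The sketch assumes a padding mask with 1 for real tokens and 0 for padding.

```python
def build_attention_mask(padding_mask, is_causal):
    """Return allowed[i][j]: may query position i attend to key position j?

    padding_mask: list of 0/1 flags, 1 for real tokens and 0 for padding.
    """
    seq_len = len(padding_mask)
    if is_causal:
        # Causal: each position sees itself and earlier positions only.
        visible = [[j <= i for j in range(seq_len)] for i in range(seq_len)]
    else:
        # Bidirectional: every position sees every other position.
        visible = [[True] * seq_len for _ in range(seq_len)]
    # In both branches, padded key positions stay hidden.
    return [
        [visible[i][j] and bool(padding_mask[j]) for j in range(seq_len)]
        for i in range(seq_len)
    ]


causal = build_attention_mask([1, 1, 0], is_causal=True)
bidirectional = build_attention_mask([1, 1, 0], is_causal=False)
```

With the padding mask [1, 1, 0], the causal result hides future positions and the padded last position, while the bidirectional result hides only the padded position.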

DreamInvoker commented 2 months ago

Got it, thank you very much!