Closed: DreamInvoker closed this issue 2 months ago
Yes, for Mistral you need to use this modeling file: https://github.com/ContextualAI/gritlm/blob/main/scripts/modeling_mistral_gritlm.py
If you want to train a different model, you need to add is_causal to the kwargs of its forward pass and use bidirectional attention whenever it is False.
I detailed it here: https://github.com/ContextualAI/gritlm?tab=readme-ov-file#models, but maybe it is not visible enough.
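A minimal sketch of that pattern in plain torch. This is only illustrative, not the actual GritLM modeling code: the function name and additive-mask convention are assumptions, and the real change lives inside the model's forward pass.

```python
import torch

def build_attention_mask(seq_len: int, is_causal: bool) -> torch.Tensor:
    """Build an additive attention mask.

    When is_causal is True, return the usual lower-triangular causal
    mask (-inf above the diagonal). When False, return an all-zero
    mask so every position can attend to every other position
    (bidirectional attention, as used for embedding).
    """
    if is_causal:
        mask = torch.full((seq_len, seq_len), float("-inf"))
        mask = torch.triu(mask, diagonal=1)  # keep -inf strictly above diagonal
    else:
        mask = torch.zeros(seq_len, seq_len)
    return mask

causal = build_attention_mask(4, is_causal=True)
bidir = build_attention_mask(4, is_causal=False)
```

The forward method would receive is_causal in its kwargs and pick the branch accordingly.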
Thank you! Should I implement the _prepare_4d_causal_attention_mask function for another model?
I think you can just import it from transformers, as is done in the Mistral modeling file. You then just need to add an if/else clause like here: https://github.com/ContextualAI/gritlm/blob/01b2ed4799f09df2d9d41a9870176606b32568aa/scripts/modeling_mistral_gritlm.py#L1032
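A rough torch-only sketch of what such an if/else clause does, expanding a 2D padding mask to the 4D shape attention expects. The helper name and shapes here are hypothetical stand-ins, not the transformers API; in the real modeling file you would call the imported mask-preparation utilities instead.

```python
import torch

def prepare_4d_mask(padding_mask: torch.Tensor, is_causal: bool) -> torch.Tensor:
    """Expand a [batch, seq] padding mask to an additive [batch, 1, seq, seq] mask.

    Padding positions (0 in padding_mask) are blocked with -inf.
    If is_causal, the lower-triangular causal constraint is added on top;
    otherwise the mask stays bidirectional.
    """
    bsz, seq_len = padding_mask.shape
    # Broadcast the key-side padding mask across all query positions.
    expanded = padding_mask[:, None, None, :].expand(bsz, 1, seq_len, seq_len)
    mask = torch.where(expanded.bool(), 0.0, float("-inf"))
    if is_causal:
        # Block attention to future positions as well.
        causal = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)
        mask = mask + causal
    return mask

pad = torch.tensor([[1, 1, 1, 0]])  # last token is padding
m_causal = prepare_4d_mask(pad, is_causal=True)
m_bidir = prepare_4d_mask(pad, is_causal=False)
```

The resulting additive mask is what gets added to the attention scores before the softmax.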
Got it, thank you very much!
When I set attn=bbcc, I encountered this bug. Do you have any advice?
torch==2.0.1 transformers==4.37.2