Open biirving opened 5 months ago
When running the command

python unllama_token_clf.py conll2003 7b

I get the following error:

RuntimeError: Expected attn_mask dtype to be bool or to match query dtype, but got attn_mask.dtype: float and query.dtype: c10::BFloat16 instead.
I am running on an A100 with CUDA 12.1, transformers 4.37.2, and torch 2.1.2.
Could you downgrade transformers to 4.32.1? Alternatively, you can try this new repo, which supports newer transformers versions: https://github.com/WhereIsAI/BiLLM
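If downgrading isn't an option, a possible workaround (a sketch of the general fix, not this repo's code) is to cast the float attention mask to the query dtype before the attention call, since newer torch SDPA kernels require the mask to be bool or to match the query dtype:

```python
import torch
import torch.nn.functional as F

# Illustrative shapes: (batch, heads, seq, head_dim)
q = torch.randn(1, 2, 4, 8, dtype=torch.bfloat16)
k = torch.randn(1, 2, 4, 8, dtype=torch.bfloat16)
v = torch.randn(1, 2, 4, 8, dtype=torch.bfloat16)

# A float32 additive mask like the one the model builds; passing it as-is
# with bf16 queries triggers the RuntimeError from the traceback above.
mask = torch.zeros(1, 1, 4, 4, dtype=torch.float32)

# Workaround: cast the mask to the query dtype before attention.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask.to(q.dtype))
```

In practice you would apply the `.to(query.dtype)` cast wherever the model constructs or forwards `attn_mask`; the exact location depends on the transformers version.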