Adibvafa opened 1 month ago
cc @molbap too as I've seen you across some of the Mamba issues and PRs :)
BTW, when you train you should not be using the `cache_position`, so I'm not super sure I understand what's going on! Unless you are in eval mode, which would mean the init dtype was wrong.
I was training using PyTorch Lightning; I didn't manually specify it to use `cache_position`. Is it unexpected that that chunk of code was run in this setting?
Yeah, `use_cache` is probably set to `True`: https://github.com/huggingface/transformers/blob/21beb57558f90131439bf03e1da672d117777901/src/transformers/models/mamba/modeling_mamba.py#L603 is where we init the cache for Mamba, and here is the if/else that calls https://github.com/huggingface/transformers/blob/21beb57558f90131439bf03e1da672d117777901/src/transformers/models/mamba/modeling_mamba.py#L248
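To illustrate why the cache path runs during training, here is a simplified sketch (not the real transformers code; `DummyConfig` and this `forward` stub are hypothetical stand-ins): when the caller leaves `use_cache` unset, the forward pass falls back to `config.use_cache`, which defaults to `True`, so the cache-init branch executes even in a training loop.

```python
# Simplified sketch of the cache-init branch in a Mamba-style forward pass.
# DummyConfig and forward() are hypothetical stand-ins, not the real API.
class DummyConfig:
    use_cache = True  # transformers configs default use_cache to True

def forward(config, use_cache=None, cache_params=None):
    # Fall back to the config default when the caller doesn't pass use_cache.
    use_cache = use_cache if use_cache is not None else config.use_cache
    if use_cache and cache_params is None:
        # In the real model this builds a MambaCache (and its conv states).
        cache_params = "MambaCache(...)"
    return cache_params

print(forward(DummyConfig()))                   # cache created even while training
print(forward(DummyConfig(), use_cache=False))  # None: cache path skipped
```

Passing `use_cache=False` explicitly (or setting it in the config) is the usual way to skip this branch during training.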
System Info

transformers==4.44.0

Who can help?

@ArthurZucker

Information

Tasks

An officially supported task in the examples folder (such as GLUE/SQuAD, ...)

Reproduction

1. Create a `MambaConfig` with `use_mambapy=False`
2. Train a `MambaForCausalLM` model

Expected behavior
The expected behavior is that training proceeds with no issues. However, this is the error I got:
Code: `cache_utils.py`, `MambaCache` class, `update_conv_state` method
I was able to fix the issue by changing

```python
conv_state[:, :, cache_position] = new_conv_state.to(conv_state.device)
```

to

```python
conv_state[:, :, cache_position] = new_conv_state.to(device=conv_state.device, dtype=conv_state.dtype)
```
I will open a PR with the fix.
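For completeness, here is a minimal standalone sketch of the failure and the fix (shapes and values are made up for illustration; only the two assignment lines mirror `update_conv_state`). Advanced-indexing assignment in PyTorch (`index_put_`) requires the source and destination dtypes to match, so a float32 update into a float16 cache buffer raises, while the explicit `dtype=` cast succeeds.

```python
import torch

# Made-up shapes: a float16 conv cache buffer and float32 incoming states,
# mimicking the dtype mismatch reported in this issue.
conv_state = torch.zeros(1, 4, 8, dtype=torch.float16)   # (batch, channels, kernel)
new_conv_state = torch.randn(1, 4, 2, dtype=torch.float32)
cache_position = torch.tensor([0, 1])

# Unpatched line: device is cast but dtype is not, so index_put_ raises.
try:
    conv_state[:, :, cache_position] = new_conv_state.to(conv_state.device)
except RuntimeError as err:
    print("unpatched update fails:", err)

# Patched line: cast to the cache's dtype as well as its device.
conv_state[:, :, cache_position] = new_conv_state.to(
    device=conv_state.device, dtype=conv_state.dtype
)
print(conv_state.dtype)  # the cache stays float16
```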