huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Mamba slow implementation datatype mismatch #32690

Open Adibvafa opened 1 month ago

Adibvafa commented 1 month ago

System Info

transformers==4.44.0

Who can help?

@ArthurZucker

Information

Tasks

Reproduction

  1. Set up a PyTorch Lightning training run with float16 precision.
  2. Set up MambaConfig with use_mambapy=False.
  3. Train a MambaForCausalLM model (a minimal sketch follows below).
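
A minimal sketch of this setup, assuming a dummy random-token dataset and a small hypothetical config; Trainer arguments vary across Lightning versions, and float16 mixed precision needs a GPU:

```python
import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader, TensorDataset
from transformers import MambaConfig, MambaForCausalLM

class MambaLitModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # Small illustrative config; the key parts are use_mambapy=False
        # (forces the slow PyTorch path) and fp16 training below.
        config = MambaConfig(vocab_size=1000, hidden_size=64,
                             num_hidden_layers=2, use_mambapy=False)
        self.model = MambaForCausalLM(config)

    def training_step(self, batch, batch_idx):
        input_ids = batch[0]
        out = self.model(input_ids=input_ids, labels=input_ids)
        return out.loss

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=1e-4)

# Dummy dataset of random token ids.
data = TensorDataset(torch.randint(0, 1000, (32, 16)))
loader = DataLoader(data, batch_size=4)

trainer = pl.Trainer(precision="16-mixed", max_steps=10,
                     accelerator="gpu", devices=1)
trainer.fit(MambaLitModule(), loader)
```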

Expected behavior

The expected behavior is that training proceeds without issues. However, this is the error I got:

Code: cache_utils.py, MambaCache class, update_conv_state method

conv_state[:, :, cache_position] = new_conv_state.to(conv_state.device)
RuntimeError: Index put requires the source and destination dtypes match, got Float for the destination and Half for the source.

I was able to fix the issue by changing

conv_state[:, :, cache_position] = new_conv_state.to(conv_state.device)

to

conv_state[:, :, cache_position] = new_conv_state.to(device=conv_state.device, dtype=conv_state.dtype)
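
For context, with that change the method would look roughly like this (the surrounding lines paraphrase MambaCache.update_conv_state from cache_utils.py and may not match the file exactly; only the .to(...) call is changed):

```python
# Sketch of the patched MambaCache.update_conv_state (paraphrased).
def update_conv_state(self, layer_idx, new_conv_state, cache_position):
    conv_state = self.conv_states[layer_idx]
    cache_position = cache_position.clamp(0, self.conv_kernel_size - 1)

    conv_state = conv_state.roll(shifts=-1, dims=-1)
    # Cast to the cache's dtype as well as its device, so the index_put does
    # not fail when the model runs in float16 but the cache holds float32.
    conv_state[:, :, cache_position] = new_conv_state.to(
        device=conv_state.device, dtype=conv_state.dtype
    )
    self.conv_states[layer_idx].zero_()
    self.conv_states[layer_idx] += conv_state
    return self.conv_states[layer_idx]
```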

I will open a PR with the fix.

amyeroberts commented 1 month ago

cc @molbap too as I've seen you across some of the Mamba issues and PRs :)

ArthurZucker commented 1 week ago

BTW, when you train you should not be using cache_position, so I'm not super sure I understand what's going on! Unless you are in eval mode, which would mean the init dtype was wrong.

Adibvafa commented 1 week ago

BTW, when you train you should not be using cache_position, so I'm not super sure I understand what's going on! Unless you are in eval mode, which would mean the init dtype was wrong.

I was training using PyTorch Lightning and didn't manually tell it to use cache_position. Is it unexpected that that chunk of code was run in this setting?

ArthurZucker commented 1 week ago

Yeah, use_cache is probably set to True: https://github.com/huggingface/transformers/blob/21beb57558f90131439bf03e1da672d117777901/src/transformers/models/mamba/modeling_mamba.py#L603 is where we initialize the cache for Mamba, and here is the if/else that calls https://github.com/huggingface/transformers/blob/21beb57558f90131439bf03e1da672d117777901/src/transformers/models/mamba/modeling_mamba.py#L248
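
If that is the case, a possible workaround on the user side (my assumption, not a fix confirmed in this thread) is to either turn the cache off or load the model so its weights, inputs, and cache share a dtype:

```python
import torch
from transformers import MambaConfig, MambaForCausalLM

# Option 1: skip the cache path entirely.
config = MambaConfig(use_mambapy=False, use_cache=False)
model = MambaForCausalLM(config)

# Option 2: keep the cache but put the whole model in float16 so the cache
# tensors and the half-precision hidden states agree in dtype.
# ("state-spaces/mamba-130m-hf" is just an example checkpoint.)
model_fp16 = MambaForCausalLM.from_pretrained(
    "state-spaces/mamba-130m-hf", torch_dtype=torch.float16
)
```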