huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for PyTorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

[Mamba] Possible Issue in beam search with Mamba HF models #29730

Open Mirmu opened 6 months ago

Mirmu commented 6 months ago

System Info

transformers: nightly build

Who can help?

@ArthurZucker

Reproduction

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

device = "cuda:0"
tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")
model = AutoModelForCausalLM.from_pretrained("state-spaces/mamba-130m-hf").to(device)

input_prompts = ["Question: Who is the lead singer of Coldplay? Answer:"]
prompted_encoded = tokenizer(input_prompts[0:1], return_tensors="pt", padding=True).to(device)["input_ids"]

with torch.inference_mode():
    generated_tokens = model.generate(input_ids=prompted_encoded, max_length=60, use_cache=True, num_beams=5)
    decoded_tokens = tokenizer.batch_decode(generated_tokens)
    print("Generation finished.")
    print(decoded_tokens)

Expected behavior

We observe a degradation of the Mamba model's generated output as the num_beams parameter increases; a small sweep across beam widths is sketched after the examples below.

Example with num_beams=1:

Generation finished.
['Question: Who is the lead singer of Coldplay? Answer: The lead singer of Coldplay is the lead singer of the band Coldplay.\n\nThe lead singer of Coldplay is the lead singer of the band Coldplay. The lead singer of Coldplay is the lead singer of the band']

Example with num_beams=100:

Generation finished.
['Question: Who is the lead singer of Coldplay? Answer: Coldplay\n\n\n Cold isfield\n\n\nQuestion:\nColdplayplayfield\n\n\n:\n\n:\n\n:\n\n:\n\n:\n\n:\n\n:\n\n:\n\n:\n']
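
A minimal sketch to see the trend across beam widths (reusing model, tokenizer, and prompted_encoded from the reproduction above; the chosen beam widths are arbitrary):

import torch

# Compare generations as the beam width grows; the output degrades as num_beams increases.
for num_beams in (1, 5, 20, 100):
    with torch.inference_mode():
        out = model.generate(input_ids=prompted_encoded, max_length=60, use_cache=True, num_beams=num_beams)
    print(f"num_beams={num_beams}:", tokenizer.batch_decode(out)[0])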
ArthurZucker commented 6 months ago

I'll mark this as a feature request, as I am not sure the original code had this either. And yes, beam generation doesn't look really good 😅, but that is probably because the Mamba state is handled differently from the attention cache, which gets reordered during beam search. I'll see what I can do!
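
For context, attention models give beam search a _reorder_cache hook so that the cached keys/values follow the surviving beams after every step; it usually looks roughly like this (a generic sketch of the tuple-cache handling, nothing Mamba-specific):

@staticmethod
def _reorder_cache(past_key_values, beam_idx):
    # For every layer, select along the batch dimension the cached tensors
    # belonging to the beams that were kept at this step.
    return tuple(
        tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
        for layer_past in past_key_values
    )

MambaCache keeps its convolution and SSM states on the cache object instead, so nothing equivalent happens to them during beam search, which would explain the degradation.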

Mirmu commented 6 months ago

Thanks for the pointer, Arthur. The issue indeed seems to come from the re-ordering logic, which I could fix easily (the handling is similar to that of past_key_values) with the following method:

    def reorder_cache(self, beam_idx):
        # Mirror the past_key_values handling: for every layer, select along the
        # batch dimension the conv and SSM states of the surviving beams.
        for layer_idx in self.conv_states:
            device = self.conv_states[layer_idx].device
            self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx.to(device))
            device = self.ssm_states[layer_idx].device
            self.ssm_states[layer_idx] = self.ssm_states[layer_idx].index_select(0, beam_idx.to(device))
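
For a quick local check (not necessarily how a final PR should wire it), the method could be attached to the model-level _reorder_cache hook that beam search uses for attention models; whether generation actually routes Mamba's cache_params through that hook is an assumption here, and the patching below is purely illustrative:

from transformers import MambaForCausalLM


def _reorder_mamba_cache(cache_params, beam_idx):
    # Delegate to the reorder_cache method proposed above so the conv and SSM
    # states follow the surviving beams, mirroring the past_key_values handling.
    cache_params.reorder_cache(beam_idx)
    return cache_params


# Illustrative monkey-patch for local testing only, not a proposed API.
MambaForCausalLM._reorder_cache = staticmethod(_reorder_mamba_cache)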
ArthurZucker commented 6 months ago

Feel free to open a PR for the fix!

javiermcebrian commented 5 months ago

Hi @Mirmu @ArthurZucker! Is there a PR for this yet? Thanks!