Open Mirmu opened 6 months ago
I'll mark this as a feature request as I am not sure the original code had this either, and yes, beam generation doesn't look really good 😅 but that is probably because the state logic is different from the cache, which is reordered. I'll see what I can do!
Thanks for the pointer, Arthur. The issue indeed seems to come from the re-ordering logic, which I could fix easily (the handling is similar to that of past_key_values) with the following method:
```python
def reorder_cache(self, beam_idx):
    # Gather the cached conv and SSM states along the beam dimension so
    # each surviving beam keeps its own history, mirroring the
    # past_key_values handling.
    for layer_idx in self.conv_states:
        device = self.conv_states[layer_idx].device
        self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx.to(device))
        device = self.ssm_states[layer_idx].device
        self.ssm_states[layer_idx] = self.ssm_states[layer_idx].index_select(0, beam_idx.to(device))
```
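For anyone without a torch environment handy, the effect of `index_select(0, beam_idx)` can be sketched with NumPy, whose `np.take(..., axis=0)` gathers rows the same way. This is an illustrative stand-in for the method above, not the actual `MambaCache` code; the dict-of-arrays cache layout here is an assumption for the toy example:

```python
import numpy as np

def reorder_cache(conv_states, ssm_states, beam_idx):
    """Reorder per-layer cached states to follow the surviving beams.

    conv_states / ssm_states: dicts mapping layer_idx -> array of shape
    (num_beams, ...); beam_idx: indices of the beams kept this step.
    np.take along axis 0 plays the role of torch's index_select(0, ...).
    """
    for layer_idx in conv_states:
        conv_states[layer_idx] = np.take(conv_states[layer_idx], beam_idx, axis=0)
        ssm_states[layer_idx] = np.take(ssm_states[layer_idx], beam_idx, axis=0)

# Toy cache: one layer, 3 beams, state dim 2; row i holds beam i's history.
conv = {0: np.array([[0., 0.], [1., 1.], [2., 2.]])}
ssm = {0: np.array([[10., 10.], [11., 11.], [12., 12.]])}

# Suppose the beam scorer kept beam 2 twice and beam 0 once.
reorder_cache(conv, ssm, np.array([2, 2, 0]))
print(conv[0])  # rows now come from beams 2, 2, 0
```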
Feel free to open a PR for the fix!
Hi @Mirmu @ArthurZucker! Is there a PR for this yet? Thanks!
System Info
nightly build of transformers
Who can help?
@ArthurZucker
Information

Tasks

- An officially supported task in the examples folder (such as GLUE/SQuAD, ...)

Reproduction
Expected behavior
We are observing a degradation of the Mamba model's generated output as the `num_beams` parameter goes up.

Example with num_beams=1:

Example with num_beams=100:
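The mechanism behind the degradation can be sketched torch-free (a toy simulation, not transformers code; the names `histories` and `states` are invented for illustration): beam search reorders the token-level cache at every step, so any per-beam state that is *not* reordered ends up paired with the wrong beam's history, and more beams mean more mismatches:

```python
# Toy beam step: each beam carries a token history and a recurrent state.
histories = ["a", "b", "c"]                  # reordered every step, like past_key_values
states = ["state_a", "state_b", "state_c"]   # stand-in for mamba conv/ssm states

beam_idx = [2, 2, 0]  # beams the scorer kept this step

histories = [histories[i] for i in beam_idx]  # correctly reordered
# BUG: states are left in place, so surviving beams now continue
# from another beam's recurrent state.
mismatched = [(h, s) for h, s in zip(histories, states)]
print(mismatched)  # [('c', 'state_a'), ('c', 'state_b'), ('a', 'state_c')]
```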