shrimai / Focused-Attention-Improves-Document-Grounded-Generation

MIT License

About the num_return_sequences parameter in the generate function #3

Open pengshancai opened 3 years ago

pengshancai commented 3 years ago

Hi, I used this repository's generation_utils.py to replace the generation_utils.py in my transformers installation. However, when I call generate() with num_return_sequences greater than 1, the program fails with the following error:

File "/root/anaconda3/envs/wow/lib/python3.7/site-packages/transformers/generation_utils.py", line 413, in generate
    attention_mask = attention_mask.unsqueeze(1).expand(
AttributeError: 'tuple' object has no attribute 'unsqueeze'

It seems the error occurs because attention_mask is a tuple, attention_mask=(source_mask, doc_mask), rather than a single tensor, so it has no unsqueeze() method.

Is there a quick fix for this?
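One possible workaround, sketched under assumptions: when num_return_sequences > 1, generate() repeats each row of input_ids and attention_mask so every returned sequence gets its own copy. That repetition can be applied to each element of the (source_mask, doc_mask) tuple separately. Each mask should keep its own sequence length, since the document mask is likely longer than the source mask. The helpers expand_mask and expand_attention_mask below are hypothetical and are not part of transformers or this repository. num_copies stands for whatever expansion factor generate() applies to input_ids at that point. To use them, you would call expand_attention_mask in place of the unsqueeze(1).expand(...) at line 413 and in any later reshape of attention_mask in the same block. If the model also takes document token ids as a separate input, those would probably need the same expansion.

```python
import torch


def expand_mask(mask: torch.Tensor, num_copies: int) -> torch.Tensor:
    """Repeat each row of a (batch, seq_len) mask num_copies times.

    Returns a (batch * num_copies, seq_len) tensor where the copies of row i
    are adjacent, matching how input_ids are expanded for generation.
    (Hypothetical helper, not a transformers API.)
    """
    batch_size, seq_len = mask.shape
    return (
        mask.unsqueeze(1)                          # (batch, 1, seq_len)
        .expand(batch_size, num_copies, seq_len)   # (batch, num_copies, seq_len), a view
        .contiguous()                              # materialize before reshaping
        .view(batch_size * num_copies, seq_len)    # (batch * num_copies, seq_len)
    )


def expand_attention_mask(attention_mask, num_copies: int):
    """Expand a single mask, or each mask in a tuple such as (source_mask, doc_mask).

    Each mask keeps its own sequence length, so source and document masks of
    different lengths are both handled. (Hypothetical helper.)
    """
    if isinstance(attention_mask, tuple):
        return tuple(expand_mask(m, num_copies) for m in attention_mask)
    return expand_mask(attention_mask, num_copies)


# Usage: batch of 2, source length 4, document length 6, 3 copies per example.
source_mask = torch.ones(2, 4, dtype=torch.long)
doc_mask = torch.ones(2, 6, dtype=torch.long)
source_exp, doc_exp = expand_attention_mask((source_mask, doc_mask), num_copies=3)
print(source_exp.shape, doc_exp.shape)  # torch.Size([6, 4]) torch.Size([6, 6])
```

Keeping the single-tensor branch lets the patched generate() work unchanged for models that pass an ordinary attention mask.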