Alex-Fabbri / Multi-News

Large-scale multi-document summarization dataset and code

ERROR in running "run_inference_newser.sh" #18

Closed TysonYu closed 4 years ago

TysonYu commented 4 years ago

Hi, thank you for providing the pre-trained model. I downloaded the pre-trained model from "newser-mmr" along with the pre-processed, truncated test data. When I try to run "run_inference_newser.sh" in "Hi-MAP" to generate summaries for the test set, I get the following error:

```
Traceback (most recent call last):
  File "translate.py", line 37, in <module>
    main(opt)
  File "translate.py", line 24, in main
    attn_debug=opt.attn_debug)
  File "/home/tiezheng/workspace/Debias/multidoc_Summarization/Multi-News/code/Hi_MAP/onmt/translate/translator.py", line 233, in translate
    batch_data = self.translate_batch(batch, data, fast=self.fast)
  File "/home/tiezheng/workspace/Debias/multidoc_Summarization/Multi-News/code/Hi_MAP/onmt/translate/translator.py", line 341, in translate_batch
    return self._translate_batch(batch, data)
  File "/home/tiezheng/workspace/Debias/multidoc_Summarization/Multi-News/code/Hi_MAP/onmt/translate/translator.py", line 622, in _translate_batch
    step=i)
  File "/home/tiezheng/anaconda3/envs/multi-news/lib/python3.6/site-packages/torch/nn/modules/module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/tiezheng/workspace/Debias/multidoc_Summarization/Multi-News/code/Hi_MAP/onmt/decoders/decoder.py", line 148, in forward
    tgt, memory_bank, state, memory_lengths=memory_lengths, sent_encoder=sent_encoder, src_sents=src_sents, dec=dec)
  File "/home/tiezheng/workspace/Debias/multidoc_Summarization/Multi-News/code/Hi_MAP/onmt/decoders/decoder.py", line 602, in _run_forward_pass
    mmr_among_words = self._run_mmr_attention(sent_encoder, sent_decoder, src_sents, attns["copy"][0].size()[-1])
  File "/home/tiezheng/workspace/Debias/multidoc_Summarization/Multi-News/code/Hi_MAP/onmt/decoders/decoder.py", line 480, in _run_mmr_attention
    sim1 = torch.bmm(self.mmr_W(sent_decoder), sent.unsqueeze(2)).squeeze(2)  # (2,1)
RuntimeError: invalid argument 7: equal number of batches expected at /opt/conda/conda-bld/pytorch_1556653099582/work/aten/src/THC/generic/THCTensorMathBlas.cu:488
```

Do you know what the reason is?
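For context, `torch.bmm` requires both operands to share the same leading (batch) dimension, which is exactly what the error message refers to. A minimal sketch (shapes chosen purely for illustration, not taken from the actual model) reproduces the same class of failure:

```python
import torch

beam_size, batch_size, dim = 5, 2, 8

# Decoder-side tensor after beam expansion: one row per (beam, batch) pair.
sent_decoder = torch.randn(beam_size * batch_size, 1, dim)  # [10, 1, dim]
# Encoder-side sentence vector, still at the original batch size.
sent = torch.randn(batch_size, dim, 1)                      # [2, dim, 1]

# torch.bmm expects both inputs to have the same first dimension,
# so mismatched batch sizes raise a RuntimeError like the one above.
try:
    torch.bmm(sent_decoder, sent)
except RuntimeError as e:
    print(e)
```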

ziweiji commented 4 years ago

The shape of sent_decoder is [beam_size * batch_size, 1, dim], while the shape of sent (from sent_encoder) is [batch_size, dim, 1]. I wonder whether the MMR attention is compatible with beam search, and how to resolve this mismatch.
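One common way to handle this kind of beam-search shape mismatch (not necessarily the change the repository author made) is to tile the encoder-side tensor `beam_size` times along the batch dimension so it lines up with the beam-expanded decoder states. A minimal sketch, assuming the shapes described above:

```python
import torch

beam_size, batch_size, dim = 5, 2, 8

sent_decoder = torch.randn(beam_size * batch_size, 1, dim)  # beam-expanded decoder states
sent_encoder = torch.randn(batch_size, dim, 1)              # encoder vectors at original batch size

# Repeat each batch element beam_size times so the batch dimensions agree.
# Note: the repeat order (repeat_interleave vs. repeat) must match how the
# beam search lays out the (beam, batch) pairs; this is only a sketch.
sent_encoder_tiled = sent_encoder.repeat_interleave(beam_size, dim=0)  # [10, dim, 1]

sim1 = torch.bmm(sent_decoder, sent_encoder_tiled).squeeze(2)
print(sim1.shape)  # torch.Size([10, 1])
```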

Alex-Fabbri commented 4 years ago

Hi, I just updated the code with a small change that should allow beam search to work. Please let me know if you have any questions.