neulab / guided_summarization

GSum: A General Framework for Guided Neural Abstractive Summarization
MIT License

About parameters in z_test.sh #35

Closed FearAlwaysWorks closed 2 years ago

FearAlwaysWorks commented 2 years ago

Hello, thank you for sharing your code. I've been trying to run your BART code recently, but I ran into a problem while running z_test.sh:

```
Traceback (most recent call last):
  File "/home/delab30/code/WangXiaoye/guided_summarization/bart/z_test.py", line 25, in <module>
    hypotheses_batch = bart.sample(slines, zlines, beam=4, lenpen=2.0, max_len_b=140, min_len=55, no_repeat_ngram_size=3, guided=True)
  File "/home/delab30/code/WangXiaoye/guided_summarization/bart/fairseq/models/bart/guided_hub_interface.py", line 125, in sample
    hypos = self.generate(input, z, beam, verbose, **kwargs)
  File "/home/delab30/code/WangXiaoye/guided_summarization/bart/fairseq/models/bart/guided_hub_interface.py", line 142, in generate
    prefix_tokens=sample['net_input']['src_tokens'].new_zeros((len(tokens), 1)).fill_(self.task.source_dictionary.bos()),
  File "/home/delab30/code/WangXiaoye/guided_summarization/bart/fairseq/tasks/fairseq_task.py", line 354, in inference_step
    return generator.generate(models, sample, prefix_tokens=prefix_tokens)
  File "/home/delab30/anaconda3/envs/guided_summary/lib/python3.6/site-packages/torch/autograd/grad_mode.py", line 26, in decorate_context
    return func(*args, **kwargs)
  File "/home/delab30/code/WangXiaoye/guided_summarization/bart/fairseq/sequence_generator.py", line 852, in generate
    return self._generate(model, sample, **kwargs)
  File "/home/delab30/anaconda3/envs/guided_summary/lib/python3.6/site-packages/torch/autograd/grad_mode.py", line 26, in decorate_context
    return func(*args, **kwargs)
  File "/home/delab30/code/WangXiaoye/guided_summarization/bart/fairseq/sequence_generator.py", line 1042, in _generate
    tokens[:, :step + 1], encoder_outs, z_encoder_outs, temperature=self.temperature,
  File "/home/delab30/anaconda3/envs/guided_summary/lib/python3.6/site-packages/torch/autograd/grad_mode.py", line 26, in decorate_context
    return func(*args, **kwargs)
  File "/home/delab30/code/WangXiaoye/guided_summarization/bart/fairseq/sequence_generator.py", line 583, in z_forward_decoder
    temperature=temperature,
  File "/home/delab30/code/WangXiaoye/guided_summarization/bart/fairseq/sequence_generator.py", line 639, in z_decode_one
    tokens, encoder_out=encoder_out, z_encoder_out=z_encoder_out, incremental_state=self.incremental_states[model],
  File "/home/delab30/code/WangXiaoye/guided_summarization/bart/fairseq/models/bart/guided_model.py", line 110, in z_forward_decoder
    decoder_out = self.decoder(prev_output_tokens, encoder_out, z_encoder_out=z_encoder_out, incremental_state=incremental_state, **extra_args)
  File "/home/delab30/anaconda3/envs/guided_summary/lib/python3.6/site-packages/torch/nn/modules/module.py", line 744, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/delab30/code/WangXiaoye/guided_summarization/bart/fairseq/models/guided_transformer.py", line 681, in forward
    alignment_heads=alignment_heads,
  File "/home/delab30/code/WangXiaoye/guided_summarization/bart/fairseq/models/guided_transformer.py", line 803, in extract_features
    need_head_weights=bool((idx == alignment_layer)),
  File "/home/delab30/anaconda3/envs/guided_summary/lib/python3.6/site-packages/torch/nn/modules/module.py", line 744, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/delab30/code/WangXiaoye/guided_summarization/bart/fairseq/modules/guided_transformer_layer.py", line 197, in forward
    need_head_weights=need_head_weights,
  File "/home/delab30/anaconda3/envs/guided_summary/lib/python3.6/site-packages/torch/nn/modules/module.py", line 744, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/delab30/code/WangXiaoye/guided_summarization/bart/fairseq/modules/multihead_attention.py", line 287, in forward
    assert key_padding_mask.size(0) == bsz
AssertionError
```

After checking the code, I found that the shape of the buffer `'active_bbsz_idx'` is 0 (line 1218 in guided_summarization/bart/fairseq/sequence_generator.py), which may be causing this problem. But I have no idea how to fix the code. Do you have any suggestions for me?
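For readers debugging the same failure: the assert at the bottom of the traceback checks that the attention's `key_padding_mask` has the same batch dimension as the beam-expanded query batch. Below is a minimal pure-Python sketch of that check (no torch required); the helper name and sizes are illustrative, not taken from the repo.

```python
# Illustrative stand-in (pure Python, no torch) for the check that fails at
# fairseq/modules/multihead_attention.py line 287:
#     assert key_padding_mask.size(0) == bsz
# During beam search the batch is expanded to bsz * beam_size; if the mask is
# reordered with an empty index buffer (as with the zero-sized
# 'active_bbsz_idx' observed above), its batch dimension no longer matches.

def check_key_padding_mask(mask_rows, bsz):
    """Mirror of the failing assertion: mask batch dim must equal bsz."""
    assert len(mask_rows) == bsz, (
        f"key_padding_mask batch {len(mask_rows)} != bsz {bsz}"
    )

bsz, beam_size = 2, 4
expanded_bsz = bsz * beam_size

# Healthy case: one mask row per beam-expanded batch element.
good_mask = [[False] * 10 for _ in range(expanded_bsz)]
check_key_padding_mask(good_mask, expanded_bsz)  # passes silently

# Broken case: reordering by an empty index buffer yields an empty mask.
active_bbsz_idx = []  # the shape-0 buffer described above
bad_mask = [good_mask[i] for i in active_bbsz_idx]
reproduced = False
try:
    check_key_padding_mask(bad_mask, expanded_bsz)
except AssertionError:
    reproduced = True
```

This only mimics the shape bookkeeping; the real fix has to keep `active_bbsz_idx` from ending up empty in the first place.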

I would appreciate it if you could help me!

zdou0830 commented 2 years ago

Hi, what command are you running?

FearAlwaysWorks commented 2 years ago

Hi, after training a BART model on the CNN/DM dataset, I ran the following command for testing:

```
SRC=/home/delab30/code/WangXiaoye/guided_summarization/bart/data/cnn_dm/matchsum_test/test.source
GUIDANCE=/home/delab30/code/WangXiaoye/guided_summarization/bart/data/cnn_dm/matchsum_test/test.matchsum
RESULT_PATH=/home/delab30/code/WangXiaoye/guided_summarization/bart/data/cnn_dm/matchsum_test/test_result
MODEL_DIR=/home/delab30/code/WangXiaoye/guided_summarization/bart/models/cnn_dm
MODEL_NAME=checkpoint_best.pt
DATA_BIN=/home/delab30/code/WangXiaoye/guided_summarization/bart/data/cnn_dm/train_bin

python z_test.py $SRC $GUIDANCE $RESULT_PATH $MODEL_DIR $MODEL_NAME $DATA_BIN
```

where $SRC and $GUIDANCE are downloaded from https://github.com/icml-2020-nlp/semsim and https://drive.google.com/file/d/12SpWwfD3syIxcC-SdSNnDOI5sbXJaylC/view respectively. As for $DATA_BIN, I use the path to the training data produced by z_bin.sh before training (I'm not sure whether that's correct).

zdou0830 commented 2 years ago

Hi, I'm not sure what's going on. Could you try torch 1.4.0 and/or 1.5.0 to see if that fixes the issue?

FearAlwaysWorks commented 2 years ago

It seems that torch 1.4.0 and 1.5.0 are incompatible with my GPU... By the way, I wonder whether the source and guidance files for testing BART should be preprocessed with z_bpe.sh and z_bin.sh like the training data.

zdou0830 commented 2 years ago

No, just raw text should be fine.
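To illustrate: in the traceback above, z_test.py calls `bart.sample(slines, zlines, ...)` directly on lists of raw strings. Below is a hypothetical sketch of that raw-text flow; the batching helper and file handling are illustrative (not copied from z_test.py), tokenization presumably happens inside the hub interface, and the model is mocked so the snippet is self-contained.

```python
# Hypothetical sketch: raw source/guidance lines go straight to sample(),
# with no z_bpe.sh / z_bin.sh preprocessing at test time.
# 'sample_fn' stands in for bart.sample and is mocked below.

def summarize_raw(sample_fn, src_lines, guidance_lines, batch_size=2):
    """Feed raw source/guidance line pairs to the model in small batches."""
    assert len(src_lines) == len(guidance_lines)
    hypotheses = []
    for i in range(0, len(src_lines), batch_size):
        slines = [s.strip() for s in src_lines[i:i + batch_size]]
        zlines = [z.strip() for z in guidance_lines[i:i + batch_size]]
        # Real call, per the traceback:
        # bart.sample(slines, zlines, beam=4, lenpen=2.0, max_len_b=140,
        #             min_len=55, no_repeat_ngram_size=3, guided=True)
        hypotheses.extend(sample_fn(slines, zlines))
    return hypotheses

# Mocked model so the sketch runs without fairseq/torch installed.
mock_sample = lambda slines, zlines: ["SUMMARY: " + s[:15] for s in slines]
outputs = summarize_raw(mock_sample,
                        ["first raw article\n", "second raw article\n"],
                        ["guidance one\n", "guidance two\n"])
```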

FearAlwaysWorks commented 2 years ago

It seems that torch 1.5 is necessary after all. I think the problem comes from a scalar type error in torch.masked_select(). So nice of you! Thank you for your help!