wangpinggl / TREQS

Text-to-SQL Generation for Question Answering on Electronic Medical Records
MIT License

I found a bug in the `word_copy()` method. #5

Open wcshin-git opened 3 years ago

wcshin-git commented 3 years ago

Hi, I found a bug in word_copy() in model_deq2seq_base.py.

The purpose of word_copy() is to replace each generated <unk> with a token from the source text. To do that, we need an attention score that tells us how strongly each <unk> attends to each source token.

def word_copy(self):
    '''
    copy words from source document.
    '''
    myseq = torch.cat(self.beam_data[0][0], 0)                 # eg. [13]
    myattn = torch.cat(self.beam_data[0][-1]['accu_attn'], 0)  # eg. [66, 400].
    myattn = myattn*self.batch_data['src_mask_unk']  
    beam_copy = myattn.topk(k=1, dim=1)[1].squeeze(-1)
    wdidx = beam_copy.data.cpu().numpy()
    out_txt = []
    myseq = torch.cat(self.beam_data[0][0], 0)
    myseq = myseq.data.cpu().numpy().tolist()
    gen_txt = [self.batch_data['id2vocab'][wd]
                if wd in self.batch_data['id2vocab'] 
                else self.batch_data['ext_id2oov'][wd] 
                for wd in myseq]
    for j in range(len(gen_txt)):
        if gen_txt[j] == '<unk>':
            gen_txt[j] = self.batch_data['src_txt'][0][wdidx[j]] 
    out_txt.append(' '.join(gen_txt))                             

    return out_txt

However, wdidx[j] does not correspond to gen_txt[j] when beam search is used; the two line up only when the beam size is 1. The reason is that myseq.shape[0] != myattn.shape[0] + 1 (e.g., 13 != 66 + 1; with beam size 1, myattn.shape[0] would be 12). With beam size > 1 no error is raised, but the copy mechanism does not work properly.
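To make the alignment requirement concrete, here is a minimal sketch in plain Python (no torch) of what the copy step needs: exactly one attention row per generated token, taken from the same hypothesis. The function name `copy_unk_tokens`, its arguments, and the toy data are hypothetical and are not taken from the repository. It also assumes the start token has already been dropped from the generated sequence, which is what the "+ 1" in the shape comparison above suggests.

```python
UNK = "<unk>"


def copy_unk_tokens(gen_tokens, step_attn, src_tokens, src_mask):
    """Replace each <unk> with the most-attended source token.

    gen_tokens: decoded tokens of ONE hypothesis (start token removed).
    step_attn:  one attention row per decoding step of that hypothesis.
    src_tokens: tokens of the source text.
    src_mask:   1.0 where a source position may be copied, else 0.0.
    """
    # The failure mode described above: rows collected from all beams
    # no longer line up with the tokens of the selected hypothesis.
    if len(step_attn) != len(gen_tokens):
        raise ValueError(
            f"{len(step_attn)} attention rows for {len(gen_tokens)} tokens"
        )

    out = []
    for token, attn in zip(gen_tokens, step_attn):
        if token == UNK:
            # Zero out positions that must not be copied, then take argmax.
            masked = [a * m for a, m in zip(attn, src_mask)]
            best = max(range(len(masked)), key=masked.__getitem__)
            token = src_tokens[best]
        out.append(token)
    return out


# Toy data (made up for illustration only).
src = ["how", "many", "patients", "have", "sepsis"]
mask = [1.0, 1.0, 1.0, 1.0, 1.0]
gen = ["select", "<unk>", "where", "<unk>"]
attn = [
    [0.70, 0.10, 0.10, 0.05, 0.05],
    [0.10, 0.10, 0.60, 0.10, 0.10],
    [0.20, 0.20, 0.20, 0.20, 0.20],
    [0.05, 0.05, 0.10, 0.10, 0.70],
]
result = copy_unk_tokens(gen, attn, src, mask)
```

With beam size > 1, the equivalent fix in the real code would presumably be to keep only the attention rows along the selected beam's path before calling topk, rather than concatenating the rows of every beam. How that path is stored in beam_data is not shown in this thread, so that part is left out of the sketch.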

wangpinggl commented 3 years ago

We used the NATS package for all the experiments in the paper. Please check https://github.com/tshi04/NATS. This is a new implementation of the code, and the results in the README were obtained with this version. Thanks for pointing out the problem; we have fixed it.