wasiahmad / NeuralCodeSum

Official implementation of our work, A Transformer-based Approach for Source Code Summarization [ACL 2020].
MIT License

Split decoder #28

Closed MatteoZhang closed 3 years ago

MatteoZhang commented 3 years ago

Hi, is the split decoder part implemented? I tried your code with the argument args.split_decoder True and got this error:

Epoch = 1 [perplexity = x.xx, ml_loss = x.xx]: 0% 0/939 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "../../main/train.py", line 708, in <module>
    main(args)
  File "../../main/train.py", line 653, in main
    train(args, train_loader, model, stats)
  File "../../main/train.py", line 283, in train
    net_loss = model.update(ex)
  File ".../model.py", line 173, in update
    example_weights=ex_weights)
  File ".../module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File ".../transformer.py", line 435, in forward
    **kwargs)
  File ".../transformer.py", line 363, in _run_forward_ml
    summ_emb)
  File ".../module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File ".../transformer.py", line 295, in forward
    return self.decode(tgt_pad_mask, tgt_emb, memory_bank, state)
  File ".../transformer.py", line 273, in decode
    f_t = self.fusion_sigmoid(torch.cat([copier_out, dec_out], dim=-1))
TypeError: expected Tensor as element 0 in argument 0, but got list
Epoch = 1 [perplexity = x.xx, ml_loss = x.xx]: 0% 0/939 [00:01<?, ?it/s]

Should we use torch.stack() instead? Thanks in advance.

wasiahmad commented 3 years ago

Yes, the split decoder mechanism is implemented; however, I found a mistake that is causing the error. The error occurs because copier_out is a list of Tensors instead of a Tensor. As you can see here, TransformerDecoder returns a list of representations, where each item is the output of one layer of the TransformerDecoder. So, to make the split decoder mechanism work, you need to make two changes.

First, in the following line,

f_t = self.fusion_sigmoid(torch.cat([copier_out, dec_out], dim=-1))

copier_out and dec_out are both lists, so modify the line to concatenate only the last layer's outputs:

f_t = self.fusion_sigmoid(torch.cat([copier_out[-1], dec_out[-1]], dim=-1))            
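To see why the original line fails, here is a minimal sketch with stand-in tensors (the shapes are hypothetical; the real ones come from the model's batch, target length, and hidden size): torch.cat needs a sequence of Tensors, so passing the per-layer lists themselves raises exactly the TypeError in the traceback, while indexing the last layer works.

```python
import torch

# Hypothetical stand-ins for the two decoders' outputs: each decoder
# returns a list with one Tensor per layer, shaped (batch, tgt_len, hidden).
num_layers, batch, tgt_len, hidden = 6, 2, 5, 8
copier_out = [torch.randn(batch, tgt_len, hidden) for _ in range(num_layers)]
dec_out = [torch.randn(batch, tgt_len, hidden) for _ in range(num_layers)]

# Concatenating the lists directly raises:
#   TypeError: expected Tensor as element 0 in argument 0, but got list
# Concatenating only the final layer's outputs succeeds:
gate_input = torch.cat([copier_out[-1], dec_out[-1]], dim=-1)
print(gate_input.shape)  # torch.Size([2, 5, 16])
```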

Second,

When the split decoder mechanism is enabled, decoder_outputs [as in here] is no longer a list of Tensors; it is a single Tensor. Therefore, to stay compatible with the downstream statements that consume it, you can simply do:

decoder_outputs = self.fusion_gate(gate_input)
decoder_outputs = [decoder_outputs]
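Putting the two fixes together, here is a rough, self-contained sketch of the fusion step. The layer definitions and the exact gating formula are my assumptions for illustration (the real ones live in the repository's transformer.py); only the last-layer indexing and the final list-wrapping are the changes described above.

```python
import torch
import torch.nn as nn

# Assumed layer shapes for the sketch; the repo's actual fusion layers may differ.
hidden = 8
fusion_sigmoid = nn.Sequential(nn.Linear(2 * hidden, hidden), nn.Sigmoid())
fusion_gate = nn.Linear(2 * hidden, hidden)

copier_last = torch.randn(2, 5, hidden)  # last-layer copier-decoder output
dec_last = torch.randn(2, 5, hidden)     # last-layer main-decoder output

# Fix 1: fuse only the last-layer outputs, not the per-layer lists.
f_t = fusion_sigmoid(torch.cat([copier_last, dec_last], dim=-1))
# Assumed gating: scale the copier stream by the sigmoid gate before fusing.
gate_input = torch.cat([f_t * copier_last, dec_last], dim=-1)
decoder_outputs = fusion_gate(gate_input)

# Fix 2: wrap the single Tensor in a list so downstream code that expects
# per-layer outputs keeps working.
decoder_outputs = [decoder_outputs]
print(decoder_outputs[0].shape)  # torch.Size([2, 5, 8])
```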

Hopefully this works, but please verify it. A pull request is welcome.