MatteoZhang closed this issue 3 years ago.
Yes, the split decoder mechanism is implemented; however, I found a mistake that causes the error. The error occurs because copier_out is a list of Tensors instead of a single Tensor. As you can see here, TransformerDecoder returns a list of representations, where each item is the output of one layer of the TransformerDecoder. So, to make the split decoder mechanism work, you need to make two changes.
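To see why torch.cat fails, here is a small reproduction. ToyLayerwiseDecoder is a hypothetical stand-in (not the repository's TransformerDecoder) that, like the decoder described above, returns one output tensor per layer in a Python list; the shapes are arbitrary.

```python
import torch
import torch.nn as nn


class ToyLayerwiseDecoder(nn.Module):
    """Hypothetical stand-in: returns a list with one output tensor per layer."""

    def __init__(self, d_model: int, num_layers: int):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(num_layers)])

    def forward(self, x: torch.Tensor) -> list:
        outputs = []
        for layer in self.layers:
            x = torch.relu(layer(x))
            outputs.append(x)  # keep each layer's representation
        return outputs


torch.manual_seed(0)
batch, tgt_len, d_model, num_layers = 2, 5, 8, 3
x = torch.randn(batch, tgt_len, d_model)
copier_out = ToyLayerwiseDecoder(d_model, num_layers)(x)
dec_out = ToyLayerwiseDecoder(d_model, num_layers)(x)

# Concatenating the lists themselves raises a TypeError, as in the traceback.
try:
    torch.cat([copier_out, dec_out], dim=-1)
    cat_failed = False
except TypeError as err:
    cat_failed = True
    print(f"TypeError: {err}")

# Taking the last layer of each list gives plain tensors that torch.cat accepts.
fused = torch.cat([copier_out[-1], dec_out[-1]], dim=-1)
print(fused.shape)  # torch.Size([2, 5, 16])
```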
First, in the following line,
f_t = self.fusion_sigmoid(torch.cat([copier_out, dec_out], dim=-1))
copier_out and dec_out are both lists, so modify the line as follows:
f_t = self.fusion_sigmoid(torch.cat([copier_out[-1], dec_out[-1]], dim=-1))
Second, when the split decoder mechanism is enabled, decoder_outputs (see here) is no longer a list of Tensors; it is a single Tensor. Therefore, to stay compatible with statements 1 and 2, you can simply wrap it in a list:
decoder_outputs = self.fusion_gate(gate_input)
decoder_outputs = [decoder_outputs]
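A minimal sketch of the split-decoder fusion with both changes applied. The fusion module definitions (Linear + Sigmoid, Linear + ReLU), the gate input, and the helper name fuse_split_outputs are assumptions for illustration, not the repository's code. The point is that the function returns a one-item list, so code that indexes the layer-wise output (for example with [-1]) behaves as it does in the non-split path.

```python
import torch
import torch.nn as nn

d_model = 8
# Hypothetical fusion modules; sizes and activations are assumptions.
fusion_sigmoid = nn.Sequential(nn.Linear(2 * d_model, d_model), nn.Sigmoid())
fusion_gate = nn.Sequential(nn.Linear(2 * d_model, d_model), nn.ReLU())


def fuse_split_outputs(copier_out: list, dec_out: list) -> list:
    """Fuse the last-layer outputs of two decoders; return a one-item list."""
    # Change 1: use the final layer of each layer-wise list.
    copier_last, dec_last = copier_out[-1], dec_out[-1]
    f_t = fusion_sigmoid(torch.cat([copier_last, dec_last], dim=-1))
    # Assumed gate input: copier output next to the gated decoder output.
    gate_input = torch.cat([copier_last, torch.mul(f_t, dec_last)], dim=-1)
    decoder_outputs = fusion_gate(gate_input)
    # Change 2: wrap the single Tensor so callers still receive a list.
    return [decoder_outputs]


# Usage with random per-layer outputs of shape [batch, tgt_len, d_model].
copier_out = [torch.randn(2, 5, d_model) for _ in range(3)]
dec_out = [torch.randn(2, 5, d_model) for _ in range(3)]
layer_wise = fuse_split_outputs(copier_out, dec_out)
final = layer_wise[-1]  # same access pattern as the non-split decoder
```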
Hopefully this works, but please verify it. A pull request is welcome.
Hi, is the split decoder part implemented? I tried your code with args.split_decoder set to True and got this error:
Epoch = 1 [perplexity = x.xx, ml_loss = x.xx]: 0% 0/939 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "../../main/train.py", line 708, in
    main(args)
  File "../../main/train.py", line 653, in main
    train(args, train_loader, model, stats)
  File "../../main/train.py", line 283, in train
    net_loss = model.update(ex)
  File ".../model.py", line 173, in update
    example_weights=ex_weights)
  File ".../module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File ".../transformer.py", line 435, in forward
    **kwargs)
  File ".../transformer.py", line 363, in _run_forward_ml
    summ_emb)
  File ".../module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "...transformer.py", line 295, in forward
    return self.decode(tgt_pad_mask, tgt_emb, memory_bank, state)
  File "...transformer.py", line 273, in decode
    f_t = self.fusion_sigmoid(torch.cat([copier_out, dec_out], dim=-1))
TypeError: expected Tensor as element 0 in argument 0, but got list
Epoch = 1 [perplexity = x.xx, ml_loss = x.xx]: 0% 0/939 [00:01<?, ?it/s]
Should we use torch.stack()? Thanks in advance.
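On the torch.stack question: stacking would make the TypeError go away, but it changes the meaning. It keeps every layer and adds a leading layer dimension, whereas the maintainer's fix keeps only the final layer of each decoder. The shapes below are arbitrary stand-ins; the assumption is that the fusion layers expect a [batch, tgt_len, 2 * d_model] input.

```python
import torch

batch, tgt_len, d_model, num_layers = 2, 5, 8, 3
# Stand-ins for the layer-wise outputs of the two decoders.
copier_out = [torch.randn(batch, tgt_len, d_model) for _ in range(num_layers)]
dec_out = [torch.randn(batch, tgt_len, d_model) for _ in range(num_layers)]

# torch.stack adds a new leading "layer" dimension instead of picking one layer.
stacked = torch.cat([torch.stack(copier_out), torch.stack(dec_out)], dim=-1)
# stacked: [num_layers, batch, tgt_len, 2 * d_model]

# Indexing with [-1] keeps only the final layer of each decoder.
last_only = torch.cat([copier_out[-1], dec_out[-1]], dim=-1)
# last_only: [batch, tgt_len, 2 * d_model]
```

A linear fusion layer applied to stacked would still run, since nn.Linear acts on the last dimension, but it would fuse every layer and leave an extra dimension that downstream code is unlikely to expect.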