naveen-kinnal closed this issue 2 years ago
I think we use gumbel_softmax the same way across all tasks to propagate gradients from the language model loss. For reference, gumbel_softmax is used here and is enabled in Yelp sentiment transfer:
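For illustration, here is a minimal sketch of how a Gumbel-softmax sample lets language-model-loss gradients reach the generator. The names (`gen_logits`, `lm_log_probs`) are hypothetical and not from the repo; this only demonstrates the general mechanism:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab_size = 8

# Hypothetical generator logits for one decoding step.
gen_logits = torch.randn(1, vocab_size, requires_grad=True)

# Soft, differentiable sample over the vocabulary.
soft_word = F.gumbel_softmax(gen_logits, tau=0.5, hard=False)

# Hypothetical language-model log-probabilities; the LM loss is computed
# against the soft sample, so the whole path stays differentiable.
lm_log_probs = torch.log_softmax(torch.randn(1, vocab_size), dim=-1)
lm_loss = -(soft_word * lm_log_probs).sum()

lm_loss.backward()
print(gen_logits.grad is not None)  # True: gradients flow to the generator
```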
Thank you. But that line is never reached. In the code below, self.hparams.gumbel_softmax is not set to True, so execution takes the if branch (get_translations) rather than get_soft_translations:
```python
if self.hparams.bt:
    if eval or (not self.hparams.gumbel_softmax):
        with torch.no_grad():
            x_trans, x_trans_mask, x_trans_len, index = self.get_translations(
                x_train, x_mask, x_len, y_sampled, y_sampled_mask, y_sampled_len, temperature)
        index = torch.tensor(index.copy(), dtype=torch.long, requires_grad=False, device=self.hparams.device)
    else:
        # with torch.no_grad():
        lm_flag = True
        x_trans, x_trans_mask, x_trans_len, index, org_index, neg_entropy = self.get_soft_translations(
            x_train, x_mask, x_len, y_sampled, y_sampled_mask, y_sampled_len)
    trans_length = sum(x_trans_len)
else:
    trans_length = 0.
```
The back-translation loss does not use Gumbel-softmax, as explained in the paper; only the LM loss uses Gumbel-softmax, here:
I think the param 'lm' is also not set to True for the Yelp style transfer. Could you please let me know in which file 'lm' is set to True?
But I don't see --gs_soft set. As a result, x_trans_lm is a hard generation, so how can the gradients of log_prior backpropagate?
Hi,
When --gs_soft is not set, gumbel_softmax outputs discretized one-hot vectors, but autograd differentiates them as if they were the soft samples (the straight-through estimator). Note also that our language model multiplies the embedding weight matrix by those one-hot vectors to obtain word embeddings, so the entire process is differentiable:
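As a minimal sketch of this straight-through behavior (the sizes and names here are illustrative, not the repo's), PyTorch's `F.gumbel_softmax` with `hard=True` returns an exact one-hot vector in the forward pass while routing gradients through the soft sample:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab_size, embed_dim = 10, 4
logits = torch.randn(1, vocab_size, requires_grad=True)
embedding = torch.nn.Embedding(vocab_size, embed_dim)

# hard=True mirrors the behavior when --gs_soft is not set: the forward
# output is a discretized one-hot vector, but gradients flow through the
# underlying soft sample (straight-through estimator).
one_hot = F.gumbel_softmax(logits, tau=1.0, hard=True)
assert one_hot.sum() == 1.0  # exactly one-hot in the forward pass

# Multiplying the one-hot vector by the embedding weight matrix selects the
# word embedding, so the path from logits to embedding is differentiable.
word_emb = one_hot @ embedding.weight

word_emb.sum().backward()
print(logits.grad is not None)  # True: gradients reach the logits
```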
Hello. I want to know why gumbel_softmax is not set to True in the model params for the Yelp sentiment-transfer task. Is it because we have a pre-trained language model and hence do not want to compute the KL loss with respect to the Gumbel-softmax log prior? Thanks.