seongminp opened this issue 3 years ago
Hi,
`log_p0` is a single distribution (over all possible tokens in the vocab) for each word in the sequence, and `tgt` is one-hot vectors along its last dimension. In the line you linked, `log_p0` has shape (batch_size, seq_len, vocab_size), so the first sum (`sum(dim=2)`) effectively indexes the distribution with `tgt` as the index, producing a tensor of shape (batch_size, seq_len) that holds the log likelihood of each token. The second sum (`sum(dim=1)`) adds up the log likelihoods along the `seq_len` dimension, returning a tensor `ll0` of shape (batch_size,) that represents the log likelihood of each sentence.
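For what it's worth, here is a minimal, self-contained sketch of that computation. The names `log_p0`, `tgt`, and `ll0` follow the explanation above; the shapes and random inputs are just illustrative:

```python
import torch
import torch.nn.functional as F

batch_size, seq_len, vocab_size = 2, 5, 10

# log_p0: a log distribution over the vocab for every token position,
# shape (batch_size, seq_len, vocab_size)
log_p0 = torch.log_softmax(torch.randn(batch_size, seq_len, vocab_size), dim=2)

# tgt: one-hot target tokens with the same shape
tgt_ids = torch.randint(0, vocab_size, (batch_size, seq_len))
tgt = F.one_hot(tgt_ids, vocab_size).float()

# Multiplying by the one-hot tgt and summing over dim=2 picks out the
# log prob of each target token -> shape (batch_size, seq_len)
token_ll = (log_p0 * tgt).sum(dim=2)

# Summing over dim=1 accumulates the per-token log likelihoods over the
# sequence -> shape (batch_size,), the log likelihood of each sentence
ll0 = token_ll.sum(dim=1)
```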
Thank you so much for the explanation!
So when training the inference network q(latent y given observed x), do you not use teacher forcing during the forward pass that obtains latent y? (Meaning there is no pre-completed decoder input and latent y is created autoregressively by the inference network).
We want the inference network to generate latent y in a similar way to the LM prior, but to obtain what the LM prior has to say about latent y, we need latent y in the first place. So I was wondering whether a non-autoregressive forward pass of the inference network is possible at all.
I also did not understand what to use as the decoder input in this case. I understand that we do give the transfer direction ‘c’ to the decoder.
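To make the question concrete, here is a rough sketch of what I mean by an autoregressive (non-teacher-forced) forward pass with soft Gumbel-softmax outputs. `decoder_step`, `x_enc`, `c_emb`, and `start_emb` are hypothetical placeholders, not names from this repo:

```python
import torch
import torch.nn.functional as F

def decode_latent_y(decoder_step, x_enc, c_emb, start_emb, max_len, temperature=1.0):
    """Autoregressive soft decoding: each step is conditioned on the soft
    (Gumbel-softmax) output of the previous step rather than on a gold token.
    All arguments are hypothetical placeholders for illustration only."""
    inp, hidden, soft_tokens = start_emb, None, []
    for _ in range(max_len):
        # Condition on source x and transfer direction c at every step
        logits, hidden = decoder_step(inp, hidden, x_enc, c_emb)
        # Soft one-hot vector over the vocab for this position
        soft = F.gumbel_softmax(logits, tau=temperature, hard=False)
        soft_tokens.append(soft)
        # Feed the soft output back in (in practice it would be projected
        # through the embedding matrix) instead of a teacher-forced gold token
        inp = soft
    return torch.stack(soft_tokens, dim=1)  # (batch, max_len, vocab)
```

With teacher forcing, `inp` at each step would instead come from a pre-specified target sequence, which is exactly what seems unavailable when latent y itself is what we are generating.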
Hello again. Thank you for sharing your work!
I have carefully read your paper (and looked through your code), but I fail to understand how the LM priors are actually calculated (going from LSTM logits -> a distribution over the vocab).
It seems to be calculated in this code snippet: https://github.com/cindyxinyiwang/deep-latent-sequence-model/blob/8a798582b1af5ef7f6ac4ca1f2138fd382a1cb06/src/model.py#L339
When you obtain the Gumbel logits and log_softmax them, I guess they become probability distributions over input x. What I fail to grasp is the exact format of that distribution.
Do we get a separate distribution for every logit dimension (hidden dim)? Or do we get a single distribution (over all possible tokens in the vocab) for each word in the sequence? If so, why is there a sum function?
I’d appreciate it greatly if you could shed some light on this.
Thank you!