I was trying to get my head around the code and I couldn't understand something:
When training the mlstm model, if we try the following set of parameters:
- `gumbel_hard = true`
- sampling method = `"greedy"` or `"sample"`
- `k > 1`
then, at the line mlstm #L291, the `logits_to_prob` function will return a strict one-hot vector, according to the PyTorch Gumbel-softmax implementation (`F.gumbel_softmax`).
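A minimal standalone sketch of what I mean (independent of the repo's code, with a hypothetical batch of logits):

```python
import torch
import torch.nn.functional as F

logits = torch.randn(1, 10)  # hypothetical logits over a 10-word vocab

# With hard=True, the straight-through Gumbel-softmax returns a strict
# one-hot vector: all probability mass sits on a single index.
probs = F.gumbel_softmax(logits, tau=1.0, hard=True)
print(probs)              # exactly one 1.0, everything else 0.0
print((probs > 0).sum())  # tensor(1)
```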
Afterwards, this probability vector is passed to the `prob_to_vocab_id` method, which is supposed to apply either `torch.topk` (beam search) or `torch.multinomial` (top-k sampling).
Implementation-wise, this shouldn't raise any errors in beam search, since `torch.topk` can handle draws; however, the top k you get aren't the actual top-k probabilities, because after the hard one-hot collapse the remaining k-1 entries are just arbitrary ties among the zeros.
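A quick sketch of the tie behaviour (standalone, not the repo's code; exactly which zero indices come back may vary across PyTorch versions and devices):

```python
import torch

one_hot = torch.zeros(10)
one_hot[3] = 1.0

# topk happily returns k entries even though only one is non-zero;
# the remaining k-1 indices are arbitrary ties among the zeros.
values, indices = torch.topk(one_hot, k=3)
print(values)   # tensor([1., 0., 0.])
print(indices)  # index 3 first, then two arbitrary zero-probability indices
```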
But if you try to sample with `torch.multinomial` from a one-hot vector where k > 1, you get a runtime error.
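A standalone reproduction (again outside the repo's code; the exact error message may vary across PyTorch versions):

```python
import torch

one_hot = torch.zeros(10)
one_hot[3] = 1.0

# Without replacement, multinomial cannot draw k > 1 samples from a
# distribution with a single non-zero category, so this raises a
# RuntimeError.
try:
    torch.multinomial(one_hot, num_samples=3, replacement=False)
except RuntimeError as err:
    print(err)
```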
Am I missing something here?

Thanks for the interesting work and code!