sosuperic / MeanSum


Possible bug when setting k > 1 & Gumbel_hard = True #13

Open hadyelsahar opened 5 years ago

hadyelsahar commented 5 years ago

Thanks for the interesting work and code!

I was trying to get my head around the code and I couldn't understand something:

When training the mlstm model, suppose we use the parameters from the issue title: k > 1 and gumbel_hard = True.

At mlstm #L291, the logits_to_prob function will return a strictly one-hot vector, per the PyTorch Gumbel-softmax implementation F.gumbel_softmax (called with hard=True).
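For reference, a minimal standalone sketch of that behavior (tau=1.0 is an arbitrary example value, not a setting from the repo):

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])

# With hard=True the forward pass returns an exactly one-hot vector;
# gradients still flow through the soft sample (straight-through trick).
probs = F.gumbel_softmax(logits, tau=1.0, hard=True)
print(probs)  # e.g. tensor([[0., 1., 0., 0.]]) -- a single 1, zeros elsewhere
```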

Afterwards, this probability vector is passed to the prob_to_vocab_id method, which is supposed to apply either torch.topk (beam search) or torch.multinomial (top-k sampling).
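In other words, something like this hypothetical sketch (names follow the issue; the actual repo code may differ):

```python
import torch

def prob_to_vocab_id_sketch(probs: torch.Tensor, k: int, method: str = "beam"):
    """Hypothetical sketch of the selection step described above --
    not the repo's actual implementation."""
    if method == "beam":
        # Beam search: take the k highest-probability token ids.
        values, ids = torch.topk(probs, k=k, dim=-1)
        return ids
    # Top-k sampling: draw k ids from the distribution.
    return torch.multinomial(probs, num_samples=k)
```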

Implementation-wise, this shouldn't raise any errors in beam search, because torch.topk can handle ties; however, beyond the single 1 the "top k" you get aren't the actual top-k probabilities, just arbitrarily ordered tied zeros, e.g.:

[screenshot: torch.topk applied to a one-hot vector]
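The same behavior can be reproduced standalone:

```python
import torch

one_hot = torch.tensor([0., 0., 1., 0., 0.])

# No error is raised, but beyond the single 1 every remaining "top"
# entry is a tie among zeros, so the returned indices are arbitrary.
values, indices = torch.topk(one_hot, k=3)
print(values)   # tensor([1., 0., 0.])
print(indices)  # the index of the 1 first, then arbitrary tied-zero indices
```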

But if you try to sample with torch.multinomial from a one-hot vector where k > 1, you get a runtime error:

[screenshot: torch.multinomial RuntimeError]
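A standalone reproduction (the exact error message depends on the PyTorch version; recent versions raise a RuntimeError because there aren't k non-zero categories to draw without replacement):

```python
import torch

one_hot = torch.tensor([0., 0., 1., 0., 0.])

# replacement=False is the default; a one-hot vector has only one
# non-zero category, so asking for k > 1 samples fails.
samples = torch.multinomial(one_hot, num_samples=3)
# RuntimeError: not enough non-zero categories to sample without replacement
```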

Am I missing something here?