jxmorris12 / categorical-vae

Categorical Variational Auto-encoders in PyTorch

Substitution of softmax function in torch? #1

Closed: tessavdheiden closed this issue 2 years ago

tessavdheiden commented 2 years ago

Hi!

Great repo and very well documented!

I noticed that you implemented your own Gumbel-Softmax function, including its gradient. Could you replace it with PyTorch's built-in Gumbel-Softmax function?
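For readers unfamiliar with what a hand-rolled Gumbel-Softmax computes, here is a minimal forward-pass sketch in plain Python. It is not the repository's implementation: the helper names (sample_gumbel, gumbel_softmax_sample) are hypothetical, and it covers only the sampling math, not the gradient handling that an autograd-based version would need.

```python
import math
import random


def sample_gumbel(rng: random.Random, eps: float = 1e-10) -> float:
    """Draw one sample from the standard Gumbel(0, 1) distribution via the inverse CDF."""
    u = 1.0 - rng.random()  # u lies in (0, 1], so log(u) is always defined
    return -math.log(-math.log(u) + eps)  # eps guards against log(0) when u == 1


def gumbel_softmax_sample(logits, tau, rng):
    """Relaxed one-hot sample: softmax((logits + Gumbel noise) / tau)."""
    scores = [(l + sample_gumbel(rng)) / tau for l in logits]
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


logits = [1.0, 2.0, 0.5]
# The same seed gives the same Gumbel noise, so only the temperature differs.
soft = gumbel_softmax_sample(logits, tau=1.0, rng=random.Random(0))
sharp = gumbel_softmax_sample(logits, tau=0.05, rng=random.Random(0))
print(soft)
print(sharp)
```

Lowering tau sharpens the sample toward a one-hot vector without changing which category wins, because dividing by a positive temperature preserves the ordering of the perturbed scores.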

Best, Tessa

jxmorris12 commented 2 years ago

@tessavdheiden Thanks! But why?

tessavdheiden commented 2 years ago

Because you could replace those 35 lines (lines 9 to 44) with a single call to F.gumbel_softmax(x, dim=-1, tau=tau), which would make your code slimmer.
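As a rough sketch of what the suggested one-liner looks like in use, the snippet below calls torch.nn.functional.gumbel_softmax (imported as F) on a random batch of logits. The batch shape, temperature, and placeholder loss are assumptions for illustration; whether the repository needs the relaxed (hard=False) or straight-through (hard=True) variant is not stated in this thread.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Unnormalized log-probabilities: a batch of 4 categorical variables with 10 classes each.
logits = torch.randn(4, 10, requires_grad=True)
tau = 0.5  # temperature: lower values push samples closer to one-hot

# Relaxed (soft) samples: each row is a probability vector.
y_soft = F.gumbel_softmax(logits, tau=tau, hard=False, dim=-1)

# Straight-through samples: one-hot in the forward pass,
# while gradients flow through the soft relaxation in the backward pass.
y_hard = F.gumbel_softmax(logits, tau=tau, hard=True, dim=-1)

# Placeholder differentiable loss, just to show that gradients reach the logits.
weights = torch.arange(10, dtype=torch.float32)
loss = (y_hard * weights).sum()
loss.backward()

print(y_soft.sum(dim=-1))  # each row sums to 1
print(y_hard)
print(logits.grad.shape)
```

Passing hard=True gives discrete one-hot samples while keeping the model trainable, which is usually the part a hand-written implementation has to wire up manually with a custom gradient.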

jxmorris12 commented 2 years ago

Oh yeah, thanks. But I wanted to implement it myself; that was kind of the point.