lucidrains / vector-quantize-pytorch

Vector (and Scalar) Quantization, in Pytorch
MIT License

Gumbel max trick does not seem to make sense in here #108

Open ivallesp opened 4 months ago

ivallesp commented 4 months ago

Hi all,

I'd like to raise some concerns about how the gumbel_sample method is used when reinmax=False.

https://github.com/lucidrains/vector-quantize-pytorch/blob/6102e37efefefb673ebc8bec3abb02d5030dd933/vector_quantize_pytorch/vector_quantize_pytorch.py#L472-L472

First, this sampling technique is mathematically equivalent to sampling from the categorical distribution given by the softmax of the (temperature-scaled) logits. The Gumbel noise does nothing here beyond sampling, and the argmax makes the operation non-differentiable (I know we apply the straight-through estimator (STE) later).

https://github.com/lucidrains/vector-quantize-pytorch/blob/6102e37efefefb673ebc8bec3abb02d5030dd933/vector_quantize_pytorch/vector_quantize_pytorch.py#L72-L77
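The equivalence in the first point can be checked numerically. The sketch below is plain Python rather than PyTorch so that it runs without dependencies. It assumes the hard sample is computed as argmax(logits / temperature + g), with g drawn i.i.d. from a standard Gumbel distribution. The helpers `gumbel_max_sample` and `softmax` and the example logits are hypothetical illustrations, not the library's code. The sketch draws many Gumbel-max samples and compares their frequencies with softmax(logits / temperature).

```python
import math
import random
from collections import Counter


def gumbel_noise(rng: random.Random) -> float:
    # Standard Gumbel(0, 1) draw via the inverse CDF: g = -log(-log(u)), u ~ Uniform(0, 1).
    # rng.random() returns values in [0, 1); clamp away from 0 to avoid log(0).
    u = max(rng.random(), 1e-12)
    return -math.log(-math.log(u))


def gumbel_max_sample(logits, temperature, rng):
    # Hypothetical helper: hard Gumbel-max sample, argmax_i (logits[i] / temperature + g_i).
    # Returns a single integer index; nothing here is differentiable.
    scores = [x / temperature + gumbel_noise(rng) for x in logits]
    return max(range(len(scores)), key=scores.__getitem__)


def softmax(logits, temperature):
    # Categorical probabilities softmax(logits / temperature), max-subtracted for stability.
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


rng = random.Random(0)
logits = [2.0, 1.0, 0.5]  # hypothetical logits for three codebook entries
temperature = 1.0
n = 100_000

counts = Counter(gumbel_max_sample(logits, temperature, rng) for _ in range(n))
empirical = [counts[i] / n for i in range(len(logits))]
probs = softmax(logits, temperature)

# The empirical frequencies of the Gumbel-max samples track softmax(logits / T).
print("Gumbel-max frequencies:", [round(p, 3) for p in empirical])
print("softmax(logits / T):   ", [round(p, 3) for p in probs])
```

The two printed lists agree to within sampling error. This is the standard Gumbel-max result: the noise turns argmax into an exact categorical sampler, but it does not provide a gradient by itself.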

Additionally, the logits are the codebook distances (dist in the first snippet above). A distance is always non-negative, so it is bounded at zero, which I expect to bias the resulting distribution. No gradients flow backward through the sampling operation (because it is a Gumbel-max, not a Gumbel-softmax), so the magnitude of the logits is never adjusted to improve the sampling.
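To make the gradient point concrete, the small sketch below compares the derivative of a hard one-hot argmax output with that of a Gumbel-softmax relaxation for one fixed noise draw. It again uses plain Python, with central finite differences standing in for autograd. The logits, the noise values, and `tau` are hypothetical. Gumbel-softmax is shown only for contrast; it is not what the linked code uses.

```python
import math


def one_hot_argmax(scores):
    # Hard (Gumbel-max style) output: a one-hot vector at the argmax of the scores.
    k = max(range(len(scores)), key=scores.__getitem__)
    return [1.0 if i == k else 0.0 for i in range(len(scores))]


def gumbel_softmax(scores, tau):
    # Relaxed (Gumbel-softmax style) output: softmax(scores / tau).
    scaled = [s / tau for s in scores]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def partial_derivative(f, logits, noise, i, j, eps=1e-4):
    # Central finite difference of f(logits + noise)[j] with respect to logits[i],
    # holding the Gumbel noise fixed (as in the reparameterized view).
    plus = list(logits)
    minus = list(logits)
    plus[i] += eps
    minus[i] -= eps
    f_plus = f([x + g for x, g in zip(plus, noise)])[j]
    f_minus = f([x + g for x, g in zip(minus, noise)])[j]
    return (f_plus - f_minus) / (2 * eps)


logits = [2.0, 1.0, 0.5]  # hypothetical logits
noise = [0.1, -0.3, 0.2]  # hypothetical fixed Gumbel draw
tau = 0.5                 # hypothetical relaxation temperature

hard_grad = partial_derivative(one_hot_argmax, logits, noise, i=0, j=0)
soft_grad = partial_derivative(lambda s: gumbel_softmax(s, tau), logits, noise, i=0, j=0)

print("d one_hot[0] / d logits[0]:       ", hard_grad)
print("d gumbel_softmax[0] / d logits[0]:", soft_grad)
```

The hard output is piecewise constant in the logits, so its derivative is zero almost everywhere (and undefined where two scores tie). That is why a straight-through estimator is needed at all. The relaxed output y has derivative y0(1 - y0)/tau with respect to its own logit.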

It seems to me that this just takes a hidden variable (the distance matrix), normalizes it with an arbitrary temperature parameter, and samples from it, which adds biased noise to the straight-through relaxation. What am I missing?