schmidtdominik / LAPO

Code for the ICLR 2024 spotlight paper: "Learning to Act without Actions" (introducing Latent Action Policies)
https://arxiv.org/abs/2312.10812

[Q] quantization clarifications #3

Closed Howuhh closed 2 days ago

Howuhh commented 3 days ago

Hi! I noticed that the VQ-VAE used in LAPO is not quite standard, at least compared to the popular implementations from https://github.com/lucidrains/vector-quantize-pytorch and https://github.com/MishaLaskin/vqvae/blob/master/models/quantizer.py which is also quite close to the original implementation in TensorFlow.

I'd love to know a bit more about where this variation came from and how num_codebooks and num_discrete_latents should be interpreted. The paper mentions that latent actions are 128-dimensional continuous vectors that are split and quantized into 8 discrete latents with 16-dimensional embeddings.

From this description, I would expect either a separate codebook for each latent (reshaping 128 to [8, 16]) or one codebook shared by all (similar to how images are reshaped from H x W x C to H * W x C before quantization), but the configs specify 2 codebooks, which doesn't quite make sense to me...

schmidtdominik commented 2 days ago

Hi @Howuhh, great question! We split the 128-dim vector generated by the IDM into 8 chunks and apply VQ separately to each chunk. We initially used 8 codebooks (one for each chunk) but had some issues with codebook collapse, where the same code vector was always selected in the nearest-neighbor lookup. We now instead use only 2 codebooks, such that each codebook is used in quantizing four of the latents, which fixed the collapse issue. We think this might be because if a codebook collapses for one of the four latents, it might still be useful for the others, and so it continues to get updated through those branches until it "de-collapses". So the number 8 comes from num_codebooks(=2) * num_discrete_latents(=4). Note that we had limited compute for this, so with more hyperparameter tuning (maybe a larger batch size) you can likely find a more elegant solution for the collapse issue.
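For illustration only (this is not the code from the LAPO repo): a minimal NumPy sketch of the chunked lookup described above, with 2 codebooks each shared by 4 of the 8 chunks. The function name `chunked_vq`, the codebook size of 64 entries, and the contiguous chunk-to-codebook assignment (chunks 0-3 use codebook 0, chunks 4-7 use codebook 1) are assumptions made for the example. It covers only the forward nearest-neighbor lookup, not the codebook update or the straight-through gradient.

```python
import numpy as np

def chunked_vq(z, codebooks, latents_per_codebook):
    """Nearest-neighbor quantization of each chunk with a shared codebook.

    z:         (batch, num_chunks * emb_dim) continuous latent actions
    codebooks: (num_codebooks, num_embs, emb_dim)
    Chunk i uses codebook i // latents_per_codebook (assumed assignment).
    """
    num_codebooks, num_embs, emb_dim = codebooks.shape
    num_chunks = num_codebooks * latents_per_codebook
    chunks = z.reshape(z.shape[0], num_chunks, emb_dim)       # (B, 8, 16)

    # Which codebook each chunk uses, e.g. [0, 0, 0, 0, 1, 1, 1, 1].
    owner = np.arange(num_chunks) // latents_per_codebook
    books = codebooks[owner]                                  # (8, num_embs, 16)

    # Squared Euclidean distance from every chunk to every code vector.
    diff = chunks[:, :, None, :] - books[None, :, :, :]       # (B, 8, num_embs, 16)
    dists = (diff ** 2).sum(axis=-1)                          # (B, 8, num_embs)
    indices = dists.argmin(axis=-1)                           # (B, 8)

    # Gather the selected code vector for every chunk.
    quantized = books[np.arange(num_chunks)[None, :], indices]  # (B, 8, 16)
    return quantized.reshape(z.shape), indices

rng = np.random.default_rng(0)
codebooks = rng.normal(size=(2, 64, 16))  # num_codebooks=2; 64 entries is a made-up size
z = rng.normal(size=(32, 128))            # stand-in for IDM outputs
z_q, idx = chunked_vq(z, codebooks, latents_per_codebook=4)
print(z_q.shape, idx.shape)  # (32, 128) (32, 8)
```

Counting the distinct values in each column of `idx` over a batch is one simple way to spot collapse: a collapsed chunk would show a single repeated index.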

More generally, the intuition for this factorized action space is that in many games each transition might be the combination of several agents' actions and environmental effects (e.g. player shoots to the right + enemy 1 jumps + enemy 2 moves leftwards + screen fades to black + ...). This chunked-VQ approach is also similar to some more recent VQ methods like LFQ and FSQ that quantize per-dimension.

Hope this clarified things and please let me know if you have any other questions!

Howuhh commented 2 days ago

@schmidtdominik Thank you very much! Makes sense now. For some reason, on a new problem I was having trouble with my implementation of LAPO with the standard VQ-VAE, although FSQ worked out of the box. But I want to reproduce the original as closely as possible, so I'll try it this way now.

schmidtdominik commented 1 day ago

@Howuhh Glad to hear! One thing to keep in mind is that VQ methods are usually used in image tokenization, where the effective batch size for VQ is actually BS x H x W (e.g. 512 x 32 x 32 = 524288), and such a large batch size can help prevent codebook collapse. In LAPO we only have one latent per frame, i.e. the effective batch size is several orders of magnitude lower, so it's expected that non-standard VQ hyperparams might be needed.
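The arithmetic behind that comparison, as a small sketch. The helper `effective_vq_batch` is hypothetical, and reusing a batch size of 512 for the LAPO case is an assumption made only to keep the two settings comparable; the thread does not state LAPO's actual batch size.

```python
import math

def effective_vq_batch(batch_size, latents_per_sample):
    """Number of vectors the quantizer sees per training step."""
    return batch_size * latents_per_sample

# Image tokenization: one latent per spatial position of a 32 x 32 grid.
image_vq = effective_vq_batch(512, 32 * 32)

# LAPO-style setting: one latent action per frame (batch size 512 assumed).
lapo_vq = effective_vq_batch(512, 1)

ratio = image_vq // lapo_vq
print(image_vq, lapo_vq, ratio)      # 524288 512 1024
print(round(math.log10(ratio), 2))   # 3.01, i.e. about three orders of magnitude
```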