AmigoLab / SynthAnatomy

VQ-VAE + Transformer based synthesis of 3D anatomical imaging data
https://link.springer.com/chapter/10.1007/978-3-031-16980-9_7
Apache License 2.0

Question about the VQVAE embedding indices #2

Closed tommybotch closed 1 year ago

tommybotch commented 1 year ago

Hello! Thank you for sharing your work. I ran the VQVAE model with the following parameters:

--no_levels=4
--downsample_parameters='((4, 2, 1, 1), (4, 2, 1, 1), (4, 2, 1, 1), (4, 2, 1, 1))'
--upsample_parameters='((4, 2, 1, 0, 1), (4, 2, 1, 0, 1), (4, 2, 1, 0, 1), (4, 2, 1, 0, 1))'
--no_res_layers=3
--no_channels=256
--num_embeddings='(2048,)' 
--embedding_dim='(256,)' 

My input has shape (1, 1, 96, 128, 96) and the encoder output has shape (1, 256, 6, 8, 6). Given 256 channels, I would expect to receive 256 embedding indices from the quantizer (i.e. expected shape (256, 6, 8, 6)). However, the index_quantize function yields embedding indices of shape (1, 6, 8, 6).

danieltudosiu commented 1 year ago

Hi @tlasmanbotch,

Thank you for your interest in our work.

Regarding your questions, let's outline the structure of a VQ-VAE. A VQ-VAE is formed by an Encoder that takes the input image [B, C, H, W, D] and projects it into an encoder space of size [B, c, h, w, d]. The quantizer then maps each c-dimensional vector of the encoder space to the closest element of its codebook, producing first an index-quantized representation [B, 1, h, w, d] (one codebook index per spatial position) and then a quantized representation [B, c, h, w, d].
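
For intuition, here is a minimal sketch of what that index lookup does conceptually; it is not the repository's actual implementation, and the names `index_quantize_sketch` and `codebook` are illustrative (some implementations also keep an extra singleton channel dimension on the index tensor):

```python
import torch

def index_quantize_sketch(z_e, codebook):
    """Map each encoder vector to the index of its nearest codebook entry.

    z_e:      encoder output of shape [B, c, h, w, d]
    codebook: embedding matrix of shape [K, c] (K = num_embeddings)
    returns:  indices of shape [B, h, w, d], one index per spatial position
    """
    B, c, h, w, d = z_e.shape
    # Flatten the spatial grid: [B*h*w*d, c]
    flat = z_e.permute(0, 2, 3, 4, 1).reshape(-1, c)
    # Distance from every encoder vector to every codebook entry: [B*h*w*d, K]
    dists = torch.cdist(flat, codebook)
    # Nearest codebook entry per vector, reshaped back to the spatial grid
    return dists.argmin(dim=1).reshape(B, h, w, d)

# With the shapes from this issue: the (1, 256, 6, 8, 6) encoder output
# collapses to a single (1, 6, 8, 6) grid of codebook indices.
z_e = torch.randn(1, 256, 6, 8, 6)
codebook = torch.randn(2048, 256)
print(index_quantize_sketch(z_e, codebook).shape)  # torch.Size([1, 6, 8, 6])
```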

  1. Could you help me understand why there is one set of indices for 256 features?
    • The index you receive from the index_quantize method is the index of the codebook element to which each vector of the encoder-space representation is mapped, so there is one index per spatial position rather than one per channel. This index-based representation is used for training transformer-style networks.
  2. I see in the paper you say you get up to 1400 tokens for a single input. What parameter allows for that representation here?
    • The 1400-token representation is determined by the input size and the network structure. In our work, the input has shape [160, 224, 160], which after four downsamplings becomes [10, 14, 10]; after flattening (since transformers operate on sequences) this ends up as a sequence of 1400 tokens. In your case you ended up with an index-based representation of size [6, 8, 6], which becomes a sequence of length 288 (see the sketch below).
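
A quick sketch of that arithmetic, assuming stride-2 downsampling at each of the four levels as in the parameters above (the helper name is illustrative):

```python
# Each downsampling block halves every spatial dimension; the resulting
# index grid is flattened into one token per position.
def sequence_length(input_shape, num_downsamples=4, stride=2):
    tokens = 1
    for dim in input_shape:
        tokens *= dim // (stride ** num_downsamples)
    return tokens

print(sequence_length((160, 224, 160)))  # 10 * 14 * 10 = 1400 (paper setting)
print(sequence_length((96, 128, 96)))    #  6 *  8 *  6 =  288 (this issue)
```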

If you have any other questions, please let me know. If you want, we can set up a meeting and clarify everything. Otherwise, if all is good, please close the issue.

Cheers,

Dan

tommybotch commented 1 year ago

Hi Dan,

Thank you for your detailed reply - I greatly appreciate the help! This all makes sense and apologies for my misunderstanding. I am closing this issue.

Best, Tommy