sara-nl / 3D-VQ-VAE-2

3D VQ-VAE-2 for high-resolution CT scan synthesis
https://scripties.uba.uva.nl/search?id=722710

Training PixelCNN unclear #5

Open · Arksyd96 opened this issue 1 month ago

Arksyd96 commented 1 month ago

Hi,

I'm using your implementation to generate MRIs. I have trained a VQ-VAE to reconstruct 3D MRIs, but I am unsure about which vectors to use for training the PixelCNN for sampling.

I attempted to understand your LMDB implementation, but it would take me a significant amount of time to fully grasp it. I'm not clear on what exactly is being stored in the LMDB database.

Given that the VQ-VAE encoder outputs multiple quantization vectors (one for each encoding block), what should be the specific input for the PixelCNN?

x = torch.randn(4, 3, 128, 128, 64).to('cuda')
decoded, (commitment_loss, quantizations, encoding_idx) = vqvae(x)

I think I'll have to modify the LMDB data module part.

Thank you!

robogast commented 1 month ago

Hi! It has been a while since I've worked on this project, so my memory is not too sharp. I'll try to see what I can do to help.

As far as I can see/remember, the input to the PixelCNN is a list of 3-dimensional one-hot encoded tensors; see how I unpack them in the PixelCNN:

https://github.com/sara-nl/3D-VQ-VAE-2/blob/0b2148f8fe344c81044bdee3ee83efa9d4cf4934/pixel_model/pixelcnn.py#L106-L120

The whole pickling/txn-context machinery etc. is just plumbing needed to make LMDB work.
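To make that concrete, here is a minimal sketch of turning per-level index maps into one-hot tensors of the shape a 3D PixelCNN would consume. The codebook size and spatial shapes below are made up for illustration, not the repo's actual values:

```python
import torch
import torch.nn.functional as F

# Toy stand-in for the VQ-VAE's per-level encoding indices
# (one index map per hierarchy level; shapes are illustrative only).
num_embeddings = 512
encoding_idx = [
    torch.randint(0, num_embeddings, (1, 16, 16, 8)),  # bottom level
    torch.randint(0, num_embeddings, (1, 8, 8, 4)),    # top level
]

# One-hot encode each level and move the class dimension to the channel
# position, yielding (batch, num_embeddings, D, H, W) tensors.
one_hot = [
    F.one_hot(idx, num_classes=num_embeddings)
     .permute(0, 4, 1, 2, 3)
     .float()
    for idx in encoding_idx
]

print([tuple(t.shape) for t in one_hot])
```

Each hierarchy level stays a separate entry in the list, which matches the "list of one-hot tensors" unpacking linked above.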

The reason I used LMDB was that, at the time, it was the only database implementation I could find that supported both memory-mapped arrays and concurrent reads (important for computational efficiency when running multi-node, which I did when sampling the full 512x512x128 volumes).

As I said, I'm not entirely up to date on these kinds of workloads anymore, but two thoughts:

If you have more questions let me know.

Robert Jan