sara-nl / 3D-VQ-VAE-2

3D VQ-VAE-2 for high-resolution CT scan synthesis
https://scripties.uba.uva.nl/search?id=722710

Clarifications on Decoder and Embeddings #3

Open aksg87 opened 2 years ago

aksg87 commented 2 years ago

@robogast

Happy to report I was able to train a VQ-VAE using a dataset. Very cool to see - and kudos for the nice Tensorboard outputs you have in place! 😎

  1. Do you have any suggestions or code for randomly sampling from the decoder in a generative fashion?

  2. Also, if you have a summary of these files and their purpose, that would be very helpful. I'd be happy to open a PR adding some comments to the repository.

Questions on:

    • calc_ssim_from_checkpoint.py # does this calculate SSIM across the dataset? ❓
    • decode_embeddings.py # specifications for db_path? ❓
    • extract_embeddings.py # does this save embeddings to disk? ❓

Ran successfully:

    • plot_from_checkpoint.py # plots a forward pass from a random sample ✅
    • train.py # trains a model ✅

Much appreciated! -Akshay

robogast commented 2 years ago

Hi! Answers to your questions:

  1. You cannot sample the decoder directly; you need to train an autoregressive prior (e.g. PixelCNN, PixelSNAIL, ViT, ..., maybe using a discrete denoising model would be cool...) on the embeddings obtained by putting your dataset through the encoder. You then sample your autoregressive model for embeddings and put those embeddings through the decoder; see the sketch after this list. See the original VQ-VAE paper: https://arxiv.org/pdf/1711.00937.pdf
  2. On the scripts:
    • calc_ssim_from_checkpoint.py -> I simply had not added SSIM as a metric to Tensorboard yet when I wrote this script, so this file can be ignored (or removed) now.
    • decode_embeddings.py -> db_path points to the embeddings generated by your autoregressive model, so you don't have them right now.
    • extract_embeddings.py -> yes, in principle this file takes your model + dataset and creates the embeddings which should be used as training input for your autoregressive model.
    • As a general note, these three files are scripts and not intended as library files, and thus should be treated as such (i.e. low quality control, a lot of hardcoded stuff).
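
For concreteness, here's a minimal sketch of that two-stage sampling pipeline. The objects prior and vqvae and the methods prior.sample and vqvae.decode_code are placeholder names for illustration, not APIs from this repo:

import torch

@torch.no_grad()
def sample_scans(prior, vqvae, n_samples, code_shape, device="cuda"):
    """Hypothetical sketch: sample a grid of discrete codes from a trained
    autoregressive prior, then decode it with the VQ-VAE decoder."""
    prior.eval()
    vqvae.eval()
    # 1) Autoregressively sample codebook indices
    #    (`prior.sample` is an assumed interface).
    codes = prior.sample(n_samples, code_shape).to(device)
    # 2) Look up the codebook vectors for those indices and run the
    #    decoder on them (`decode_code` is likewise an assumed name).
    return vqvae.decode_code(codes)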

Nice to see that you're progressing :)

aksg87 commented 2 years ago

@robogast - Appreciate all of the information! Need to review the paper again :)

I look forward to trying the other scripts and posting how things go!

aksg87 commented 2 years ago

Hi @robogast

Your comments make much more sense now after reviewing the literature further :)

This is a nice overview from AI Epiphany!

https://www.youtube.com/watch?v=VZFVUrYcig0&t=1736s

aksg87 commented 2 years ago

Hi @robogast

I was trying to better understand encoding_idx. My understanding is that this is the last item returned for each of the 3 bottleneck layers? I'm curious why we throw the rest of the information away.

Thanks in advance! -Akshay

import torch

GPU = torch.device("cuda")  # assumed definition; the script references GPU

def extract_samples(model, dataloader):
    """Yield the discrete code indices for every batch in the dataloader."""
    model.eval()
    model.to(GPU)

    with torch.no_grad():
        for sample, _ in dataloader:
            sample = sample.to(GPU)
            # model.encode returns one tuple per bottleneck level; zip(*...)
            # transposes them so the last field (the code indices) is
            # collected across all levels.
            *_, encoding_idx = zip(*model.encode(sample))
            yield encoding_idx
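
For reference, here is a self-contained toy illustration of what that unpacking does, assuming model.encode returns one tuple per bottleneck level with the code indices as the last field (the field names below are my assumption):

# Suppose model.encode(sample) returns one tuple per bottleneck level:
levels = [
    ("loss_bottom", "quant_bottom", "idx_bottom"),
    ("loss_mid", "quant_mid", "idx_mid"),
    ("loss_top", "quant_top", "idx_top"),
]
# zip(*levels) transposes the list: it groups the i-th field across all
# levels, so `*_` absorbs the losses and quantized tensors, while
# encoding_idx ends up holding the code indices from *every* level.
*_, encoding_idx = zip(*levels)
print(encoding_idx)  # ('idx_bottom', 'idx_mid', 'idx_top')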