Project-MONAI / GenerativeModels

MONAI Generative Models makes it easy to train, evaluate, and deploy generative models and related applications
Apache License 2.0

Stage 2 VQ-GAN Tutorial - Transformer training next-token prediction #408

Closed: HieuPhan33 closed this issue 1 year ago.

HieuPhan33 commented 1 year ago

Hi,

Thanks for sharing the code for training the VQ-GAN.

I see that the tutorial (in tutorials/generative/2d_vqgan) only trains Stage 1 (image reconstruction, i.e. learning the latent space).
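Stage 1 is what turns an image into tokens for stage 2: each encoder output vector is snapped to its nearest codebook entry, and the index of that entry becomes the token. A minimal plain-Python sketch of that lookup, assuming squared Euclidean distance and a tiny hypothetical codebook (the real model works on tensors; this is not the MONAI implementation):

```python
from typing import List, Sequence, Tuple

def quantize(latents: Sequence[Sequence[float]],
             codebook: Sequence[Sequence[float]]) -> Tuple[List[int], List[List[float]]]:
    """Map each latent vector to the index of its nearest codebook vector."""
    indices: List[int] = []
    quantized: List[List[float]] = []
    for z in latents:
        # Squared Euclidean distance from z to every codebook vector
        dists = [sum((a - b) ** 2 for a, b in zip(z, e)) for e in codebook]
        k = min(range(len(codebook)), key=dists.__getitem__)
        indices.append(k)                 # the discrete token used by stage 2
        quantized.append(list(codebook[k]))  # what the decoder actually receives
    return indices, quantized

codebook = [[0.0, 0.0], [1.0, 1.0], [-1.0, 2.0]]  # hypothetical 3-entry codebook
latents = [[0.1, -0.2], [0.9, 1.2], [-0.8, 1.7]]  # hypothetical encoder outputs
idx, zq = quantize(latents, codebook)
print(idx)  # [0, 1, 2]
```

The list of indices is exactly the kind of sequence the stage-2 transformer is trained on.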

VQ-GAN also has a second stage, which trains an autoregressive Transformer to predict the next token so that it can synthesize new images.
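Next-token training boils down to pairing each sequence of codes with a copy of itself shifted by one position. A minimal sketch, assuming an extra start-of-sequence id one past the codebook size (a common convention, used here as an assumption rather than as the tutorial's exact code; the helper name is hypothetical):

```python
from typing import List, Tuple

def next_token_pairs(tokens: List[int], bos_token: int) -> Tuple[List[int], List[int]]:
    """Build (input, target) sequences for next-token prediction.

    The input is the sequence shifted right by one with a BOS token in front,
    so at position i the model has seen tokens[:i] and must predict tokens[i].
    """
    inputs = [bos_token] + tokens[:-1]
    targets = list(tokens)
    return inputs, targets

codes = [3, 7, 7, 1, 12]  # hypothetical flattened codes from a 16-entry codebook
BOS = 16                  # assumption: reserve id 16 as the start token
x, y = next_token_pairs(codes, BOS)
print(x)  # [16, 3, 7, 7, 1]
print(y)  # [3, 7, 7, 1, 12]
```

With a causal mask in the transformer, one forward pass over `x` yields predictions for every position of `y` at once, which is why training is parallel even though sampling is sequential.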

Could you please share the code to train the Transformer?

In addition, VQ-GAN Transformer training supports conditional image synthesis by prepending the tokens (codebook indices) of a conditioning image to the sequence (e.g. an MR image to synthesize a CT, or a zebra image to synthesize a horse). Do you plan to support this feature in MONAI?
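The prepending approach can be expressed at the sequence level: concatenate the condition's codes in front of the target's codes, train with the usual one-step shift, and only score the positions whose target belongs to the image being synthesized. A minimal sketch, assuming both images share one token vocabulary (the function and variable names are hypothetical):

```python
from typing import List, Tuple

def build_conditional_sequence(
    cond_tokens: List[int], target_tokens: List[int]
) -> Tuple[List[int], List[int], List[bool]]:
    """Prepend conditioning tokens to the target tokens for next-token training.

    Returns (inputs, targets, loss_mask). The model reads `inputs` and is scored
    against `targets`; `loss_mask` is True only where the target is a token of
    the image being synthesized, so no loss is spent predicting the condition.
    """
    full = cond_tokens + target_tokens
    inputs = full[:-1]
    targets = full[1:]
    # targets[i] == full[i + 1]; it is a target-image token when i + 1 >= len(cond_tokens)
    loss_mask = [i + 1 >= len(cond_tokens) for i in range(len(targets))]
    return inputs, targets, loss_mask

mr_tokens = [5, 9, 11]  # hypothetical codes of the conditioning (e.g. MR) image
ct_tokens = [2, 4, 6]   # hypothetical codes of the image to synthesize (e.g. CT)
x, y, mask = build_conditional_sequence(mr_tokens, ct_tokens)
print([t for t, m in zip(y, mask) if m])  # [2, 4, 6]
```

At sampling time the condition's codes are fed in as a fixed prefix and only the remaining positions are generated. If the conditioning image uses a different codebook, its ids would also need to be remapped into a separate range of the vocabulary; that detail is left out here.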

marksgraham commented 1 year ago

Hi,

We've got a tutorial here that covers training both the stage-1 VQ-GAN and the stage-2 transformer.

We currently support conditioning the transformer with cross-attention, by passing the context parameter. The version you've described, with prepended tokens, would also be possible, but you'd have to implement it manually.
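Cross-attention lets a context tensor steer generation without changing the token sequence itself: queries come from the image tokens, while keys and values come from the conditioning input. A stripped-down sketch of that mechanism in plain Python, assuming a single head and identity query/key/value projections (a real transformer layer learns these projections; this is not the library's implementation):

```python
import math
from typing import List

Matrix = List[List[float]]

def softmax(xs: List[float]) -> List[float]:
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def cross_attention(queries: Matrix, context: Matrix) -> Matrix:
    """Single-head scaled dot-product cross-attention with identity projections.

    queries: one vector per sequence token (what is being generated).
    context: conditioning vectors (keys and values both come from here).
    Each output row is a softmax-weighted average of the context rows.
    """
    d = len(queries[0])
    out: Matrix = []
    for q in queries:
        scores = [sum(a * b for a, b in zip(q, c)) / math.sqrt(d) for c in context]
        weights = softmax(scores)
        out.append([sum(w * c[j] for w, c in zip(weights, context))
                    for j in range(len(context[0]))])
    return out

# With a single context vector, every token receives exactly that vector.
print(cross_attention([[1.0, 0.0], [0.0, 1.0]], [[2.0, 0.0]]))  # [[2.0, 0.0], [2.0, 0.0]]
```

Because the conditioning enters only through keys and values, the token sequence length and the next-token objective stay unchanged, which is the practical difference from prepending condition tokens.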

HieuPhan33 commented 1 year ago

Thanks for your reply,

I have one more question: does the current Transformer training code support 3D images with spatial_dims=3?

marksgraham commented 1 year ago

Yes: set spatial_dims=3 in the Ordering object so it knows it is collapsing a 3D image into a 1D sequence (you also need to train a 3D VQ-VAE if you're training a VQ-VAE + Transformer model).
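For a 3D model, the ordering has to turn a (D, H, W) grid of latent codes into one sequence and be invertible, so that sampled tokens can be placed back into a volume for the 3D VQ-VAE decoder. A plain-Python sketch of that flatten/unflatten round trip, assuming a simple axis-priority scan (the `axis_order` parameter and helper names are hypothetical and not part of the MONAI Ordering API):

```python
from itertools import product
from typing import Iterator, List, Sequence, Tuple

Volume = List[List[List[int]]]

def _scan(shape: Tuple[int, int, int], axis_order: Tuple[int, int, int]) -> Iterator[Tuple[int, ...]]:
    """Yield (z, y, x) coordinates, looping over axes in `axis_order` (outermost first)."""
    for idx in product(*(range(shape[a]) for a in axis_order)):
        # idx holds coordinates in scan order; map them back to (z, y, x)
        yield tuple(idx[axis_order.index(a)] for a in range(3))

def to_sequence(vol: Volume, axis_order: Tuple[int, int, int] = (0, 1, 2)) -> List[int]:
    """Collapse a (D, H, W) token grid into a 1D sequence."""
    shape = (len(vol), len(vol[0]), len(vol[0][0]))
    return [vol[z][y][x] for z, y, x in _scan(shape, axis_order)]

def from_sequence(seq: Sequence[int], shape: Tuple[int, int, int],
                  axis_order: Tuple[int, int, int] = (0, 1, 2)) -> Volume:
    """Inverse of to_sequence: write tokens back into a (D, H, W) grid."""
    d, h, w = shape
    vol = [[[0] * w for _ in range(h)] for _ in range(d)]
    for token, (z, y, x) in zip(seq, _scan(shape, axis_order)):
        vol[z][y][x] = token
    return vol

# 2x2x2 grid of hypothetical latent codes, value = 4*z + 2*y + x
vol = [[[4 * z + 2 * y + x for x in range(2)] for y in range(2)] for z in range(2)]
seq = to_sequence(vol)                 # raster scan, x varies fastest
seq_rev = to_sequence(vol, (2, 1, 0))  # x outermost, z varies fastest
print(seq)      # [0, 1, 2, 3, 4, 5, 6, 7]
print(seq_rev)  # [0, 4, 2, 6, 1, 5, 3, 7]
```

Whatever scan is chosen, the same ordering must be used to flatten training volumes and to un-flatten sampled sequences; otherwise the decoder receives codes in the wrong positions.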