Project-MONAI / GenerativeModels

MONAI Generative Models makes it easy to train, evaluate, and deploy generative models and related applications
Apache License 2.0

Fix VQVAETransformerInferer __call__ for training #290

Closed Warvito closed 1 year ago

Warvito commented 1 year ago

Currently, the inferer is not very useful for performing the forward pass of the VQ-VAE + transformer during training, since it does not return the target sequence. For example, in https://github.com/Project-MONAI/GenerativeModels/blob/589d79ec15002259fe1c569c500f52472e501baa/tutorials/generative/2d_vqvae_transformer/2d_vqvae_transformer_tutorial.py#L449

if we try to use the inferer with something like

        logits = inferer(
            inputs=images,
            vqvae_model=vqvae_model,
            transformer_model=transformer_model,
            ordering=ordering,
        ).transpose(1, 2)

we will not have access to quantizations_target to compute the cross-entropy loss in loss = ce_loss(logits, quantizations_target).

marksgraham commented 1 year ago

Hi @Warvito

There is now a return_latent option that you can use to get the target sequence back:

https://github.com/Project-MONAI/GenerativeModels/blob/589d79ec15002259fe1c569c500f52472e501baa/generative/inferers/inferer.py#L432
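The pattern behind return_latent can be sketched with a toy stand-in (the names toy_infer, prediction, and latent below are hypothetical illustrations, not the real MONAI Generative API): when the flag is set, the call returns the target latent sequence alongside the prediction, so a training loop has both tensors it needs for the loss.

```python
# Toy sketch of the return_latent pattern; all names here are hypothetical
# stand-ins, not the actual VQVAETransformerInferer implementation.

def toy_infer(inputs, return_latent=False):
    """Mimic an inferer that can optionally return the target latent sequence."""
    prediction = [x * 2 for x in inputs]   # stand-in for the transformer logits
    latent = [x % 3 for x in inputs]       # stand-in for the quantized target sequence
    spatial_shape = (len(inputs),)         # stand-in for the latent spatial dims
    if return_latent:
        return prediction, latent, spatial_shape
    return prediction

# Training-style call: prediction and its target come back together,
# so a loss (e.g. cross-entropy) can be computed between the two.
prediction, target, spatial_shape = toy_infer([1, 2, 3], return_latent=True)
```

With a single flag, the same call serves both sampling (prediction only) and training (prediction plus target), which avoids duplicating the encode-and-order logic in the tutorial's training loop.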

Warvito commented 1 year ago

Sorry, can't believe that I missed it. xD Thanks, Mark!