Closed Warvito closed 1 year ago
Currently, I think the inferer is not very useful for performing the forward pass of the VQ-VAE + transformer during training, since it does not return the target sequence. For example, in https://github.com/Project-MONAI/GenerativeModels/blob/589d79ec15002259fe1c569c500f52472e501baa/tutorials/generative/2d_vqvae_transformer/2d_vqvae_transformer_tutorial.py#L449
if we try to use the inferer with something like
```python
logits = inferer(
    inputs=images,
    vqvae_model=vqvae_model,
    transformer_model=transformer_model,
    ordering=ordering,
).transpose(1, 2)
```
we will not have access to `quantizations_target` to compute the CE loss in `loss = ce_loss(logits, quantizations_target)`.
Hi @Warvito
There is now a `return_latent` option that you can use to get the target sequence back:
https://github.com/Project-MONAI/GenerativeModels/blob/589d79ec15002259fe1c569c500f52472e501baa/generative/inferers/inferer.py#L432
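For reference, the training-loop pattern the thread describes would then look roughly like the sketch below. This is a hedged example: the inferer call and the exact `return_latent=True` return value are assumed from the linked code, so here the model calls are replaced by dummy tensors and only the loss computation is actually run.

```python
import torch

# Stand-in shapes (assumption: batch of 2 images, a latent sequence of
# 16 tokens, and a codebook of 32 embeddings).
batch, seq_len, num_tokens = 2, 16, 32

# With return_latent=True the inferer is expected to return both the
# transformer prediction and the quantized target sequence, e.g.
# (assumed signature, based on the linked inferer.py):
#
#   logits, quantizations_target = inferer(
#       inputs=images,
#       vqvae_model=vqvae_model,
#       transformer_model=transformer_model,
#       ordering=ordering,
#       return_latent=True,
#   )
#
# Dummy tensors standing in for those return values:
logits = torch.randn(batch, seq_len, num_tokens)
quantizations_target = torch.randint(num_tokens, (batch, seq_len))

# CrossEntropyLoss expects class logits in dim 1, i.e. a
# (batch, num_tokens, seq_len) tensor, hence the transpose(1, 2).
ce_loss = torch.nn.CrossEntropyLoss()
loss = ce_loss(logits.transpose(1, 2), quantizations_target)
print(loss.item())
```

The `.transpose(1, 2)` is the same reshaping the tutorial applies before the loss, moving the token dimension into the class position that `CrossEntropyLoss` requires.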
Sorry, can't believe that I missed it. xD Thanks, Mark!