NVlabs / NVAE

The Official PyTorch Implementation of "NVAE: A Deep Hierarchical Variational Autoencoder" (NeurIPS 2020 spotlight paper)
https://arxiv.org/abs/2007.03898

FFHQ model checkpoint leads to out of memory even on single image inference #24

Closed bhoov closed 3 years ago

bhoov commented 3 years ago

I am trying to play with the results of the model using the pretrained checkpoints; however, I keep getting out-of-memory errors when passing even a single image through the model for inference.

I load the state_dict and args of the pretrained FFHQ checkpoint into the model and pass a single image of shape (1, 3, 256, 256), dtype torch.float32, with values in [0, 1] (per the LMDB DataLoader), through the model's forward pass. I am using a Tesla V100 16GB, torch==1.6.0, CUDA 11.0, and Ubuntu 18.04.
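Roughly, this is what I'm doing (a sketch following the loading code in the repo's evaluate.py; the checkpoint path is a placeholder, and the AutoEncoder constructor arguments may differ in other versions):

```python
import torch
import utils                   # from this repo
from model import AutoEncoder  # from this repo

# Load the pretrained FFHQ checkpoint (key names follow evaluate.py).
checkpoint = torch.load('path/to/ffhq/checkpoint.pt', map_location='cpu')
args = checkpoint['args']
arch_instance = utils.get_arch_cells(args.arch_instance)

model = AutoEncoder(args, None, arch_instance)
model.load_state_dict(checkpoint['state_dict'], strict=False)
model = model.cuda().eval()

# A single image in the format produced by the LMDB DataLoader.
x = torch.rand(1, 3, 256, 256).cuda()  # placeholder image; values in [0, 1]
output = model(x)  # <-- OOMs here on a 16GB V100
```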

I did a bit of debugging -- it looks like the out-of-memory error occurs while the image is passing through the loop over the dec_tower.
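In case it helps, here is roughly how I localized it (a sketch; `dec_tower` is the decoder ModuleList on the AutoEncoder, and the forward-hook API is standard PyTorch):

```python
import torch

# Print allocated GPU memory after each decoder cell to see where usage
# blows up. The i=i default argument binds the loop index into each lambda.
for i, cell in enumerate(model.dec_tower):
    cell.register_forward_hook(
        lambda mod, inp, out, i=i: print(
            f"dec_tower[{i}]: {torch.cuda.memory_allocated() / 1e9:.2f} GB allocated"
        )
    )
```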

Is there any reason, algorithmically, that the forward pass for a single 256x256 image should consume more than 16 GB of GPU memory? Is there a memory leak somewhere?

Any insight and guidance on how to fix this would be appreciated.

bhoov commented 3 years ago

Figured it out. Gradient tracking is on by default, and calling model.eval() does not disable it. Running the forward pass inside a with torch.no_grad(): context solved the issue. Closing.
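For anyone who hits this later, the fix is just (standard PyTorch; `model` and `x` as in the sketch above):

```python
import torch

model.eval()  # disables dropout/batch-norm updates, but does NOT disable autograd
with torch.no_grad():  # stops autograd from caching activations for a backward pass
    output = model(x)  # now fits in memory for single-image inference
```

Without no_grad(), autograd keeps the intermediate activations of every cell in the deep encoder/decoder towers alive for a potential backward pass, which is what exhausts the 16 GB even for a batch of one.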