autonomousvision / giraffe

This repository contains the code for the CVPR 2021 paper "GIRAFFE: Representing Scenes as Compositional Generative Neural Feature Fields"
https://m-niemeyer.github.io/project-pages/giraffe/index.html
MIT License

Fitting a large number of objects in memory #12

Closed kjmillerCURIS closed 3 years ago

kjmillerCURIS commented 3 years ago

Thanks for sharing this awesome work, and congrats on winning best paper!

I'd like to train GIRAFFE on a custom dataset with up to 20+ objects per image, but I'm finding that a batch of 32 images won't fit into 11GB of GPU memory. At 64x64 resolution I can render a batch of at most 18 images, and at 256x256 resolution a batch of at most 9. I haven't tried training yet, but I would expect it to take at least as much memory as inference.

Do you think it would be safe to reduce the training batch size, or would that make the GAN training unstable at some point? Thanks.

m-niemeyer commented 3 years ago

Hi @kjmillerCURIS , thanks for your question! I think training with a batch size of 16 or higher is always a good idea. ~12 also still works, but I would not go below 10 or 8. Regarding fitting lots of objects: I played around with ~10 objects (see Figure 11 of the supplementary), and for this I heavily reduced the number of hidden dimensions of the object / background feature fields (e.g. to 16/32). You could further reduce the number of sample points, as well as the output feature size of the feature fields and the hidden dimension of the neural renderer. Good luck with your research!
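For reference, the knobs mentioned above could be collected in one place before launching training. This is purely illustrative: the parameter names below are hypothetical placeholders, not the repo's actual config keys, so check the config files shipped with the repo for the real names and defaults.

```python
# Hypothetical, slimmed-down model settings for scenes with many objects.
# None of these key names are taken from the repo; they only mirror the
# quantities discussed above (hidden dims, sample points, feature sizes).
slim_model_settings = {
    "object_field_hidden_dim": 16,      # hidden size of the object feature fields
    "background_field_hidden_dim": 32,  # hidden size of the background feature field
    "n_ray_samples": 32,                # fewer sample points per camera ray
    "out_feature_dim": 64,              # output feature size of the feature fields
    "neural_renderer_hidden_dim": 64,   # hidden size of the 2D neural renderer
}
```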

kjmillerCURIS commented 3 years ago

Thanks @m-niemeyer! I just realized there might be another option. I could split the batch into smaller parts and call forward() and backward() on each part, letting PyTorch accumulate the gradients before taking an optimizer step. Can you think of any potential pitfalls with that approach? It looks like neither the generator nor the discriminator uses batch normalization, only instance normalization without any moving average. The only other concern might be the R1 gradient penalty of the discriminator, but I could compute that part without splitting the batch, since it only uses real images.

Is my understanding correct? Thanks.

m-niemeyer commented 3 years ago

Hi @kjmillerCURIS , that's correct: what you describe is often called gradient accumulation, and it's totally fine to do this (see e.g. this post). The only downside is that training takes longer. However, you only need to apply this to the generator update - remember, the discriminator only sees the 2D renderings, so there it doesn't matter how many objects you have in the scene.
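For anyone landing on this issue later, here is a minimal sketch of the generator-side gradient accumulation discussed above. The names (`generator`, `discriminator`, `z_sampler`, `g_optimizer`) are placeholders rather than the repo's actual training code, and the non-saturating loss is just a common choice for illustration; the R1 penalty would stay in the discriminator step on full real batches, as described in the comments above.

```python
import torch
import torch.nn.functional as F

def generator_step(generator, discriminator, g_optimizer, z_sampler,
                   batch_size=32, micro_batch=8):
    """One generator update split into micro-batches whose gradients
    accumulate in .grad before a single optimizer step."""
    g_optimizer.zero_grad()
    n_chunks = batch_size // micro_batch
    for _ in range(n_chunks):
        z = z_sampler(micro_batch)          # latent codes for one micro-batch
        fake = generator(z)                 # render a micro-batch of images
        d_fake = discriminator(fake)
        # Non-saturating GAN loss, scaled by 1/n_chunks so the accumulated
        # gradient matches one full-batch backward pass.
        loss = F.softplus(-d_fake).mean() / n_chunks
        loss.backward()                     # gradients add up across chunks
    g_optimizer.step()                      # single parameter update
```

The discriminator update can stay unsplit, since it only sees rendered and real 2D images, so its memory cost does not grow with the number of objects in the scene.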