hamadichihaoui / BIRD

This is the official implementation of "Blind Image Restoration via Fast Diffusion Inversion"

Inquiry on Memory Usage Issues with Diffusion Model Optimization #2

Open Breeze-Zero opened 1 month ago

Breeze-Zero commented 1 month ago

Hello, I wanted to express my gratitude for your work; it's been instrumental in my current project. However, I've encountered some confusion that I'm hoping you can shed light on.

I understand that in your approach, you freeze the weights of the diffusion model and only optimize the input random latent. While I appreciate the elegance of this method, I've noticed an issue with GPU memory usage that I believe may be related to the backward process.

Since the backward (denoising) process cannot run under torch.no_grad(), autograd still records the computation graph needed to compute gradients with respect to the latent variables, and these stored activations occupy GPU memory. As a result, GPU memory usage increases with each step of the backward process. In my case, when optimizing with a pre-trained 64-channel model on images of size (448, 168), memory usage reaches up to 30 GB.
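To illustrate why torch.no_grad() is not an option here, below is a minimal sketch using a hypothetical toy denoiser and a hypothetical update rule (neither is the BIRD network or its actual sampler). The model weights are frozen and only the latent is optimized. Under no_grad, nothing is recorded, so no gradient can reach the latent. Without it, autograd keeps the activations of every step until backward() is called, which is why memory grows with the number of steps and with the image size.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for a pre-trained denoiser (not the BIRD model).
denoiser = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1),
    nn.ReLU(),
    nn.Conv2d(16, 3, 3, padding=1),
)
# Freeze the model weights: only the latent will receive gradients.
for p in denoiser.parameters():
    p.requires_grad_(False)


def reverse_process(x, num_steps=5):
    # Toy reverse process: each step reuses the frozen denoiser.
    # Every step's intermediate activations are kept for backward().
    for _ in range(num_steps):
        x = x - 0.1 * denoiser(x)
    return x


target = torch.randn(1, 3, 32, 32)
latent = torch.randn(1, 3, 32, 32, requires_grad=True)

# Under no_grad no graph is built, so the latent cannot be optimized.
with torch.no_grad():
    out_no_grad = reverse_process(latent)

# Without no_grad, the full graph over all steps is stored until backward().
out = reverse_process(latent)
loss = torch.mean((out - target) ** 2)
loss.backward()
```

After backward(), latent.grad holds the gradient used to update the latent, while the frozen weights receive none.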

I'm not sure whether you have encountered this before. Since I am porting the code manually, I may have overlooked a step that is causing this issue.

hamadichihaoui commented 3 weeks ago

@Breeze-Zero thanks for your interest in our work, and sorry for the late reply. The memory you need depends on the number of optimized values, i.e. the size of the latent. By default I use an image of size 256x256x3 (about 0.2 million values), which requires around 1.2 GB. In your case it is 64x448x168 (about 4.8 million values), which is roughly 25 times more, so you need roughly 25 times more memory (about 30 GB, consistent with what you observe). A technique called gradient checkpointing could probably reduce the memory requirement; I plan to investigate it.
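A minimal sketch of how gradient checkpointing could be applied, assuming the reverse process is a loop of per-step calls to a frozen denoiser. The toy model and the step function are hypothetical, not BIRD's code. Wrapping each step in torch.utils.checkpoint.checkpoint stores only that step's input; the activations inside the step are discarded and recomputed during backward(), trading extra compute for lower memory. The sketch also checks that the gradient matches the non-checkpointed version.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)

# Hypothetical frozen denoiser standing in for the pre-trained model.
denoiser = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1),
    nn.ReLU(),
    nn.Conv2d(16, 3, 3, padding=1),
)
for p in denoiser.parameters():
    p.requires_grad_(False)


def step(x):
    # One toy denoising step (hypothetical update rule).
    return x - 0.1 * denoiser(x)


def reverse_process(x, num_steps=5, use_checkpoint=True):
    for _ in range(num_steps):
        if use_checkpoint:
            # Only x is saved; activations inside step() are recomputed in backward().
            x = checkpoint(step, x, use_reentrant=False)
        else:
            x = step(x)
    return x


target = torch.randn(1, 3, 32, 32)
latent = torch.randn(1, 3, 32, 32, requires_grad=True)

# Checkpointed pass.
loss = torch.mean((reverse_process(latent) - target) ** 2)
loss.backward()
grad_ckpt = latent.grad.clone()

# Plain pass for comparison.
latent.grad = None
loss_plain = torch.mean((reverse_process(latent, use_checkpoint=False) - target) ** 2)
loss_plain.backward()
grad_plain = latent.grad.clone()
```

The per-step granularity is a design choice: checkpointing each step keeps one tensor per step instead of all internal activations, at the cost of roughly one extra forward pass through the denoiser during backward().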