Closed · jameshball closed this 1 year ago
When changing it to a 256x256->512x512 model, it used around 46 GB with the SRUnet1024, and then after saving a checkpoint and restarting, it OOMed :( Is this memory usage normal? The 64x64->256x256 model used around 17 GB.
@jameshball hey James, it is because you have self-attention turned on, and at an early stage in the unet (you only have two stages)
you can turn off self-attention altogether for the super-resolution nets, add more stages, and place the cross attention on the innermost stage. those two changes should help with memory
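In imagen-pytorch terms, that advice looks roughly like the following sketch (assuming the `layer_attns`, `layer_cross_attns`, and `memory_efficient` keyword arguments on `Unet`; the dims here are illustrative, not a recommendation):

```python
from imagen_pytorch import Unet

# sketch of an SR unet with self-attention off everywhere and
# cross-attention only on the innermost of four stages
sr_unet = Unet(
    dim = 128,
    dim_mults = (1, 2, 4, 8),                         # four stages instead of two
    layer_attns = False,                              # no self-attention at any stage
    layer_cross_attns = (False, False, False, True),  # cross-attention on innermost stage only
    memory_efficient = True                           # downsample earlier to cut activation memory
)
```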
@lucidrains Thanks for the helpful reply and explanation! Doesn't the SRUnet1024 already do that, since layer_attns = False and layer_cross_attns = (False, False, False, True)?
Even using SRUnet1024 and scaling back some parameters like the dim_mults, I get an OOM.
I still think it's bizarre that it's trying to allocate over 512 GiB of memory at once in the first example!
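As a rough back-of-envelope of where an allocation that size can come from (assuming fp32 attention scores and self-attention sitting at an early stage of the 512x512 unet, which is an assumption about this particular config):

```python
# a 512x512 image downsampled once gives a 256x256 feature map,
# i.e. 65,536 tokens entering self-attention
tokens = 256 * 256
bytes_fp32 = 4

# the attention score matrix for ONE head and ONE batch element
attn_matrix = tokens ** 2 * bytes_fp32
print(attn_matrix / 2**30, "GiB")   # -> 16.0 GiB

# multiplied by the number of attention heads and the batch size,
# a single allocation past 512 GiB is entirely plausible
```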
@jameshball yea, 512 and 1024 are difficult to train
in the paper, what they did was to train the super-resolution unets on random crops (the feature is built-in here). because the unet is fully convolutional, it should generalize to higher resolutions at sampling time
@jameshball i forget the details, but i believe they did random crops of 256x256 (but you should double check that, since it's been a while since i read this paper)
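A minimal sketch of how the built-in crop feature is wired up, assuming the `random_crop_sizes` keyword on the `Imagen` constructor (one entry per unet; `None` means that stage sees full images) and that the `SRUnet256` / `SRUnet1024` presets are importable from the package:

```python
from imagen_pytorch import Unet, Imagen, SRUnet256, SRUnet1024

# illustrative base unet; dims are placeholders
base_unet = Unet(dim = 64, dim_mults = (1, 2, 4, 8))

imagen = Imagen(
    unets = (base_unet, SRUnet256(), SRUnet1024()),
    image_sizes = (64, 256, 1024),
    random_crop_sizes = (None, None, 256)   # the 1024 unet trains on random 256x256 crops,
)                                           # but still samples at full 1024x1024 since it is
                                            # fully convolutional
```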
Ah, that's really clever! 256x256 is right, yep. Memory usage is extremely low now, so it looks perfect. Thanks so much for your help! Gonna be a long one to train, I think...
@jameshball glad you have it working! i'll get back to this repository soon and add some better error messages, or write up a section in the readme with what we discussed above
I will be doing a lot of the more intense stuff over the coming months, like this, and also potentially extending inpainting as I alluded to in #290.
Yep, using this for my master's thesis to generate histopathology images conditioned on clinical parameters, and potentially follow up with a paper, so that's the goal indeed. I'm very indebted to you for your hard work!
@jameshball that is so cool! these generative models will definitely be a boon to medical education
I'm currently trying to train a 256x256->1024x1024 model and was experiencing memory issues when using the SRUnet1024 preset, so I massively scaled down to get somewhere that worked, and then aimed to increase parameters for as long as it stayed stable.
I haven't had any issues when training the 64x64 base or the 64x64->256x256 model, even with dim=512 on the base model.
I am using these parameters for the Unet just for testing (obviously they won't perform well):
And when running with the 64x64 and 64x64->256x256 Unets replaced by NullUnets, I get this error:
On the default SRUnet1024 I get this OOM, which looks more normal, but it's still surprising:
Do you have any idea why this might be happening? This seems very surprising on a Quadro RTX 8000! What's the expected memory usage here?