Closed: sgbaird closed this issue 2 years ago
Works fine if I drop the `SRUnet256` and replace it with the regular `Unet`. Still trying to play around with parameters for the `SRUnet256` to see if I can get it under 6 GB, at least for prototyping.
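For concreteness, the swap looks like this (hyperparameters are illustrative, loosely following the README's second-stage `Unet`):

```python
from imagen_pytorch import Unet

# second stage as a plain Unet instead of SRUnet256
# (hyperparameters illustrative, loosely following the README)
unet2 = Unet(
    dim=32,
    dim_mults=(1, 2, 4, 8),
    num_resnet_blocks=(2, 4, 8, 8),
    layer_attns=(False, False, False, True),
    layer_cross_attns=(False, False, False, True),
)
```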
@lucidrains so `SRUnet256` overrides many of the parameters? https://github.com/lucidrains/imagen-pytorch/blob/ccf848cfd1f8306e654112d61e95376180a67a90/imagen_pytorch/imagen_pytorch.py#L1419-L1431
Seems to be the case after inspecting the values with a breakpoint. This is the trimmed configuration I was testing:
```python
from imagen_pytorch import SRUnet256

unet2 = SRUnet256(
    dim=4,
    dim_mults=(1, 2),
    num_resnet_blocks=(2, 2),
    layer_attns=(False, False),
    layer_cross_attns=(False, False),
    attn_heads=1,
)
```
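That would be consistent with the subclass supplying its own constructor defaults. A minimal sketch of that pattern (class name, parameters, and merge order are illustrative, not the library's exact code):

```python
from imagen_pytorch import Unet

# illustrative sketch of the subclass-defaults pattern, NOT the actual implementation
class SRUnet256Sketch(Unet):
    def __init__(self, *args, **kwargs):
        default_kwargs = dict(
            dim_mults=(1, 2, 4, 8),
            num_resnet_blocks=(2, 4, 8, 8),
        )
        # the merge order decides who wins: {**default_kwargs, **kwargs}
        # lets the caller's values override the defaults, while
        # {**kwargs, **default_kwargs} would silently clobber them
        super().__init__(*args, **{**default_kwargs, **kwargs})
```

Whether your values survive depends entirely on which side of that merge they land on, which is worth checking at the linked lines.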
I don't think you should expect to be able to run inference on 6 GB, much less train anything.
@samedii

> Still trying to play around with parameters for the `SRUnet256` to see if I can get it under 6 GB, at least for prototyping.

Will probably do production runs via Slurm submissions on my university's HPC cluster, which will give me much more than 6 GB.
I'm hitting the same issue with 12 GB, and have also run into it on a 40 GB A100 that I used to verify. I have trained and run inference successfully on 8–16 GB on versions up until 0.7. I jumped from 0.3 to 0.7, so I'll have to backtrack to find the last version that worked on my equipment.
@lupinetine this is my first time using the library, so if you figure out where the change occurred, I'd love to know! cc @lucidrains
Also, I realize `max_batch_size=1` is a silly choice; I was just trying to see if I could get anything to run without the OOM error. https://github.com/lucidrains/imagen-pytorch/issues/24#issuecomment-1142306411
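For what it's worth, `max_batch_size` only controls gradient accumulation: the trainer splits each batch into sub-batches of at most that size, so `1` already minimizes per-step activation memory. A minimal unconditional sketch following the README's trainer usage (hyperparameters illustrative; `unet2` is the trimmed `SRUnet256` from above):

```python
import torch
from imagen_pytorch import Unet, Imagen, ImagenTrainer

# base unet (hyperparameters illustrative)
unet1 = Unet(
    dim=32,
    dim_mults=(1, 2, 4),
    num_resnet_blocks=1,
    layer_attns=(False, False, True),
    layer_cross_attns=False,
)

imagen = Imagen(
    condition_on_text=False,
    unets=(unet1, unet2),  # unet2: the trimmed SRUnet256 above
    image_sizes=(64, 256),
    timesteps=1000,
).cuda()

trainer = ImagenTrainer(imagen)

training_images = torch.randn(4, 3, 256, 256).cuda()

# the batch of 4 is split into sub-batches of 1 and gradients are
# accumulated, trading throughput for lower peak memory
loss = trainer(training_images, unet_number=2, max_batch_size=1)
trainer.update(unet_number=2)
```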
Might be fixed with some of the new releases.
should be better now that only one unet is loaded into memory at any given time
if it still OOMs, you should buy a better graphics card
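For reference, the general PyTorch pattern behind keeping only one unet on the GPU at a time looks roughly like this (a sketch of the idea, not imagen-pytorch's exact internals):

```python
import torch

def keep_one_on_gpu(unets, active_index, device='cuda'):
    # park every unet on the CPU except the active one, so peak GPU
    # memory is bounded by the largest single unet plus its activations
    for i, unet in enumerate(unets):
        unet.to(device if i == active_index else 'cpu')
    torch.cuda.empty_cache()  # release the evicted unets' cached blocks
```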
- Based on the README usage instructions, except with `max_batch_size=1`, running on Windows.
- The OOM error occurs during the `SRUnet` (set a breakpoint and checked; see the probe sketch below).
- I'm using an NVIDIA GeForce RTX 2060.
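A quick way to confirm which stage blows up is to reset and read CUDA's peak-memory counters around each unet's training step; a minimal probe, assuming the `trainer` and `training_images` from the sketch earlier in the thread:

```python
import torch

for unet_number in (1, 2):
    torch.cuda.reset_peak_memory_stats()
    loss = trainer(training_images, unet_number=unet_number, max_batch_size=1)
    trainer.update(unet_number=unet_number)
    peak_gb = torch.cuda.max_memory_allocated() / 1024 ** 3
    print(f'unet {unet_number}: peak {peak_gb:.2f} GB allocated')
```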
See also #12