lucidrains / imagen-pytorch

Implementation of Imagen, Google's Text-to-Image Neural Network, in Pytorch
MIT License

Absurd memory usage (>512GB?!) when trying to train 256x256->1024x1024 network #300

Closed by jameshball 1 year ago

jameshball commented 1 year ago

I'm currently trying to train a 256x256->1024x1024 model. I ran into memory issues with the Imagen SR_1024 Unet, so I massively scaled it down to get something that worked, planning to then increase the parameters until training was stable.

I haven't had any issues training the 64x64 base model or the 64x64->256x256 model, even with dim=512 on the base model.

I am using these parameters for the Unet just for testing (obviously they won't perform well):

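from imagen_pytorch import Unet

# Deliberately tiny super-resolution Unet, just for testing
unet = Unet(
    dim=8,
    cond_dim=512,
    dim_mults=(1, 2),              # only two stages
    num_resnet_blocks=2,
    memory_efficient=True,
    layer_attns=(False, True),     # self-attention on the second stage
    layer_cross_attns=(False, True),
)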

When I run with the 64x64 base Unet and the 64x64->256x256 Unet both replaced by NullUnets, I get this error:

Traceback (most recent call last):
  File "train.py", line 184, in <module>
    main()
  File "train.py", line 150, in main
    loss = trainer.train_step(unet_number=args.unet_number, max_batch_size=4)
  File "~/.local/lib/python3.8/site-packages/imagen_pytorch/trainer.py", line 601, in train_step
    loss = self.step_with_dl_iter(self.train_dl_iter, unet_number = unet_number, **kwargs)
  File "~/.local/lib/python3.8/site-packages/imagen_pytorch/trainer.py", line 619, in step_with_dl_iter
    loss = self.forward(**{**kwargs, **model_input})
  File "~/.local/lib/python3.8/site-packages/imagen_pytorch/trainer.py", line 135, in inner
    out = fn(model, *args, **kwargs)
  File "~/.local/lib/python3.8/site-packages/imagen_pytorch/trainer.py", line 972, in forward
    loss = self.imagen(*chunked_args, unet = self.unet_being_trained, unet_number = unet_number, **chunked_kwargs)
  File "~/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "<@beartype(imagen_pytorch.imagen_pytorch.Imagen.forward) at 0x7fb425d24b80>", line 62, in forward
  File "~/.local/lib/python3.8/site-packages/imagen_pytorch/imagen_pytorch.py", line 2537, in forward
    return self.p_losses(unet, images, times, text_embeds = text_embeds, text_mask = text_masks, cond_images = cond_images, noise_scheduler = noise_scheduler, lowres_cond_img = lowres_cond_img, lowres_aug_times = lowres_aug_times, pred_objective = pred_objective, p2_loss_weight_gamma = p2_loss_weight_gamma, random_crop_size = random_crop_size, **kwargs)
  File "<@beartype(imagen_pytorch.imagen_pytorch.Imagen.p_losses) at 0x7fb425d24a60>", line 33, in p_losses
  File "~/.local/lib/python3.8/site-packages/imagen_pytorch/imagen_pytorch.py", line 2426, in p_losses
    pred = unet.forward(
  File "~/.local/lib/python3.8/site-packages/imagen_pytorch/imagen_pytorch.py", line 1645, in forward
    x = attn_block(x, c)
  File "~/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "~/.local/lib/python3.8/site-packages/imagen_pytorch/imagen_pytorch.py", line 982, in forward
    x = attn(x, context = context) + x
  File "~/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "~/.local/lib/python3.8/site-packages/imagen_pytorch/imagen_pytorch.py", line 527, in forward
    sim = einsum('b h i d, b j d -> b h i j', q, k) * self.cosine_sim_scale
  File "~/.local/lib/python3.8/site-packages/torch/functional.py", line 378, in einsum
    return _VF.einsum(equation, operands)  # type: ignore[attr-defined]
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 512.32 GiB (GPU 0; 47.32 GiB total capacity; 5.62 GiB already allocated; 40.35 GiB free; 5.89 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

With the default Imagen SR_1024, I get this OOM instead, which looks more normal but is still surprising:

Traceback (most recent call last):
  File "train.py", line 176, in <module>
    main()
  File "train.py", line 142, in main
    loss = trainer.train_step(unet_number=args.unet_number, max_batch_size=4)
  File "~/.local/lib/python3.8/site-packages/imagen_pytorch/trainer.py", line 601, in train_step
    loss = self.step_with_dl_iter(self.train_dl_iter, unet_number = unet_number, **kwargs)
  File "~/.local/lib/python3.8/site-packages/imagen_pytorch/trainer.py", line 619, in step_with_dl_iter
    loss = self.forward(**{**kwargs, **model_input})
  File "~/.local/lib/python3.8/site-packages/imagen_pytorch/trainer.py", line 135, in inner
    out = fn(model, *args, **kwargs)
  File "~/.local/lib/python3.8/site-packages/imagen_pytorch/trainer.py", line 972, in forward
    loss = self.imagen(*chunked_args, unet = self.unet_being_trained, unet_number = unet_number, **chunked_kwargs)
  File "~/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "<@beartype(imagen_pytorch.imagen_pytorch.Imagen.forward) at 0x7f96d45c9af0>", line 62, in forward
  File "~/.local/lib/python3.8/site-packages/imagen_pytorch/imagen_pytorch.py", line 2537, in forward
    return self.p_losses(unet, images, times, text_embeds = text_embeds, text_mask = text_masks, cond_images = cond_images, noise_scheduler = noise_scheduler, lowres_cond_img = lowres_cond_img, lowres_aug_times = lowres_aug_times, pred_objective = pred_objective, p2_loss_weight_gamma = p2_loss_weight_gamma, random_crop_size = random_crop_size, **kwargs)
  File "<@beartype(imagen_pytorch.imagen_pytorch.Imagen.p_losses) at 0x7f96d45c99d0>", line 33, in p_losses
  File "~/.local/lib/python3.8/site-packages/imagen_pytorch/imagen_pytorch.py", line 2426, in p_losses
    pred = unet.forward(
  File "~/.local/lib/python3.8/site-packages/imagen_pytorch/imagen_pytorch.py", line 1642, in forward
    x = resnet_block(x, t)
  File "~/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "~/.local/lib/python3.8/site-packages/imagen_pytorch/imagen_pytorch.py", line 717, in forward
    h = self.block2(h, scale_shift = scale_shift)
  File "~/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "~/.local/lib/python3.8/site-packages/imagen_pytorch/imagen_pytorch.py", line 654, in forward
    return self.project(x)
  File "~/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "~/.local/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 463, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "~/.local/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 459, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 256.00 MiB (GPU 0; 47.32 GiB total capacity; 46.02 GiB already allocated; 185.69 MiB free; 46.06 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

Do you have any idea why this might be happening? This seems very surprising on a Quadro RTX 8000! What's the expected memory usage here?

jameshball commented 1 year ago

When I changed it to a 256x256->512x512 model, it used around 46 GB with the SRUnet1024, and after saving a checkpoint and restarting, it OOMed :( Is this memory usage normal? The 64x64->256x256 model used around 17 GB.

lucidrains commented 1 year ago

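@jameshball hey James, it's because you have self-attention turned on at an early stage in the unet (you only have two stages), where the feature map is still large and the attention matrix grows quadratically with the number of pixels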

you can turn off self-attention altogether for the super-resolution nets, add more stages, and place your cross attention on the innermost stage. those changes should help with memory
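A rough back-of-envelope check of this explanation, using the tensor shape from the first traceback: the einsum 'b h i d, b j d -> b h i j' builds a similarity matrix with one entry per pair of pixels, so for a side x side feature map it holds batch x heads x side^4 values. The sketch assumes fp32 (4 bytes per value), a batch of 4 (max_batch_size=4 in train.py), 8 heads, and that every stage begins with a 2x downsample when memory_efficient=True. These are assumptions, not values confirmed in this thread, and self_attention_bytes_per_stage is a hypothetical helper.

```python
GIB = 2 ** 30


def attention_matrix_bytes(batch, heads, side, bytes_per_el=4):
    """Bytes for one (batch, heads, i, j) self-attention similarity tensor over a
    side x side feature map, where i == j == side * side (one token per pixel)."""
    tokens = side * side
    return batch * heads * tokens * tokens * bytes_per_el


def self_attention_bytes_per_stage(image_size, layer_attns, batch=4, heads=8):
    """Hypothetical helper: self-attention matrix size at each unet stage.

    ASSUMPTION: every stage starts with a 2x downsample (an assumption about
    memory_efficient=True), so stage k runs at image_size // 2 ** (k + 1).
    """
    sizes = []
    for stage, has_attn in enumerate(layer_attns):
        side = image_size // 2 ** (stage + 1)
        sizes.append(attention_matrix_bytes(batch, heads, side) if has_attn else 0)
    return sizes


# The test config from the issue: two stages, self-attention on the second.
print([b / GIB for b in self_attention_bytes_per_stage(1024, (False, True))])
# -> [0.0, 512.0]
# Four stages with self-attention only on the innermost one, for comparison.
print([b / GIB for b in self_attention_bytes_per_stage(1024, (False, False, False, True))])
# -> [0.0, 0.0, 0.0, 2.0]
```

Under those assumptions the two-stage config lands at 512 GiB, close to the 512.32 GiB in the error (the small difference is not accounted for here). Halving the side length cuts the matrix 16x, which is why moving attention inward, or turning self-attention off entirely, helps so much. Cross-attention is far cheaper because j is the number of text tokens rather than the number of pixels.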

jameshball commented 1 year ago

@lucidrains Thanks for the helpful reply and explanation! Doesn't SRUnet1024 already do that, since it has layer_attns = False and layer_cross_attns = (False, False, False, True)?

Even with SRUnet1024 and some parameters scaled back, like dim_mults, I still get an OOM.

I still think it's bizarre that it's trying to allocate over 512GiB of memory at once in the first example!

lucidrains commented 1 year ago

@jameshball yeah, 512x512 and 1024x1024 are difficult to train

in the paper, they trained on random crops (the feature is built in here). Because the unet is fully convolutional, it should generalize to higher resolutions
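A minimal single-channel NumPy illustration of the mechanism (not imagen-pytorch code): a convolution's weights have a fixed size that does not depend on the input, so the same kernel applies to a 256x256 crop or a full 1024x1024 image, and away from the borders the result on a crop matches the corresponding window of the result on the full image. conv2d_same is a hypothetical helper, and this only shows that the weights transfer, not how well a trained model generalizes.

```python
import numpy as np


def conv2d_same(image, kernel):
    """Single-channel 2D cross-correlation (what PyTorch's Conv2d computes) with
    zero padding, so the output has the same height and width as the input."""
    kh, kw = kernel.shape
    padded = np.pad(image, ((kh // 2, kh // 2), (kw // 2, kw // 2)))
    out = np.zeros_like(image, dtype=np.float64)
    for dy in range(kh):
        for dx in range(kw):
            out += kernel[dy, dx] * padded[dy:dy + image.shape[0], dx:dx + image.shape[1]]
    return out


rng = np.random.default_rng(0)
kernel = rng.standard_normal((3, 3))      # one fixed-size set of weights
full = rng.standard_normal((1024, 1024))  # stand-in for a full-resolution feature map
crop = full[100:356, 200:456]             # a 256x256 training crop taken from it

out_full = conv2d_same(full, kernel)
out_crop = conv2d_same(crop, kernel)
print(out_crop.shape, out_full.shape)  # (256, 256) (1024, 1024)
# Away from the crop's zero-padded border, the two results agree.
print(np.allclose(out_crop[1:-1, 1:-1], out_full[101:355, 201:455]))  # True
```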

lucidrains commented 1 year ago

@jameshball i forget the details, but i believe they did random crops of 256x256 (but you should double check that, since it's been a while since i read this paper)
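lucidrains says the random-crop feature is built in, and the tracebacks in this issue do show a random_crop_size argument being passed into p_losses. The NumPy sketch below is not that implementation, only an illustration of the idea for the 256x256->1024x1024 case. paired_random_crop is a hypothetical helper, and it assumes the low-res conditioning image has already been upsampled to the target resolution, so one shared offset keeps the target aligned with its conditioning. A 256x256 crop has 1/16 of the pixels of a 1024x1024 image, so convolutional activations shrink roughly in proportion.

```python
import numpy as np


def paired_random_crop(target, lowres_cond, crop_size, rng):
    """Crop the same random window from a high-res target and its low-res
    conditioning image (ASSUMED already upsampled to the target resolution).

    Both arrays are (batch, channels, height, width). Sharing one offset keeps
    every cropped pixel aligned with the conditioning it came from.
    """
    assert target.shape == lowres_cond.shape, "upsample lowres_cond to target size first"
    _, _, h, w = target.shape
    top = int(rng.integers(0, h - crop_size + 1))
    left = int(rng.integers(0, w - crop_size + 1))
    window = (slice(None), slice(None), slice(top, top + crop_size), slice(left, left + crop_size))
    return target[window], lowres_cond[window]


rng = np.random.default_rng(0)
hires = rng.standard_normal((2, 3, 1024, 1024), dtype=np.float32)
# Stand-in for a 256x256 conditioning image upsampled 4x with nearest neighbour.
lowres_up = hires[:, :, ::4, ::4].repeat(4, axis=2).repeat(4, axis=3)
x, cond = paired_random_crop(hires, lowres_up, 256, rng)
print(x.shape, cond.shape)  # (2, 3, 256, 256) (2, 3, 256, 256)
```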

jameshball commented 1 year ago

Ah, that's really clever! 256x256 is right, yep. Memory usage is extremely low now, so it looks perfect. Thanks so much for your help! Gonna be a long one to train, I think...

lucidrains commented 1 year ago

@jameshball glad you have it working! i'll get back to this repository soon and add some better error messages

or write up a section in the readme with what we discussed above

jameshball commented 1 year ago

I'll be doing a lot of the more intensive work over the coming months, like this, and potentially also extending inpainting, as I alluded to in #290

Yep, I'm using this for my master's thesis to generate histopathology images conditioned on clinical parameters, and potentially to follow up with a paper, so that's the goal indeed. I'm very indebted to you for your hard work!

lucidrains commented 1 year ago

@jameshball that is so cool! these generative models will definitely be a boon to medical education