EugenHotaj / pytorch-generative

Easy generative modeling in PyTorch.
MIT License

Training ImageGPT on 64 * 64 size images #22

Closed. ysig closed this issue 3 years ago.

ysig commented 3 years ago

Hey,

I wrapped your implementation of ImageGPT for a project I am doing in Colab. I have single-channel images of size 64, with the intention of scaling up to 128. Running your code crashes (out of memory) even if I set batch_size to 2. It only runs with batch_size=1, and even then it doesn't use much memory (around 3 GB) while being extremely slow: 60 hours for ~200K pictures (20 s per picture). Is this normal, or is there something wrong?

EugenHotaj commented 3 years ago

Hey @ysig,

Attention computation is very expensive and scales quadratically with the number of pixels in the input image (both in memory and time). A 64 * 64 image is a sequence of 64 * 64 = 4096 pixels, so the memory needed for the attention weights is roughly 4096^2 * 32 * n_attention_heads * batch_size bits (where 32 is the number of bits per float). So, for 1 attention layer with n_attention_heads=1 and batch_size=1, this already comes to ~0.5 Gbit (~64 MB) just for the attention activations. Our default ImageGPT architecture (https://github.com/EugenHotaj/pytorch-generative/blob/master/pytorch_generative/models/image_gpt.py#L59-L75), which has 8 attention layers with n_attention_heads=4 each, would consume ~17 Gbit (~2 GB) per image just for the attention activations, and the model parameters, optimizer variables, activations of other layers, etc. need even more memory on top of that.
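For concreteness, here is that back-of-the-envelope estimate as a tiny script. It is only a rough sketch: it counts nothing but the seq_len x seq_len attention weight matrices, ignoring parameters, gradients, and all other activations.

```python
def attention_weight_bytes(image_size, n_layers, n_heads, batch_size, bytes_per_float=4):
    """Rough memory footprint of the attention weight matrices alone."""
    seq_len = image_size * image_size          # one token per pixel
    # One seq_len x seq_len matrix per head, per layer, per example.
    n_elements = seq_len ** 2 * n_heads * n_layers * batch_size
    return n_elements * bytes_per_float

# Default ImageGPT config in this repo: 8 attention layers, 4 heads each.
print(attention_weight_bytes(64, n_layers=8, n_heads=4, batch_size=1) / 1e9)   # ~2.1 GB
print(attention_weight_bytes(128, n_layers=8, n_heads=4, batch_size=1) / 1e9)  # ~34 GB
```

The 128 * 128 line is why the next point holds: the attention weights alone already dwarf a single Colab GPU's memory before counting anything else.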

Scaling to 128 * 128 images on Colab is probably impossible with a single GPU.

To reduce the amount of memory, you could try a few things:
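One concrete illustration (a generic PyTorch memory-saving knob I'm suggesting here, not something specific to this repo) is automatic mixed precision, which keeps most activations in fp16 and roughly halves activation memory:

```python
import torch
import torch.nn.functional as F

# Hypothetical training step: `model` is an autoregressive image model and
# `x` a batch of binarized single-channel images of shape (N, 1, 64, 64).
scaler = torch.cuda.amp.GradScaler()

def train_step(model, optimizer, x):
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():   # run the forward pass in mixed precision
        logits = model(x)
        loss = F.binary_cross_entropy_with_logits(logits, x)
    scaler.scale(loss).backward()     # scale the loss to avoid fp16 underflow
    scaler.step(optimizer)
    scaler.update()
    return loss.item()
```

Gradient checkpointing (torch.utils.checkpoint) is another standard way to trade extra compute for lower activation memory.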

> while being extremely slow: 60 hours for ~200K pictures (20 s per picture)

ImageGPT is a heavy model (as discussed above), so this is not surprising, especially if you only use batch_size=1. It might be possible to optimize things further, but it's hard to say up front without knowing your training setup. That being said, one thing that will almost surely speed up your training is to swap out our MaskedAttention (https://github.com/EugenHotaj/pytorch-generative/blob/master/pytorch_generative/models/image_gpt.py#L38-L42) with PyTorch's MultiheadAttention (https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html). The latter is implemented with optimized CUDA kernels and is much faster.
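A minimal sketch of what that swap could look like, using torch.nn.MultiheadAttention with a causal mask. The module name and interface below are illustrative assumptions, not the repo's existing MaskedAttention API, and batch_first=True requires PyTorch >= 1.9:

```python
import torch
from torch import nn

class CausalSelfAttention(nn.Module):
    """Causal self-attention over a flattened image using the fused
    nn.MultiheadAttention instead of a hand-rolled masked attention."""

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

    def forward(self, x):  # x: (batch, seq_len, embed_dim)
        seq_len = x.shape[1]
        # Upper-triangular -inf mask so pixel i only attends to pixels <= i.
        mask = torch.triu(
            torch.full((seq_len, seq_len), float("-inf"), device=x.device),
            diagonal=1,
        )
        out, _ = self.attn(x, x, x, attn_mask=mask, need_weights=False)
        return out

# e.g. one 64 * 64 image flattened to a 4096-token sequence of 128-d embeddings:
y = CausalSelfAttention(embed_dim=128, num_heads=4)(torch.randn(1, 64 * 64, 128))
```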

Finally, I'm not sure what your use-case is, but you should also look into PixelSNAIL. It also uses attention to capture global context but it's a lot more scalable than ImageGPT.

ysig commented 3 years ago

Ok! Thanks a lot for your answer! Really appreciate it :)
