lucidrains / imagen-pytorch

Implementation of Imagen, Google's Text-to-Image Neural Network, in Pytorch
MIT License

Partial ignore time causes increased memory usage #320

Status: Open. HReynaud opened this issue 1 year ago.

HReynaud commented 1 year ago

Hello,

I have been training models with ignore_time randomly toggled at every step, and this significantly increases memory usage. For example, on a dummy task:
- With ignore_time=True fixed, I see ~4.5 GB of memory usage.
- With ignore_time=False fixed, I see ~6.3 GB (which makes sense, since more layers are used).
- With ignore_time set to True or False with 50% probability at each step, I see 8.4 GB.
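A minimal sketch of the 50% mixed schedule, assuming the flag is drawn once per step from a Python `random.Random` instance in each process. The helper `sample_ignore_time` and the per-rank seeds are hypothetical, and forwarding the flag into the model is not shown. With independently seeded RNGs, different processes can pick different values for the same step, which is the situation the parity experiment described in this issue was meant to rule out.

```python
import random


def sample_ignore_time(rng: random.Random, p_ignore: float = 0.5) -> bool:
    """Hypothetical helper: return True (skip temporal layers) with probability p_ignore."""
    # rng.random() is uniform on [0, 1), so this is True with probability p_ignore.
    return rng.random() < p_ignore


# Simulate two processes that each draw the flag from their own, differently
# seeded RNG (an assumption about how the training loop might be set up).
num_steps = 20
rank0_rng = random.Random(0)
rank1_rng = random.Random(1)
schedule0 = [sample_ignore_time(rank0_rng) for _ in range(num_steps)]
schedule1 = [sample_ignore_time(rank1_rng) for _ in range(num_steps)]

# Count the steps on which the two processes picked different values.
mismatched_steps = sum(a != b for a, b in zip(schedule0, schedule1))
print(f"rank 0 and rank 1 disagree on {mismatched_steps} of {num_steps} steps")
```

The exact count depends on the seeds, so no particular value is implied here; the point is only that nothing forces the processes to agree.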

This only happens with accelerate / multi-GPU training, never with single-GPU training, and it causes OOM errors.

I also tried deriving ignore_time from the parity of the current training step, so that with accelerate all processes would train with the same ignore_time value at each step, but it did not help.
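A minimal sketch of the parity idea, plus a shared-seed variant, assuming every process sees the same global step counter. Both helper names are hypothetical. Either way, all processes pick the same `ignore_time` value for a given step; per the report above, that consistency on its own did not reduce the memory usage.

```python
import random


def parity_ignore_time(step: int) -> bool:
    """Hypothetical helper: ignore time on even steps, train temporal layers on odd steps."""
    # Depends only on the global step counter, which every process shares.
    return step % 2 == 0


def shared_seed_schedule(base_seed: int, num_steps: int, p_ignore: float = 0.5) -> list[bool]:
    """Hypothetical alternative: random flags, but drawn from the same seed on every process."""
    rng = random.Random(base_seed)  # identical seed -> identical sequence on every rank
    return [rng.random() < p_ignore for _ in range(num_steps)]


world_size = 4  # assumption: 4 processes
num_steps = 6

parity = [parity_ignore_time(step) for step in range(num_steps)]
print(parity)  # [True, False, True, False, True, False]

# Each simulated rank builds its schedule independently but gets the same values.
per_rank = [shared_seed_schedule(1234, num_steps) for _rank in range(world_size)]
assert all(schedule == per_rank[0] for schedule in per_rank)
```

The shared-seed variant keeps the 50% random mix while still guaranteeing agreement across processes, whereas parity gives a strict alternation.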

The problem does not happen when memory_efficient=True is set.

A side note: the new attention-stabilizing trick seems to increase memory usage significantly (~30%). Could it be made optional with a flag? That would also help with backward compatibility for already-trained models.

axel588 commented 1 year ago

@HReynaud How long does text-to-image training take for you on 10,000 images?

HReynaud commented 1 year ago

@axel588 Your question is very vague and my answer will not help much.

I am training on ~7k videos on 8x A100 GPUs, and the videos start looking "okay" after 20 hours. Usually I aim for 48 hours, though. I do NOT train with text; I use other forms of conditioning. But my dataset is very niche, so my observations will likely not transfer to other datasets.