huggingface / diffusers

🤗 Diffusers: State-of-the-art diffusion models for image and audio generation in PyTorch and FLAX.
https://huggingface.co/docs/diffusers
Apache License 2.0

Training InstructPix2Pix SDXL has OOM issue #4985

Closed. shugerdou closed this issue 1 year ago.

shugerdou commented 1 year ago

Describe the bug

Hello,

I used a machine with a 40GB A100 to train InstructPix2Pix SDXL on the toy data, following the instructions.

However, even after reducing the batch size to 1, I still run out of memory. Has anyone else observed the same issue? Thanks.

Reproduction

accelerate launch train_instruct_pix2pix_sdxl.py \
    --pretrained_model_name_or_path=$MODEL_NAME \
    --dataset_name=$DATASET_ID \
    --enable_xformers_memory_efficient_attention \
    --resolution=256 --random_flip \
    --train_batch_size=4 --gradient_accumulation_steps=4 --gradient_checkpointing \
    --max_train_steps=15000 \
    --checkpointing_steps=5000 --checkpoints_total_limit=1 \
    --learning_rate=5e-05 --max_grad_norm=1 --lr_warmup_steps=0 \
    --conditioning_dropout_prob=0.05 \
    --seed=42 \
    --push_to_hub

Logs

No response

System Info

Who can help?

No response

thibble commented 1 year ago

You can log free GPU memory throughout your script to understand exactly what is taking up too much memory. I'm having similar issues and tried running accelerate config to set up some optimizations (DeepSpeed, PyTorch FSDP), but without success so far.
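As a minimal sketch of the memory logging suggested above, the helper below prints used, allocated, reserved, and peak memory for one CUDA device. It assumes PyTorch; the function names and tag strings are illustrative and not part of the training script.

```python
def format_gpu_memory(tag, free_bytes, total_bytes, allocated_bytes, reserved_bytes, peak_bytes):
    """Build a one-line report; all inputs are in bytes, output is in GiB."""
    gib = 2**30
    # "used" is device-wide, so it also counts the CUDA context and other processes.
    used = total_bytes - free_bytes
    return (
        f"[{tag}] used {used / gib:.2f} / {total_bytes / gib:.2f} GiB | "
        f"allocated {allocated_bytes / gib:.2f} GiB | "
        f"reserved {reserved_bytes / gib:.2f} GiB | "
        f"peak {peak_bytes / gib:.2f} GiB"
    )


def log_gpu_memory(tag, device=0):
    """Print memory stats for one CUDA device; returns None if CUDA is unavailable."""
    try:
        import torch  # imported lazily so format_gpu_memory works without PyTorch
    except ImportError:
        return None
    if not torch.cuda.is_available():
        return None
    free_bytes, total_bytes = torch.cuda.mem_get_info(device)
    report = format_gpu_memory(
        tag,
        free_bytes,
        total_bytes,
        torch.cuda.memory_allocated(device),      # tensors currently alive
        torch.cuda.memory_reserved(device),       # held by the caching allocator
        torch.cuda.max_memory_allocated(device),  # high-water mark since start or last reset
    )
    print(report)
    return report
```

Calling it with a hypothetical tag such as log_gpu_memory("after optimizer.step()") at a few points in the loop shows which stage causes the jump. In a multi-GPU run, passing each process's own device index makes per-rank differences visible.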

@sayakpaul somehow pulled it off here, on 8x80GB GPUs but also at 768px resolution. As far as I can see, his code has no special CPU offloading, or even pre-computed prompt embeddings.

@sayakpaul do you mind sharing how you made it work?

beyzacevik commented 1 year ago

I have the same issue. I'm using a 60GB dataset on a machine with 180GB of RAM, but the process gets killed while generating splits.

thibble commented 1 year ago

I've had that happen as well, but pre-computing embeddings outside of the training loop fixed it for me. I can send you the code snippets. However, the specifics will depend on your dataset. Are your prompts standardized, or do they differ for each example? It would also help to know the size of your dataset and the resolution of your images.
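A minimal sketch of pre-computing prompt embeddings outside the training loop, assuming prompts repeat across examples: each distinct prompt is encoded once, and the training data then only carries the cached result. encode_fn is a placeholder for whatever text-encoder call you use, and the edit_prompt field name is an assumption about your dataset layout; both helper names are hypothetical.

```python
def precompute_embeddings(prompts, encode_fn):
    """Encode each distinct prompt exactly once and return a prompt -> embedding dict."""
    cache = {}
    for prompt in prompts:
        if prompt not in cache:
            cache[prompt] = encode_fn(prompt)  # the only place the text encoder runs
    return cache


def attach_embeddings(examples, cache, prompt_key="edit_prompt"):
    """Return copies of the examples with a cached 'prompt_embeds' entry added."""
    return [{**example, "prompt_embeds": cache[example[prompt_key]]} for example in examples]
```

Once the cache is built, the text encoders are no longer needed during training, so they can be moved off the GPU or deleted. With many unique prompts, writing the cache to disk instead of holding it in RAM may be necessary.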

sayakpaul commented 1 year ago

Thanks for your interest, folks!

We used an efficient dataloader based on webdataset which seems to be the only major difference here. FWIW, that script runs seamlessly with 1024x1024 too.

thibble commented 1 year ago

Interesting... The dataset loading shouldn't affect GPU memory, right? I can train on one GPU but run out of memory when training on multiple GPUs.

sayakpaul commented 1 year ago

Yes, the way we cache the embeddings might affect that behaviour, since caching is part of the data preparation process in the official example.

thibble commented 1 year ago

Interesting, thank you! I'll take a closer look at the embedding caching part. Do you have pointers to articles or explanations of why multi-GPU training seems to have a memory peak on GPU 1?

For example, a training run needs 40GB of VRAM on a single GPU, but when parallelized, GPU 1 needs 50GB while the others need only 40GB.

shugerdou commented 1 year ago

@thibble Thank you for the great discussion. I can see which part of the model takes a lot of GPU memory, but I am curious about the reason and the solution. I have been using SDXL and InstructPix2Pix but have never seen an OOM issue, even with a larger batch size.

thibble commented 1 year ago

Can you elaborate on how you've been using SDXL and InstructPix2Pix? Do you mean you have trained on the same configuration without OOM issues before? (256px, 40GB A100 machine)

shugerdou commented 1 year ago

@thibble Sorry, I did not explain clearly. I meant that I have used the original InstructPix2Pix code on the same machine, which uses SD1.5 as the backbone and does not use the Hugging Face libraries. I have not trained InstructPix2Pix with the Hugging Face code yet.

shugerdou commented 1 year ago

Thanks @sayakpaul. I am posting the error below. Have you seen it before? Alternatively, could you share the dataset so that I can use a webdataset dataloader? I can then try to debug whether data loading is the issue.

torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 50.00 MiB (GPU 13; 39.59 GiB total capacity; 36.61 GiB already allocated; 44.12 MiB free; 37.07 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
    optimizer.step()
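The traceback points at optimizer.step(), and reserved memory (37.07 GiB) is close to allocated memory (36.61 GiB), so the fragmentation that the max_split_size_mb hint targets does not look like the main problem. A rough back-of-the-envelope estimate may explain why. Assumptions: the SDXL UNet has about 2.6B parameters, the trainable weights and gradients stay in fp32, and AdamW keeps two fp32 moment buffers per parameter. Activations, the VAE, and the text encoders are not counted.

```python
def full_finetune_state_gib(num_params, weight_bytes=4, grad_bytes=4,
                            optimizer_states=2, state_bytes=4):
    """GiB for weights + gradients + optimizer state; activations are NOT included.

    Defaults assume fp32 trainable weights and gradients plus AdamW, which keeps
    two fp32 moment buffers (exp_avg and exp_avg_sq) per parameter.
    """
    bytes_per_param = weight_bytes + grad_bytes + optimizer_states * state_bytes
    return num_params * bytes_per_param / 2**30


# Assumption: the SDXL UNet has roughly 2.6B parameters.
UNET_PARAMS = 2.6e9
print(f"UNet + grads + AdamW state: {full_finetune_state_gib(UNET_PARAMS):.2f} GiB")
# prints: UNet + grads + AdamW state: 38.74 GiB
```

Under these assumptions, this state alone nearly fills the 39.59 GiB card before any activations are counted, which would also be consistent with batch size having little effect. Setting state_bytes=1 roughly models 8-bit optimizer states: 2.6e9 parameters x 10 bytes comes to about 24.2 GiB.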
github-actions[bot] commented 1 year ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

HaozheZhao commented 5 months ago

I have also run into the same bug. Even when training SDXL on 256x256 images with batch size 1, it still runs out of memory, even on 8x80GB A100 GPUs. Are there any suggestions?

sayakpaul commented 5 months ago

Have you tried enabling gradient accumulation and gradient checkpointing?
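Gradient accumulation splits one optimizer step into several smaller forward/backward passes, so it lowers activation memory but leaves the weights, gradients, and optimizer state untouched; gradient checkpointing likewise trades recomputation for activation memory only. The toy sketch below uses a scalar least-squares model as a stand-in for the UNet (all names are hypothetical) and checks that summing micro-batch gradients scaled by 1/accum_steps reproduces the full-batch gradient when micro-batches have equal size.

```python
def mse_grad(w, batch):
    """d/dw of mean((w * x - y) ** 2) over (x, y) pairs, for a scalar weight w."""
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)


def accumulated_grad(w, batch, micro_batch_size):
    """Gradient-accumulation version: equal-size micro-batches, each scaled by 1/steps."""
    if len(batch) % micro_batch_size != 0:
        raise ValueError("micro-batches must have equal size for exact equivalence")
    micro_batches = [batch[i:i + micro_batch_size]
                     for i in range(0, len(batch), micro_batch_size)]
    accum_steps = len(micro_batches)
    grad = 0.0
    for micro_batch in micro_batches:
        # Analogous to calling (loss / accum_steps).backward() once per micro-batch.
        grad += mse_grad(w, micro_batch) / accum_steps
    return grad


data = [(1.0, 2.0), (2.0, 3.0), (3.0, 5.0), (4.0, 9.0)]
print(mse_grad(0.5, data), accumulated_grad(0.5, data, 2), accumulated_grad(0.5, data, 1))
# -22.0 -22.0 -22.0
```

With the flags in the reproduction command, each optimizer step averages over train_batch_size x gradient_accumulation_steps = 4 x 4 = 16 examples per GPU, while only 4 are held in memory at a time.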

HaozheZhao commented 5 months ago

Thanks! I've realized that factors such as batch size, image size, gradient accumulation, and gradient checkpointing don't significantly impact the occurrence of Out of Memory (OOM) errors. The critical factor appears to be enabling enable_xformers_memory_efficient_attention.

jiangyuhangcn commented 4 months ago

Hi! May I ask how much memory you are using?