Nikolai10 / PerCo

PyTorch implementation of PerCo (Towards Image Compression with Perfect Realism at Ultra-Low Bitrates, ICLR 2024)
Apache License 2.0

How much GPU memory required? #6

Closed: herok97 closed this issue 1 month ago

herok97 commented 1 month ago

Thank you for your impressive work, including your repositories on perceptual/generative models. As a PhD student, I am very interested in your public contribution and find it highly motivating.

I have a question regarding the GPU memory required for your experiments, especially during training. I have 8 RTX 3090 GPUs, each with 24GB of memory.

I attempted to train the PerCo model on a single GPU with a batch size of 1, but I encountered an OOM (Out Of Memory) error. To reduce memory usage, I tried the following methods, but they were ineffective:

  1. Reduced the resolution of the sample noise: --resolution=256, --resolution=128, or --resolution=64
  2. Enabled xformers: --enable_xformers_memory_efficient_attention
  3. Used mixed precision: mixed_precision: fp16

Could you please explain the memory required for training or provide any suggestions on how to resolve this issue? Thank you.

Nikolai10 commented 1 month ago

Hello @herok97,

thanks for your interest!

In general, fine-tuning Stable Diffusion (not PerCo (SD)) should be possible with 24GB VRAM according to: https://github.com/huggingface/diffusers/tree/main/examples/text_to_image (Section Hardware):

With gradient_checkpointing and mixed_precision it should be possible to fine tune the model on a single 24GB GPU. For higher batch_size and faster training it's better to use GPUs with >30GB memory.
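To illustrate what gradient_checkpointing trades off, here is a minimal pure-PyTorch sketch (not PerCo code; the module names are made up for illustration) using torch.utils.checkpoint. Instead of storing every intermediate activation for the backward pass, each block's activations are recomputed during backward, cutting activation memory at the cost of extra compute:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class Block(nn.Module):
    """Stand-in for one expensive sub-module (e.g. a U-Net block)."""
    def __init__(self, dim=64):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, dim))

    def forward(self, x):
        return self.net(x)

class CheckpointedStack(nn.Module):
    """Chain of blocks whose activations are recomputed in backward, not stored."""
    def __init__(self, num_blocks=4, dim=64):
        super().__init__()
        self.blocks = nn.ModuleList(Block(dim) for _ in range(num_blocks))

    def forward(self, x):
        for block in self.blocks:
            # use_reentrant=False is the recommended checkpointing variant
            x = checkpoint(block, x, use_reentrant=False)
        return x

model = CheckpointedStack()
x = torch.randn(2, 64, requires_grad=True)
loss = model(x).pow(2).mean()
loss.backward()  # intermediate activations are recomputed here
```

In the linked diffusers example this mechanism is enabled via the --gradient_checkpointing flag rather than wired up by hand.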

Compared to this example, PerCo (SD) additionally incorporates BLIP-2 as an image captioning model, which we use at runtime to obtain text descriptions. So one idea would be to pre-compute these text captions and load them as part of your dataset.
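A minimal sketch of the dataset side of that idea (the file layout and class name here are assumptions, not part of the PerCo code): captions are generated offline once, e.g. with BLIP-2, stored as a JSON mapping from image file name to caption, and simply looked up during training, so the captioning model never has to be resident in GPU memory:

```python
import json
from pathlib import Path

class PrecomputedCaptionDataset:
    """Pairs each training image with a caption generated offline (e.g. by BLIP-2).

    Expects a JSON file mapping image file names to caption strings:
        {"000001.png": "a photo of ...", ...}
    """
    def __init__(self, image_dir, caption_file):
        self.image_paths = sorted(Path(image_dir).glob("*.png"))
        with open(caption_file) as f:
            self.captions = json.load(f)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        path = self.image_paths[idx]
        # Image loading/transforms omitted; return the path and its caption.
        return path, self.captions[path.name]
```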

In our tutorial, we also provide an example with a single A100 GPU (40GB VRAM), which works fine. We have not fully tested smaller scenarios, including mixed_precision.

Kind regards, Nikolai

herok97 commented 1 month ago

Dear @Nikolai10,

Thank you so much for your thoughtful response and the valuable reference materials.👍👍👍 I realize that I need to become much more familiar with Stable Diffusion models and related techniques.

After testing, it seems that, as you mentioned, BLIP-2 is the bottleneck: even with gradient_checkpointing and fp16 I could not avoid the OOM error in my case. The most effective approach appears to be generating the text descriptions in advance and then training with the pre-computed captions.

Thank you once again.

Since this question is more of a request for help than an issue that needs to be resolved, I will close it for now.😊