huggingface / diffusers

🤗 Diffusers: State-of-the-art diffusion models for image and audio generation in PyTorch and FLAX.
https://huggingface.co/docs/diffusers
Apache License 2.0

Progressive Memory Increase on loading lora weights #8532

Open danishboy000 opened 3 weeks ago

danishboy000 commented 3 weeks ago

Describe the bug

When invoking pipeline.load_lora_weights(path) repeatedly, an incremental memory increase is observed across iterations, indicating a potential memory management issue. Although pipeline.unload_lora_weights() is called after each load operation, it does not consistently return memory usage to the pre-load state, resulting in a gradual memory build-up. The attached logs, captured with memory_profiler, show the increase in memory utilization across iterations.

Reproduction

from diffusers import DiffusionPipeline
import torch
from memory_profiler import profile

pipeline = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

@profile
def lora_load(path):
    pipeline.load_lora_weights(path)
    pipeline.unload_lora_weights()

for i in range(10):
    print(f"Iteration {i}:")
    path = "lora/test.safetensors"
    lora_load(path)

Logs

Iteration 0:
Filename: /home/ubuntu/lora-test/lora.py

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
    28   6730.1 MiB   6730.1 MiB           1   @profile
    29                                         def lora_load(path):
    30   6742.7 MiB     12.6 MiB           1           pipeline.load_lora_weights(path)
    31   6739.7 MiB     -3.0 MiB           1           pipeline.unload_lora_weights()

Iteration 1:
Filename: /home/ubuntu/lora-test/lora.py

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
    28   6739.7 MiB   6739.7 MiB           1   @profile
    29                                         def lora_load(path):
    30   6742.6 MiB      2.9 MiB           1           pipeline.load_lora_weights(path)
    31   6741.7 MiB     -0.9 MiB           1           pipeline.unload_lora_weights()

Iteration 2:
Filename: /home/ubuntu/lora-test/lora.py

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
    28   6741.7 MiB   6741.7 MiB           1   @profile
    29                                         def lora_load(path):
    30   6742.5 MiB      0.9 MiB           1           pipeline.load_lora_weights(path)
    31   6742.5 MiB      0.0 MiB           1           pipeline.unload_lora_weights()

Iteration 3:
Filename: /home/ubuntu/lora-test/lora.py

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
    28   6742.5 MiB   6742.5 MiB           1   @profile
    29                                         def lora_load(path):
    30   6742.6 MiB      0.1 MiB           1           pipeline.load_lora_weights(path)
    31   6742.6 MiB      0.0 MiB           1           pipeline.unload_lora_weights()

Iteration 4:
Filename: /home/ubuntu/lora-test/lora.py

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
    28   6742.6 MiB   6742.6 MiB           1   @profile
    29                                         def lora_load(path):
    30   6742.7 MiB      0.1 MiB           1           pipeline.load_lora_weights(path)
    31   6742.7 MiB      0.0 MiB           1           pipeline.unload_lora_weights()

Iteration 5:
Filename: /home/ubuntu/lora-test/lora.py

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
    28   6742.7 MiB   6742.7 MiB           1   @profile
    29                                         def lora_load(path):
    30   6743.5 MiB      0.8 MiB           1           pipeline.load_lora_weights(path)
    31   6743.5 MiB      0.0 MiB           1           pipeline.unload_lora_weights()

Iteration 6:
Filename: /home/ubuntu/lora-test/lora.py

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
    28   6743.5 MiB   6743.5 MiB           1   @profile
    29                                         def lora_load(path):
    30   6743.6 MiB      0.1 MiB           1           pipeline.load_lora_weights(path)
    31   6743.6 MiB      0.0 MiB           1           pipeline.unload_lora_weights()

Iteration 7:
Filename: /home/ubuntu/lora-test/lora.py

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
    28   6743.6 MiB   6743.6 MiB           1   @profile
    29                                         def lora_load(path):
    30   6743.6 MiB      0.1 MiB           1           pipeline.load_lora_weights(path)
    31   6743.6 MiB      0.0 MiB           1           pipeline.unload_lora_weights()

Iteration 8:
Filename: /home/ubuntu/lora-test/lora.py

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
    28   6743.6 MiB   6743.6 MiB           1   @profile
    29                                         def lora_load(path):
    30   6743.7 MiB      0.1 MiB           1           pipeline.load_lora_weights(path)
    31   6743.7 MiB      0.0 MiB           1           pipeline.unload_lora_weights()

Iteration 9:
Filename: /home/ubuntu/lora-test/lora.py

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
    28   6743.7 MiB   6743.7 MiB           1   @profile
    29                                         def lora_load(path):
    30   6743.7 MiB     -0.1 MiB           1           pipeline.load_lora_weights(path)
    31   6743.7 MiB      0.0 MiB           1           pipeline.unload_lora_weights()

System Info

Versions:

python: 3.10.12
diffusers: 0.29.0
torch: 2.3.1
accelerate: 0.31.0
peft: 0.11.1
memory_profiler: 0.61.0
transformers: 4.41.2
safetensors: 0.4.3

Who can help?

@sayakpaul

sayakpaul commented 3 weeks ago

This thread is related: https://github.com/huggingface/peft/issues/1639. Cc: @BenjaminBossan

danishboy000 commented 3 weeks ago

This thread is related: huggingface/peft#1639. Cc: @BenjaminBossan

Thanks @sayakpaul, just to be clear, my issue is related to CPU memory (RAM) and not GPU memory.

BenjaminBossan commented 2 weeks ago

Hmm, I wonder if it's really related or not, because the latest finding in that thread seems to be that memory has stabilized but latency goes up. Here we still see memory going up, although only very slowly and not linearly.

danishboy000 commented 2 weeks ago

@BenjaminBossan the other thread is related to GPU memory, while I'm facing this issue with CPU memory (RAM).

BenjaminBossan commented 2 weeks ago

the other thread is related to GPU memory, while I'm facing this issue with CPU memory (RAM)

The code you showed is, however, using the GPU, or am I misunderstanding something? Also, could you check if the leak goes away if you call:

import gc
torch.cuda.empty_cache()
gc.collect()
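
For reference, one way to fold that check into the reproduction loop (a minimal sketch based on the original script above, with psutil used to read the process RSS; not a confirmed fix) would be:

import gc
import os

import psutil
import torch
from diffusers import DiffusionPipeline

process = psutil.Process(os.getpid())

pipeline = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

path = "lora/test.safetensors"
for i in range(10):
    pipeline.load_lora_weights(path)
    pipeline.unload_lora_weights()
    # Explicitly free cached CUDA blocks and run the garbage collector,
    # as suggested above, then report the process RSS in MB.
    torch.cuda.empty_cache()
    gc.collect()
    print(f"Iteration {i}: {process.memory_info().rss / (1024 * 1024):.1f} MB RSS")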

When it comes to pure CPU usage, I didn't find a general trend of increasing memory. For this, I modified the script like so:

import os
from diffusers import DiffusionPipeline
import torch
import psutil

process = psutil.Process(os.getpid())

model_id = "runwayml/stable-diffusion-v1-5"
pipeline = DiffusionPipeline.from_pretrained(
    model_id, torch_dtype=torch.float16, variant="fp16",
)

def lora_load(path, name):
    pipeline.load_lora_weights(path, adapter_name=name)
    pipeline.unload_lora_weights()

for i in range(10):
    print(f"Iteration {i}:")
    path = "markredito/FilmGrain-LoRA-stablediffusion"
    name = f"adapter_{i}"
    lora_load(path, name)
    print(f"Memory Usage: {process.memory_info().rss / (1024 * 1024)} MB")

The output is:

Iteration 0:
Memory Usage: 423.7109375 MB
Iteration 1:
Memory Usage: 424.28515625 MB
Iteration 2:
Memory Usage: 423.40625 MB
Iteration 3:
Memory Usage: 423.98046875 MB
Iteration 4:
Memory Usage: 421.42578125 MB
Iteration 5:
Memory Usage: 425.171875 MB
Iteration 6:
Memory Usage: 426.34765625 MB
Iteration 7:
Memory Usage: 423.7890625 MB
Iteration 8:
Memory Usage: 421.2265625 MB
Iteration 9:
Memory Usage: 422.59375 MB

So the memory usage fluctuates a bit, but there is no general upwards trend.

When I comment out the line pipeline.unload_lora_weights(), I observe a steady increase:

Iteration 0:
Memory Usage: 424.91015625 MB
Iteration 1:
Memory Usage: 440.0078125 MB
Iteration 2:
Memory Usage: 456.50390625 MB
Iteration 3:
Memory Usage: 470.8203125 MB
Iteration 4:
Memory Usage: 487.28515625 MB
Iteration 5:
Memory Usage: 501.32421875 MB
Iteration 6:
Memory Usage: 516.5234375 MB
Iteration 7:
Memory Usage: 532.6875 MB
Iteration 8:
Memory Usage: 547.30078125 MB
Iteration 9:
Memory Usage: 562.5390625 MB

Therefore, unloading seems to be doing what it's supposed to do.

danishboy000 commented 2 weeks ago

Yes, I am aware that load_lora_weights increases GPU memory, and that part is working well for me. However, I have also observed an increase in CPU memory in production, where we handle nearly 10,000 requests per day.

I have tested your code for between 500 and 1,000 iterations using various LoRA files, and it has consistently shown an upward trend in CPU memory usage. The increase is not linear, and I noticed varying increments with different LoRA files. For example, you might try running this LoRA file (https://civitai.com/models/470073/popyays-epic-fantasy-style-or-pony-and-sdxl?modelVersionId=560995) with SDXL 1.0 as the base model. Here is the result that I obtained:

[graph: CPU memory usage over iterations]

Another example: here the increase looks insignificant, but at production scale it still leads to churned nodes.

[graph: CPU memory usage over iterations]

BenjaminBossan commented 2 weeks ago

Thanks for providing more details. Just to be completely clear, so that I can attempt to replicate:

  1. The model is on GPU but you observe a memory leak on CPU RAM.
  2. Each, or most, requests in your scenario involves unloading the existing LoRA and loading a new LoRA (how many distinct adapters do you have?).
  3. To consistently see the effect, we need to check >100 requests.
  4. Did you also observe this with other models or did you only test SDXL?

danishboy000 commented 2 weeks ago

Thanks for providing more details. Just to be completely clear, so that I can attempt to replicate:

1. The model is on GPU but you observe a memory leak on CPU RAM.

2. Each, or most, requests in your scenario involves unloading the existing LoRA and loading a new LoRA (how many distinct adapters do you have?).

3. To consistently see the effect, we need to check >100 requests.

4. Did you also observe this with other models or did you only test SDXL?
  1. Yes, and to be more precise, I'm using a g5dn.xlarge instance on AWS.
  2. Yes. I used the code that you shared in your last message, just with more iterations and a different base model and LoRA file.
  3. Yes, 500-1000 iterations will clearly show the upward trend.
  4. Yes, even with SD1.5 and the LoRA file you shared in your code. When I ran it for 500 iterations, it showed a 15-20 MB overall increase. Some LoRA files showed an increment of up to 100 MB over 500 iterations.

danishboy000 commented 2 weeks ago

Here I've run your exact same code, but for 500 iterations:

[graph: CPU memory usage over 500 iterations]
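
For reference, the long-run variant is just the script above with more iterations; a sketch that also records the readings for plotting (matplotlib and the output filename are assumptions, not part of the original script) could look like:

import os
from diffusers import DiffusionPipeline
import torch
import psutil
import matplotlib.pyplot as plt

process = psutil.Process(os.getpid())

model_id = "runwayml/stable-diffusion-v1-5"
pipeline = DiffusionPipeline.from_pretrained(
    model_id, torch_dtype=torch.float16, variant="fp16",
)

readings = []
path = "markredito/FilmGrain-LoRA-stablediffusion"
for i in range(500):
    pipeline.load_lora_weights(path, adapter_name=f"adapter_{i}")
    pipeline.unload_lora_weights()
    rss_mb = process.memory_info().rss / (1024 * 1024)
    readings.append(rss_mb)
    print(f"Iteration {i}: {rss_mb:.1f} MB")

# Plot the recorded RSS values to visualize the trend (hypothetical filename).
plt.plot(readings)
plt.xlabel("Iteration")
plt.ylabel("CPU RSS (MB)")
plt.savefig("memory_usage.png")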

BenjaminBossan commented 1 week ago

I could reproduce these results. However, I don't think they're related to diffusers or PEFT. Instead, they seem to be related to safetensors. To demonstrate, the snippet below shows the same increase for me:

import gc
import os

import psutil
from safetensors.torch import load_file

process = psutil.Process(os.getpid())

path = "filmgrain_v1.safetensors"
for i in range(500):
    sd = load_file(path)
    del sd
    gc.collect()
    mem = process.memory_info().rss
    print(f"Memory Usage: {mem / (1024 * 1024)} MB")

When I convert the adapter to a PyTorch checkpoint and use torch.load instead, the memory is constant:

import torch

path = "filmgrain_v1.bin"
for i in range(500):
    sd = torch.load(path)
    del sd
    gc.collect()  # not even necessary
    mem = process.memory_info().rss
    print(f"Memory Usage: {mem / (1024 * 1024)} MB")

This leads me to conclude that this is a safetensors issue.
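
For reference, such a one-off conversion can be done in a few lines; a minimal sketch, reusing the filenames from the snippets above (this is not necessarily how the conversion above was produced):

import torch
from safetensors.torch import load_file

# Load the safetensors adapter once and re-save it as a PyTorch checkpoint.
state_dict = load_file("filmgrain_v1.safetensors")
torch.save(state_dict, "filmgrain_v1.bin")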