huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for PyTorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Quantized T5EncoderModel cannot be removed from VRAM on CUDA systems #31479

Closed · lstein closed this issue 3 months ago

lstein commented 4 months ago

System Info

Who can help?

@ArthurZucker @SunMarc I am adding support for Stable Diffusion 3 to the InvokeAI project (https://www.invoke.com/). This task requires the various pipeline components, including the T5EncoderModel, to be loaded into and unloaded from VRAM sequentially.

When quantizing a pretrained T5EncoderModel with BitsAndBytes (load_in_8bit=True), following the instructions at https://huggingface.co/docs/diffusers/main/en/api/pipelines/stable_diffusion/stable_diffusion_3, the resulting model gets "stuck" in CUDA VRAM and cannot be removed even after deleting all references to it. See the code below for an illustration of the problem.

This happens as soon as the quantized model is loaded and is not related to running inference with the model. It does not happen with other text encoder models, such as CLIPTextModelWithProjection, where VRAM usage returns to zero after the final object reference is removed and garbage collection is run.

Information

Tasks

Reproduction

Here is a script that illustrates the problem:


import gc
import torch
from transformers import T5EncoderModel, BitsAndBytesConfig

FULL_MODEL = 'stabilityai/stable-diffusion-3-medium-diffusers'

print("* With unquantized model *")
model = T5EncoderModel.from_pretrained(FULL_MODEL,
                                       torch_dtype=torch.float16,
                                       subfolder='text_encoder_3',
                                       ).to('cuda')

print('After loading, VRAM usage=',torch.cuda.memory_allocated())

referrers = gc.get_referrers(model)
print('Referrers = ',len(referrers))

del model
print('After model deletion, VRAM usage=', torch.cuda.memory_allocated())

gc.collect()
torch.cuda.empty_cache()
print('After gc_collect and empty_cache, VRAM usage=',torch.cuda.memory_allocated())

print("\n* With quantized model *")
quantization_config = BitsAndBytesConfig(load_in_8bit=True)
model = T5EncoderModel.from_pretrained(FULL_MODEL,
                                       torch_dtype=torch.float16,
                                       subfolder='text_encoder_3',
                                       quantization_config=quantization_config,
                                       low_cpu_mem_usage=True,
                                       )
print('After loading, VRAM usage=',torch.cuda.memory_allocated())

referrers = gc.get_referrers(model)
print('Referrers = ',len(referrers))

del model
print('After model deletion, VRAM usage=',torch.cuda.memory_allocated())

gc.collect()
torch.cuda.empty_cache()
print('After gc_collect and empty_cache, VRAM usage=',torch.cuda.memory_allocated())
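
As a side note (a small addition to the script above, not part of the original repro), torch.cuda.memory_reserved() is a useful extra data point when comparing against nvidia-smi, since it reports the pool the caching allocator has claimed from the driver rather than live tensor allocations:

# Optional cross-check (sketch): memory_allocated() counts live tensor
# allocations, while memory_reserved() counts the caching allocator's
# reserved pool, which is closer to what nvidia-smi shows.
print('Allocated:', torch.cuda.memory_allocated())
print('Reserved: ', torch.cuda.memory_reserved())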

Expected behavior

The expected behavior is for the last line to print: After gc_collect and empty_cache, VRAM usage=0

But the observed behavior is:

* With unquantized model *
Downloading shards: 100%|██████████| 2/2 [00:00<00:00, 3250.14it/s]
Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  4.53it/s]
After loading, VRAM usage= 11538935808
Referrers =  1
After model deletion, VRAM usage= 0
After gc_collect and empty_cache, VRAM usage= 0

* With quantized model *
Downloading shards: 100%|██████████| 2/2 [00:00<00:00, 19108.45it/s]
Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.05it/s]
After loading, VRAM usage= 7918596096
Referrers =  7
After model deletion, VRAM usage= 7918596096
After gc_collect and empty_cache, VRAM usage= 7918596096

SunMarc commented 4 months ago

Hi @lstein, thanks for reporting! This issue is similar to https://github.com/huggingface/transformers/issues/21094. If you do the following, the memory should be freed:

print("\n* With quantized model *")
quantization_config = BitsAndBytesConfig(load_in_8bit=True)
model = T5EncoderModel.from_pretrained(FULL_MODEL,
                                       torch_dtype=torch.float16,
                                       subfolder='text_encoder_3',
                                       quantization_config=quantization_config,
                                       low_cpu_mem_usage=True,
                                       )
print('After loading, VRAM usage=',torch.cuda.memory_allocated())

referrers = gc.get_referrers(model)
print('Referrers = ',len(referrers))

model = None
print('After model deletion, VRAM usage=',torch.cuda.memory_allocated())

gc.collect()
torch.cuda.empty_cache()
print('After gc_collect and empty_cache, VRAM usage=',torch.cuda.memory_allocated())

LMK if this works on your side!

lstein commented 4 months ago

Hi @SunMarc, thanks so much for the rapid response! Unfortunately this doesn't seem to work on my side. Can I confirm that the only difference in the proposed solution is to replace del model with model = None?

Here is the output from a cut-and-paste of the proposed solution:

* With quantized model *
Downloading shards: 100%|██████████| 2/2 [00:00<00:00, 6781.41it/s]
Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.06it/s]
After loading, VRAM usage= 7918596096
Referrers =  7
After model deletion, VRAM usage= 7918596096
After gc_collect and empty_cache, VRAM usage= 7918596096

I'll check out the other issue thread to try to understand the problem better.
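
For what it's worth, a quick way to see what is still holding the 8-bit model before it is deleted (a debugging sketch, not something run in this thread) is to print the types of the objects returned by gc.get_referrers(), since the quantized model reports 7 referrers versus 1 for the fp16 one:

import gc

# Sketch: list the objects that still reference the quantized model.
for ref in gc.get_referrers(model):
    print(type(ref))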

lstein commented 4 months ago

I've tried the various strategies described in issue #21094, including calling accelerate.release_memory(), but without success so far. There is also no change when explicitly loading with device_map="auto" or device_map="cuda".
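
For reference, accelerate.release_memory() is typically invoked along these lines (a sketch; the exact call used here is not shown in the thread, so the variable handling below is an assumption):

import torch
from accelerate.utils import release_memory

# release_memory() drops the passed references, calls gc.collect() and
# torch.cuda.empty_cache(), then returns None placeholders for them.
(model,) = release_memory(model)
print('VRAM usage=', torch.cuda.memory_allocated())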

SunMarc commented 4 months ago

Yeah, that's right. I replaced del model with model = None and it worked on my side. Here's my output:

* With quantized model *
`low_cpu_mem_usage` was None, now set to True since model is quantized.
Downloading shards: 100%|██████████| 2/2 [00:00<00:00, 18040.02it/s]
Loading checkpoint shards: 100%|██████████| 2/2 [00:21<00:00, 10.75s/it]
After loading, VRAM usage= 7918596096
After model deletion, VRAM usage= 7918596096
After gc_collect and empty_cache, VRAM usage= 0

Just to be sure: when you check nvidia-smi, the memory is not freed, is that right?

lstein commented 4 months ago

I'm glad to hear it is working on your end! There must be some difference in our environments. Could you let me know what versions of Python, transformers, torch, accelerate and CUDA you're using? My info is at the top, except for the CUDA library, which is version 12.2.

Yes, I've confirmed with nvidia-smi that the memory is indeed allocated and used until the process ends.

I appreciate your working through this with me. It's become a bit of a blocker.

Interestingly, this almost works:

model = T5EncoderModel.from_pretrained(...)

# Drop the state dict's references to the parameter tensors
state_dict = model.state_dict()
for k in state_dict:
    state_dict[k] = None

model = None
gc.collect()
torch.cuda.empty_cache()
print('After gc_collect and empty_cache, VRAM usage=',torch.cuda.memory_allocated())

The last line prints VRAM usage= 8192, so it is essentially just the model parameters that were being held; once their references are cleared, only a small residual allocation remains.
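
Consolidating that workaround into something reusable (a sketch only; the helper name is made up, and it carries the same caveat that roughly 8 KiB stays allocated):

import gc
import torch

def clear_parameter_refs(model):
    # Sketch of the workaround above: null out the state-dict references
    # to the parameter tensors so the allocator can reclaim them.
    state_dict = model.state_dict()
    for key in state_dict:
        state_dict[key] = None

clear_parameter_refs(model)
model = None
gc.collect()
torch.cuda.empty_cache()
print('VRAM usage=', torch.cuda.memory_allocated())  # ~8192 bytes remain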

SunMarc commented 4 months ago

Hi @lstein, I'm on the main branch of transformers and accelerate. As for torch, it's v2.3.0, and my CUDA version is 12.2 according to nvidia-smi.

lstein commented 4 months ago

Hmmm. Can't understand why I'm having the problem, then.

github-actions[bot] commented 3 months ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.