huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for PyTorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

How to avoid the peak RAM memory usage of a model when I want to load to GPU #28476

Closed JoanFM closed 9 months ago

JoanFM commented 10 months ago

System Info

Who can help?

I am using transformers to load a model onto the GPU, and I observed that before the model is moved to the GPU there is a peak of RAM usage that is later freed. I assume the model is loaded on the CPU before being moved to the GPU.

On the GPU the model takes around 4 GiB, but to load it I need more than 7 GiB of RAM, which seems weird.

Is there a way to load it directly to the GPU without spending so much RAM?

I have tried the low_cpu_mem_usage parameter and setting device_map to "cuda" and "auto", but no luck:

from transformers import AutoModel

m = AutoModel.from_pretrained(
    "jinaai/jina-embeddings-v2-base-en",
    trust_remote_code=True,
    low_cpu_mem_usage=True,
    device_map="auto",
)

Information

Tasks

Reproduction

from transformers import AutoModel

m = AutoModel.from_pretrained(
    "jinaai/jina-embeddings-v2-base-en",
    trust_remote_code=True,
    low_cpu_mem_usage=True,
    device_map="auto",
)

Expected behavior

Not having such a RAM peak when loading the model onto the GPU.

ArthurZucker commented 10 months ago

Hey 🤗 thanks for opening an issue! I am not sure you can avoid going through CPU RAM (the transfer goes from SSD to CPU to GPU); I am not aware of anything that supports loading straight from SSD to GPU. However, device_map="auto" should always let you load the model without exceeding the available RAM. The peak can come from torch.set_default_dtype(torch.float16) and the fact that you are not specifying a dtype: the model might be loaded in float32, then cast, then transferred.
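For reference, one minimal way to check that CPU-side peak is to read the process's peak resident set size after the load. A sketch, assuming Linux (where the standard-library resource module reports ru_maxrss in KiB) and the snippet from the report:

import resource

from transformers import AutoModel

# Load exactly as in the report, then read the peak resident set size.
m = AutoModel.from_pretrained(
    "jinaai/jina-embeddings-v2-base-en",
    trust_remote_code=True,
    low_cpu_mem_usage=True,
    device_map="auto",
)

# ru_maxrss is in KiB on Linux (bytes on macOS); 2**20 KiB = 1 GiB.
peak_gib = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 2**20
print(f"Peak RSS during load: {peak_gib:.2f} GiB")

A peak of roughly twice the float16 size of the weights would be consistent with a float32 intermediate copy being created before the cast.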

JoanFM commented 10 months ago

So what would you actually suggest doing? Which dtype parameter should I pass?

ArthurZucker commented 10 months ago

float16 or something like that. Or use TEI https://huggingface.co/docs/text-embeddings-inference/index
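Concretely, that suggestion amounts to passing torch_dtype to from_pretrained so the weights are materialized in float16 from the start, instead of being loaded in float32 and cast afterwards. A sketch with the model from the report:

import torch

from transformers import AutoModel

# Request float16 up front to avoid a float32 intermediate copy in RAM.
m = AutoModel.from_pretrained(
    "jinaai/jina-embeddings-v2-base-en",
    trust_remote_code=True,
    low_cpu_mem_usage=True,
    device_map="auto",
    torch_dtype=torch.float16,
)

torch_dtype="auto" should also work, picking up the dtype recorded in the checkpoint's config instead of hard-coding one.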

github-actions[bot] commented 9 months ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

JoanFM commented 9 months ago

Thanks for helping, it was indeed an issue with the dtype.