huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

MixTral 8*7B GPU Memory usage keeps increasing during inference #28707

Closed by oroojlooy 8 months ago

oroojlooy commented 9 months ago

System Info

The machine has 8x A100-40GB GPUs.

Who can help?

@Narsil @SunMarc

Reproduction

I create an instance of the MixTralModel class and call it in a loop over my prompts.

import transformers
import torch

class MixTralModel:
    def __init__(self, temperature=0.0, max_new_tokens=356, do_sample=False, top_k=50, top_p=0.7):
        self.temperature = temperature
        self.max_new_tokens = max_new_tokens
        self.do_sample = do_sample
        self.top_k = top_k
        self.top_p = top_p

        if do_sample and temperature == 0.0:
            raise ValueError(
                "`temperature` (=0.0) has to be a strictly positive float, otherwise your next token scores will be "
                "invalid. If you're looking for greedy decoding strategies, set `do_sample=False`")
        self.pipeline = transformers.pipeline(
            "text-generation",
            model="mistralai/Mixtral-8x7B-Instruct-v0.1",
            device_map="auto",
            # device="cuda:0",
            # model_kwargs={"torch_dtype": torch.float16, "load_in_4bit": True},
            model_kwargs={"torch_dtype": torch.float16},
        )

    def __call__(self, raw_messages: str) -> str:
        """
        An example of message is:
        messages = [{"role": "user", "content": "Explain what a Mixture of Experts is in less than 100 words."}]
        """
        try:
            messages = [{"role": "user", "content": raw_messages}]
            prompt = self.pipeline.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            outputs = self.pipeline(prompt, max_new_tokens=self.max_new_tokens, do_sample=self.do_sample,
                                    temperature=self.temperature, top_k=self.top_k, top_p=self.top_p)
            return outputs[0]["generated_text"]
        except Exception as e:
            print(e)            

if __name__ == "__main__":
    model = MixTralModel(temperature=0.0, max_new_tokens=356, do_sample=False, top_k=50, top_p=0.7)
    messages = "Explain what a Mixture of Experts is in less than 100 words."
    out = model(messages)
    print(out)

Expected behavior

When I call an instance of the class above on my data, GPU memory usage keeps increasing over time until I get a CUDA out-of-memory error. It looks like a memory leak, or perhaps gradients (?) are being kept in memory.

Each time, memory jumps by roughly 4 GB (about 3,996 MiB) on every GPU that holds model weights (GPUs 0 to 6; GPU 7 stays at 414 MiB). For example, here is the GPU memory usage before and after one jump:


+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage  |
|=======================================================================================|
|    0   N/A  N/A     93020      C   .../miniconda3/envs/mixtral/bin/python    18234MiB |
|    1   N/A  N/A     93020      C   .../miniconda3/envs/mixtral/bin/python    20764MiB |
|    2   N/A  N/A     93020      C   .../miniconda3/envs/mixtral/bin/python    20764MiB |
|    3   N/A  N/A     93020      C   .../miniconda3/envs/mixtral/bin/python    20764MiB |
|    4   N/A  N/A     93020      C   .../miniconda3/envs/mixtral/bin/python    20764MiB |
|    5   N/A  N/A     93020      C   .../miniconda3/envs/mixtral/bin/python    20764MiB |
|    6   N/A  N/A     93020      C   .../miniconda3/envs/mixtral/bin/python    15430MiB |
|    7   N/A  N/A     93020      C   .../miniconda3/envs/mixtral/bin/python      414MiB |
+---------------------------------------------------------------------------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage  |
|=======================================================================================|
|    0   N/A  N/A     93020      C   .../miniconda3/envs/mixtral/bin/python    22230MiB |
|    1   N/A  N/A     93020      C   .../miniconda3/envs/mixtral/bin/python    24760MiB |
|    2   N/A  N/A     93020      C   .../miniconda3/envs/mixtral/bin/python    24762MiB |
|    3   N/A  N/A     93020      C   .../miniconda3/envs/mixtral/bin/python    24760MiB |
|    4   N/A  N/A     93020      C   .../miniconda3/envs/mixtral/bin/python    24760MiB |
|    5   N/A  N/A     93020      C   .../miniconda3/envs/mixtral/bin/python    24760MiB |
|    6   N/A  N/A     93020      C   .../miniconda3/envs/mixtral/bin/python    19426MiB |
|    7   N/A  N/A     93020      C   .../miniconda3/envs/mixtral/bin/python      414MiB |
+---------------------------------------------------------------------------------------+

Note that the jump does not happen on every call; overall, the process gets killed after about 120 calls.

amyeroberts commented 9 months ago

Hi @oroojlooy, thanks for raising this issue!

I believe there's an accumulation of gradients happening due to the multiple forward passes on the model.

Putting the pipeline calls in the torch.no_grad context should help:

with torch.no_grad():
    outputs = self.pipeline(prompt, max_new_tokens=self.max_new_tokens, do_sample=self.do_sample, temperature=self.temperature, top_k=self.top_k, top_p=self.top_p)

As a side note - if do_sample=False, then parameters like top_p won't have any effect on the generation.
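A minimal sketch of that side note, assuming you want to keep a single set of defaults in the class: build the generation keyword arguments and include the sampling parameters (temperature, top_k, top_p) only when do_sample=True. The helper name `generation_kwargs` is hypothetical; the keys themselves are standard `generate` arguments that the text-generation pipeline forwards.

```python
def generation_kwargs(max_new_tokens, do_sample=False, temperature=1.0, top_k=50, top_p=1.0):
    """Hypothetical helper: keyword arguments for pipeline(...) or model.generate(...)."""
    kwargs = {"max_new_tokens": max_new_tokens, "do_sample": do_sample}
    if do_sample:
        # Sampling parameters only influence the output when sampling is enabled.
        if temperature <= 0.0:
            raise ValueError("temperature must be strictly positive when do_sample=True")
        kwargs.update(temperature=temperature, top_k=top_k, top_p=top_p)
    # With do_sample=False, decoding is greedy and temperature/top_k/top_p are left out.
    return kwargs
```

In the reproduction above this would be used as `self.pipeline(prompt, **generation_kwargs(self.max_new_tokens, self.do_sample, self.temperature, self.top_k, self.top_p))`, which also avoids passing temperature=0.0 to a greedy run.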

oroojlooy commented 9 months ago

Hi @amyeroberts, thanks for your reply. I tried adding torch.no_grad, but it does not help; memory keeps increasing. I also tried running the model via TGI, which is supposed to manage memory efficiently. Memory still increases with TGI, although I don't get a CUDA out-of-memory error there: usage rises to around 39 GB on every GPU and stays there.

Could this be related to the structure of Mixtral 8x7B, which has eight experts? That is, do I see a memory jump whenever one of the experts gets loaded?

Also thanks for the tip about do_sample!

amyeroberts commented 9 months ago

Hi @oroojlooy,

You'll need to provide some more details about how the memory increases, in particular for the non-TGI case: is there a sudden spike? Does it go up and down? After how many calls do you see this increase?
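One way to collect those details is to snapshot PyTorch's per-GPU memory counters around each call and log the difference. This is a minimal sketch, assuming PyTorch with CUDA; `gpu_snapshot` and `memory_deltas_mib` are hypothetical helper names. Note that nvidia-smi reports memory held by the process (including PyTorch's caching allocator and the CUDA context), so `torch.cuda.memory_reserved` is the closer match to the nvidia-smi tables, while `torch.cuda.memory_allocated` counts only live tensors.

```python
from typing import Dict


def gpu_snapshot(reserved: bool = True) -> Dict[int, int]:
    """Bytes in use on each visible GPU, as reported by PyTorch's caching allocator."""
    import torch  # imported lazily so memory_deltas_mib can be used without a GPU

    if not torch.cuda.is_available():
        return {}
    read = torch.cuda.memory_reserved if reserved else torch.cuda.memory_allocated
    return {i: read(i) for i in range(torch.cuda.device_count())}


def memory_deltas_mib(before: Dict[int, int], after: Dict[int, int]) -> Dict[int, float]:
    """Per-GPU change between two snapshots, in MiB."""
    gpus = sorted(set(before) | set(after))
    return {gpu: (after.get(gpu, 0) - before.get(gpu, 0)) / 2**20 for gpu in gpus}


# Example usage (assumes `model` and a list `prompts` from the reproduction script):
# for call_idx, prompt in enumerate(prompts):
#     before = gpu_snapshot()
#     model(prompt)
#     print(call_idx, memory_deltas_mib(before, gpu_snapshot()))
```

Logging one line per call makes it easy to see whether the growth is a sudden spike, a steady climb, or an up-and-down pattern, and after how many calls it starts.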

Note: in the case of generation, you're making autoregressive calls to the model i.e. the model is being repeatedly called with an increasing input length. If options such as use_cache aren't selected, then you would expect a memory increase, even after the model has been loaded.

I'd suggest not using the pipeline, and using the modeling code directly. This will give you more control and enable you to monitor better what is causing increases in memory.

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

class GenerationModel:
    def __init__(self, model_id, temperature=0.0, max_new_tokens=356, do_sample=False, top_k=50, top_p=0.7):
        self.model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", torch_dtype=torch.float16)
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.temperature = temperature
        self.max_new_tokens = max_new_tokens
        self.do_sample = do_sample
        self.top_k = top_k
        self.top_p = top_p

        if do_sample and temperature == 0.0:
            raise ValueError(
                "`temperature` (=0.0) has to be a strictly positive float, otherwise your next token scores will be "
                "invalid. If you're looking for greedy decoding strategies, set `do_sample=False`")

    def __call__(self, raw_messages: str) -> str:
        """
        An example of message is:
        messages = [{"role": "user", "content": "Explain what a Mixture of Experts is in less than 100 words."}]
        """
        try:
            messages = [{"role": "user", "content": raw_messages}]
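            prompt = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            # The chat template already inserts the BOS token, so don't add special tokens a second time.
            inputs = self.tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(self.model.device)
            # Sampling parameters only have an effect when do_sample=True.
            sampling = dict(temperature=self.temperature, top_k=self.top_k, top_p=self.top_p) if self.do_sample else {}
            with torch.no_grad():
                # max_new_tokens bounds the generated tokens independently of the prompt length.
                outputs = self.model.generate(**inputs, max_new_tokens=self.max_new_tokens,
                                              do_sample=self.do_sample, use_cache=True, **sampling)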

            generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

            return generated_text
        except Exception as e:
            print(e)

if __name__ == "__main__":
    model = GenerationModel(
        model_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
        temperature=0.0,
        max_new_tokens=356,
        do_sample=False,
    )
    messages = "Explain what a Mixture of Experts is in less than 100 words."
    out = model(messages)
    print(out)

cc @gante, the generation wizard, who will know more about getting this to run well.

Could this be related to the structure of Mixtral 8x7B, which has eight experts? That is, do I see a memory jump whenever one of the experts gets loaded?

You can run an experiment with a non-MoE model and see :)

gante commented 9 months ago

Hi @oroojlooy 👋

As @amyeroberts wrote, memory consumption in transformers is expected to grow throughout generation (i.e. during the pipeline call in your script) as the input/output grows longer. This is because we don't pre-allocate memory, unlike TGI (which is why you see a fixed memory footprint there once the model is loaded). It is also independent of the model being an MoE; this is simply how text generation works.

To confirm that there is no memory leak, you can try a simple test: call your pipeline repeatedly with the same input and with do_sample=False. You should not see memory increases as you repeat the calls.
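That test can be scripted. A minimal sketch, assuming `generate_once` wraps a single call with a fixed prompt and do_sample=False, and `read_memory` returns a byte count such as `torch.cuda.memory_allocated(0)`. Both callables, the helper names, and the 64 MiB tolerance are illustrative choices, not part of transformers.

```python
from typing import Callable, List


def repeated_call_readings(
    generate_once: Callable[[], object],
    read_memory: Callable[[], int],
    n_calls: int = 10,
) -> List[int]:
    """Run the same generation n_calls times, recording memory after each call."""
    readings = []
    for _ in range(n_calls):
        generate_once()
        readings.append(read_memory())
    return readings


def looks_like_leak(readings: List[int], tolerance_bytes: int = 64 * 2**20) -> bool:
    """True if memory keeps growing after the first (warm-up) call.

    With identical inputs and greedy decoding every call does the same work,
    so a healthy run should stay flat within a small tolerance.
    """
    steady = readings[1:]  # the first call may allocate workspaces once
    return len(steady) >= 2 and steady[-1] - steady[0] > tolerance_bytes


# Example usage (assumes `model` from the reproduction script and a CUDA device):
# readings = repeated_call_readings(lambda: model("same prompt"), lambda: torch.cuda.memory_allocated(0))
# print(looks_like_leak(readings))
```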

oroojlooy commented 9 months ago

Thanks @amyeroberts and @gante for the replies. @amyeroberts, I tried a non-MoE model as you suggested and I can confirm what you pointed out! @gante, I had actually already tried what you suggested and did not see any memory jump.

But I am still confused about how inference works in language models. My understanding was that, for a fixed batch size, the memory usage of the network should be fixed: regardless of how long the input+output is, its maximum length is always capped by the LLM's context size, for which memory is allocated when we load the model. So, for batch size b, the memory utilization of the model would be dtype_size*b*num_params + dtype_size*num_operations. Shouldn't this memory utilization stay fixed throughout inference?

gante commented 9 months ago

@oroojlooy to understand why the memory grows (and what you can do about it), have a look at this guide -- especially section 2, which covers the self-attention layer :)

oroojlooy commented 9 months ago

@gante So, if I understand correctly, the QK^T matrix, whose size grows as N^2, stays in the cache, and that causes the surge in memory usage?

gante commented 9 months ago

@oroojlooy There are two sources of increased memory requirements as the sequence length (N) grows when caches are used:

  1. The materialization of the QK^T multiplication, which may grow as quickly as N^2 (flash attention decreases this), as you wrote;
  2. The cached key and values, which grow linearly with N.

transformers allocates memory on demand, only when it is needed, whereas TGI checks the maximum possible memory usage at startup. Both reach roughly the same peak memory usage for a given model, maximum input length, and attention implementation.
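To put rough numbers on those two sources, here is a back-of-the-envelope estimate. It assumes fp16 (2 bytes per value), eager attention that materializes QK^T, and the configuration published for Mixtral-8x7B (32 layers, 32 query heads, 8 key/value heads, head dimension 128); treat these values as assumptions to check against the model's config.json. The function names are hypothetical, and the figures ignore softmax upcasting and other temporaries, so they are lower bounds.

```python
BYTES_FP16 = 2

# Assumed Mixtral-8x7B attention configuration (verify against config.json).
NUM_LAYERS = 32
NUM_Q_HEADS = 32
NUM_KV_HEADS = 8
HEAD_DIM = 128


def kv_cache_bytes(seq_len: int, batch: int = 1) -> int:
    """Keys + values cached for every layer: grows linearly with seq_len."""
    return 2 * NUM_LAYERS * NUM_KV_HEADS * HEAD_DIM * seq_len * batch * BYTES_FP16


def attention_scores_bytes(q_len: int, kv_len: int, batch: int = 1) -> int:
    """One layer's materialized QK^T scores (eager attention, no flash attention)."""
    return batch * NUM_Q_HEADS * q_len * kv_len * BYTES_FP16


for n in (1_000, 8_000, 32_000):
    print(
        f"N={n}: KV cache {kv_cache_bytes(n) / 2**30:.2f} GiB, "
        f"prefill QK^T per layer {attention_scores_bytes(n, n) / 2**30:.2f} GiB, "
        f"decode-step QK^T per layer {attention_scores_bytes(1, n) / 2**20:.2f} MiB"
    )
```

The KV cache term is linear in N and is held for the whole call. The QK^T term is quadratic in N only during the prefill of the prompt; with the cache, each decoding step scores a single new query against N cached keys, which is linear. Neither term depends on the model being an MoE.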

oroojlooy commented 8 months ago

@gante Thanks for all the explanations!

ahmedkotb98 commented 6 months ago

Is there any solution to this problem?

gante commented 6 months ago

@ahmedkotb98 do you have any related issue that was not discussed above?

Dr-Left commented 6 months ago

Same issue with the same code: Llama-3-70B works fine, but Mixtral-8x7B fails. After the model loads, memory first goes up, then drops back to a low level (I think it's the level right after loading), then goes up again, cycling like this. Inference is also very slow compared to other 70B models.

gante commented 6 months ago

@Dr-Left do you have a short reproducible script for the situation you describe, and can you share your environment versions? It would help us pinpoint issues 🤗