huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

LLM does not deallocate memory during inference #31519

Closed Tomas542 closed 4 months ago

Tomas542 commented 5 months ago

System Info

Who can help?

@ArthurZucker @younesbelkada @zucchini-nlp

Information

Tasks

Reproduction

Init model

import torch
from peft import PeftModel
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig

base_model = 'google/flan-t5-xxl'
ckpt = './results/checkpoints_t5_1/checkpoint-4600'
device = 'cuda:0'

tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)

# Load the base model in fp16 on a single GPU
model = AutoModelForSeq2SeqLM.from_pretrained(
    base_model,
    device_map=device,
    max_memory={0: "60GB"},
    trust_remote_code=True,
    torch_dtype=torch.float16,
    offload_state_dict=True,
)

# Attach the fine-tuned PEFT adapter from the checkpoint
model = PeftModel.from_pretrained(
    model,
    ckpt,
)

model.eval()
model = torch.compile(model)

Create generation config

# Beam-sample decoding: 4 beams, sampling enabled, KV cache disabled
generation_config = GenerationConfig(
    do_sample=True,
    temperature=0.8,
    top_p=0.75,
    top_k=40,
    num_beams=4,
    max_new_tokens=224,
    stream_output=False,
    model=model,
    use_cache=False,
)

Generate text

# create_prompt and example are defined elsewhere (not shown)
inputs = create_prompt(example)
input_ids = inputs['input_ids']
input_ids = input_ids.to(device)
gt = example['output']

# Generate without gradient tracking
with torch.inference_mode():
    generation_output = model.generate(
        input_ids=input_ids,
        generation_config=generation_config,
        return_dict_in_generate=False,
        output_scores=False,
    )
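
A complementary way to see what the generate call itself allocates is PyTorch's own peak-memory counters. A minimal sketch, assuming the same model, input_ids and generation_config as above:

# Reset the peak-memory counter, run one generation, then read the peak.
# memory_allocated() only counts live tensors; max_memory_allocated() is the
# high-water mark reached inside the block.
torch.cuda.reset_peak_memory_stats(device)

with torch.inference_mode():
    generation_output = model.generate(
        input_ids=input_ids,
        generation_config=generation_config,
    )

print(f"allocated after generate: {torch.cuda.memory_allocated(device) / 2**30:.1f} GiB")
print(f"peak during generate:     {torch.cuda.max_memory_allocated(device) / 2**30:.1f} GiB")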

And we track GPU memory with this function:

import subprocess

def show_gpu(msg):
    """
    Print current GPU memory usage as reported by nvidia-smi.
    ref: https://discuss.pytorch.org/t/access-gpu-memory-usage-in-pytorch/3192/4
    """
    def query(field):
        return subprocess.check_output(
            ['nvidia-smi', f'--query-gpu={field}',
             '--format=csv,nounits,noheader'],
            encoding='utf-8')

    def to_int(result):
        # Take the first line, i.e. the first GPU
        return int(result.strip().split('\n')[0])

    used = to_int(query('memory.used'))
    total = to_int(query('memory.total'))
    pct = used / total
    print('\n' + msg, f'{100*pct:2.1f}% ({used} out of {total})')

After initialization we've used: GPU 32.1% (26157 out of 81559). Then the first, second, and third generations:

GPU 63.0% (51387 out of 81559)
GPU 63.0% (51389 out of 81559)
GPU 93.9% (76613 out of 81559)

And we allocated more memory than we specified in max_memory.
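
Part of what nvidia-smi reports as "used" is PyTorch's caching allocator holding freed blocks for reuse rather than returning them to the driver. A quick sketch (assuming the same device as above) to separate live tensors from cached-but-free memory:

# memory_allocated: memory occupied by live tensors
# memory_reserved:  memory held by PyTorch's caching allocator (what nvidia-smi
#                   sees, plus CUDA context overhead), even if currently free
alloc = torch.cuda.memory_allocated(device) / 2**30
reserved = torch.cuda.memory_reserved(device) / 2**30
print(f"allocated: {alloc:.1f} GiB, reserved: {reserved:.1f} GiB, "
      f"cached-but-free: {reserved - alloc:.1f} GiB")

If allocated returns to roughly the post-load value after each generation while reserved keeps climbing, the memory is being cached rather than leaked.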

Expected behavior

The model should free memory after generation, or should not try to allocate more memory than we specified in max_memory.
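
As far as I can tell, max_memory passed to from_pretrained only guides weight placement when a device map is computed; it is not a runtime cap on activations during generation. A rough workaround sketch (not the library's own mechanism) that hard-caps the process's share of the GPU and releases cached memory between generations, assuming the 80 GB card and 60 GB budget from above:

import gc

# Illustrative hard cap: allocations beyond ~60 GB of an 80 GB card will
# raise an out-of-memory error instead of growing further.
torch.cuda.set_per_process_memory_fraction(60 / 80, device=0)

with torch.inference_mode():
    generation_output = model.generate(
        input_ids=input_ids,
        generation_config=generation_config,
    )
text = tokenizer.batch_decode(generation_output, skip_special_tokens=True)

# Drop references to the outputs, collect Python garbage, then return
# cached blocks to the driver so nvidia-smi reflects the real usage.
del generation_output
gc.collect()
torch.cuda.empty_cache()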

ArthurZucker commented 4 months ago

Hey! Was your issue fixed?