Open Ricky-Ting opened 9 months ago
Thanks, let me reproduce and get back to you
Really sorry for the delay. I was able to reproduce your issue. I'm not sure what is causing the accumulation, but you can empty the XPU cache between iterations to avoid the OOM: `torch.xpu.empty_cache()`
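A minimal sketch of that workaround applied to a generation loop. It assumes an XPU-enabled build; `run_one_generation` is a hypothetical stand-in for the `model.generate(...)` call in the reproduction script below, not part of any library API:

```python
import torch

def generate_with_cache_clearing(run_one_generation, seq_lens):
    """Run one generation per seq_len, freeing cached XPU blocks in between."""
    for seq_len in seq_lens:
        # run_one_generation is a hypothetical placeholder for the
        # model.generate(...) call from the reproduction script.
        run_one_generation(seq_len)
        # Guard so this sketch is also safe on builds without XPU support.
        if hasattr(torch, "xpu") and torch.xpu.is_available():
            torch.xpu.synchronize()   # wait for this iteration's kernels
            torch.xpu.empty_cache()   # return cached allocator blocks to the device
```

The synchronize-before-empty ordering matters: clearing the cache while kernels from the current iteration are still in flight would not release the blocks they hold.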
Describe the bug
When `seq_len` gets larger and larger, the device memory utilization keeps climbing and eventually hits OOM on an Arc A770. But if you run a single inference with `seq_len=1300`, the device memory occupied is only around 15.2 GB.

Related code
```python
import torch
import intel_extension_for_pytorch as ipex
from transformers import AutoModelForCausalLM

pretrained = "/mnt/disk1/models/Llama-2-7b-chat-hf/"
device = 'xpu'
model = AutoModelForCausalLM.from_pretrained(pretrained, trust_remote_code=True, use_cache=True)
model = model.half().to(device)

with torch.inference_mode():
    for seq_len in range(1200, 1301, 5):
        input_ids = torch.randint(5, 2000, (1, seq_len)).to(device)
        attention_mask = torch.ones_like(input_ids).to(device)
        print(input_ids.shape)
        generations = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=256,
            do_sample=False,
        )
        torch.xpu.synchronize()
```
Outputs
```
torch.Size([1, 1205])
torch.Size([1, 1210])
torch.Size([1, 1215])
torch.Size([1, 1220])
torch.Size([1, 1225])
torch.Size([1, 1230])
torch.Size([1, 1235])
torch.Size([1, 1240])
torch.Size([1, 1245])
torch.Size([1, 1250])
torch.Size([1, 1255])
Traceback (most recent call last):
  File "/home/arda/baorong/tmp/generate.py", line 21, in
```
Versions