abertsch72 / unlimiformer

Public repo for the NeurIPS 2023 paper "Unlimiformer: Long-Range Transformers with Unlimited Length Input"
MIT License

Sanity check: VRAM usage on llama-2-7b-chat-hf higher than without Unlimiformer on low tokens? #26

Open SharkWipf opened 1 year ago

SharkWipf commented 1 year ago

I'm trying out the new Unlimiformer Llama-2 code on llama-2-7b-chat-hf, on a 24GB 3090. I understand Unlimiformer probably wasn't designed with consumer GPUs in mind, but I'd hoped to squeeze some more context out of my GPU locally before having to resort to expensive cloud GPUs. I managed to get everything working, but the VRAM usage per token seems to be higher than on stock llama-2-7b-hf. I'd expect some overhead from running Unlimiformer, but this is more than I expected. With vanilla Transformers (same versions and everything) in fp16, I can ingest up to ~5350 tokens at once before running out of memory. With Unlimiformer, 5350 tokens runs out of memory, and I can barely do more than 4096 tokens (5000 OOMs). Is this expected overhead? And is this overhead fixed, or does it vary with the model size?

Semi-related side questions: Is there anything Unlimiformer does that would prevent it from working with bitsandbytes 8/4-bit quantization, or should that be a matter of simply enabling it? And should QLoRA training with PEFT work?

SharkWipf commented 1 year ago

I should add, I'm running it like this:

python src/run_generation.py --model_type llama --model_name_or_path meta-llama/Llama-2-7b-chat-hf --prompt example_inputs/harry_potter_notfull.txt --test_unlimiformer --fp16 --length 200 --layer_begin 22 --use_datastore True --gpu_datastore False --index_devices 0 --datastore_device 0
abertsch72 commented 1 year ago

Thanks for your interest! Unfortunately there is some overhead from running Unlimiformer because of our datastore construction, even at the same length of input.

When you say you're ingesting 5350 tokens with vanilla transformers, are you passing all tokens to the model without truncation? Llama-2 has a context window of 4096, right?

For Unlimiformer, the cost-per-token varies according to the size of the model's embeddings and (for decoder-only models) the number of layers that you apply Unlimiformer at. Since you're applying Unlimiformer at the last 10 layers and the model has a base context length of 4096, you're storing an additional 10*(n-4096) tokens when you pass an input of length n>4096.
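To make that per-token cost concrete, here is a back-of-the-envelope sketch of the formula above. It assumes (our assumption, not confirmed by the maintainers) that each stored entry is one hidden-state vector of the model's hidden size (4096 for Llama-2-7b), kept in fp16 (2 bytes) or fp32 (4 bytes). The function name is hypothetical and not part of the Unlimiformer code.

```python
def extra_datastore_bytes(n_tokens, context_len=4096, num_layers=10,
                          hidden_size=4096, bytes_per_value=2):
    """Rough size of the extra stored entries for an input of n_tokens.

    Hypothetical model: for every layer Unlimiformer is applied at, one
    hidden_size vector is stored per token beyond the native context.
    """
    extra_tokens = max(0, n_tokens - context_len)
    stored_vectors = num_layers * extra_tokens  # the 10*(n-4096) term
    return stored_vectors * hidden_size * bytes_per_value


if __name__ == "__main__":
    for n in (4096, 4097, 5000):
        size = extra_datastore_bytes(n)
        print(f"n={n}: {size / 2**20:.2f} MiB")
```

Under these assumptions the estimate grows linearly both in the number of layers and in the number of tokens past the native context. Passing bytes_per_value=4 models an fp32 store.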

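There are a few things you could try to reduce this:

  • try setting --use_datastore=False
  • switch the index to RAM by using the flag --gpu_index=False
  • move the index/datastore to a second GPU, if you have another one: this will slightly impact speed and not affect performance. This could be a <24GB GPU.
  • try swapping the type of index to a more compressed one from this list: this will affect speed and may affect performance. You'd have to make this change in the code in index_building.py here.
  • set layer_begin to a later layer or add a layer_end to reduce the number of layers using Unlimiformer: this will increase speed and likely hurt performance, but it's probably the easiest way to cut down on memory.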

Issue #19 (https://github.com/abertsch72/unlimiformer/issues/19) discusses using quantization. If you follow Uri's recommendation there (set model_clone = model instead of cloning the model), are you able to use bitsandbytes in your setup? There's no conceptual reason it should fail, but we haven't tried this personally.

Likewise, I can't think of a reason that QLoRA would fail, though we haven't tried it. If you do, please let us know how it goes!

SharkWipf commented 1 year ago

Thanks for your extensive reply.

> When you say you're ingesting 5350 tokens with vanilla transformers, are you passing all tokens to the model without truncation? Llama-2 has a context window of 4096, right?

That's right, I'm just primitively shoving everything in there. The output is fairly garbage, as expected, but I was only trying to find the limit of what fits in VRAM. Interestingly, with Unlimiformer the exact point where it runs out of memory seems to be 4098 tokens; given this explanation, I'd have expected it to run out at 4097 tokens. It's probably just something about how I'm truncating. (On closer inspection, Unlimiformer consistently picks up one token fewer than what I truncate at. I'm not sure why, but at least that checks out.)

For reference, my rather primitive code I used, both for testing without Unlimiformer and for generating the prompt for Unlimiformer (no need to dive into it or anything though):

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

MODEL_NAME = "meta-llama/Llama-2-7b-chat-hf"

# Load tokenizer and model; loading directly in fp16 avoids
# materializing a full fp32 copy of the weights in CPU RAM first
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.float16).cuda()

# Read file content as prompt
with open("example_inputs/harry_potter_full.txt", "r", encoding="utf-8") as file:
    prompt = file.read()

# Truncate the prompt if it is too long.
# Note: tokenizer.encode() prepends the BOS token (<s>), which counts toward MAX_TOKENS.
MAX_TOKENS = 4097  # 4098 OOMs on Unlimiformer, but 4097 works?
tokenized_prompt = tokenizer.encode(prompt, return_tensors="pt")
if tokenized_prompt.shape[1] > MAX_TOKENS:
    print(f"Warning: Truncating prompt to {MAX_TOKENS} tokens.")
    tokenized_prompt = tokenized_prompt[:, :MAX_TOKENS]

# Convert the truncated token IDs back to text (skip_special_tokens drops the BOS token)
truncated_text = tokenizer.decode(tokenized_prompt[0], skip_special_tokens=True)

# Write the truncated text to a different file (used as the Unlimiformer prompt)
with open("example_inputs/harry_potter_notfull.txt", "w", encoding="utf-8") as file:
    file.write(truncated_text)

tokenized_prompt = tokenized_prompt.cuda()

# Generate up to 100 new tokens (arbitrary, adjust as desired)
output = model.generate(
    tokenized_prompt,
    attention_mask=torch.ones_like(tokenized_prompt),  # no padding, so attend to every token
    max_new_tokens=100,
    temperature=1.0,
)

# Decode the generated text
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
print(generated_text)

> For Unlimiformer, the cost-per-token varies according to the size of the model's embeddings and (for decoder-only models) the number of layers that you apply Unlimiformer at. Since you're applying Unlimiformer at the last 10 layers and the model has a base context length of 4096, you're storing an additional 10*(n-4096) tokens when you pass an input of length n>4096.

That makes sense.

> • try setting --use_datastore=False
> • switch the index to RAM by using the flag --gpu_index=False

With both of these set, I can go past the "stock" Transformers limit by a couple hundred tokens, but unfortunately no further, it seems. I guess that makes sense.

> move the index/datastore to a second GPU, if you have another one: this will slightly impact speed and not affect performance. This could be a <24GB GPU.

I'll have to see if I can fit one of my A2000s in this machine. It's pretty weak, though, so it probably won't be great.

> try swapping the type of index to a more compressed one from this list: this will affect speed and may affect performance. You'd have to make this change in the code in index_building.py here.

I'll have to try this out later. Though if we're already moving the index to RAM, I don't think it would affect VRAM usage anymore at that point, right? Or does it also affect the datastore?
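For intuition about what a "more compressed" index saves, the sketch below compares per-vector storage for two common faiss layouts. A flat index (IndexFlat) stores every vector as dim float32 values. A product-quantized index (IndexPQ) stores m sub-quantizer codes of nbits bits each, i.e. ceil(m·nbits/8) bytes per vector. Which index Unlimiformer builds by default, and what you would swap in via index_building.py, are not stated in this thread. The helper names and the example sizes are illustrative assumptions.

```python
def flat_index_bytes(num_vectors, dim):
    """Flat (uncompressed) index: every vector stored as dim float32 values."""
    return num_vectors * dim * 4


def pq_index_bytes(num_vectors, m, nbits=8):
    """PQ-compressed codes: m sub-quantizer codes of nbits bits each per vector.

    Ignores the fixed-size codebooks, which do not grow with num_vectors.
    """
    code_size = (m * nbits + 7) // 8  # bytes per encoded vector, rounded up
    return num_vectors * code_size


if __name__ == "__main__":
    num_vectors = 10 * (5000 - 4096)  # e.g. 10 layers x 904 extra tokens
    dim = 4096  # Llama-2-7b hidden size
    flat = flat_index_bytes(num_vectors, dim)
    pq = pq_index_bytes(num_vectors, m=64)
    print(f"flat: {flat / 2**20:.1f} MiB, PQ (m=64): {pq / 2**20:.2f} MiB")
```

With dim=4096 and m=64, each vector shrinks from 16 KiB to 64 bytes, at the cost of approximate search. The saving applies to whichever device holds the index, so it matters for VRAM only when the index lives on the GPU.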

> set layer_begin to a later layer or add a layer_end to reduce the number of layers using Unlimiformer: this will increase speed and likely hurt performance, but it's probably the easiest way to cut down on memory.

Yeah, ideally I'd like to move it earlier, not later, heh.
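The layer_begin/layer_end trade-off can be sketched as a simple range. The sketch assumes Llama-2-7b's 32 decoder layers and treats layer_begin as inclusive, consistent with the maintainer's note that --layer_begin 22 covers the last 10 layers. It also assumes, hypothetically, that layer_end is exclusive and defaults to the last layer. Check the repo's argument handling before relying on the exact boundaries.

```python
def unlimiformer_layers(num_layers, layer_begin, layer_end=None):
    """Indices of the layers Unlimiformer would be applied at (assumed semantics)."""
    if layer_end is None:
        layer_end = num_layers  # hypothetical default: run to the last layer
    if not 0 <= layer_begin <= layer_end <= num_layers:
        raise ValueError("expected 0 <= layer_begin <= layer_end <= num_layers")
    return list(range(layer_begin, layer_end))


if __name__ == "__main__":
    LLAMA2_7B_LAYERS = 32
    for begin in (16, 22, 28):
        layers = unlimiformer_layers(LLAMA2_7B_LAYERS, begin)
        print(f"layer_begin={begin}: {len(layers)} layers")
```

Because the extra storage scales with the layer count, moving layer_begin earlier from 22 to 16 would, under these assumptions, store 16 layers' worth of entries instead of 10, a 60% increase.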

> https://github.com/abertsch72/unlimiformer/issues/19 discusses using quantization. If you follow Uri's recommendation there (set model_clone = model instead of cloning the model), are you able to use bitsandbytes in your setup? There's no conceptual reason it should fail, but we haven't tried this personally.

Ah, I even saw that issue before, but forgot about it since I was mostly interested in decoder-only support. I'll give that a try later too.

> Likewise, I can't think of a reason that QLoRA would fail, though we haven't tried it. If you do, please let us know how it goes!

And this too.

I'll leave this issue open for a bit while I figure out the quantization and QLoRA side, and I'll report back with my results (it may be a few days, though). I imagine it might help other people as well, especially now that Llama support is in.

urialon commented 1 year ago

Thanks @SharkWipf, that would certainly help.

SharkWipf commented 1 year ago

I haven't quite gotten around to messing with bitsandbytes and peft yet, but I decided to give the new native Transformers GPTQ support a quick try.

After installing auto-gptq and optimum per their instructions, and changing one line in the example run code, I could successfully run inference on a GPTQ-quantized model from the hub:

pip install auto-gptq optimum --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
...
python src/my_generation.py --model_type llama --model_name_or_path TheBloke/Llama-2-13B-chat-GPTQ --prompt example_inputs/harry_potter_notfull.txt --test_unlimiformer --fp16 --length 200 --layer_begin 22 --use_datastore False --gpu_datastore False --gpu_index False

The only real difference there is the removal of --fp16

The one line I had to modify:

446c446
<     model = model_class.from_pretrained(args.model_name_or_path, **model_kwargs)
---
>     model = model_class.from_pretrained(args.model_name_or_path, **model_kwargs, torch_dtype=torch.float16, device_map="auto")

(This is not a clean fix, just a quick-and-dirty workaround for GPTQ support)

I haven't tested the limits of this approach yet; it's been a bit chaotic here. I do intend to test further when I get around to it (Unlimiformer itself, GPTQ, bitsandbytes, PEFT and more), but I figured I'd share the GPTQ route for now in case anyone wants to know.

EDIT: Some quick numbers, because I couldn't resist, on my 3090 (24GB VRAM):

Overall the encoding step seems to be taking most VRAM by far in general, and tends to be what runs out of memory fastest. I assume this much is expected.
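A rough illustration of why the encoding (prefill) step tends to hit the memory ceiling first: attention implementations that materialize the full score matrix (eager attention, as opposed to FlashAttention or other memory-efficient kernels) hold a batch × heads × n × n tensor per layer. This sketch assumes Llama-2-7b's 32 attention heads and fp16 scores. It sizes only that one tensor and is not a full memory model; implementations that upcast the softmax to fp32 need more.

```python
def attention_scores_bytes(seq_len, num_heads=32, batch_size=1, bytes_per_value=2):
    """Size of one layer's materialized attention-score tensor (batch, heads, n, n)."""
    return batch_size * num_heads * seq_len * seq_len * bytes_per_value


if __name__ == "__main__":
    for n in (2048, 4096, 5350):
        gib = attention_scores_bytes(n) / 2**30
        print(f"n={n}: {gib:.2f} GiB per layer")
```

Doubling the input length quadruples this tensor. That pattern is consistent with prefill running out of memory first, while stored per-token entries grow only linearly with length.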

Output quality is, well, kind of garbage on all of these: random words and whitespace, mostly reusing terminology from the input. This happens with or without Unlimiformer, though. Hopefully PEFT finetuning will help there. I'd also like to see whether I can get Unlimiformer to work together with RoPE, but I haven't looked into this at all yet.

EDIT2: It seems Llama-2 already uses RoPE by default, so there's probably nothing to enable. Tweaking the RoPE settings doesn't break anything further, but it doesn't seem to offer any additional benefit either.

urialon commented 1 year ago

Wow, thanks a lot @SharkWipf .

If you manage to run more experiments, we (and, I'm sure, other users) would love to hear about them.