turboderp / exllamav2

A fast inference library for running LLMs locally on modern consumer-class GPUs
MIT License

Clear cache to avoid OOM with iterative generation #291

Closed: cdreetz closed this issue 3 weeks ago

cdreetz commented 5 months ago

I am running a 3.0bpw Mixtral on a 3090, so there's just enough room to fit it, but I noticed VRAM usage increasing through a generation loop until eventually hitting OOM. I don't really need to hold the KV cache, since each generation is independent of the others, so I can just clear it with each call. Code provided below.

import sys
import torch

from exllamav2 import ExLlamaV2, ExLlamaV2Config, ExLlamaV2Cache, ExLlamaV2Tokenizer
from exllamav2.generator import ExLlamaV2StreamingGenerator, ExLlamaV2Sampler

config = ExLlamaV2Config()
config.model_dir = "/mnt/models/"
config.prepare()

model = ExLlamaV2(config)
cache = ExLlamaV2Cache(model, lazy = True)

model.load_autosplit(cache)

tokenizer = ExLlamaV2Tokenizer(config)
generator = ExLlamaV2StreamingGenerator(model, cache, tokenizer)
generator.set_stop_conditions([tokenizer.eos_token_id])
gen_settings = ExLlamaV2Sampler.Settings()

def process_with_language_model(input_text):
    cache = ExLlamaV2Cache(model, lazy = True)
    cache.current_seq_len = 0
    instruction_ids = tokenizer.encode(f"[INST] {input_text} [/INST]", add_bos = True)
    context_ids = instruction_ids if generator.sequence_ids is None \
        else torch.cat([generator.sequence_ids, instruction_ids], dim = -1)
    generator.begin_stream(context_ids, gen_settings)

    output_text = ""
    while True:
        chunk, eos, _ = generator.stream()
        if eos: break
        output_text += chunk
        print(chunk, end="")
        sys.stdout.flush()

    return output_text.strip()
turboderp commented 5 months ago

If you keep generating with no token limit you'll eventually fill up the cache, but at that point you should get an overflow error, not OoM. Calling generator.begin_stream will reset the cache.

If repeated calls to this function add up to less than 32k tokens of context in total (the default for Mixtral) but you're still getting OoM at some point, there could be a memory leak somewhere, but the cache itself is statically allocated and doesn't grow with the sequence length.

It could also be that you're right on the cusp of running out of VRAM and something other than exllamav2 (like your window manager or whatever) is allocating a little too much in the background. You could try setting config.max_seq_len = 20000 or so, right after config.prepare(), which shrinks the cache allocation and could free up some memory overall.
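For instance, a minimal sketch of that change, reusing the setup from the example above (20000 is just the ballpark value suggested here, not a tuned number):

    config = ExLlamaV2Config()
    config.model_dir = "/mnt/models/"
    config.prepare()

    # Cap the context below Mixtral's 32k default so the (static) cache
    # tensors are allocated smaller, leaving a bit more VRAM headroom.
    config.max_seq_len = 20000

    model = ExLlamaV2(config)
    cache = ExLlamaV2Cache(model, lazy = True)  # cache size follows config.max_seq_len
    model.load_autosplit(cache)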

These two lines:

    cache = ExLlamaV2Cache(model, lazy = True)
    cache.current_seq_len = 0

have no effect, since the generator keeps using the first cache object it was created with. And with lazy = True no VRAM is allocated for the new cache anyway; allocation is deferred to the autosplit model loader, which isn't called again at this point.

cdreetz commented 5 months ago

Ah, it might be that overflow. I didn't actually see an OoM error; I just noticed it stopped generating, and RunPod showed GPU utilization at 99%, so I assumed it was OoM. I'll try it again with generator.begin_stream. Also, with lazy mode, do I not need to set the cache seq len to 0?

turboderp commented 5 months ago

You don't need to create a second instance of the cache at all. There's only one cache used by the generator, and it's the first ExLlamaV2Cache you create in your example, the one passed to the ExLlamaV2StreamingGenerator constructor.

Lazy mode just creates the cache without allocating the cache tensors in VRAM. This is for use with the load_autosplit function, which allocates cache layers as it's loading the model, to keep everything aligned across multiple GPUs without knowing in advance which layers will end up on which device.

To actually reset the cache you would indeed set current_seq_len = 0, but this is done implicitly by generator.begin_stream anyway, which starts a generation with the provided token IDs, inserting them into the context starting from position zero.
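Put together, a sketch of the loop with the redundant cache removed (assuming, as in the original post, that each prompt is independent, so the previous sequence isn't concatenated back in):

    def process_with_language_model(input_text):
        # Reuse the one cache the generator already owns; begin_stream
        # resets it by writing the new context from position zero.
        instruction_ids = tokenizer.encode(f"[INST] {input_text} [/INST]", add_bos = True)
        generator.begin_stream(instruction_ids, gen_settings)

        output_text = ""
        while True:
            chunk, eos, _ = generator.stream()
            if eos: break
            output_text += chunk

        return output_text.strip()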

bjj commented 5 months ago

I have found repeated generation causes VRAM usage to creep up. I'm doing:

    gc.collect()
    torch.cuda.empty_cache()

after generations, which keeps VRAM usage consistent but does cost a bit of perf. But the alternative is having VRAM climb until it hits "shared" memory and then perf goes off a cliff.
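For reference, a sketch of where that cleanup sits in an outer loop (the prompts list is hypothetical, and process_with_language_model is the function from the example above):

    import gc
    import torch

    for prompt in prompts:  # hypothetical list of independent prompts
        result = process_with_language_model(prompt)
        # Workaround: force a GC pass and release PyTorch's cached allocator
        # blocks between generations; keeps VRAM flat at a small perf cost.
        gc.collect()
        torch.cuda.empty_cache()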

turboderp commented 5 months ago

This does look like a memory management issue in PyTorch. Manually calling gc.collect() and torch.cuda.empty_cache() isn't supposed to be necessary, and it's something of a mystery why it sometimes helps, since all of that functionality should be invoked automatically whenever it's needed.

With the shared memory option in later NVIDIA drivers, it's possible that Torch simply doesn't recognize that it's running out of memory and doesn't properly manage its cache since it sees more VRAM than what's physically available. Another reason to disable that evil feature, I guess?

bjj commented 5 months ago

I found that a big part of my problem was that variable scope in Python is pretty broad. I had things like caches = [_.cache for _ in work], and even when that went out of "scope" (in the C++ sense I'm used to thinking in), it was still alive under Python's rules. That meant I often held references to caches and sampler states while allocating new caches (or unloading a model and loading a new one), so I was getting fragmentation and unnecessarily high peak memory usage.

Just going through the code and whacking temporaries to None goes a long way toward keeping reserved memory from growing.
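A minimal sketch of that pattern (the work list and its .cache attribute are hypothetical, standing in for whatever objects hold cache references):

    import gc
    import torch

    caches = [w.cache for w in work]  # hypothetical per-job cache references
    # ... run the jobs ...

    # Drop the references before allocating new caches or loading another
    # model, so the old tensors can actually be freed instead of
    # fragmenting the allocator and inflating peak VRAM usage.
    caches = None
    gc.collect()
    torch.cuda.empty_cache()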