eth-sri / lmql

A language for constraint-guided and efficient LLM programming.
https://lmql.ai
Apache License 2.0

[In-Process Models] There's a memory leak in the server #222

Closed · freckletonj closed this 1 year ago

freckletonj commented 1 year ago

My VRAM continues to fill up over the course of inference (using the server) until OOMing.

I'm not sure what else to say. Just run watch nvidia-smi and you can see it fail to free old memory.

I'm not sure of the root cause yet.

It'd be nice if we could instantiate the model ourselves and control these things directly.

The inline model feature is frustrating if you're working in an interpreter because you'd need to repeatedly reload a model.

Why not something like guidance, where you load your tokenizer and model, and pass them to lmql?

E.g.:

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import lmql

model = AutoModelForCausalLM.from_pretrained(...)
tokenizer = AutoTokenizer.from_pretrained(...)
model.eval()

lm = lmql.use_model(model, tokenizer)  # proposed API

torch.cuda.empty_cache()  # whenever you want

lbeurerkellner commented 1 year ago

Thanks for raising this. I am currently investigating this on my end.

In general, in-process model loading with HuggingFace is difficult to work with because of long loading times. For this reason, we primarily advise using lmql serve-model instead, which decouples model inference from the front-end interpreter (see https://docs.lmql.ai/en/latest/language/hf.html). This is also the reason why we currently do not allow users to instantiate the model themselves: we move it to a background process/thread, decoupling the interpreter from the actual inference.
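For reference, that workflow looks roughly like this (only a sketch; the endpoint argument and default port are assumptions about your setup, see the docs linked above):

# in one terminal, the long-lived inference process:
lmql serve-model tiiuae/falcon-7b-instruct --cuda

# in your interpreter, only the lightweight client (safe to re-run repeatedly):
import lmql

m = lmql.model("tiiuae/falcon-7b-instruct", endpoint="localhost:8080")

@lmql.query
def greet():
    '''lmql
    "Say hello:[GREETING]" where len(TOKENS(GREETING)) < 10
    return GREETING
    '''

print(greet(model=m))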

If you need more direct access to the concrete model constructor call, you can have a look at https://github.com/eth-sri/lmql/blob/main/src/lmql/models/lmtp/lmtp_programmatic_serve_example.py, which passes all provided parameters directly to the actual model constructor and spins up an inference server.

Now, we could allow users to provide model instances themselves, but if we used these models synchronously in the same process as the interpreter, each generate() call would block the entire execution, which is not desirable in highly parallel settings: you can execute many LMQL queries concurrently while the actual inference is batched on the backend. Under the current architecture, everything runs in a highly async manner.
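To illustrate the concurrent pattern, here is a rough sketch (the query and document names are made up, and I omit the model selection for brevity):

import asyncio
import lmql

@lmql.query
async def summarize(text):
    '''lmql
    "Summarize {text} in one sentence:[SUMMARY]" where len(TOKENS(SUMMARY)) < 64
    return SUMMARY
    '''

async def main():
    docs = ["first report", "second report", "third report"]
    # all queries run concurrently; the backend batches the actual forward passes
    summaries = await asyncio.gather(*(summarize(d) for d in docs))
    print(summaries)

asyncio.run(main())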

I hope this clarifies your questions. I'll be sure to investigate the memory issue.

freckletonj commented 1 year ago

Thanks @lbeurerkellner! I was actually seeing the leak via the server.

As an aside, servers are great too, but sometimes when I'm developing, I don't mind instantiating a model myself. I typically do something like:

try:
    already_loaded          # raises NameError on the first run only
except NameError:
    model = <...>           # expensive model load happens once
    already_loaded = True

This lets me re-eval the entire file without accidentally reinitializing the model.

I understand that wouldn't work well with your parallelism stuff, which is super cool, so no big deal if it has to be via server.

For my use case, I was trying to get AutoGPTQ to work well, which would have been easier if I could customize the model more myself. I know lmql claims to support GPTQ, but it wasn't working for me (I can't recall why now).

lbeurerkellner commented 1 year ago

I investigated this a bit, running with falcon-7b-instruct and Llama-2-7B (regular transformers and autogptq) via the server.

I ran a script that executes queries of increasing sequence length (see below for the script to reproduce). Maybe you can run this with your configuration, so we can learn a bit more about the memory allocation behaviour.

Summarizing the plots below, I find that transformers seems to clean up regularly on its own, while auto-gptq seems to use constant memory. So far I could not detect a memory leak; I am still experimenting with different configurations.

lmql: (main)
transformers: 4.33.3
torch: 2.0.0
bitsandbytes: 0.39.0
auto-gptq: 0.3.0

Note that in a recent transformers version, they ported Falcon to a new (official) implementation, which seems to avoid the background threading issues. This also means you can just use lmql.model("local:tiiuae/falcon-7b-instruct", dtype="4bit", cuda=True) now.
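That is, roughly the following should work in-process (untested sketch, reusing the query shape from the script below):

import lmql

falcon = lmql.model("local:tiiuae/falcon-7b-instruct", dtype="4bit", cuda=True)

@lmql.query
def complete(prompt):
    '''lmql
    "{prompt} and[RESULT]" where len(TOKENS(RESULT)) == 16
    return RESULT
    '''

print(complete("This is a great day", model=falcon))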

Results for 7B models, 4bit quantised (I can't run more on my machine):

Llama-2-7b-chat transformers

lmql serve-model meta-llama/Llama-2-7b-chat-hf --dtype 4bit 

[attached plot: gpu-llama2]

Falcon-7b-instruct auto-gptq

lmql serve-model TheBloke/falcon-7b-instruct-GPTQ --static --use_safetensors True --trust_remote_code True --device_map auto --load_in_4bit True  --use_flash_attention_2 True --loader "auto-gptq" --bits 4 --dtype bfloat16

[attached plot: gpu-falcon-7b-autogptq]

Falcon-7b-instruct transformers

lmql serve-model tiiuae/falcon-7b-instruct --dtype 4bit --cuda

[attached plot: gpu]

Script to create these figures:

import asyncio
import lmql
import time
import torch

from tqdm import tqdm
import pynvml

memory_over_time = []
utilisation_over_time = []
seqlen_over_time = []
query_over_time = []
prompt = ""

query_time = 0

import matplotlib.pyplot as plt

async def monitor(tokenizer):
    """
    Monitors inference (memory use, utilization and prompt token length)
    and plots it to gpu.png.
    """
    pynvml.nvmlInit()
    handle = pynvml.nvmlDeviceGetHandleByIndex(0)

    while True:
        meminfo = pynvml.nvmlDeviceGetMemoryInfo(handle)
        utilisation = pynvml.nvmlDeviceGetUtilizationRates(handle)

        memory_over_time.append(meminfo.used / 1024 / 1024)
        utilisation_over_time.append(utilisation.gpu)
        seqlen_over_time.append(len(tokenizer(prompt)["input_ids"]))
        query_over_time.append(query_time)

        # plot memory, utilization and sequence length in three subplots side by side
        plt.clf()

        plt.figure(figsize=(16, 16))

        plt.subplot(2, 2, 1)
        plt.plot(memory_over_time)
        plt.title("Memory usage (MB)")
        plt.ylim(0, meminfo.total / 1024 / 1024 + 128)
        plt.subplot(2, 2, 2)
        plt.plot(utilisation_over_time)
        plt.title("GPU utilization (%)")
        plt.ylim(0, 100)
        plt.subplot(2, 2, 3)
        plt.plot(seqlen_over_time)
        plt.title("Sequence length")
        plt.ylim(0, 2048)
        plt.subplot(2, 2, 4)
        plt.plot(query_over_time)
        plt.title("Query time (s)")
        plt.ylim(0, max(query_over_time) * 1.1)

        plt.tight_layout()

        plt.savefig("gpu.png")
        plt.close()

        await asyncio.sleep(1)

@lmql.query(chunksize=16)
async def q(prompt):
    '''lmql
    "{prompt} and[RESULT]" where len(TOKENS(RESULT)) == 16
    return RESULT
    '''

async def main():
    # set to your configuration
    model_name = "tiiuae/falcon-7b-instruct"
    model_object = lmql.model(model_name, <other config>)

    global prompt, query_time
    tokenizer = lmql.tokenizer(model_name)
    prompt = "This is a great day"
    monitor_task = asyncio.create_task(monitor(tokenizer))

    for i in tqdm(range(1000)):
        s = time.time()
        try:
            prompt += (await q(prompt, model=model_object))[0]
        except AssertionError as e:
            if "The decoder returned a sequence that exceeds the provided max_len" in str(e):
                print("Reached max_len, exiting early.")
                break
            else:
                raise e
        query_time = time.time() - s

    monitor_task.cancel()

if __name__ == "__main__":
    asyncio.run(main())

lbeurerkellner commented 1 year ago

One follow-up question here: did you switch models within the same server process, e.g. by running queries for different models against the same endpoint? I remember once observing something funky there.

freckletonj commented 1 year ago

@lbeurerkellner Wow, thank you for the thorough followup!

I had to pull down the latest lmql because lmql.tokenizer didn't exist. I notice the version is still lmql-0.99999, the same as I had before, but now the API does expose lmql.tokenizer. (I'm sure you've talked about this already, but perhaps consider semantic versioning, like 0.1.7, etc.)

Some environment details:

os = ubuntu 22

>>> torch.version.cuda
'12.1'

>>> torch.__version__
'2.2.0.dev20230926+cu121'

model_name = "/home/user/models/falcon-7b-instruct"  # the latest version of this was pulled down days ago

# running with flash-attention-2
lmql serve-model "/home/josh/_/models/falcon-7b-instruct" --static --trust_remote_code True --device_map auto --dtype float16 --use_flash_attention_2 True

So, upon running this, I do not see a memory leak anymore. I had to update lmql, perhaps something changed?

The behavior previously was: I was working in an interpreter and struggling to get the syntax of my prompts to work with the lmql parser, so I was frequently hitting runtime errors, usually from the client, with the server running the entire time. IIRC the server was sometimes throwing errors too, but I forget what. I didn't see them as relevant to the OOM at the time, but I have noticed before that errors can cause a model to hang on to memory it would have otherwise released. Perhaps my issue was related to that?
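My rough guess, and it is only a guess: an exception keeps its traceback alive, and with it every frame's locals (activations, past_key_values, ...), in whatever process holds the model. Something along these lines might recover the memory after an error (untested sketch):

import gc
import sys
import torch

def free_gpu_after_error():
    # drop the interpreter's reference to the last uncaught exception
    # (only relevant in interactive sessions)
    sys.last_value = None
    sys.last_traceback = None
    gc.collect()               # collect the now-unreferenced tensors
    torch.cuda.empty_cache()   # hand cached-but-unused blocks back to the driver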

I'm happy to close this for now and re-raise it with more details if I run into it again.

Thanks again for the thorough response!!!

[attached plot: gpu]

freckletonj commented 1 year ago

There is something squirrelly going on. Here's a little chat-loop script, and even when the client program closes, the server keeps stuff in memory. Is it past_key_values? I do notice that when it approaches OOMing, it suddenly frees memory. Or could it be doing something like DeepSpeed to swap? I'm not sure, but the tok/s stays fast, so I think it really is freeing memory.
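If it helps narrow this down, the snippet below (it would have to run inside the server process, so this is just a suggestion) distinguishes memory held by live tensors from memory that PyTorch's caching allocator is merely hanging on to; the latter would also explain the "frees right before OOM" pattern:

import torch

print(f"allocated: {torch.cuda.memory_allocated() / 2**20:.0f} MiB")  # live tensors
print(f"reserved:  {torch.cuda.memory_reserved() / 2**20:.0f} MiB")   # held by the caching allocator

torch.cuda.empty_cache()  # releases cached-but-unused blocks back to the driver (visible in nvidia-smi)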

import lmql

def format_transcript(transcript):
    return '\n'.join([f'{x["person"].upper()}: {x["text"]}' for x in transcript])

@lmql.query
def prompt(transcript):
    '''lmql
sample(temperature=1.0, max_len=8000)
"""SYSTEM: You're a helpful assistant.
{transcript}
ASSISTANT: [RESPONSE]
""" where (
    STOPS_BEFORE(RESPONSE, '\n') and len(TOKENS(RESPONSE)) > 100 and len(TOKENS(RESPONSE)) < 2000
    )
from
  "/home/me/models/mistral/Mistral-7B-Instruct-v0.1"
'''

transcript = [
    # {'person':'user', 'text': 'hi'},
    # {'person':'assistant', 'text': 'wow'},
]

while True:
    inp = input('USER: ')
    if inp.lower().strip() in {'quit', 'exit', 'close'}:
        break
    transcript.append({'person': 'user', 'text': inp})

    params = dict(
        transcript=format_transcript(transcript)
    )

    xs = prompt(**params)
    assert len(xs) == 1
    resp = xs[0].variables['RESPONSE']
    transcript.append({'person':'assistant', 'text': resp})
    print(f'ASSISTANT: {resp}')