eth-sri / lmql

A language for constraint-guided and efficient LLM programming.
https://lmql.ai
Apache License 2.0

Speed and Optimisation of Inference #233

Open ambroser53 opened 1 year ago

ambroser53 commented 1 year ago

I'm using LMQL as the front-end for a big project that requires a lot of inference. I am using an A100 80GB but finding inference to be incredibly slow. I get roughly 50 samples through every 3 hours with a 13B LLaMA2-chat model, which seems exceptionally slow, especially since it is loaded with 4-bit quantisation and bfloat16. It's not a particularly long prompt either. I'm using the following decoder: beam_sample(n=2, temperature=0.7, top_k=0, repetition_penalty=1.1, top_p=0.7)
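
For reference, a minimal sketch of the kind of query I am running (the model id, loader flags and prompt are placeholders from my setup and may need adjusting; only the decoder clause is exactly the one above):

import lmql

@lmql.query
async def answer(question):
    '''lmql
    beam_sample(n=2, temperature=0.7, top_k=0, repetition_penalty=1.1, top_p=0.7)
        "Q: {question}\n"
        "A: [ANSWER]"
    from
        lmql.model("local:meta-llama/Llama-2-13b-chat-hf", cuda=True, load_in_4bit=True)
    where
        len(TOKENS(ANSWER)) < 200
    '''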

I'm wondering what techniques there are to speed up inference. For example, can async be used to get the equivalent of "batches" when running locally? Are there any threading considerations that must be made to minimise latency from computation within LMQL itself? Is there anything else I can do? Are there ways of splitting across multiple GPUs to improve inference speed, like with the device map?

Thanks for your help.

lbeurerkellner commented 1 year ago

Hi there. There are several knobs to turn and play with to get better throughput. The performance you are seeing can definitely be improved by a lot:

Let me know how these different strategies work for you. Happy to also look and profile some more specific query code, if you can share something :)

ambroser53 commented 1 year ago

Thanks for your response. I have implemented everything you said into my project, but I am still finding that running LMQL in async creates a lot of problems and I cannot get it to work. Here is a link to my repository so you can run it yourself, as it is a very specific case. It has separate client and server batch sizes, using a Semaphore on the client and the batch_size argument in the lmql.model specification (which I'm assuming allows LMTP to carry out its own queuing and batching?). The semaphore batches binary_tournament_evolve, which itself can issue anywhere from 1 to 2N + 1 queries. I am not very familiar with asynchronous programming, so I am unsure whether this is a good solution in and of itself.
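
For clarity, a simplified sketch of the client-side setup I mean (my_query stands in for the actual breeder queries, and the batch sizes are illustrative):

import asyncio
import lmql

CLIENT_BATCH = 16                      # client-side: max queries in flight at once
sem = asyncio.Semaphore(CLIENT_BATCH)

@lmql.query
async def my_query(prompt):
    '''lmql
    "{prompt} [ANSWER]" where len(TOKENS(ANSWER)) < 64
    '''

async def guarded(prompt, model):
    # the semaphore caps how many queries are submitted concurrently;
    # LMTP then batches whatever is in flight, up to its own batch_size
    async with sem:
        return await my_query(prompt, model=model)

async def main(prompts):
    m = lmql.model("local:EleutherAI/pythia-70m", batch_size=32)  # server-side batching
    return await asyncio.gather(*(guarded(p, m) for p in prompts))

if __name__ == "__main__":
    results = asyncio.run(main(["prompt 1", "prompt 2", "prompt 3"]))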

Essentially there are a bunch of setbacks stopping this from working effectively:

Thank you for your help. Since this code base requires so much querying, if I could be sure that I'm making the right choices, or at least had solid logging, I could set it off with confidence.

lbeurerkellner commented 1 year ago

Thanks a lot for the repo, I will definitely have a closer look. It would be awesome to do a bit of performance engineering/profiling work here as a case study, so we can learn what we can improve and which performance tricks actually work.

Once I have had a proper look, I can also give you a more detailed response to the different points you brought up.

Regarding 2, it can indeed happen that two concurrent workloads that are quite different in nature (e.g. argmax vs. beam_var decoding) conflict at the batching/scheduler level, leading to suboptimal scheduling and performance.

Re 3, in one process, all calls should share the same in-process instance. Is this just a suspicion or did you observe in-process models being loaded twice, e.g., when monitoring the GPU use?

Re profiling, 0.7 added some basic tracing via inference certificates. This allows you to see the individual calls that go out to the backend. You can also run with verbose=True to see generate() request logging. We also have full OpenTelemetry support in the works for end-to-end tracing with timing information, but it is not fully ready yet.
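
As a rough sketch (my_query stands in for one of your query functions; the exact placement of the certificate and verbose keywords may differ slightly between versions, so please verify against the docs):

import lmql

async def profile_run():
    # generate() request logging (assumption: verbose is accepted on the model handle)
    m = lmql.model("local:EleutherAI/pythia-70m", verbose=True)
    # inference certificate: records the individual backend calls made by this query run
    return await my_query("some prompt", model=m, certificate="breeder-run-certificate.json")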

ambroser53 commented 1 year ago

Thanks for the quick reply. Re 2, so if all the different workloads share the same decoding method, they should naturally have pretty well-optimised scheduling and performance? Are there any other factors that may affect this, like the number of constraints within queries, query length, etc.?

Re 3, this was just a suspicion, as I would often send off many tasks at the script's opening and receive nothing back from any of them. Also, I just generally found GPU monitoring to be very inconsistent compared to regular pytorch/transformers use.

ambroser53 commented 1 year ago

OK, so because I was optimising so much and running the repo under absolutely tiny conditions where it must have performed quickly, I think this isn't a speed issue but actually a bug, potentially within the LMQL code. Essentially, I think it is getting caught in some kind of async infinite loop. If you check out the debug branch of the repo you will see the most recent code. I have isolated it to the run_fitness_test method, where my debugger produced the stack trace below.

(screenshot: debugger stack trace)

I don't think this is supposed to happen. Should I make a new issue?

ambroser53 commented 1 year ago

Running with the toy environment with EleutherAI/pythia-70m on cpu with nest_asyncio.apply() gave the following new error:

Error: Detected an access to a lmql.model(..., inprocess=True)/local:<MODEL> inside the multiprocessing worker process itself. This is not supported and may lead to unexpected behavior. To avoid this, please make sure to not call lmql.model(..., inprocess=True) on the top level of your script, but only inside functions or the __main__ block.

I definitely do not have lmql.model at the top level of the script; in fact it is only defined within a string passed to lmql.query. Does this mean that perhaps LMQL has a problem with double-dipping into the asyncio event loop and allowing re-entrance? Using nest_asyncio.apply() along with torch.multiprocessing.set_start_method('spawn'), although it doesn't fix this error, does allow the program to run on CUDA devices.
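
For reference, the shape of the workaround I am running now looks roughly like this (simplified sketch; in my actual code the model is defined inside the query string rather than via a direct lmql.model call):

import asyncio
import nest_asyncio
import torch.multiprocessing as mp
import lmql

@lmql.query
async def ask(question):
    '''lmql
    "Q: {question}\nA: [ANSWER]" where len(TOKENS(ANSWER)) < 32
    '''

def make_model():
    # in-process model construction kept inside a function, not at module top level
    return lmql.model("EleutherAI/pythia-70m", inprocess=True)

async def main():
    return await ask("What is LMQL?", model=make_model())

if __name__ == "__main__":
    mp.set_start_method("spawn")   # CUDA with multiprocessing needs 'spawn'
    nest_asyncio.apply()           # tolerate nested/re-entrant event loops
    print(asyncio.run(main()))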

ambroser53 commented 1 year ago

Going back to speed and optimisation since it now seems to be working fine with these changes, is there any explanation for why one might be getting a GPU memory usage graph like this?

(screenshot: GPU memory usage graph; looking specifically at the long green and brown plots)

This is with a server-side batch size of 4 and a client-side batch size of 50. It seems as if LMQL takes a while to build up to its maximum utilisation; is this correct? If so, what would be the reason for this?

lbeurerkellner commented 1 year ago

Hi there, I haven't found time to get deeper into the repo, sorry about that. Still, it's great to hear that things are moving forward. Thanks a lot for the smaller reproducible example; I will try to run it. A couple of things:

ambroser53 commented 1 year ago

Yes those changes fixed the issue and it seems to run fine now despite that error message.

With memory use, I'm just trying to fully utilise all 80GB of VRAM to improve performance; no particular OOM situations yet. Thanks for bringing that to my attention, I'll try to crank it as high as I can and see how much of a speed boost I can get.

A retry mechanism on OOM would be very good. Will there be much of this parallelisation effect on performance for sample decoding with small n values?
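
To be concrete, the kind of client-side retry I have in mind is roughly this (a hypothetical sketch, not existing LMQL functionality; with an in-process model I would expect the OOM to surface as a RuntimeError):

import asyncio

async def with_oom_retry(run_query, retries=3, wait=30):
    # retry a query a few times if the backend hits CUDA OOM, backing off in between
    for attempt in range(1, retries + 1):
        try:
            return await run_query()
        except RuntimeError as e:
            if "out of memory" not in str(e).lower() or attempt == retries:
                raise
            await asyncio.sleep(wait * attempt)

# usage: result = await with_oom_retry(lambda: my_query("some prompt"))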

ambroser53 commented 1 year ago

I've attempted to use flash attention 2 to see whether that improves speed and found that it actually decreased it:

(screenshot: GPU utilisation traces for the two runs)

The bottom is without flash attention and the top is with it. Would there be any reason that this is happening? Is there a way to use LMQL that utilises its benefits?
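
For reference, the flash attention configuration I am testing corresponds to a plain transformers load roughly like this (the exact flag depends on the transformers version, and how the equivalent arguments reach LMQL's loader in my setup may differ):

import torch
from transformers import AutoModelForCausalLM

# newer transformers versions: attn_implementation="flash_attention_2"
# older versions used use_flash_attention_2=True instead
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-13b-chat-hf",
    torch_dtype=torch.bfloat16,
    load_in_4bit=True,
    attn_implementation="flash_attention_2",
    device_map="auto",
)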

ambroser53 commented 1 year ago

We did a test with flash attention and a batch size of 1 and found that although the inference time for single samples/the mean time to generate was longer, it actually managed to complete full runs of the codebase much faster (it completed a cycle in 148 minutes, whereas a batch size of 10 didn't finish even at 178 minutes).

Could it be that the internal batching of LMQL isn't working as intended? Does LMQL use padding when batching? Padding is meant to cause big slowdowns when using flash attention. Where can I find the actual point of inference within the LMQL code?

lbeurerkellner commented 1 year ago

Thanks for your continued investigation.

An update on my end: I now have a setup with the debug branch with Pythia and am doing test runs on my machine. I am running with:

python LMQL_prompt_breeder.py --task answerability --is_async --fitness_test_suite ./data/fitness_test_suites/squad_reqs_fitness_test_suite.json --test_CoT --day_care_dir data/lmql_specifications/day_care/ --client_batch_size 16 --model_name_or_path EleutherAI/pythia-70m

Together with an inference server launched via:

lmql serve-model EleutherAI/pythia-70m --cuda --batch_size 32 --busy_logging

I changed the lmql.model(...) calls in breeder.lmql to use the server as opposed to inprocess=True. The --busy_logging option is new and only available on the current development branch lmtp-cancel. It prints idle and streaming status (including tok/s) to the console of the inference server process. With this, it is much easier to see if/when inference is happening and at what throughput.
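
Concretely, the client-side change looks roughly like this (assuming the default serve-model port; adjust the endpoint to your setup):

import lmql

# before: lmql.model("EleutherAI/pythia-70m", inprocess=True, ...)
# after: talk to the separately launched lmql serve-model process
model = lmql.model("EleutherAI/pythia-70m", endpoint="localhost:8080")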

With the call above, I am so far observing some idle times, but mostly consistent token streaming/inference without major pauses. Maybe you can share a breeder command that triggers some odd behaviour for you; e.g. I suppose you are using beam_sample or a more custom decoder? Also, it would be helpful to know at what stage of process execution things seem to slow down, e.g. over time or right at the beginning.

Re flash attention/padding, we do actually rely on padding to batch together requests of different sequence lengths in the backend. Batching happens here, and padding is done via this method. The resulting batch of sequences is then passed to the transformers backend in this file.
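
To illustrate the general idea only (a simplified standalone sketch, not the actual LMQL implementation): variable-length token sequences are padded to a common length, with an attention mask marking the real tokens, before a single generate() call.

import torch

def pad_batch(seqs, pad_token_id):
    # left-pad variable-length token id lists to a common length and build the
    # matching attention mask (1 = real token, 0 = padding)
    max_len = max(len(s) for s in seqs)
    input_ids = torch.full((len(seqs), max_len), pad_token_id, dtype=torch.long)
    attention_mask = torch.zeros((len(seqs), max_len), dtype=torch.long)
    for i, s in enumerate(seqs):
        input_ids[i, max_len - len(s):] = torch.tensor(s, dtype=torch.long)
        attention_mask[i, max_len - len(s):] = 1
    return input_ids, attention_mask

ids, mask = pad_batch([[5, 7, 9], [3, 1]], pad_token_id=0)
# ids  -> [[5, 7, 9], [0, 3, 1]]
# mask -> [[1, 1, 1], [0, 1, 1]]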

I hope this is helpful to track down where exactly padding does happen. Please also let me know about the breeder commands you are running, so I can get a good view on the state of things, to start some profiling work.