predibase / lorax

Multi-LoRA inference server that scales to 1000s of fine-tuned LLMs
https://loraexchange.ai
Apache License 2.0

Does lorax currently support GPT2 finetuned adapters? #84

Open abhijithnair1 opened 11 months ago

abhijithnair1 commented 11 months ago

System Info

lorax:latest

Reproduction

@tgaddair I have a few adapters fine-tuned using GPT-2 as the base model.

Architecture of GPT-2:

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)

The adapters are fine-tuned with "c_attn" and "c_proj" as the target modules; does lorax currently support this?
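
(For reference, a PEFT-style LoRA config for this kind of adapter looks roughly like the sketch below; the rank, alpha, and dropout values are illustrative placeholders, not necessarily the ones used for these adapters.)

# Minimal sketch of a GPT-2 LoRA adapter targeting c_attn and c_proj.
# The r / lora_alpha / lora_dropout values are placeholders.
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

base_model = AutoModelForCausalLM.from_pretrained("gpt2")

lora_config = LoraConfig(
    r=8,                                  # placeholder rank
    lora_alpha=16,                        # placeholder scaling
    lora_dropout=0.05,
    target_modules=["c_attn", "c_proj"],
    fan_in_fan_out=True,                  # GPT-2 uses transformers' Conv1D, which stores weights transposed
    task_type="CAUSAL_LM",
)

model = get_peft_model(base_model, lora_config)
model.print_trainable_parameters()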

Expected behavior

Question about compatibility.

tgaddair commented 11 months ago

Hey @abhijithnair1, we don't yet support GPT-2, but we were just discussing adding it. Happy to take a stab at it this week.

abhijithnair1 commented 11 months ago

@tgaddair That's really good news for us. We have a few use cases that are latency-constrained, which is why we went with GPT-2. Having GPT-2 support will be a big win for us; I'll be looking out for an update.

abhijithnair1 commented 11 months ago

@tgaddair Is there any update on the progress of the GPT-2 addition?

tgaddair commented 11 months ago

Hey @abhijithnair1, good news! We just landed a couple of PRs that enable GPT-2.

Let us (in particular, @geoffreyangus) know if you run into any issues!

abhijithnair1 commented 11 months ago

@tgaddair Thank you for all the work. This is really good news. Are the latest changes available in the :latest Docker image?

tgaddair commented 11 months ago

@abhijithnair1 yes, should be ready to use!

abhijithnair1 commented 11 months ago

@tgaddair @geoffreyangus I am seeing some difference in the inference time. I fine-tuned a TinyLlama adapter with QLoRA and was getting around 150ms on average for generating 10 tokens (with a 100-token input), but for the same prompt I am now getting 200-210ms on a much smaller model like gpt2-medium. Here is the adapter used: https://huggingface.co/abhipn/gpt2-medium-finetuned-qlora (GPU used is an A100 40GB). We are looking for an ideal latency of 80-90ms.

Do you know what could be the reason?

tgaddair commented 11 months ago

Hey @abhijithnair1, could you share the script you used to generate 10 tokens with 150ms latency? 80-90ms should be doable; I'd be happy to dig into this some more.

tgaddair commented 11 months ago

@abhijithnair1, for reference, here's a quick script I ran to sanity check this:

import time
from lorax import Client

endpoint_url = "http://127.0.0.1:8080"
client = Client(endpoint_url, timeout=30)

prompt = "Hello, I'm a language model, "

start_t = time.time()
response = client.generate(prompt, max_new_tokens=10)
print(time.time() - start_t)

print(response.generated_text)

On an A100 with 40GB VRAM, this took 103ms with gpt2-medium.
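
(Not from the original comment, but a common refinement: warm up first and average over several requests, since the very first request can include one-time setup costs. A sketch reusing the client and prompt above:)

# Sketch: average latency over several requests after an untimed warm-up call.
# Reuses `client` and `prompt` from the snippet above.
import time

client.generate(prompt, max_new_tokens=10)  # warm-up, not timed

latencies = []
for _ in range(20):
    start_t = time.time()
    client.generate(prompt, max_new_tokens=10)
    latencies.append(time.time() - start_t)

print(f"mean: {1000 * sum(latencies) / len(latencies):.1f} ms")
print(f"min:  {1000 * min(latencies):.1f} ms")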

abhijithnair1 commented 11 months ago

@tgaddair The latency that I mentioned also includes swapping of the models, etc.:

adapters_repeat = [...]  # list of local adapter paths for the gpt2-medium adapters
for adapter_id in adapters_repeat:
    generated_text = client.generate(prompt, do_sample=False, max_new_tokens=10, adapter_id=adapter_id, adapter_source="local", best_of=1).generated_text

This was taking around 200-250ms, whereas for the TinyLlama QLoRA adapters it was 100-150ms. My assumption was that since gpt2-medium is much smaller than TinyLlama, this should have been faster, right?

Can you try a list of adapters (or just repeat the same adapters in a loop) and verify the timing when it includes loading the adapters as well?

FYI, we loaded the adapters from the file system instead of the Hub.

tgaddair commented 11 months ago

Hey @abhijithnair1, happy to run that test. Do you have any adapters I could use for testing? If not, could you tell me the ranks and target modules of your adapters, as they can have an effect on performance? Also, how many adapters are in your list?

My expectation would be that you will notice an increase in latency the first time each adapter is loaded, but subsequent requests should be nearly identical in latency to querying the base model (as there is no additional swapping required at that point). But again, happy to run some tests to verify that further.
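
(To make the cold vs. warm distinction concrete, here is a rough sketch of how to time the first request per adapter separately from subsequent ones; the endpoint URL and adapter paths are placeholders.)

# Sketch: separate the first (cold-load) request per adapter from later (warm) requests.
# The endpoint URL and adapter paths below are placeholders.
import time
from collections import defaultdict
from lorax import Client

client = Client("http://127.0.0.1:8080", timeout=30)
prompt = "Hello, I'm a language model, "
adapters = ["/data/adapters/adapter1", "/data/adapters/adapter2"]

timings = defaultdict(list)
for round_idx in range(5):
    for adapter_id in adapters:
        start_t = time.time()
        client.generate(prompt, do_sample=False, max_new_tokens=10,
                        adapter_id=adapter_id, adapter_source="local", best_of=1)
        bucket = "cold" if round_idx == 0 else "warm"
        timings[bucket].append(time.time() - start_t)

for bucket, values in timings.items():
    print(bucket, f"{1000 * sum(values) / len(values):.1f} ms")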

abhijithnair1 commented 10 months ago

@tgaddair We trained two adapters on gpt2-medium with private company data; these are the LoRA parameters we used:

target_modules: ["c_proj", "c_fc", "c_attn"]
r: 16
lora_alpha: 32

For TinyLlama we also fine-tuned two adapters; these are the parameters we used for both:

target_modules: ['q_proj', 'v_proj', 'k_proj', 'o_proj']

Here are some additional requirements we had:

Input prompt length: 100 tokens
Output length: 10 tokens

First, we built adapter_list as (adapter1, adapter2, adapter1, adapter2, and so on...) and then made predictions like this:

for adapter_id in adapter_list:
    client.generate(prompt, do_sample=False, max_new_tokens=10, adapter_id=adapter_id, adapter_source="local", best_of=1,).generated_text

Using the above code snippet, we calculated the time taken to generate predictions. We were hoping the GPT-2 adapters would be much faster than the TinyLlama ones.

tgaddair commented 10 months ago

Thanks @abhijithnair1, as an update, I found some unnecessary mallocs that were contributing about 50ms of overhead, which are being removed in #139.

From my tests, the difference in latency is attributable to the cost of applying the LoRA layers, not to any adapter switching time (which should be negligible).
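
(As a rough illustration: c_proj appears in both the attention and MLP sublayers of every GPT-2 block, so the GPT-2 adapter applies LoRA to at least as many modules per block as the TinyLlama one. A quick sketch to count the targeted modules; the TinyLlama model ID below is an assumption about which variant was used.)

# Rough sketch: count how many modules each adapter's target list matches.
# The TinyLlama model ID is an assumption; substitute the actual base model.
from transformers import AutoModelForCausalLM

def count_targets(model_name, target_modules):
    model = AutoModelForCausalLM.from_pretrained(model_name)
    return sum(
        any(name.endswith(target) for target in target_modules)
        for name, _ in model.named_modules()
    )

# gpt2-medium: 24 blocks, each with c_attn, attention c_proj, c_fc, and MLP c_proj -> 96 modules
print(count_targets("gpt2-medium", ["c_attn", "c_proj", "c_fc"]))

# TinyLlama: 22 blocks, each with q_proj, k_proj, v_proj, o_proj -> 88 modules
print(count_targets("TinyLlama/TinyLlama-1.1B-Chat-v1.0", ["q_proj", "k_proj", "v_proj", "o_proj"]))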

There are a few more things I will look to try this week, but hopefully this will get you closer to your target latency.

abhijithnair1 commented 10 months ago

@tgaddair An improvement of 50ms is noticeable in latency-sensitive scenarios. Thank you so much for all your contributions to this ticket. I am looking forward to the final release with all the improvements.

abhijithnair1 commented 10 months ago

@tgaddair I tried with the :latest image and am still seeing latency in the 180-200ms range for gpt2-medium (almost the same latency as with 3-4x larger models). This is the prompt I'm using:

You are a helpful assistant providing detailed answers to multiplication questions. Ensure you provide a thorough explanation using the long multiplication method for the math problem given. Make sure the answer is accurate. What is the multiplication of 524 * 192?<|endoftext|>

time_taken = []
for _ in range(num_runs):  # num_runs: number of timed requests (not specified)
    start_t = time.time()
    client.generate(prompt, do_sample=False, max_new_tokens=10, adapter_id=adapter_id, adapter_source="local", best_of=1)
    time_taken.append(time.time() - start_t)

np.mean(time_taken)  # 187 ms
np.min(time_taken)   # 173 ms

I see a 15-20ms improvement from my end.

tgaddair commented 10 months ago

Hey @abhijithnair1, your prompt in this case is longer than the one I was using, so that may contribute to the smaller decrease in latency (less of the time was spent on small operations like mallocs).

I'm currently working on a bigger series of changes involving CUDA graphs that, from early testing, have shown as much as an additional 30-50ms decrease in latency for 10 tokens, so I'll let you know when that's ready to test out.

tgaddair commented 10 months ago

Hey @abhijithnair1, quick update on this. The CUDA graph PR should be coming in the next day or two. From my testing (on A100), I observed that latency with batch size 1 and 1 adapter went from 147ms to 93ms, so about a 54ms reduction in latency.

abhijithnair1 commented 10 months ago

@tgaddair That's a huge improvement in latency. I'll be looking forward to the update and to testing things out.

tgaddair commented 10 months ago

Here's the PR: https://github.com/predibase/lorax/pull/154

There are still a few small things to add and test before it's ready to use, but my hope is we can land it today or tomorrow.