predibase / lorax

Multi-LoRA inference server that scales to 1000s of fine-tuned LLMs
https://loraexchange.ai
Apache License 2.0

Is there any plan to support dynamic lora for qwen/chatglm models? #101

Open KrisWongz opened 11 months ago

KrisWongz commented 11 months ago

Feature request

Great work! I have successfully run multi-LoRA with Llama-2-70B. I would like to ask whether you plan to support other models, such as Qwen, which would be very helpful.

Motivation

null

Your contribution

null

felixstander commented 11 months ago

Support for Alibaba's open-source Qwen model would be wonderful.

tgaddair commented 11 months ago

Hey @KrisWongz @felixstander, thanks for trying out LoRAX!

Looking at the code for Qwen, it looks pretty similar to Llama. It sounds like the main difference may be the use of a bias in the QKV computation, which shouldn't be a problem.
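To illustrate why a QKV bias should not get in the way of LoRA, here is a minimal NumPy sketch. It assumes a Qwen-style fused QKV projection (one linear layer producing Q, K and V together, with a bias) and the standard LoRA update `W x + b + scaling * B A x`. The function and argument names are hypothetical; this is not LoRAX's code.

```python
import numpy as np


def qkv_with_lora(x, w_qkv, b_qkv, lora_a, lora_b, scaling):
    """Fused QKV projection with bias, plus a LoRA delta (illustrative only).

    x:       (tokens, hidden)
    w_qkv:   (3 * hidden, hidden)  frozen base weight
    b_qkv:   (3 * hidden,)         frozen base bias
    lora_a:  (r, hidden)           adapter down-projection
    lora_b:  (3 * hidden, r)       adapter up-projection
    scaling: usually lora_alpha / r
    """
    base = x @ w_qkv.T + b_qkv         # the bias lives in the frozen base layer
    delta = (x @ lora_a.T) @ lora_b.T  # the low-rank update carries no bias
    out = base + scaling * delta
    q, k, v = np.split(out, 3, axis=-1)  # split the fused output into Q, K, V
    return q, k, v
```

Because the bias belongs to the frozen base layer, the adapter only contributes the low-rank term, so the bias needs no special handling on the adapter side.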

I can definitely try taking a stab at it and see how it goes.

tgaddair commented 11 months ago

Hey @KrisWongz @felixstander, #103 should add support for Qwen. The base model appears to generate results consistent with the example on the Hugging Face Hub. Do you have an adapter I can use to test that adapter loading works as expected?

tgaddair commented 11 months ago

Note that you'll need to run with --trust-remote-code when launching LoRAX as the tokenizer is custom and hosted on HF.
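As a sketch of how the launch flags discussed in this thread fit together, here is a small Python helper that only assembles a `lorax-launcher` argument list; it does not start anything. The flag names come from this thread (`--model-id`, `--trust-remote-code`, `--quantize`, `--port`, `--max-input-length`, `--max-total-tokens`, `--max-batch-prefill-tokens`) and the defaults mirror the `Args { ... }` log in the next comment. The two sanity checks are an assumption modeled on the limits TGI-style launchers enforce, not a copy of LoRAX's own validation, and the helper name is hypothetical.

```python
import shlex


def build_launcher_cmd(model_id, *, trust_remote_code=False, quantize=None, port=None,
                       max_input_length=1024, max_total_tokens=2048,
                       max_batch_prefill_tokens=4096):
    """Hypothetical helper: assemble lorax-launcher argv without running it."""
    # Assumed constraints: a prompt must leave room for at least one new token,
    # and a single max-length prompt must fit into one prefill batch.
    if max_input_length >= max_total_tokens:
        raise ValueError("max_input_length must be < max_total_tokens")
    if max_batch_prefill_tokens < max_input_length:
        raise ValueError("max_batch_prefill_tokens must be >= max_input_length")
    cmd = [
        "lorax-launcher", "--model-id", model_id,
        "--max-input-length", str(max_input_length),
        "--max-total-tokens", str(max_total_tokens),
        "--max-batch-prefill-tokens", str(max_batch_prefill_tokens),
    ]
    if trust_remote_code:
        # Needed for Qwen: its tokenizer is custom code hosted on HF.
        cmd.append("--trust-remote-code")
    if quantize is not None:
        cmd += ["--quantize", quantize]
    if port is not None:
        cmd += ["--port", str(port)]
    return cmd


cmd = build_launcher_cmd("/root/autodl-tmp/Qwen-14B-Chat-Int4",
                         trust_remote_code=True, quantize="gptq", port=6006)
print(shlex.join(cmd))
```

`shlex.join(cmd)` produces a shell-ready command line; on a machine with LoRAX installed, the same list could be passed to `subprocess.run`.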

KrisWongz commented 11 months ago

> Note that you'll need to run with --trust-remote-code when launching LoRAX as the tokenizer is custom and hosted on HF.

I pulled the latest Docker image and set --trust-remote-code on startup. Launch command:

sudo docker run --gpus all \
  --shm-size 10g \
  -p 8081:80 \
  -v /home/shaohongen/Temp/Models/Qwen:/data \
  ghcr.io/predibase/lorax:latest \
  --model-id /data/Tongyi-Finance-14B-Chat \
  --trust-remote-code

But it still reports an error:

2023-12-06T06:11:56.409578Z INFO lorax_launcher: Args { model_id: "/data/Tongyi-Finance-14B-Chat", adapter_id: "", source: "hub", adapter_source: "hub", revision: None, validation_workers: 2, sharded: None, num_shard: None, quantize: None, dtype: None, trust_remote_code: true, max_concurrent_requests: 128, max_best_of: 2, max_stop_sequences: 4, max_input_length: 1024, max_total_tokens: 2048, waiting_served_ratio: 1.2, max_batch_prefill_tokens: 4096, max_batch_total_tokens: None, max_waiting_tokens: 20, max_active_adapters: 128, adapter_cycle_time_s: 2, hostname: "4d4c7a004768", port: 80, shard_uds_path: "/tmp/lorax-server", master_addr: "localhost", master_port: 29500, huggingface_hub_cache: Some("/data"), weights_cache_override: None, disable_custom_kernels: false, cuda_memory_fraction: 1.0, json_output: false, otlp_endpoint: None, cors_allow_origin: [], watermark_gamma: None, watermark_delta: None, ngrok: false, ngrok_authtoken: None, ngrok_edge: None, env: false, download_only: false }
2023-12-06T06:11:56.409692Z WARN lorax_launcher: trust_remote_code is set. Trusting that model /data/Tongyi-Finance-14B-Chat do not contain malicious code.
2023-12-06T06:11:56.409982Z INFO download: lorax_launcher: Starting download process.
2023-12-06T06:12:04.991708Z WARN lorax_launcher: cli.py:143 No safetensors weights found for model /data/Tongyi-Finance-14B-Chat at revision None. Converting PyTorch weights to safetensors.

2023-12-06T06:12:25.164707Z ERROR download: lorax_launcher: Download encountered an error:
Traceback (most recent call last):
  File "/opt/conda/bin/lorax-server", line 8, in <module>
    sys.exit(app())
  File "/opt/conda/lib/python3.9/site-packages/lorax_server/cli.py", line 199, in download_weights
    _download_weights(model_id, revision, extension, auto_convert, source)
  File "/opt/conda/lib/python3.9/site-packages/lorax_server/cli.py", line 173, in _download_weights
    utils.convert_files(local_pt_files, local_st_files, discard_names)
  File "/opt/conda/lib/python3.9/site-packages/lorax_server/utils/convert.py", line 112, in convert_files
    convert_file(pt_file, sf_file, discard_names)
  File "/opt/conda/lib/python3.9/site-packages/lorax_server/utils/convert.py", line 71, in convert_file
    loaded = torch.load(pt_file, map_location="cpu")
  File "/opt/conda/lib/python3.9/site-packages/torch/serialization.py", line 809, in load
    return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
  File "/opt/conda/lib/python3.9/site-packages/torch/serialization.py", line 1172, in _load
    result = unpickler.load()
  File "/opt/conda/lib/python3.9/site-packages/torch/serialization.py", line 1142, in persistent_load
    typed_storage = load_tensor(dtype, nbytes, key, _maybe_decode_ascii(location))
  File "/opt/conda/lib/python3.9/site-packages/torch/serialization.py", line 1112, in load_tensor
    storage = zip_file.get_storage_from_record(name, numel, torch.UntypedStorage)._typed_storage()._untyped_storage
  File "/opt/conda/lib/python3.9/site-packages/transformers/dynamic_module_utils.py", line 579, in _raise_timeout_error
    raise ValueError(
ValueError: Loading this model requires you to execute custom code contained in the model repository on your local machine. Please set the option trust_remote_code=True to permit loading of this model.

Error: DownloadError

felixstander commented 11 months ago

> Hey @KrisWongz @felixstander, #103 should add support for Qwen. The base model appears to generate results consistent with the example on Huggingface Hub. Do you have an adapter I can use to test that the adapter loading works as expected?

Really appreciate your work! I haven't tested it yet, but I will upload a couple of my fine-tuned adapters to the Hugging Face Hub for you to test soon.

tgaddair commented 11 months ago

Thanks @felixstander!

@KrisWongz it looks like the model weights .bin file is trying to execute some code on deserialization. I wasn't able to reproduce this using the base model from Hugging Face here: https://huggingface.co/jxy/Tongyi-Finance-14B-Chat.

This is one of the problems with pickle, though: it can do unpredictable things like this. Can you try converting the weights to safetensors format and trying again?

felixstander commented 11 months ago

Does LoRAX support the Qwen 4-bit GPTQ version without requiring flash attention v2?

As far as I can see, all the models you currently support are built on top of flash attention by default. Unfortunately, some of our inference GPUs are still V100s, which flash attention doesn't support. :(

tgaddair commented 11 months ago

Currently we rely on flash attention, but we can definitely explore alternatives, like falling back to paged attention during prefill if needed.
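The fallback idea can be sketched as a per-phase dispatch. This assumes, as in TGI-derived servers, that decoding reads the KV cache through paged attention and only the prefill pass calls FlashAttention-2, which requires a GPU with compute capability 8.0 or newer (Ampere and later); a V100 is 7.0. The function and backend names are hypothetical; this is not LoRAX code.

```python
def select_attention_backend(compute_capability, phase):
    """Pick an attention kernel per phase (hypothetical dispatch, not LoRAX code).

    compute_capability: (major, minor), e.g. (7, 0) for V100, (8, 6) for RTX 3090.
    phase: "prefill" (whole prompt at once) or "decode" (one new token per step).
    """
    if phase == "decode":
        # Assumed: decoding already reads the KV cache via paged attention.
        return "paged_attention"
    if phase != "prefill":
        raise ValueError(f"unknown phase: {phase!r}")
    if tuple(compute_capability) >= (8, 0):
        # FlashAttention-2 targets Ampere (sm80) and newer GPUs.
        return "flash_attention_v2"
    # Older GPUs such as V100 (sm70): fall back to paged attention for prefill too.
    return "paged_attention"
```

On a CUDA machine, the capability tuple could come from `torch.cuda.get_device_capability()`.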

felixstander commented 11 months ago

@tgaddair I'm testing Qwen-14-gptq-int4 on an RTX 3090 right now. My launch parameters are as follows:

lorax-launcher --model-id /root/autodl-tmp/Qwen-14B-Chat-Int4 --quantize gptq --trust-remote-code --port 6006

But I got the following error:

2023-12-07T06:58:35.659747Z INFO lorax_launcher: server.py:263 Server started at unix:///tmp/lorax-server-0

2023-12-07T06:58:35.748970Z INFO shard-manager: lorax_launcher: Shard ready in 8.511468629s rank=0
2023-12-07T06:58:35.843122Z INFO lorax_launcher: Starting Webserver
2023-12-07T06:58:35.855480Z WARN lorax_router: router/src/main.rs:169: Could not find a fast tokenizer implementation for /root/autodl-tmp/Qwen-14B-Chat-Int4
2023-12-07T06:58:35.855561Z WARN lorax_router: router/src/main.rs:172: Rust input length validation and truncation is disabled
2023-12-07T06:58:35.855586Z WARN lorax_router: router/src/main.rs:197: no pipeline tag found for model /root/autodl-tmp/Qwen-14B-Chat-Int4
2023-12-07T06:58:35.876373Z INFO lorax_router: router/src/main.rs:216: Warming up model
2023-12-07T06:58:37.188252Z ERROR lorax_launcher: interceptor.py:41 Method Warmup encountered an error.
Traceback (most recent call last):
  File "/root/lorax/server/lorax_server/models/flash_causal_lm.py", line 843, in warmup
    _, batch = self.generate_token(batch)
  File "/root/miniconda3/envs/lorax/lib/python3.9/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/root/lorax/server/lorax_server/models/flash_causal_lm.py", line 939, in generate_token
    raise e
  File "/root/lorax/server/lorax_server/models/flash_causal_lm.py", line 936, in generate_token
    out = self.forward(batch, adapter_data)
  File "/root/lorax/server/lorax_server/models/flash_causal_lm.py", line 895, in forward
    return self.model.forward(
  File "/root/lorax/server/lorax_server/models/custom_modeling/flash_qwen_modeling.py", line 471, in forward
    hidden_states = self.transformer(
  File "/root/miniconda3/envs/lorax/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/lorax/server/lorax_server/models/custom_modeling/flash_qwen_modeling.py", line 427, in forward
    hidden_states, residual = layer(
  File "/root/miniconda3/envs/lorax/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/lorax/server/lorax_server/models/custom_modeling/flash_qwen_modeling.py", line 352, in forward
    attn_output = self.attn(
  File "/root/miniconda3/envs/lorax/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/lorax/server/lorax_server/models/custom_modeling/flash_qwen_modeling.py", line 235, in forward
    paged_attn.reshape_and_cache(
  File "/root/lorax/server/lorax_server/utils/paged_attn.py", line 23, in reshape_and_cache
    cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping)
RuntimeError: expected scalar type Int but found Long

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/root/miniconda3/envs/lorax/bin/lorax-server", line 8, in <module>
    sys.exit(app())
  File "/root/miniconda3/envs/lorax/lib/python3.9/site-packages/typer/main.py", line 311, in __call__
    return get_command(self)(*args, **kwargs)
  File "/root/miniconda3/envs/lorax/lib/python3.9/site-packages/click/core.py", line 1130, in __call__
    return self.main(*args, **kwargs)
  File "/root/miniconda3/envs/lorax/lib/python3.9/site-packages/typer/core.py", line 778, in main
    return _main(
  File "/root/miniconda3/envs/lorax/lib/python3.9/site-packages/typer/core.py", line 216, in _main
    rv = self.invoke(ctx)
  File "/root/miniconda3/envs/lorax/lib/python3.9/site-packages/click/core.py", line 1657, in invoke
    return _process_result(sub_ctx.command.invoke(sub_ctx))
  File "/root/miniconda3/envs/lorax/lib/python3.9/site-packages/click/core.py", line 1404, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/root/miniconda3/envs/lorax/lib/python3.9/site-packages/click/core.py", line 760, in invoke
    return __callback(*args, **kwargs)
  File "/root/miniconda3/envs/lorax/lib/python3.9/site-packages/typer/main.py", line 683, in wrapper
    return callback(**use_params)  # type: ignore
  File "/root/lorax/server/lorax_server/cli.py", line 83, in serve
    server.serve(
  File "/root/lorax/server/lorax_server/server.py", line 271, in serve
    asyncio.run(
  File "/root/miniconda3/envs/lorax/lib/python3.9/asyncio/runners.py", line 44, in run
    return loop.run_until_complete(main)
  File "/root/miniconda3/envs/lorax/lib/python3.9/asyncio/base_events.py", line 634, in run_until_complete
    self.run_forever()
  File "/root/miniconda3/envs/lorax/lib/python3.9/asyncio/base_events.py", line 601, in run_forever
    self._run_once()
  File "/root/miniconda3/envs/lorax/lib/python3.9/asyncio/base_events.py", line 1905, in _run_once
    handle._run()
  File "/root/miniconda3/envs/lorax/lib/python3.9/asyncio/events.py", line 80, in _run
    self._context.run(self._callback, *self._args)
  File "/root/miniconda3/envs/lorax/lib/python3.9/site-packages/grpc_interceptor/server.py", line 159, in invoke_intercept_method
    return await self.intercept(
  File "/root/lorax/server/lorax_server/interceptor.py", line 38, in intercept
    return await response
  File "/root/miniconda3/envs/lorax/lib/python3.9/site-packages/opentelemetry/instrumentation/grpc/_aio_server.py", line 82, in _unary_interceptor
    raise error
  File "/root/miniconda3/envs/lorax/lib/python3.9/site-packages/opentelemetry/instrumentation/grpc/_aio_server.py", line 73, in _unary_interceptor
    return await behavior(request_or_iterator, context)
  File "/root/lorax/server/lorax_server/server.py", line 74, in Warmup
    max_supported_total_tokens = self.model.warmup(batch)
  File "/root/lorax/server/lorax_server/models/flash_causal_lm.py", line 845, in warmup
    raise RuntimeError(
RuntimeError: Not enough memory to handle 4096 prefill tokens. You need to decrease --max-batch-prefill-tokens

2023-12-07T06:58:37.188701Z ERROR warmup{max_input_length=1024 max_prefill_tokens=4096}:warmup: lorax_client: router/client/src/lib.rs:33: Server error: Not enough memory to handle 4096 prefill tokens. You need to decrease --max-batch-prefill-tokens
Error: Warmup(Generation("Not enough memory to handle 4096 prefill tokens. You need to decrease --max-batch-prefill-tokens"))
2023-12-07T06:58:37.245694Z ERROR lorax_launcher: Webserver Crashed
2023-12-07T06:58:37.245719Z INFO lorax_launcher: Shutting down shards
2023-12-07T06:58:37.510784Z INFO shard-manager: lorax_launcher: Shard terminated rank=0
Error: WebserverFailed

felixstander commented 11 months ago

Setting max_prefill_tokens equal to max_input_length doesn't work either:

2023-12-07T07:09:11.100552Z ERROR warmup{max_input_length=1024 max_prefill_tokens=1024}:warmup: lorax_client: router/client/src/lib.rs:33: Server error: Not enough memory to handle 1024 prefill tokens. You need to decrease --max-batch-prefill-tokens
Error: Warmup(Generation("Not enough memory to handle 1024 prefill tokens. You need to decrease --max-batch-prefill-tokens"))
2023-12-07T07:09:11.154507Z ERROR lorax_launcher: Webserver Crashed
2023-12-07T07:09:11.154534Z INFO lorax_launcher: Shutting down shards
2023-12-07T07:09:11.441994Z INFO shard-manager: lorax_launcher: Shard terminated rank=0
Error: WebserverFailed

felixstander commented 11 months ago

Even after lowering max-input-length and max-batch-prefill-tokens to 100 tokens, it still reports the Not enough memory error. I also noticed that memory utilization jumps to 100% before it crashes.

KrisWongz commented 11 months ago

> Thanks @felixstander!
>
> @KrisWongz it looks like the model weights .bin file is trying to execute some code on deserialization. I wasn't able to repro this using the base model from Huggingface here: https://huggingface.co/jxy/Tongyi-Finance-14B-Chat.
>
> This is one of the issues with pickle, though, which is it can do unpredictable things like this. Can you try converting the weights to safetensors format and trying again?

Thanks a lot! I successfully ran Qwen with multi-LoRA. My solution was to convert the Qwen weights to .safetensors locally.

But there is one small problem I have been dealing with for a long time: when I run inference with Qwen, generation never stops before reaching the maximum length. I suspect this is related to stop words. I found that 'generate()' has a 'stop=[]' parameter, but I got an error when I tried to pass it.
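One plausible reading of this symptom, sketched in plain Python: a server typically ends a sequence when it sees the EOS token, when the decoded text ends with a requested stop sequence, or when it hits the length limit. If this checkpoint's EOS token is not `<|im_end|>` (an assumption, not something verified here) and no stop sequence is passed, only the length limit can fire. The function name, arguments and token IDs are hypothetical, and the stop-sequence check assumes the decoded text actually contains the marker.

```python
def check_stop(generated_text, last_token_id, num_generated, *,
               eos_token_id, stop_sequences, max_new_tokens):
    """Return why generation should stop, or None to keep going (hypothetical)."""
    # 1. The model emitted the tokenizer's end-of-sequence token.
    if last_token_id == eos_token_id:
        return "eos_token"
    # 2. The decoded text now ends with one of the requested stop sequences.
    for seq in stop_sequences:
        if generated_text.endswith(seq):
            return "stop_sequence"
    # 3. Length budget exhausted: the only exit left when 1 and 2 never fire.
    if num_generated >= max_new_tokens:
        return "length"
    return None
```

Without `<|im_end|>` in `stop_sequences`, a reply that ends in `<|im_end|>` keeps going until the `"length"` branch triggers.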

client.generate(prompt, max_new_tokens=32,temperature=0.7,stop=["<|im_end|>"]).generated_text

error:

Traceback (most recent call last):
  File "/home/shaohongen/Temp/WZ_test/lorax/test_lorax_qwen.py", line 20, in <module>
    print(client.generate(prompt, max_new_tokens=32,temperature=0.7,stop=["<|im_end|>"]).generated_text)
TypeError: generate() got an unexpected keyword argument 'stop'

The other parameters work, though; only 'stop' fails:

generate{parameters=GenerateParameters { adapter_id: None, adapter_source: None, best_of: None, temperature: Some(0.7), repetition_penalty: None, top_k: None, top_p: None, typical_p: None, do_sample: false, max_new_tokens: 32, return_full_text: Some(false), stop: [], truncate: None, watermark: false, details: true, decoder_input_details: false, seed: None }

By the way, 'do_sample=True' makes the output better, but not every time.
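A generic sketch of why `do_sample=True` changes the output. With sampling off, decoding is greedy: the most likely token is picked at every step, so the same prompt always yields the same text (including the same repetitive loop, if the model falls into one). With sampling on, the next token is drawn from the temperature-scaled softmax distribution, so some runs come out better and some do not. This is textbook decoding logic in plain Python, not LoRAX's implementation, and the function names are hypothetical.

```python
import math
import random


def softmax_with_temperature(logits, temperature):
    """Turn raw scores into probabilities; lower temperature sharpens them."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def next_token(logits, do_sample, temperature=1.0, rng=None):
    if not do_sample:
        # Greedy: always the highest-scoring token, so output is deterministic.
        return max(range(len(logits)), key=lambda i: logits[i])
    # Sampling: draw from the temperature-scaled distribution, so runs differ.
    probs = softmax_with_temperature(logits, temperature)
    rng = rng or random.Random()
    return rng.choices(range(len(logits)), weights=probs, k=1)[0]
```

Passing a seeded `random.Random` makes a sampled run reproducible, which helps when comparing adapters.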

tgaddair commented 11 months ago

Hey @KrisWongz, can you try using the param stop_sequences instead of stop?

Example:

client.generate(prompt, max_new_tokens=32, temperature=0.7, stop_sequences=["<|im_end|>"]).generated_text
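For reference, the same request can be sent over HTTP without the Python client. In the REST body the field is named `stop`, matching the `GenerateParameters { ... stop: [] ... }` log above, while the Python client exposes it as `stop_sequences`. This sketch assumes the TGI-style `POST /generate` endpoint with an `{"inputs": ..., "parameters": {...}}` body and a `generated_text` field in the response. `adapter_id` appears in the same log, but the adapter name below is hypothetical, and `http://localhost:8081` assumes the `-p 8081:80` mapping from the Docker command earlier in the thread.

```python
import json
import urllib.request


def build_generate_payload(prompt, max_new_tokens=32, temperature=0.7,
                           stop=None, adapter_id=None):
    parameters = {"max_new_tokens": max_new_tokens, "temperature": temperature}
    if stop:
        # Server-side name is "stop"; the Python client calls it stop_sequences.
        parameters["stop"] = list(stop)
    if adapter_id is not None:
        parameters["adapter_id"] = adapter_id
    return {"inputs": prompt, "parameters": parameters}


def generate(base_url, payload, timeout=60):
    req = urllib.request.Request(
        f"{base_url}/generate",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(req, timeout=timeout) as resp:
        return json.loads(resp.read())["generated_text"]


# Example (requires a running server; "my-qwen-adapter" is hypothetical):
# text = generate("http://localhost:8081",
#                 build_generate_payload(prompt, stop=["<|im_end|>"],
#                                        adapter_id="my-qwen-adapter"))
```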
tgaddair commented 11 months ago

Hey @felixstander, it looks like the error telling you to decrease --max-batch-prefill-tokens is misleading; the actual error here is:

File "/root/lorax/server/lorax_server/utils/paged_attn.py", line 23, in reshape_and_cache
cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping)
RuntimeError: expected scalar type Int but found Long

Let me see if I can reproduce this error on my side.

tgaddair commented 11 months ago

Hey @felixstander, I wasn't able to reproduce the error using the most recent Docker image. Can you try pulling the latest Docker image and trying again? There were a couple of recent changes for GPT-Q that may have fixed this issue.

KrisWongz commented 11 months ago

> Hey @KrisWongz, can you try using the param stop_sequences instead of stop?
>
> Example:
>
> client.generate(prompt, max_new_tokens=32, temperature=0.7, stop_sequences=["<|im_end|>"]).generated_text

It works on the base model, but it seems to have no effect once the LoRA adapter is added. It may be a problem with the template settings I used when fine-tuning, although the adapter runs fine under Qwen's original web_demo. I'll keep trying, thanks for your help!
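If the template is the culprit, the inference prompt has to match what the adapter saw during fine-tuning. As a sketch, assuming the adapter was trained on the ChatML layout that Qwen-Chat uses (`<|im_start|>role ... <|im_end|>` blocks), the prompt can be built as follows; the function name and default system message are illustrative. An adapter trained on a different template may never emit `<|im_end|>`, in which case the stop sequence has nothing to match.

```python
def build_qwen_chat_prompt(user_message, system="You are a helpful assistant.",
                           history=()):
    """Build a ChatML-style prompt (assumed to match the fine-tuning template)."""
    parts = [f"<|im_start|>system\n{system}<|im_end|>"]
    for user_turn, assistant_turn in history:
        parts.append(f"<|im_start|>user\n{user_turn}<|im_end|>")
        parts.append(f"<|im_start|>assistant\n{assistant_turn}<|im_end|>")
    parts.append(f"<|im_start|>user\n{user_message}<|im_end|>")
    # Leave the assistant turn open: the model writes its reply and should then
    # emit <|im_end|>, which is what stop_sequences=["<|im_end|>"] matches.
    parts.append("<|im_start|>assistant\n")
    return "\n".join(parts)
```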

thincal commented 8 months ago

Any plans for ChatGLM model support? Thanks.

tgaddair commented 8 months ago

Hey @thincal, we can definitely add ChatGLM support. I can create a separate issue to track that.

tgaddair commented 8 months ago

#280

thincal commented 8 months ago

@tgaddair it seems the Qwen model type is now qwen2, so which version does the current LoRAX implementation support? Ref: https://modelscope.cn/models/qwen/Qwen1.5-14B-Chat-AWQ/file/view/master/config.json
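A quick way to see which architecture a checkpoint declares is to read `model_type` from its `config.json`. To our knowledge the original Qwen releases declare `qwen` (architecture `QWenLMHeadModel`), while Qwen1.5 declares `qwen2`, as the linked config shows. This sketch only reports what the config says; which of the two LoRAX supports is exactly the open question here, and the helper name and descriptions are hypothetical.

```python
import json
from pathlib import Path

# Descriptions are illustrative labels, not LoRAX's internal model registry.
KNOWN_QWEN_TYPES = {
    "qwen": "original Qwen (custom modeling code, needs --trust-remote-code)",
    "qwen2": "Qwen1.5 / Qwen2 architecture",
}


def describe_model_type(model_dir):
    """Return (model_type, architectures, description) from config.json."""
    config_path = Path(model_dir) / "config.json"
    config = json.loads(config_path.read_text(encoding="utf-8"))
    model_type = config.get("model_type")
    architectures = config.get("architectures", [])
    description = KNOWN_QWEN_TYPES.get(model_type, "not a Qwen type known to this sketch")
    return model_type, architectures, description
```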