predibase / lorax

Multi-LoRA inference server that scales to 1000s of fine-tuned LLMs
https://loraexchange.ai
Apache License 2.0

Some error records and questions #115

Open KrisWongz opened 11 months ago

KrisWongz commented 11 months ago

System Info

Docker image: 2023-12-06
GPUs: 2x A40 (48 GB)
OS: CentOS 7

Information

Tasks

Reproduction

null

Expected behavior

I tested three models: Qwen-14B, Yi-34B-Chat (Llama-2 based), and XuanYuan-70B-Chat (Llama-2 based). For each model I prepared 2-3 LoRA adapters and ran into some problems. XuanYuan runs completely normally; all of the questions below concern Qwen and Yi.

  1. Without a LoRA adapter, generation runs to max_new_tokens when no stop words are set, but the output is not nonsense, and adding stop words makes it terminate normally. After loading a LoRA adapter, however, generation runs to max_new_tokens and produces nonsense. This may be a prompt-template configuration issue from fine-tuning (see the sketch at the end of this comment).

  2. When num_shard is set to 2 (two GPUs), a prefill length of 4096 causes an out-of-memory error during warmup, and even 1024 reports the error (Qwen). num_shard = 1 runs normally:

    RuntimeError: Not enough memory to handle 1028 prefill tokens. You need to decrease --max-batch-prefill-tokens
    2023-12-08T10:14:01.820643Z ERROR warmup{max_input_length=1024 max_prefill_tokens=1028}:warmup: lorax_client: router/client/src/lib.rs:33: Server error: Not enough memory to handle 1028 prefill tokens. You need to decrease --max-batch-prefill-tokens
    Error: Warmup(Generation("Not enough memory to handle 1028 prefill tokens. You need to decrease --max-batch-prefill-tokens"))
    2023-12-08T10:14:01.894611Z ERROR lorax_launcher: Webserver Crashed
    2023-12-08T10:14:01.894628Z INFO lorax_launcher: Shutting down shards
    2023-12-08T10:14:02.427130Z INFO shard-manager: lorax_launcher: Shard terminated rank=1
    2023-12-08T10:14:03.432023Z INFO shard-manager: lorax_launcher: Shard terminated rank=0
    Error: WebserverFailed

  3. As long as max_new_tokens is higher than 200, the connection fails; values under 200 run normally. Settings:

    Args { model_id: "/data/yi-34b-chat", adapter_id: "", source: "hub", adapter_source: "hub", revision: None, validation_workers: 2, sharded: None, num_shard: None, quantize: Some(BitsandbytesNF4), dtype: None, trust_remote_code: false, max_concurrent_requests: 128, max_best_of: 2, max_stop_sequences: 4, max_input_length: 1024, max_total_tokens: 2048, waiting_served_ratio: 1.2, max_batch_prefill_tokens: 4096, max_batch_total_tokens: None, max_waiting_tokens: 20, max_active_adapters: 128, adapter_cycle_time_s: 2, hostname: "c92d36636b23", port: 80, shard_uds_path: "/tmp/lorax-server", master_addr: "localhost", master_port: 29500, huggingface_hub_cache: Some("/data"), weights_cache_override: None, disable_custom_kernels: false, cuda_memory_fraction: 1.0, json_output: false, otlp_endpoint: None, cors_allow_origin: [], watermark_gamma: None, watermark_delta: None, ngrok: false, ngrok_authtoken: None, ngrok_edge: None, env: false, download_only: false }

Post:

prompt = "<|im_start|>user\nTell me a story<|im_end|>\n<|im_start|>assistant\n"
adapter_id = "/data/chat-int8-3-epoch-1024-manual_2360-self-5000_1207-1"
print(client.generate(prompt, max_new_tokens=300, temperature=0.8, do_sample=True, stop_sequences=["<|im_end|>"], adapter_id=adapter_id).generated_text)

Error:

Traceback (most recent call last):
  File "/home/shaohongen/Temp/WZ_test/lorax/test_lorax_yi.py", line 9, in <module>
    print(client.generate(prompt, max_new_tokens=300, temperature=0.8, do_sample=True, stop_sequences=["<|im_end|>"], adapter_id=adapter_id).generated_text)
  File "/home/shaohongen/miniconda3/envs/slora/lib/python3.9/site-packages/lorax/client.py", line 148, in generate
    resp = requests.post(
  File "/home/shaohongen/miniconda3/envs/slora/lib/python3.9/site-packages/requests/api.py", line 115, in post
    return request("post", url, data=data, json=json, **kwargs)
  File "/home/shaohongen/miniconda3/envs/slora/lib/python3.9/site-packages/requests/api.py", line 59, in request
    return session.request(method=method, url=url, **kwargs)
  File "/home/shaohongen/miniconda3/envs/slora/lib/python3.9/site-packages/requests/sessions.py", line 589, in request
    resp = self.send(prep, **send_kwargs)
  File "/home/shaohongen/miniconda3/envs/slora/lib/python3.9/site-packages/requests/sessions.py", line 703, in send
    r = adapter.send(request, **kwargs)
  File "/home/shaohongen/miniconda3/envs/slora/lib/python3.9/site-packages/requests/adapters.py", line 532, in send
    raise ReadTimeout(e, request=request)
requests.exceptions.ReadTimeout: HTTPConnectionPool(host='127.0.0.1', port=8081): Read timed out. (read timeout=10)

  4. Question: three models of quite different sizes (14B, 34B, 70B) all occupy about the same GPU memory under int4 quantization, roughly 44 GB. With num_shard=2 and int4 the usage is 37 GB on each GPU. Could that be the reason for the max_new_tokens limit? And why is the memory usage the same across model sizes?

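Regarding issue 1 above: if an adapter is fine-tuned without the chat template's stop token in the training targets, the model never learns to emit it and generation runs until max_new_tokens. Below is a purely illustrative sketch of appending the stop token when building training examples; the tokenizer path, function name, and field names are hypothetical and not the actual fine-tuning script used here:

    # Illustrative only: make sure the assistant turn ends with the template's
    # stop token so the adapter learns to generate it.
    from transformers import AutoTokenizer

    STOP_TOKEN = "<|im_end|>"
    tokenizer = AutoTokenizer.from_pretrained("/data/yi-34b-chat", trust_remote_code=True)

    def build_example(user_msg: str, assistant_msg: str) -> dict:
        prompt = f"<|im_start|>user\n{user_msg}{STOP_TOKEN}\n<|im_start|>assistant\n"
        completion = assistant_msg + STOP_TOKEN  # stop token included in the labels
        enc = tokenizer(prompt + completion, truncation=True, max_length=1024)
        enc["labels"] = enc["input_ids"].copy()  # causal-LM style labels
        return enc
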
tgaddair commented 11 months ago

Hey @KrisWongz, let me take a stab at answering your questions:

  1. It may be an issue with how the LoRA adapter was fine-tuned. If it didn't properly include the stop token during training, then it might have caused the model to forget to generate it.
  2. For this error, was there anything in the stack trace above it about max prefill tokens? Sometimes this error is misleading and covers up a different error. Please feel free to share the full stack trace and I will take a look.
  3. This is probably happening because the client timeout is set to 10s. You can increase it to prevent the connection from closing, e.g. Client(..., timeout=60); see the sketch after this list.
  4. How are you determining the amount of memory used? If you're looking at nvidia-smi, it will usually show nearly all the GPU memory in use, because LoRAX will try to consume as much memory as it can for batching unless you set cuda_memory_fraction.
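
To make point (3) concrete, here is a minimal sketch of raising the client timeout; the URL/port and adapter path are taken from the logs above and may differ in your setup:

    from lorax import Client

    # The default client timeout is 10s; generations with max_new_tokens > ~200
    # can exceed it and trigger the ReadTimeout shown above.
    client = Client("http://127.0.0.1:8081", timeout=60)

    prompt = "<|im_start|>user\nTell me a story<|im_end|>\n<|im_start|>assistant\n"
    resp = client.generate(
        prompt,
        max_new_tokens=300,
        stop_sequences=["<|im_end|>"],
        adapter_id="/data/chat-int8-3-epoch-1024-manual_2360-self-5000_1207-1",
    )
    print(resp.generated_text)
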
tgaddair commented 11 months ago

Hey @KrisWongz, I found the issue with tensor parallelism. I have a PR up to fix it here: #120.

There may be some additional issues with how LoRA weights are loaded for Qwen using tensor parallelism. I'll continue to investigate that after this lands, but for now are you able to make progress without tensor parallelism?

KrisWongz commented 11 months ago

> Hey @KrisWongz, I found the issue with tensor parallelism. I have a PR up to fix it here: #120.
>
> There may be some additional issues with how LoRA weights are loaded for Qwen using tensor parallelism. I'll continue to investigate that after this lands, but for now are you able to make progress without tensor parallelism?

I really appreciate your patience and attention to my questions. A few updates:

  1. I pulled the latest code today, and Qwen's stop_word now works. But Yi-34B-Chat still doesn't (it works normally in the official web demo).

  2. This problem still exists in both the old code and the new code when setting num_shard = 2 (Qwen); there is a note on the underlying error after this list. Full log:

    2023-12-11T10:22:22.284794Z INFO lorax_launcher: Args { model_id: "/data/Tongyi-Finance-14B-Chat", adapter_id: "", source: "hub", adapter_source: "hub", revision: None, validation_workers: 2, sharded: None, num_shard: Some(2), quantize: Some(BitsandbytesNF4), dtype: None, trust_remote_code: true, max_concurrent_requests: 128, max_best_of: 2, max_stop_sequences: 4, max_input_length: 1024, max_total_tokens: 2048, waiting_served_ratio: 1.2, max_batch_prefill_tokens: 4096, max_batch_total_tokens: None, max_waiting_tokens: 20, max_active_adapters: 128, adapter_cycle_time_s: 2, hostname: "9d59555582a1", port: 80, shard_uds_path: "/tmp/lorax-server", master_addr: "localhost", master_port: 29500, huggingface_hub_cache: Some("/data"), weights_cache_override: None, disable_custom_kernels: false, cuda_memory_fraction: 1.0, json_output: false, otlp_endpoint: None, cors_allow_origin: [], watermark_gamma: None, watermark_delta: None, ngrok: false, ngrok_authtoken: None, ngrok_edge: None, env: false, download_only: false }
    2023-12-11T10:22:22.284909Z WARN lorax_launcher: trust_remote_code is set. Trusting that model /data/Tongyi-Finance-14B-Chat do not contain malicious code.
    2023-12-11T10:22:22.284923Z INFO lorax_launcher: Sharding model on 2 processes
    2023-12-11T10:22:22.285260Z INFO download: lorax_launcher: Starting download process.
    2023-12-11T10:22:26.040324Z INFO lorax_launcher: cli.py:103 Files are already present on the host. Skipping download.
    2023-12-11T10:22:26.695029Z INFO download: lorax_launcher: Successfully downloaded weights.
    2023-12-11T10:22:26.695262Z INFO shard-manager: lorax_launcher: Starting shard rank=0
    2023-12-11T10:22:26.695285Z INFO shard-manager: lorax_launcher: Starting shard rank=1
    2023-12-11T10:22:36.712161Z INFO shard-manager: lorax_launcher: Waiting for shard to be ready... rank=0
    2023-12-11T10:22:36.712245Z INFO shard-manager: lorax_launcher: Waiting for shard to be ready... rank=1
    2023-12-11T10:22:36.829060Z INFO lorax_launcher: server.py:263 Server started at unix:///tmp/lorax-server-1
    2023-12-11T10:22:36.912865Z INFO shard-manager: lorax_launcher: Shard ready in 10.216895746s rank=1
    2023-12-11T10:22:36.994271Z INFO lorax_launcher: server.py:263 Server started at unix:///tmp/lorax-server-0
    2023-12-11T10:22:37.013098Z INFO shard-manager: lorax_launcher: Shard ready in 10.317219386s rank=0
    2023-12-11T10:22:37.112131Z INFO lorax_launcher: Starting Webserver
    2023-12-11T10:22:37.122407Z WARN lorax_router: router/src/main.rs:169: Could not find a fast tokenizer implementation for /data/Tongyi-Finance-14B-Chat
    2023-12-11T10:22:37.122433Z WARN lorax_router: router/src/main.rs:172: Rust input length validation and truncation is disabled
    2023-12-11T10:22:37.122439Z WARN lorax_router: router/src/main.rs:197: no pipeline tag found for model /data/Tongyi-Finance-14B-Chat
    2023-12-11T10:22:37.143239Z INFO lorax_router: router/src/main.rs:216: Warming up model
    2023-12-11T10:22:37.341737Z ERROR lorax_launcher: interceptor.py:41 Method Warmup encountered an error.
    Traceback (most recent call last):
      File "/opt/conda/lib/python3.9/site-packages/lorax_server/models/flash_causal_lm.py", line 843, in warmup
        _, batch = self.generate_token(batch)
      File "/opt/conda/lib/python3.9/contextlib.py", line 79, in inner
        return func(*args, **kwds)
      File "/opt/conda/lib/python3.9/site-packages/lorax_server/models/flash_causal_lm.py", line 939, in generate_token
        raise e
      File "/opt/conda/lib/python3.9/site-packages/lorax_server/models/flash_causal_lm.py", line 936, in generate_token
        out = self.forward(batch, adapter_data)
      File "/opt/conda/lib/python3.9/site-packages/lorax_server/models/flash_causal_lm.py", line 895, in forward
        return self.model.forward(
      File "/opt/conda/lib/python3.9/site-packages/lorax_server/models/custom_modeling/flash_qwen_modeling.py", line 470, in forward
        hidden_states = self.transformer(
      File "/opt/conda/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
        return self._call_impl(*args, **kwargs)
      File "/opt/conda/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
        return forward_call(*args, **kwargs)
      File "/opt/conda/lib/python3.9/site-packages/lorax_server/models/custom_modeling/flash_qwen_modeling.py", line 427, in forward
        hidden_states, residual = layer(
      File "/opt/conda/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
        return self._call_impl(*args, **kwargs)
      File "/opt/conda/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
        return forward_call(*args, **kwargs)
      File "/opt/conda/lib/python3.9/site-packages/lorax_server/models/custom_modeling/flash_qwen_modeling.py", line 352, in forward
        attn_output = self.attn(
      File "/opt/conda/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
        return self._call_impl(*args, **kwargs)
      File "/opt/conda/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
        return forward_call(*args, **kwargs)
      File "/opt/conda/lib/python3.9/site-packages/lorax_server/models/custom_modeling/flash_qwen_modeling.py", line 222, in forward
        query, kv = qkv.split(
      File "/opt/conda/lib/python3.9/site-packages/torch/_tensor.py", line 864, in split
        return torch._VF.split_with_sizes(self, split_size, dim)
    RuntimeError: split_with_sizes expects split_sizes to sum exactly to 7680 (input tensor's size at dimension 1), but got split_sizes=[5120, 10240]

    The above exception was the direct cause of the following exception:

    Traceback (most recent call last):
      File "/opt/conda/bin/lorax-server", line 8, in <module>
        sys.exit(app())
      File "/opt/conda/lib/python3.9/site-packages/typer/main.py", line 311, in __call__
        return get_command(self)(*args, **kwargs)
      File "/opt/conda/lib/python3.9/site-packages/click/core.py", line 1157, in __call__
        return self.main(*args, **kwargs)
      File "/opt/conda/lib/python3.9/site-packages/typer/core.py", line 778, in main
        return _main(
      File "/opt/conda/lib/python3.9/site-packages/typer/core.py", line 216, in _main
        rv = self.invoke(ctx)
      File "/opt/conda/lib/python3.9/site-packages/click/core.py", line 1688, in invoke
        return _process_result(sub_ctx.command.invoke(sub_ctx))
      File "/opt/conda/lib/python3.9/site-packages/click/core.py", line 1434, in invoke
        return ctx.invoke(self.callback, **ctx.params)
      File "/opt/conda/lib/python3.9/site-packages/click/core.py", line 783, in invoke
        return __callback(*args, **kwargs)
      File "/opt/conda/lib/python3.9/site-packages/typer/main.py", line 683, in wrapper
        return callback(**use_params)  # type: ignore
      File "/opt/conda/lib/python3.9/site-packages/lorax_server/cli.py", line 84, in serve
        server.serve(
      File "/opt/conda/lib/python3.9/site-packages/lorax_server/server.py", line 271, in serve
        asyncio.run(
      File "/opt/conda/lib/python3.9/asyncio/runners.py", line 44, in run
        return loop.run_until_complete(main)
      File "/opt/conda/lib/python3.9/asyncio/base_events.py", line 634, in run_until_complete
        self.run_forever()
      File "/opt/conda/lib/python3.9/asyncio/base_events.py", line 601, in run_forever
        self._run_once()
      File "/opt/conda/lib/python3.9/asyncio/base_events.py", line 1905, in _run_once
        handle._run()
      File "/opt/conda/lib/python3.9/asyncio/events.py", line 80, in _run
        self._context.run(self._callback, *self._args)
      File "/opt/conda/lib/python3.9/site-packages/grpc_interceptor/server.py", line 165, in invoke_intercept_method
        return await self.intercept(
      File "/opt/conda/lib/python3.9/site-packages/lorax_server/interceptor.py", line 38, in intercept
        return await response
      File "/opt/conda/lib/python3.9/site-packages/opentelemetry/instrumentation/grpc/_aio_server.py", line 82, in _unary_interceptor
        raise error
      File "/opt/conda/lib/python3.9/site-packages/opentelemetry/instrumentation/grpc/_aio_server.py", line 73, in _unary_interceptor
        return await behavior(request_or_iterator, context)
      File "/opt/conda/lib/python3.9/site-packages/lorax_server/server.py", line 74, in Warmup
        max_supported_total_tokens = self.model.warmup(batch)
      File "/opt/conda/lib/python3.9/site-packages/lorax_server/models/flash_causal_lm.py", line 845, in warmup
        raise RuntimeError(
    RuntimeError: Not enough memory to handle 4096 prefill tokens. You need to decrease --max-batch-prefill-tokens
    2023-12-11T10:22:37.343322Z ERROR warmup{max_input_length=1024 max_prefill_tokens=4096}:warmup: lorax_client: router/client/src/lib.rs:33: Server error: Not enough memory to handle 4096 prefill tokens. You need to decrease --max-batch-prefill-tokens
    2023-12-11T10:22:37.363703Z ERROR lorax_launcher: interceptor.py:41 Method Warmup encountered an error.
    Traceback (most recent call last):
      File "/opt/conda/lib/python3.9/site-packages/lorax_server/models/flash_causal_lm.py", line 843, in warmup
        _, batch = self.generate_token(batch)
      File "/opt/conda/lib/python3.9/contextlib.py", line 79, in inner
        return func(*args, **kwds)
      File "/opt/conda/lib/python3.9/site-packages/lorax_server/models/flash_causal_lm.py", line 939, in generate_token
        raise e
      File "/opt/conda/lib/python3.9/site-packages/lorax_server/models/flash_causal_lm.py", line 936, in generate_token
        out = self.forward(batch, adapter_data)
      File "/opt/conda/lib/python3.9/site-packages/lorax_server/models/flash_causal_lm.py", line 895, in forward
        return self.model.forward(
      File "/opt/conda/lib/python3.9/site-packages/lorax_server/models/custom_modeling/flash_qwen_modeling.py", line 470, in forward
        hidden_states = self.transformer(
      File "/opt/conda/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
        return self._call_impl(*args, **kwargs)
      File "/opt/conda/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
        return forward_call(*args, **kwargs)
      File "/opt/conda/lib/python3.9/site-packages/lorax_server/models/custom_modeling/flash_qwen_modeling.py", line 427, in forward
        hidden_states, residual = layer(
      File "/opt/conda/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
        return self._call_impl(*args, **kwargs)
      File "/opt/conda/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
        return forward_call(*args, **kwargs)
      File "/opt/conda/lib/python3.9/site-packages/lorax_server/models/custom_modeling/flash_qwen_modeling.py", line 352, in forward
        attn_output = self.attn(
      File "/opt/conda/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
        return self._call_impl(*args, **kwargs)
      File "/opt/conda/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
        return forward_call(*args, **kwargs)
      File "/opt/conda/lib/python3.9/site-packages/lorax_server/models/custom_modeling/flash_qwen_modeling.py", line 222, in forward
        query, kv = qkv.split(
      File "/opt/conda/lib/python3.9/site-packages/torch/_tensor.py", line 864, in split
        return torch._VF.split_with_sizes(self, split_size, dim)
    RuntimeError: split_with_sizes expects split_sizes to sum exactly to 7680 (input tensor's size at dimension 1), but got split_sizes=[5120, 10240]

    The above exception was the direct cause of the following exception:

    Traceback (most recent call last):
      File "/opt/conda/bin/lorax-server", line 8, in <module>
        sys.exit(app())
      File "/opt/conda/lib/python3.9/site-packages/typer/main.py", line 311, in __call__
        return get_command(self)(*args, **kwargs)
      File "/opt/conda/lib/python3.9/site-packages/click/core.py", line 1157, in __call__
        return self.main(*args, **kwargs)
      File "/opt/conda/lib/python3.9/site-packages/typer/core.py", line 778, in main
        return _main(
      File "/opt/conda/lib/python3.9/site-packages/typer/core.py", line 216, in _main
        rv = self.invoke(ctx)
      File "/opt/conda/lib/python3.9/site-packages/click/core.py", line 1688, in invoke
        return _process_result(sub_ctx.command.invoke(sub_ctx))
      File "/opt/conda/lib/python3.9/site-packages/click/core.py", line 1434, in invoke
        return ctx.invoke(self.callback, **ctx.params)
      File "/opt/conda/lib/python3.9/site-packages/click/core.py", line 783, in invoke
        return __callback(*args, **kwargs)
      File "/opt/conda/lib/python3.9/site-packages/typer/main.py", line 683, in wrapper
        return callback(**use_params)  # type: ignore
      File "/opt/conda/lib/python3.9/site-packages/lorax_server/cli.py", line 84, in serve
        server.serve(
      File "/opt/conda/lib/python3.9/site-packages/lorax_server/server.py", line 271, in serve
        asyncio.run(
      File "/opt/conda/lib/python3.9/asyncio/runners.py", line 44, in run
        return loop.run_until_complete(main)
      File "/opt/conda/lib/python3.9/asyncio/base_events.py", line 634, in run_until_complete
        self.run_forever()
      File "/opt/conda/lib/python3.9/asyncio/base_events.py", line 601, in run_forever
        self._run_once()
      File "/opt/conda/lib/python3.9/asyncio/base_events.py", line 1905, in _run_once
        handle._run()
      File "/opt/conda/lib/python3.9/asyncio/events.py", line 80, in _run
        self._context.run(self._callback, *self._args)
      File "/opt/conda/lib/python3.9/site-packages/grpc_interceptor/server.py", line 165, in invoke_intercept_method
        return await self.intercept(
      File "/opt/conda/lib/python3.9/site-packages/lorax_server/interceptor.py", line 38, in intercept
        return await response
      File "/opt/conda/lib/python3.9/site-packages/opentelemetry/instrumentation/grpc/_aio_server.py", line 82, in _unary_interceptor
        raise error
      File "/opt/conda/lib/python3.9/site-packages/opentelemetry/instrumentation/grpc/_aio_server.py", line 73, in _unary_interceptor
        return await behavior(request_or_iterator, context)
      File "/opt/conda/lib/python3.9/site-packages/lorax_server/server.py", line 74, in Warmup
        max_supported_total_tokens = self.model.warmup(batch)
      File "/opt/conda/lib/python3.9/site-packages/lorax_server/models/flash_causal_lm.py", line 845, in warmup
        raise RuntimeError(
    RuntimeError: Not enough memory to handle 4096 prefill tokens. You need to decrease --max-batch-prefill-tokens
    2023-12-11T10:22:37.365320Z ERROR warmup{max_input_length=1024 max_prefill_tokens=4096}:warmup: lorax_client: router/client/src/lib.rs:33: Server error: Not enough memory to handle 4096 prefill tokens. You need to decrease --max-batch-prefill-tokens
    Error: Warmup(Generation("Not enough memory to handle 4096 prefill tokens. You need to decrease --max-batch-prefill-tokens"))
    2023-12-11T10:22:37.413280Z ERROR lorax_launcher: Webserver Crashed
    2023-12-11T10:22:37.413305Z INFO lorax_launcher: Shutting down shards
    2023-12-11T10:22:37.600037Z INFO shard-manager: lorax_launcher: Shard terminated rank=1
    2023-12-11T10:22:37.673855Z INFO shard-manager: lorax_launcher: Shard terminated rank=0
    Error: WebserverFailed

  3. timeout = 60 solves the problem. By the way, the default timeout is still 10 in the latest Docker image:

    requests.exceptions.ReadTimeout: HTTPConnectionPool(host='127.0.0.1', port=8081): Read timed out. (read timeout=10)

  4. Yes, it's 'nvidia-smi'. I got it.
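
Regarding the warmup failure in item 2: the out-of-memory message appears to mask the real error in the traceback, the split_with_sizes mismatch. The fused QKV projection has already been sharded to 7680 columns per GPU, while the split sizes [5120, 10240] still correspond to the full, unsharded hidden size. A rough sketch of the arithmetic, assuming Qwen-14B's hidden size is 5120 and num_shard=2 (an inference from the log, not a confirmed diagnosis):

    # Illustrative arithmetic only, based on the numbers in the traceback above.
    hidden_size = 5120                       # assumed Qwen-14B hidden size
    num_shard = 2

    fused_qkv = 3 * hidden_size              # 15360 columns before sharding
    per_shard = fused_qkv // num_shard       # 7680 columns actually held by each GPU

    unsharded_split = [hidden_size, 2 * hidden_size]                           # [5120, 10240] -> fails
    sharded_split = [hidden_size // num_shard, 2 * hidden_size // num_shard]   # [2560, 5120] -> sums to 7680
    print(per_shard, unsharded_split, sharded_split)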

tgaddair commented 11 months ago

Hey @KrisWongz, I can take a look at the first issue with Yi-34b-chat this week hopefully. Issue (2) should be resolved now as I landed the PR this morning. For the timeout default, that will be updated when I push a new version of the client, but setting the timeout manually should work for now.

I'm still trying to track down some differences in the LoRA outputs between the tensor parallel and single GPU results, but it should be functional (and there are no apparent differences between single and tensor parallel Qwen without adapters).
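
If it is useful for reproducing that comparison, here is a rough sketch of diffing the two setups from the client side. The two base URLs are hypothetical: run one server with num_shard=1 and one with num_shard=2, and leave sampling off so decoding is greedy and deterministic.

    from lorax import Client

    prompt = "<|im_start|>user\nTell me a story<|im_end|>\n<|im_start|>assistant\n"
    adapter_id = "/data/chat-int8-3-epoch-1024-manual_2360-self-5000_1207-1"

    # Hypothetical endpoints: one server launched with num_shard=1, the other with num_shard=2.
    single = Client("http://127.0.0.1:8080", timeout=60)
    sharded = Client("http://127.0.0.1:8081", timeout=60)

    params = dict(max_new_tokens=128, stop_sequences=["<|im_end|>"], adapter_id=adapter_id)
    a = single.generate(prompt, **params).generated_text
    b = sharded.generate(prompt, **params).generated_text
    print("identical" if a == b else f"diverged:\n--- num_shard=1 ---\n{a}\n--- num_shard=2 ---\n{b}")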

Thanks for your patience as we work through these issues!

KrisWongz commented 11 months ago

I will keep following LoRAX and will report any problems I run into along the way. Respect, @tgaddair.