google / jetstream-pytorch

PyTorch/XLA integration with JetStream (https://github.com/google/JetStream) for LLM inference
Apache License 2.0

Empty response returned for prompt responses when using run_server_with_ray.py and batch_size > 1 #137

Open richardsliu opened 1 week ago

richardsliu commented 1 week ago

When multiple prompts are sent to the server, only the first prompt returns any results. Every request after the first one returns an empty response.

I've tried 3 different ways to bring up the server (all using interleave singlehost on a TPU v4):

python run_interactive.py --size=7b --model_name=llama-2 --batch_size=32 --max_cache_length=2048 --tokenizer_path=/home/ray/jetstream-pytorch/tokenizer.model --checkpoint_path=/home/ray/jetstream-pytorch/ckpt --quantize_weights=True --quantize_type="int8_per_channel" --quantize_kv_cache=True --sharding_config="default_shardings/llama.yaml"

No issues.

python run_server.py --model_name=llama-2 --size=7b --batch_size=32 --max_cache_length=2048 --tokenizer_path=/home/ray/jetstream-pytorch/tokenizer.model --checkpoint_path=/home/ray/jetstream-pytorch/ckpt --quantize_weights=True --quantize_type="int8_per_channel" --quantize_kv_cache=True --sharding_config="default_shardings/llama.yaml"

No issues.

python run_server_with_ray.py --tpu_chips=16 --model_name=llama-2 --size=7b --batch_size=32 --max_cache_length=2048 --tokenizer_path=/home/ray/jetstream-pytorch/tokenizer.model --checkpoint_path=/home/ray/jetstream-pytorch/ckpt --quantize_weights=True --quantize_type="int8_per_channel" --quantize_kv_cache=True --sharding_config="default_shardings/llama.yaml"

This configuration reproduces the problem above. Debugging the code further, it looks like the model returned a stop token:

I0627 11:25:18.014765 137073235306240 orchestrator.py:741] >>>>data: {data}
2024-06-27 11:25:18,046 - root - INFO - Generate engine 0 step 202 - slots free : 31 / 32, took 40520.31ms
I0627 11:25:18.046449 137073243698944 orchestrator.py:678] Generate engine 0 step 202 - slots free : 31 / 32, took 40520.31ms
2024-06-27 11:25:18,046 - root - INFO - Generate thread making a decision with: prefill_backlog=0 generate_free_slots=31
I0627 11:25:18.046644 137073243698944 orchestrator.py:588] Generate thread making a decision with: prefill_backlog=0 generate_free_slots=31
2024-06-27 11:25:18,047 - root - INFO - Complete [False], slot_tokens: [[13]], slot_lengths: [2]
I0627 11:25:18.047123 137073235306240 token_utils.py:194] Complete [False], slot_tokens: [[13]], slot_lengths: [2]
2024-06-27 11:25:18,047 - root - INFO - Sample idx: 0 Speculation idx: 0 Token: 13
I0627 11:25:18.047416 137073235306240 token_utils.py:209] Sample idx: 0 Speculation idx: 0 Token: 13
2024-06-27 11:25:18,047 - root - INFO - Return samples [ReturnSample(text=['<0x0A>'], token_ids=[13])]
I0627 11:25:18.047532 137073235306240 token_utils.py:230] Return samples [ReturnSample(text=['<0x0A>'], token_ids=[13])]
2024-06-27 11:25:18,047 - root - INFO - >>>>results: [ReturnSample(text=['<0x0A>'], token_ids=[13])] complete: [False]
I0627 11:25:18.047641 137073235306240 orchestrator.py:725] >>>>results: [ReturnSample(text=['<0x0A>'], token_ids=[13])] complete: [False]
2024-06-27 11:25:18,047 - root - INFO - Detokenizing generate step 201 took 1.05ms
I0627 11:25:18.047804 137073235306240 orchestrator.py:734] Detokenizing generate step 201 took 1.05ms
2024-06-27 11:25:18,067 - root - INFO - Generate engine 0 step 203 - slots free : 31 / 32, took 20.82ms
I0627 11:25:18.067497 137073243698944 orchestrator.py:678] Generate engine 0 step 203 - slots free : 31 / 32, took 20.82ms
2024-06-27 11:25:18,068 - root - INFO - Complete [False], slot_tokens: [[0]], slot_lengths: [3]
I0627 11:25:18.068114 137073235306240 token_utils.py:194] Complete [False], slot_tokens: [[0]], slot_lengths: [3]
2024-06-27 11:25:18,068 - root - INFO - Sample idx: 0 Speculation idx: 0 Token: 0
I0627 11:25:18.068206 137073235306240 token_utils.py:209] Sample idx: 0 Speculation idx: 0 Token: 0
2024-06-27 11:25:18,068 - root - INFO - >>>complete: tok_id: 0 stop_tokens:  {0, 2} valid: 1
I0627 11:25:18.068270 137073235306240 token_utils.py:216] >>>complete: tok_id: 0 stop_tokens:  {0, 2} valid: 1
2024-06-27 11:25:18,068 - root - INFO - Return samples [ReturnSample(text=[], token_ids=[])]
I0627 11:25:18.068324 137073235306240 token_utils.py:230] Return samples [ReturnSample(text=[], token_ids=[])]
2024-06-27 11:25:18,068 - root - INFO - >>>>results: [ReturnSample(text=[], token_ids=[])] complete: [ True]
I0627 11:25:18.068485 137073235306240 orchestrator.py:725] >>>>results: [ReturnSample(text=[], token_ids=[])] complete: [ True]
2024-06-27 11:25:18,068 - root - INFO - Detokenizing generate step 202 took 0.72ms
I0627 11:25:18.068594 137073235306240 orchestrator.py:734] Detokenizing generate step 202 took 0.72ms
2024-06-27 11:25:18,088 - root - INFO - Generate engine 0 step 204 - slots free : 31 / 32, took 21.10ms
I0627 11:25:18.088757 137073243698944 orchestrator.py:678] Generate engine 0 step 204 - slots free : 31 / 32, took 21.10ms
2024-06-27 11:25:18,089 - root - INFO - Detokenizing generate step 203 took 0.04ms
I0627 11:25:18.089094 137073235306240 orchestrator.py:734] Detokenizing generate step 203 took 0.04ms
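The log shows the completion logic at work: the slot first samples token 13 (<0x0A>, a newline), which is emitted, and on the next step samples token 0, which is in stop_tokens {0, 2}, so the request is marked complete with an empty ReturnSample. The Python sketch below replays those two steps. It is a simplified re-creation of what the log shows, assuming a single slot and a hard-coded token table; ReturnSample here is a local dataclass and process_token is a hypothetical helper, not the real token_utils.py code.

```python
from dataclasses import dataclass, field


@dataclass
class ReturnSample:
    """Per-slot output, mirroring the ReturnSample(text=..., token_ids=...) log lines."""
    text: list[str] = field(default_factory=list)
    token_ids: list[int] = field(default_factory=list)


# Stop-token set as printed in the log. In the Llama-2 tokenizer, id 0 is <unk>
# and id 2 is the end-of-sequence token </s>.
STOP_TOKENS = {0, 2}

# Hypothetical stand-in for the tokenizer: only the piece that appears in the log.
ID_TO_PIECE = {13: "<0x0A>"}


def process_token(token_id: int, valid: bool, complete: bool) -> tuple[ReturnSample, bool]:
    """Turn one sampled token into a ReturnSample plus the updated completion flag.

    Simplified re-creation of the behaviour visible in the log, not the actual
    token_utils.py implementation.
    """
    sample = ReturnSample()
    if complete or not valid:
        # Already finished, or not a valid position: emit nothing, keep the state.
        return sample, complete
    if token_id in STOP_TOKENS:
        # A stop token finishes the request without emitting any text.
        return sample, True
    sample.text.append(ID_TO_PIECE.get(token_id, f"<id {token_id}>"))
    sample.token_ids.append(token_id)
    return sample, False


# Replay the two decode steps from the log: token 13, then token 0.
complete = False
for token_id in (13, 0):
    sample, complete = process_token(token_id, valid=True, complete=complete)
    print(sample, complete)
```

Running it prints ReturnSample(text=['<0x0A>'], token_ids=[13]) False followed by ReturnSample(text=[], token_ids=[]) True, the same sequence as the "Return samples" and "complete" lines in the log.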

This only reproduces with run_server_with_ray.py, and only when batch_size is greater than 1.
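Apart from the entry script, the three launch commands above differ only by the Ray-specific --tpu_chips=16 flag; the model, batch, cache, quantization, and sharding flags are identical, which is consistent with the failure being tied to run_server_with_ray.py. One way to check this is to parse the flags and diff them. The Python sketch below does that; parse_flags and flag_diff are illustrative helpers, not part of jetstream-pytorch.

```python
import shlex

# The three launch commands exactly as given in this report.
COMMANDS = {
    "run_interactive.py": (
        "python run_interactive.py --size=7b --model_name=llama-2 --batch_size=32"
        " --max_cache_length=2048"
        " --tokenizer_path=/home/ray/jetstream-pytorch/tokenizer.model"
        " --checkpoint_path=/home/ray/jetstream-pytorch/ckpt"
        ' --quantize_weights=True --quantize_type="int8_per_channel"'
        ' --quantize_kv_cache=True --sharding_config="default_shardings/llama.yaml"'
    ),
    "run_server.py": (
        "python run_server.py --model_name=llama-2 --size=7b --batch_size=32"
        " --max_cache_length=2048"
        " --tokenizer_path=/home/ray/jetstream-pytorch/tokenizer.model"
        " --checkpoint_path=/home/ray/jetstream-pytorch/ckpt"
        ' --quantize_weights=True --quantize_type="int8_per_channel"'
        ' --quantize_kv_cache=True --sharding_config="default_shardings/llama.yaml"'
    ),
    "run_server_with_ray.py": (
        "python run_server_with_ray.py --tpu_chips=16 --model_name=llama-2 --size=7b"
        " --batch_size=32 --max_cache_length=2048"
        " --tokenizer_path=/home/ray/jetstream-pytorch/tokenizer.model"
        " --checkpoint_path=/home/ray/jetstream-pytorch/ckpt"
        ' --quantize_weights=True --quantize_type="int8_per_channel"'
        ' --quantize_kv_cache=True --sharding_config="default_shardings/llama.yaml"'
    ),
}


def parse_flags(command: str) -> dict[str, str]:
    """Map flag name -> value for every --name=value token after the script path."""
    flags = {}
    for token in shlex.split(command)[2:]:  # skip "python" and the script path
        name, _, value = token.partition("=")
        flags[name.lstrip("-")] = value
    return flags


def flag_diff(baseline: dict[str, str], other: dict[str, str]) -> dict[str, str]:
    """Flags in `other` that are absent from, or differ from, `baseline`."""
    return {k: v for k, v in other.items() if baseline.get(k) != v}


parsed = {script: parse_flags(cmd) for script, cmd in COMMANDS.items()}
baseline = parsed["run_interactive.py"]
for script, flags in parsed.items():
    missing = sorted(set(baseline) - set(flags))
    print(f"{script}: added/changed={flag_diff(baseline, flags)} missing={missing}")
```

Only the run_server_with_ray.py line reports a change, {'tpu_chips': '16'}, and no command is missing any flag used by the others.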

FanhaiLu1 commented 6 days ago

Hi Richard,

I tested llama-2 7B with run_server_with_ray.py (--batch_size=32). Instead of sending requests one by one, I used the benchmark script to send 200 requests and got 198 responses back. I verified the responses and they are accurate and correct. Here is one example:

{
  "prompt": "Can you explain the benefit of Neo4J and why I should use it over a more generic relational database",
  "original_output": "Neo4j is a graph database, which is a type of database that is designed to store and process data as a graph. This means that it is particularly well-suited for storing data that is naturally represented as a network, such as the connections between people in a social network or the relationships between different pieces of data in a recommendation system.\n\nOne of the key benefits of using a graph database like Neo4j is that it allows you to easily represent complex relationships between data, and to quickly and efficiently query the database to find data that meets certain criteria. For example, in a social network, you might use Neo4j to find the shortest path between two people, or to identify clusters of friends within the network.\n\nAnother key benefit of Neo4j is that it is highly scalable, which means that it can handle large amounts of data and support high levels of concurrency. This makes it a good choice for applications that require real-time data processing or that need to support a large number of users.\n\nIn general, Neo4j is a good choice if you have data that is naturally represented as a graph, and you need to perform complex queries on that data to find relationships or patterns. It may be less appropriate for applications that require highly structured, tabular data, or that need to support transactions and other ACID properties.",
  "generated_text": "like MySQL?\n\nNeo4j is a graph database, which means it stores data as a network of interconnected nodes or vertices, rather than as a traditional table-based relational database like MySQL. Here are some benefits of using Neo4j over a more generic relational database like MySQL:\n\n\n1. Handling complex relationships : Neo4j is designed to handle complex, hierarchical, and networked data structures, which are difficult or impossible to model in a traditional relational database. For example, in a social network, each user can have many followers, and each follower can follow many users. In a graph database, these relationships can be represented as a single node for each user, with edges connecting the nodes to represent the follow and follower relationships.\n\n\n2. Flexible schema : In a traditional relational database, the schema must be defined upfront and can't be changed easily. In contrast, Neo4j has a flexible schema that can be evolved over time as the data and use cases change. This makes it easier to adapt to new use cases and to incorporate new data sources.\n\n\n3. Real-time querying : Neo4j is designed for real-time querying and can handle complex, graph-based queries much faster than a traditional rel",
  "success": true,
  "latency": 250.83245306299068,
  "prompt_len": 23,
  "sample_idx": 39736
},

Are you running it on GKE? Can you use the latest code on the main branch and run a benchmark test?
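A success count such as 198 out of 200 does not by itself show whether any successful request came back empty, which is the symptom reported in this issue. The Python sketch below checks both. It assumes the benchmark output is a JSON array of records with the success, generated_text, and sample_idx fields seen in the example above; load_results and summarize are hypothetical helpers, and the records in the usage lines are toy data.

```python
import json


def load_results(path: str) -> list[dict]:
    """Read benchmark output, assumed to be a JSON array of per-request records."""
    with open(path) as f:
        return json.load(f)


def summarize(results: list[dict]) -> dict:
    """Count requests, successes, and successful-but-empty responses."""
    succeeded = [r for r in results if r.get("success")]
    # Whitespace-only text (e.g. a lone newline) counts as an empty response.
    empty = [
        r.get("sample_idx")
        for r in succeeded
        if not (r.get("generated_text") or "").strip()
    ]
    return {
        "total": len(results),
        "succeeded": len(succeeded),
        "empty_sample_idx": empty,
    }


# Toy records for illustration only; real ones would come from load_results(...).
records = [
    {"success": True, "generated_text": "like MySQL?\n\nNeo4j is ...", "sample_idx": 0},
    {"success": True, "generated_text": "\n", "sample_idx": 1},
    {"success": False, "generated_text": "", "sample_idx": 2},
]
print(summarize(records))
```

Treating whitespace-only text as empty also catches a response that consists only of the newline token seen in the log.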