google / jetstream-pytorch

PyTorch/XLA integration with JetStream (https://github.com/google/JetStream) for LLM inference
Apache License 2.0

How to run benchmark on CloudTPU v4-8 #82

Closed JackCaoG closed 3 months ago

JackCaoG commented 3 months ago

I am trying to run a benchmark on TPU v4-8. I followed the instructions in the README and downloaded the weights and tokenizer for Llama2.

I first tried to do a local run with

/jetstream-pytorch# python run_interactive.py --size=7b --model_name=$model_name --batch_size=64 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/llama.yaml
2024-05-15 23:04:15.206016: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
Number of param Gbytes: 12.551521301269531
Number of param:  292
Initialize engine 0.3343957440229133
Name: tok_embeddings.weight, shape: (32000, 4096) x bfloat16
Name: layers.0.attention.wo.weight, shape: (4096, 4096) x bfloat16
Name: layers.0.attention.wq.weight, shape: (4096, 4096) x bfloat16
Name: layers.0.attention.wk.weight, shape: (4096, 4096) x bfloat16
Name: layers.0.attention.wv.weight, shape: (4096, 4096) x bfloat16
Name: layers.0.feed_forward.w1.weight, shape: (11008, 4096) x bfloat16
Name: layers.0.feed_forward.w2.weight, shape: (4096, 11008) x bfloat16
Name: layers.0.feed_forward.w3.weight, shape: (11008, 4096) x bfloat16
Name: layers.0.attention_norm.weight, shape: (4096,) x bfloat16
Name: layers.0.ffn_norm.weight, shape: (4096,) x bfloat16
Name: norm.weight, shape: (4096,) x bfloat16
Name: output.weight, shape: (32000, 4096) x bfloat16
Name: freqs_cis, shape: (2048, 64) x complex64
Load params  15.498905850981828
W0515 23:04:36.588934 140360324777792 vocabularies.py:411] T5 library uses PAD_ID=0, which is different from the sentencepiece vocabulary, which defines pad_id=-1
normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.
---- Input prompts are: I believe the meaning of life is
---- Encoded tokens are: [   1  306 4658  278 6593  310 2834  338    0    0    0    0    0    0
    0    0]
[Traced<ShapedArray(int32[16])>with<DynamicJaxprTrace(level=2/0)>]
> /usr/local/lib/python3.10/inspect.py(3108)_bind()
-> raise TypeError('too many positional arguments') from None
(Pdb) l
3103                    # We have a positional argument to process
3104                    try:
3105                        param = next(parameters)
3106                    except StopIteration:
3107                        breakpoint()
3108 ->                     raise TypeError('too many positional arguments') from None
3109                    else:
3110                        if param.kind in (_VAR_KEYWORD, _KEYWORD_ONLY):
3111                            # Looks like we have no parameter for this positional
3112                            # argument
3113                            raise TypeError(

But it failed with "too many positional arguments".
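
For context, this message comes from Python's inspect machinery (the pdb frame above is inside inspect.py): Signature.bind raises it when a callable is given more positional arguments than its signature accepts. A minimal repro, independent of JetStream:

import inspect

def decode(tokens):              # a callable that takes exactly one positional argument
    return tokens

sig = inspect.signature(decode)
sig.bind([1, 306, 4658])         # fine: one positional argument
try:
    sig.bind([1, 306, 4658], "unexpected extra positional")
except TypeError as e:
    print(e)                     # -> too many positional arguments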

I then tried to do a benchmark run with python benchmarks/benchmark_serving.py --tokenizer $tokenizer_path --num-prompts 2000 --dataset-path $dataset_path --dataset sharegpt --save-request-outputs --warm-up=True. It seems like the warm-up flag is no longer there:

benchmark_serving.py: error: unrecognized arguments: --warm-up=True

After I removed it, the new failure is:

Making request
Making request
Making request
Traceback (most recent call last):
  File "/workspaces/dk3/jetstream-pytorch/deps/JetStream/benchmarks/benchmark_serving.py", line 782, in <module>
    main(parsed_args)
  File "/workspaces/dk3/jetstream-pytorch/deps/JetStream/benchmarks/benchmark_serving.py", line 578, in main
    benchmark_result, request_outputs = asyncio.run(
  File "/usr/local/lib/python3.10/asyncio/runners.py", line 44, in run
    return loop.run_until_complete(main)
  File "/usr/local/lib/python3.10/asyncio/base_events.py", line 649, in run_until_complete
    return future.result()
  File "/workspaces/dk3/jetstream-pytorch/deps/JetStream/benchmarks/benchmark_serving.py", line 450, in benchmark
    outputs = await asyncio.gather(*tasks)
  File "/workspaces/dk3/jetstream-pytorch/deps/JetStream/benchmarks/benchmark_serving.py", line 407, in send_request
    generated_token_list, ttft, latency = await grpc_async_request(
  File "/workspaces/dk3/jetstream-pytorch/deps/JetStream/benchmarks/benchmark_serving.py", line 381, in grpc_async_request
    async for sample_list in response:
  File "/usr/local/lib/python3.10/site-packages/grpc/aio/_call.py", line 356, in _fetch_stream_responses
    await self._raise_for_status()
  File "/usr/local/lib/python3.10/site-packages/grpc/aio/_call.py", line 263, in _raise_for_status
    raise _create_rpc_error(
grpc.aio._call.AioRpcError: <AioRpcError of RPC that terminated with:
    status = StatusCode.UNAVAILABLE
    details = "failed to connect to all addresses; last error: UNKNOWN: ipv4:0.0.0.0:9000: Failed to connect to remote host: Connection refused"
    debug_error_string = "UNKNOWN:Error received from peer  {created_time:"2024-05-15T23:07:35.419653175+00:00", grpc_status:14, grpc_message:"failed to connect to all addresses; last error: UNKNOWN: ipv4:0.0.0.0:9000: Failed to connect to remote host: Connection refused"}"
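
The UNAVAILABLE / "Connection refused" error just means nothing was listening on port 9000 yet, since the server had not been started. A small hypothetical helper (not part of JetStream) can confirm the server is up before launching the benchmark:

import socket

def server_is_listening(host="localhost", port=9000, timeout=2.0):
    """Return True if something accepts TCP connections on host:port."""
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        return False

print(server_is_listening())     # stays False until run_server.py is serving on port 9000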

I then tried to start the server with

jetstream-pytorch# python run_server.py --param_size=7b --model_name=$model_name --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir   --tokenizer_path=$tokenizer_path --platform=tpu=8 --model=llama-2

Now it failed with:

2024-05-15 23:10:02.139303: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
2024-05-15 23:10:04,525 - jax._src.dispatch - DEBUG - Finished tracing + transforming convert_element_type for pjit in 0.00039696693420410156 sec
2024-05-15 23:10:04,525 - jax._src.xla_bridge - DEBUG - No jax_plugins namespace packages available
2024-05-15 23:10:04,538 - jax._src.xla_bridge - DEBUG - Initializing backend 'tpu'
2024-05-15 23:10:07,296 - jax._src.xla_bridge - DEBUG - Backend 'tpu' initialized
2024-05-15 23:10:07,296 - jax._src.xla_bridge - DEBUG - Initializing backend 'cpu'
2024-05-15 23:10:07,339 - jax._src.xla_bridge - DEBUG - Backend 'cpu' initialized
2024-05-15 23:10:07,342 - jax._src.interpreters.pxla - DEBUG - Compiling convert_element_type for with global shapes and types [ShapedArray(int64[])]. Argument mapping: [UnspecifiedValue].
2024-05-15 23:10:07,344 - jax._src.dispatch - DEBUG - Finished jaxpr to MLIR module conversion jit(convert_element_type) in 0.0024080276489257812 sec
2024-05-15 23:10:07,345 - jax._src.compiler - DEBUG - get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)]]
2024-05-15 23:10:07,345 - jax._src.compiler - DEBUG - get_compile_options XLA-AutoFDO profile: using XLA-AutoFDO profile version -1
2024-05-15 23:10:07,346 - jax._src.cache_key - DEBUG - get_cache_key hash of serialized computation: 9b22b9c19f05dd8133ae8a5a2050357629ebcb51a48b3bb309de5f6bca1cb75c
2024-05-15 23:10:07,346 - jax._src.cache_key - DEBUG - get_cache_key hash after serializing computation: 9b22b9c19f05dd8133ae8a5a2050357629ebcb51a48b3bb309de5f6bca1cb75c
2024-05-15 23:10:07,346 - jax._src.cache_key - DEBUG - get_cache_key hash of serialized jax_lib version: a1df1fad9908669f67d4192ac24d3ac2f035c18fbfeae09917f9a757d81b169b
2024-05-15 23:10:07,346 - jax._src.cache_key - DEBUG - get_cache_key hash after serializing jax_lib version: e7eb13c076f6b06ca898aa33cb2cd8894ec60abfb3ff9d1e036a197b4bef0017
2024-05-15 23:10:07,346 - jax._src.cache_key - DEBUG - Including XLA flag in cache key: --xla_tpu_use_enhanced_launch_barrier=true
2024-05-15 23:10:07,346 - jax._src.cache_key - DEBUG - Including XLA flag in cache key: --xla_tpu_use_enhanced_launch_barrier=true
2024-05-15 23:10:07,346 - jax._src.cache_key - DEBUG - get_cache_key hash of serialized XLA flags: 542121283c29d07821f077110af0f57c569ed8ea4865a1e540e2676eabe4a410
2024-05-15 23:10:07,346 - jax._src.cache_key - DEBUG - get_cache_key hash after serializing XLA flags: c47f531da3be18db3fbaf17fd8d9a755e1a27e17ad576f4d00746acead1f79a9
2024-05-15 23:10:07,346 - jax._src.cache_key - DEBUG - get_cache_key hash of serialized compile_options: 4b173992891f565fdb6ab35b3f9ff4821d922f92051e08788649da7420e48e8d
2024-05-15 23:10:07,346 - jax._src.cache_key - DEBUG - get_cache_key hash after serializing compile_options: 4b4e5c57279092972bced7d06fbf54b87188f1cff72d4776b23be7926b185e68
2024-05-15 23:10:07,346 - jax._src.cache_key - DEBUG - get_cache_key hash of serialized accelerator_config: e629e0039fb5ebd0b0c09477f8ecd05dfdc0380fb83107fb2ff6a3a0854873d2
2024-05-15 23:10:07,346 - jax._src.cache_key - DEBUG - get_cache_key hash after serializing accelerator_config: 672da232108087e9f90279af3f2ade8a770b77b8104ef58ba8025fde1156fecd
2024-05-15 23:10:07,346 - jax._src.cache_key - DEBUG - get_cache_key hash of serialized compression: 0ea55c28f8014d8886b6248fe3da5d588f55c0823847a6b4579f1131b051b5e2
2024-05-15 23:10:07,346 - jax._src.cache_key - DEBUG - get_cache_key hash after serializing compression: a2ebd972881825165cc8faca346f3f55853e3f0087e5820905c46b3352ad9729
2024-05-15 23:10:07,346 - jax._src.cache_key - DEBUG - get_cache_key hash of serialized custom_hook: e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855
2024-05-15 23:10:07,346 - jax._src.cache_key - DEBUG - get_cache_key hash after serializing custom_hook: a2ebd972881825165cc8faca346f3f55853e3f0087e5820905c46b3352ad9729
2024-05-15 23:10:07,346 - jax._src.compilation_cache - DEBUG - get_executable_and_time: cache is disabled/not initialized
2024-05-15 23:10:07,392 - jax._src.compiler - DEBUG - Not writing persistent cache entry for 'jit_convert_element_type' because it took < 1.00 seconds to compile (0.05s)
2024-05-15 23:10:07,392 - jax._src.dispatch - DEBUG - Finished XLA compilation of jit(convert_element_type) in 0.04703927040100098 sec
2024-05-15 23:10:07,395 - jax._src.dispatch - DEBUG - Finished tracing + transforming bitwise_and for pjit in 0.0004899501800537109 sec
2024-05-15 23:10:07,396 - jax._src.dispatch - DEBUG - Finished tracing + transforming _threefry_seed for pjit in 0.0018978118896484375 sec
2024-05-15 23:10:07,397 - jax._src.interpreters.pxla - DEBUG - Compiling _threefry_seed for with global shapes and types [ShapedArray(int64[])]. Argument mapping: [UnspecifiedValue].
2024-05-15 23:10:07,401 - jax._src.dispatch - DEBUG - Finished jaxpr to MLIR module conversion jit(_threefry_seed) in 0.004235267639160156 sec
2024-05-15 23:10:07,401 - jax._src.compiler - DEBUG - get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)]]
2024-05-15 23:10:07,402 - jax._src.compiler - DEBUG - get_compile_options XLA-AutoFDO profile: using XLA-AutoFDO profile version -1
2024-05-15 23:10:07,402 - jax._src.cache_key - DEBUG - get_cache_key hash of serialized computation: 8fd441afb36ae3174d422336e96edc5ca8a8bfef88fa2623cb9d73106df51584
2024-05-15 23:10:07,402 - jax._src.cache_key - DEBUG - get_cache_key hash after serializing computation: 8fd441afb36ae3174d422336e96edc5ca8a8bfef88fa2623cb9d73106df51584
2024-05-15 23:10:07,402 - jax._src.cache_key - DEBUG - get_cache_key hash of serialized jax_lib version: a1df1fad9908669f67d4192ac24d3ac2f035c18fbfeae09917f9a757d81b169b
2024-05-15 23:10:07,402 - jax._src.cache_key - DEBUG - get_cache_key hash after serializing jax_lib version: cca438ca1aa8a60a6652dd9cb9c4c43985aba555e8bee3908bd04a5413156e74
2024-05-15 23:10:07,402 - jax._src.cache_key - DEBUG - Including XLA flag in cache key: --xla_tpu_use_enhanced_launch_barrier=true
2024-05-15 23:10:07,402 - jax._src.cache_key - DEBUG - Including XLA flag in cache key: --xla_tpu_use_enhanced_launch_barrier=true
2024-05-15 23:10:07,402 - jax._src.cache_key - DEBUG - get_cache_key hash of serialized XLA flags: 542121283c29d07821f077110af0f57c569ed8ea4865a1e540e2676eabe4a410
2024-05-15 23:10:07,402 - jax._src.cache_key - DEBUG - get_cache_key hash after serializing XLA flags: 4cf01b54ff647228c289482b9b278cccc23bfe2531eb7f8843f78538d827b89e
2024-05-15 23:10:07,402 - jax._src.cache_key - DEBUG - get_cache_key hash of serialized compile_options: c8121bb1813529baac77e2b01b093ae98d127d1302fcac282285f208d8e2ee1e
2024-05-15 23:10:07,402 - jax._src.cache_key - DEBUG - get_cache_key hash after serializing compile_options: f0b81240f6d2bafd46e0eb5b437b6e1da700211fb1cae0f986b2be4530a8614d
2024-05-15 23:10:07,403 - jax._src.cache_key - DEBUG - get_cache_key hash of serialized accelerator_config: e629e0039fb5ebd0b0c09477f8ecd05dfdc0380fb83107fb2ff6a3a0854873d2
2024-05-15 23:10:07,403 - jax._src.cache_key - DEBUG - get_cache_key hash after serializing accelerator_config: 95df519473fb00de03de8d2d106ae7f07740e8db0a3373daafad31cac4ef5a53
2024-05-15 23:10:07,403 - jax._src.cache_key - DEBUG - get_cache_key hash of serialized compression: 0ea55c28f8014d8886b6248fe3da5d588f55c0823847a6b4579f1131b051b5e2
2024-05-15 23:10:07,403 - jax._src.cache_key - DEBUG - get_cache_key hash after serializing compression: 1f8bb5ae6d28dc076530283f1d0e423298cf809e4b3a27f4eda8c1df7c4baedf
2024-05-15 23:10:07,403 - jax._src.cache_key - DEBUG - get_cache_key hash of serialized custom_hook: e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855
2024-05-15 23:10:07,403 - jax._src.cache_key - DEBUG - get_cache_key hash after serializing custom_hook: 1f8bb5ae6d28dc076530283f1d0e423298cf809e4b3a27f4eda8c1df7c4baedf
2024-05-15 23:10:07,403 - jax._src.compilation_cache - DEBUG - get_executable_and_time: cache is disabled/not initialized
2024-05-15 23:10:07,433 - jax._src.compiler - DEBUG - Not writing persistent cache entry for 'jit__threefry_seed' because it took < 1.00 seconds to compile (0.03s)
2024-05-15 23:10:07,433 - jax._src.dispatch - DEBUG - Finished XLA compilation of jit(_threefry_seed) in 0.03176140785217285 sec
FATAL Flags parsing error: Unknown command line flag 'model_name'
Pass --helpshort or --helpfull to see help on flags.

I am wondering what I did wrong here.

raghul-tonita commented 3 months ago

s/--model_name=/--model/

There is a bug in the README: the argument is --model, not --model_name. There is no flag called model_name, so it has to be removed from the command.

Your command already has --model=llama-2 at the end, so you just need to drop the --model_name param.

Look here: https://github.com/google/jetstream-pytorch/blob/f3cf2b73813616fb8f182e7a82006a69c62cb661/run_server.py#L74

JackCaoG commented 3 months ago

Thanks @raghul-tonita. If I want to benchmark the throughput, what should I do? Do I need to first start the server and then run the benchmark script?

raghul-tonita commented 3 months ago

Yes. I set up the server and load tested it using a variant of https://github.com/google/JetStream/blob/main/jetstream/tools/load_tester.py. In short: set up the server, then send async gRPC requests with your dataset to benchmark.

You could also use the same dataset that was used to generate the report in benchmarks.md and run the benchmarking script, though I haven't done this myself, so I can't say for sure.
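
The load-testing pattern is essentially an asyncio fan-out over streaming gRPC calls. A minimal, self-contained sketch of just the concurrency pattern (fake_request is a stand-in; in a real test it would be a grpc.aio streaming call against the server on port 9000):

import asyncio
import time

async def fake_request(prompt):
    """Stand-in for one streaming decode call; returns (ttft, total_latency)."""
    start = time.perf_counter()
    await asyncio.sleep(0.05)                  # pretend: time to first token
    ttft = time.perf_counter() - start
    await asyncio.sleep(0.10)                  # pretend: remaining decode time
    return ttft, time.perf_counter() - start

async def run_load_test(prompts):
    results = await asyncio.gather(*(fake_request(p) for p in prompts))
    latencies = [total for _, total in results]
    print(f"requests: {len(results)}, mean latency: {sum(latencies) / len(latencies):.3f}s")

asyncio.run(run_load_test(["I believe the meaning of life is"] * 8))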

JackCaoG commented 3 months ago

Thanks @raghul-tonita, I am able to start the server with

python run_server.py --param_size=7b --batch_size=64 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir   --tokenizer_path=$tokenizer_path --platform=tpu=4 --model=llama-2 --sharding_config="default_shardings/llama.yaml"

I am also able to run the benchmark script that sends gRPC requests with

jetstream-pytorch/deps/JetStream# python benchmarks/benchmark_serving.py --tokenizer $tokenizer_path --num-prompts 2000  --dataset-path  $dataset_path --dataset sharegpt --save-request-outputs

JackCaoG commented 3 months ago

For the server start, I had to manually add --sharding_config="default_shardings/llama.yaml", otherwise it complains that it can't find sharding_config.

For the benchmark script, I found that --warm-up=True is no longer supported, so I had to remove that option. The benchmark script eventually output:

Successful requests: 2000
Benchmark duration: 511.800847 s
Total input tokens: 220485
Total generated tokens: 622832
Request throughput: 3.91 requests/s
Input token throughput: 430.80 tokens/s
Output token throughput: 1216.94 tokens/s
Mean TTFT: 361578.43 ms
Median TTFT: 361889.06 ms
P99 TTFT: 493906.13 ms
Mean TPOT: 6019.10 ms
Median TPOT: 1346.14 ms
P99 TPOT: 111032.06 ms

I don't know if I can trust this data because there is a long warm-up time at the beginning. The throughput also seems a bit low given that I am using a v4-8 with Llama2-7B.
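
For what it's worth, the headline numbers do follow directly from the totals above, so they are at least internally consistent:

# Quick sanity check against the totals reported above.
duration_s = 511.800847
requests, input_tokens, output_tokens = 2000, 220485, 622832

print(requests / duration_s)        # ~3.91 requests/s
print(input_tokens / duration_s)    # ~430.80 tokens/s
print(output_tokens / duration_s)   # ~1216.94 tokens/s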

JackCaoG commented 3 months ago

Ah, thanks Han, the new flag is warmup-first instead of warm-up. With that I see a throughput of:

Successful requests: 2000
Benchmark duration: 279.988127 s
Total input tokens: 220485
Total generated tokens: 623221
Request throughput: 7.14 requests/s
Input token throughput: 787.48 tokens/s
Output token throughput: 2225.88 tokens/s
Mean TTFT: 130993.14 ms
Median TTFT: 132435.39 ms
P99 TTFT: 262211.08 ms
Mean TPOT: 2297.64 ms
Median TPOT: 479.33 ms
P99 TPOT: 43248.99 ms

JackCaoG commented 3 months ago

Let me try to submit a PR to fix the README, and then we can close this issue.