google / jetstream-pytorch

PyTorch/XLA integration with JetStream (https://github.com/google/JetStream) for LLM inference
Apache License 2.0

Support llama3 #64

Closed · bhavya01 closed this 4 months ago

bhavya01 commented 4 months ago

Tested with run_interactive.py:

```bash
python run_interactive.py --size=8b --model=llama-3 --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path
```

I also ran the benchmark for llama-2 on a TPU v4-8 and got the following numbers:

```
Successful requests: 1999
Benchmark duration: 393.904737 s
Total input tokens: 220485
Total generated tokens: 608985
Request throughput: 5.07 requests/s
Input token throughput: 559.74 tokens/s
Output token throughput: 1546.02 tokens/s
Mean TTFT: 284112.00 ms
Median TTFT: 284544.72 ms
P99 TTFT: 368924.02 ms
Mean TPOT: 5756.67 ms
Median TPOT: 1095.59 ms
P99 TPOT: 109265.00 ms
```
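As a quick sanity check, the throughput figures above follow directly from the reported counts and the benchmark duration (plain Python, using only the numbers from this summary, no project dependencies):

```python
# Reported counts and duration from the benchmark summary above.
successful_requests = 1999
duration_s = 393.904737
total_input_tokens = 220485
total_generated_tokens = 608985

# Throughput = count / duration; these should reproduce the reported values.
print(f"Request throughput: {successful_requests / duration_s:.2f} requests/s")        # ~5.07
print(f"Input token throughput: {total_input_tokens / duration_s:.2f} tokens/s")       # ~559.74
print(f"Output token throughput: {total_generated_tokens / duration_s:.2f} tokens/s")  # ~1546.02
```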

I still need to run the benchmark script in the JetStream repo to get the metrics for llama-3.

FanhaiLu1 commented 4 months ago

Thanks for adding llama3 support! Accuracy is critical; can you share the output results from both llama2 and llama3 from run_interactive?

bhavya01 commented 4 months ago

> Thanks for adding llama3 support! Accuracy is critical; can you share the output results from both llama2 and llama3 from run_interactive?

This is the output for both models: https://gist.github.com/bhavya01/40a344e671a2e5dde980f163141545db

FanhaiLu1 commented 4 months ago

> Thanks for adding llama3 support! Accuracy is critical; can you share the output results from both llama2 and llama3 from run_interactive?
>
> This is the output for both models: https://gist.github.com/bhavya01/40a344e671a2e5dde980f163141545db

Can you use the output without this PR as a baseline for comparison (it's hard to tell whether quality dropped without a baseline)? If possible, can you run both the baseline and the test without quantization?

bhavya01 commented 4 months ago

> Thanks for adding llama3 support! Accuracy is critical; can you share the output results from both llama2 and llama3 from run_interactive?
>
> This is the output for both models: https://gist.github.com/bhavya01/40a344e671a2e5dde980f163141545db
>
> Can you use the output without this PR as a baseline for comparison (it's hard to tell whether quality dropped without a baseline)? If possible, can you run both the baseline and the test without quantization?

Yes, that makes sense. I did the comparison for llama-2 without quantization, and both results look weird: https://gist.github.com/bhavya01/660cd636d678f42a01501d093d63c2b1

With quantization, they look pretty similar. I added the llama2_before output to this gist: https://gist.github.com/bhavya01/40a344e671a2e5dde980f163141545db
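For the baseline-vs-PR comparison discussed here, a minimal sketch of one way to diff the two runs (the file names are hypothetical; it assumes each run_interactive run's generations were saved to a text file, one completion per line):

```python
import difflib

# Hypothetical files: generations from the baseline run (without this PR)
# and from a run with this PR applied.
with open("llama2_baseline.txt") as f:
    baseline = f.readlines()
with open("llama2_with_pr.txt") as f:
    candidate = f.readlines()

# A unified diff makes any divergence between the two runs easy to spot.
diff = difflib.unified_diff(baseline, candidate, fromfile="baseline", tofile="with_pr")
print("".join(diff) or "Outputs are identical.")
```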

bhavya01 commented 4 months ago

The unit tests are failing because we test against JetStream v0.2.0. We should have a new JetStream release this week, after which these tests will be fixed.

FanhaiLu1 commented 4 months ago

> The unit tests are failing because we test against JetStream v0.2.0. We should have a new JetStream release this week, after which these tests will be fixed.

@JoeZijunZhou Hi Zijun, could you let us know when you plan to create a new JetStream release? @bhavya01 Given the current test status, we need to tag the latest JetStream release before submitting this PR.