google / JetStream

JetStream is a throughput and memory optimized engine for LLM inference on XLA devices, starting with TPUs (and GPUs in future -- PRs welcome).

Manual model warmup to resolve AOT model warmup performance degradation #126

Closed: vivianrwu closed this 1 month ago

vivianrwu commented 1 month ago

Use manual model warmup instead of the AOT-implemented model warmup, since with AOT we observe performance degradation at higher batch sizes of the MaxText configuration, as mentioned in https://github.com/google/JetStream/pull/92:

  1. OOM at higher batch sizes (after model warmup, during an active request)
  2. Detokenizing generate step time that slows down exponentially at higher batch sizes

With manual warmup, it has been verified that the detokenizing generate step time remains the same as JetStream's optimal behavior for all batch sizes.
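For reference, below is a minimal sketch of what a manual warmup loop can look like: dummy prefill/insert/generate calls at server startup trigger JIT compilation for each padded length, instead of pre-compiling and storing executables ahead of time. The engine method names and parameters used here (init_decode_state, prefill, insert, generate, prefill_lengths) are assumptions for illustration and may not match the actual JetStream code.

```python
# Illustrative sketch only: manual warmup that triggers JIT compilation by
# running dummy prefill/insert/generate calls at startup. Method names and
# parameters are assumptions, not necessarily JetStream's actual engine API.
import jax.numpy as jnp

def manual_warmup(engine, params, prefill_lengths=(64, 128, 256, 512, 1024)):
    decode_state = engine.init_decode_state()
    for length in prefill_lengths:
        # Dummy token sequence at the padded prefill length.
        dummy_tokens = jnp.ones((length,), dtype=jnp.int32)
        prefix, _ = engine.prefill(
            params=params, padded_tokens=dummy_tokens, true_length=length
        )
        decode_state = engine.insert(prefix, decode_state, slot=0)
    # One generate step compiles the decode path at the configured batch size.
    decode_state, _ = engine.generate(params, decode_state)
    return decode_state
```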

curl --request POST --header "Content-type: application/json" \
 -s localhost:8000/generate --data '{
    "prompt": "What are the top 5 programming languages",
    "max_tokens": 200
}'
{
    "response": " for data science in 2023?\n\n1. Python\n2. R\n3. SQL\n4. Java\n5. Scala\n\n**Note:** The order is based on popularity and demand in the data science industry in 2023."
}
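The same smoke test can be issued from Python; a small sketch below, assuming the HTTP server from the curl example is listening on localhost:8000.

```python
# Python equivalent of the curl smoke test above, assuming the HTTP server is
# listening on localhost:8000.
import requests

resp = requests.post(
    "http://localhost:8000/generate",
    json={"prompt": "What are the top 5 programming languages", "max_tokens": 200},
    timeout=60,
)
print(resp.json()["response"])
```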
vivianrwu commented 1 month ago

Do we need to update unit tests?

Unit tests do not need to be updated because the warmup logic is conditioned on engine.warm.
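As an illustration only, the gating can be thought of along these lines; the warm attribute usage and the helper below are assumptions for the sketch, not the actual server code.

```python
# Illustration only: run manual warmup only when the engine is not already
# warm. The `warm` attribute and this helper are assumptions for the sketch.
def maybe_manual_warmup(engine, params):
    if not getattr(engine, "warm", False):
        manual_warmup(engine, params)  # the warmup sketch shown earlier
```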

QQ on the description:

  1. We set the max pdbs when we start the server; this value should be within the memory cap (based on a calculation with the devices used), so then it would not OOM, right?

Yes, I think storing the compiled graphs from AOT and executing them from AOT is what takes up the memory. We observe the OOM at the generate request.
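For context, the AOT pattern under discussion looks roughly like the sketch below, using JAX's public ahead-of-time API (jax.jit(...).lower(...).compile()); the function and shapes are made up for illustration and are not JetStream's actual warmup code.

```python
# Rough sketch of the AOT pattern: lowering and compiling a jitted function
# ahead of time and keeping the resulting executables around. The function and
# shapes are stand-ins, not JetStream's real prefill/generate functions.
import jax
import jax.numpy as jnp

def generate_step(state):
    return state * 2  # stand-in for the real decode step

# One retained executable per batch configuration; holding all of these (plus
# the compilation cache) is the suspected source of the extra memory use.
compiled = {
    batch: jax.jit(generate_step)
    .lower(jnp.zeros((batch, 1024), dtype=jnp.bfloat16))
    .compile()
    for batch in (8, 16, 32, 64)
}

out = compiled[8](jnp.zeros((8, 1024), dtype=jnp.bfloat16))
```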

  2. Why would a higher actual batch size have very slow detokenization? Could you share some investigation or profiles?

Yes, you can reference https://github.com/google/JetStream/pull/92 for some of the investigation. The doc has also been shared internally.
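For anyone reproducing locally, a rough way to measure the per-step generate time at a given batch size is sketched below; it assumes the same hypothetical engine interface as the warmup sketch above and is not the benchmark used in the linked investigation.

```python
# Rough timing sketch for the generate step, assuming the same hypothetical
# engine interface as the warmup sketch above. For deeper inspection, wrap the
# loop in jax.profiler.trace(...) and view the trace in TensorBoard/Perfetto.
import time
import jax

def time_generate_step(engine, params, steps=50):
    decode_state = engine.init_decode_state()
    times = []
    for _ in range(steps):
        start = time.perf_counter()
        decode_state, _ = engine.generate(params, decode_state)
        jax.block_until_ready(decode_state)  # wait for async dispatch to finish
        times.append(time.perf_counter() - start)
    return sum(times) / len(times)
```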

FanhaiLu1 commented 1 month ago

With manual warmup, it has been verified that the detokenizing generate step time remains the same as JetStream's optimal behavior for all batch sizes.

Did you figure out the root cause of the performance issue and the OOM for AOT?

vivianrwu commented 1 month ago

Did you figure out the root cause of the performance issue and the OOM for AOT?

RCA has been attempted; the root cause of the OOM is potentially the added space needed to store the compiled graphs as executables, alongside saving the cache in the compilation cache directory. The root cause of the performance issue has not been concluded; it could be suboptimal AOT executables. I can share the investigation offline.
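For reference, the on-disk compilation cache mentioned above is configured through standard JAX flags, roughly as below; the directory path is only an example and this is not JetStream-specific code.

```python
# Standard JAX persistent compilation cache configuration (directory path is
# only an example). Compiled programs are written to this directory; per the
# discussion above, this on-disk cache plus the executables held in memory is
# the suspected extra space.
import jax

jax.config.update("jax_compilation_cache_dir", "/tmp/jax_compilation_cache")
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
```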