google / maxtext

A simple, performant and scalable Jax LLM!
Apache License 2.0
1.39k stars 247 forks

Fix decode.py to also use first_token from prefill_call #756

Closed: vipannalla closed this 2 weeks ago

vipannalla commented 2 weeks ago

NOTE: This redoes PR-754; I ran into git issues with that one.

@RissyRan tested these changes using composer + xpk setup -- https://screenshot.googleplex.com/8iDFz44j8haMKeG