AI-Hypercomputer / maxtext

A simple, performant and scalable Jax LLM!
Apache License 2.0

add kl divergence for forward_pass_logit_checker #832

Closed: ZhaoyueCheng closed this 1 month ago.
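The page does not include the diff, so the following is only a minimal sketch of what a KL-divergence check over logits could look like in JAX. It is not MaxText's implementation. It assumes "golden" reference logits and model logits shaped `[batch, seq, vocab]`. The helper names `kl_divergence_from_logits` and `kl_within_tolerance`, along with the `max_kl_tolerance` default, are hypothetical.

```python
import jax
import jax.numpy as jnp


def kl_divergence_from_logits(golden_logits, model_logits):
    """Per-token KL(P_golden || P_model), reduced over the last (vocab) axis.

    Hypothetical helper; inputs are assumed to be raw logits [batch, seq, vocab].
    """
    # log_softmax is numerically stabler than log(softmax(x)).
    log_p = jax.nn.log_softmax(golden_logits, axis=-1)
    log_q = jax.nn.log_softmax(model_logits, axis=-1)
    p = jnp.exp(log_p)
    # KL(P || Q) = sum_v P(v) * (log P(v) - log Q(v)); result shape [batch, seq].
    return jnp.sum(p * (log_p - log_q), axis=-1)


def kl_within_tolerance(golden_logits, model_logits, max_kl_tolerance=1e-4):
    """Hypothetical check: passes if the worst per-token KL is within tolerance."""
    kl = kl_divergence_from_logits(golden_logits, model_logits)
    max_kl = float(jnp.max(kl))
    return max_kl <= max_kl_tolerance, max_kl


if __name__ == "__main__":
    key_g, key_m = jax.random.split(jax.random.PRNGKey(0))
    golden = jax.random.normal(key_g, (2, 4, 8))  # [batch, seq, vocab]
    noisy = golden + 0.01 * jax.random.normal(key_m, golden.shape)
    passed, max_kl = kl_within_tolerance(golden, noisy)
    print(f"passed={passed} max_kl={max_kl:.2e}")
```

An elementwise logit tolerance compares raw numbers. A KL check compares the predicted token distributions, so it ignores a constant per-token shift in the logits, which does not change the softmax output. KL is asymmetric. This sketch treats the golden logits as the reference distribution P, so disagreements are weighted by how likely the reference considers each token.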

ZhaoyueCheng commented 1 month ago

Description

Test

rdyro commented 1 month ago

Great, thanks for implementing this! This looks good to me.