Closed fehiepsi closed 1 year ago
Fix maybe_extract_keys logic and move lambda outside of util.train to avoid recompiling lax.cond.
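
For illustration only, here is a minimal sketch of the hoisting pattern described above. It is not the actual coix code: the names `_double`, `_identity`, `step_hoisted`, and `step_inline` are hypothetical. The assumption is that JAX's tracing caches are keyed on the identity of the branch functions, so a lambda created inside a repeatedly called function looks like a new function on every call.

```python
import jax.numpy as jnp
from jax import lax


# Hypothetical branch functions, defined once at module level so the same
# function objects are passed to lax.cond on every call.
def _double(x):
    return x * 2.0


def _identity(x):
    return x


def step_hoisted(pred, x):
    # Reuses the module-level branches; under the stated assumption, JAX can
    # recognise them as the same functions across calls.
    return lax.cond(pred, _double, _identity, x)


def step_inline(pred, x):
    # Same result, but each call builds two fresh lambda objects, which
    # (under the stated assumption) cannot hit an identity-keyed cache.
    return lax.cond(pred, lambda v: v * 2.0, lambda v: v, x)


x = jnp.float32(3.0)
print(step_hoisted(True, x), step_inline(True, x))
```

Both variants compute the same values; the difference is only in whether the branch function objects are stable across calls.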
FUTURE_COPYBARA_INTEGRATE_REVIEW=https://github.com/jax-ml/coix/pull/15 from jax-ml:prng 222781afd64b8f77482a0f4adee181830b291783