jax-ml / coix

Inference Combinators in JAX
https://coix.readthedocs.io/en/latest/
Apache License 2.0

Fix maybe_extract_keys logic and move lambda outside of `util.train` to avoid recompiling lax.cond. #17

Closed: fehiepsi closed this pull request 1 year ago

fehiepsi commented 1 year ago

Fix the `maybe_extract_keys` logic and move the lambda out of `util.train` to avoid recompiling `lax.cond`.
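Why does moving a lambda help? JAX caches traced and compiled code partly by function identity, and a `lambda` written inside a loop or inside a function body is a new object every time that code runs. The sketch below illustrates the general principle only. It is not coix's `util.train`, and the names `step`, `double`, and the two loop functions are hypothetical. A jitted step receives its `lax.cond` branch as a static argument. A trace-time counter shows that reusing one module-level function compiles once, while a fresh lambda on every call forces a retrace each time.

```python
from functools import partial

import jax
import jax.numpy as jnp

# The Python body of a jitted function only runs when JAX (re)traces it,
# so this counter measures traces/compilations rather than calls.
_TRACES = {"count": 0}


@partial(jax.jit, static_argnums=0)
def step(branch_fn, x):
    """Hypothetical training step: apply `branch_fn` when x > 0, else identity."""
    _TRACES["count"] += 1
    return jax.lax.cond(x > 0, branch_fn, lambda v: v, x)


def double(v):
    # Defined once at module level, so its identity is stable across calls.
    return v * 2.0


def loop_with_inner_lambda(xs):
    # Anti-pattern: each iteration builds a new lambda object, i.e. a new
    # static argument, so jit misses its cache and retraces every time.
    return [step(lambda v: v * 2.0, x) for x in xs]


def loop_with_hoisted_fn(xs):
    # Passing the same function object lets jit reuse the compiled step.
    return [step(double, x) for x in xs]


def count_traces(loop_fn, xs):
    """Run `loop_fn` and return (outputs as floats, number of traces)."""
    _TRACES["count"] = 0
    outs = loop_fn(xs)
    return [float(o) for o in outs], _TRACES["count"]


xs = [jnp.float32(v) for v in (1.0, -1.0, 2.0)]
hoisted = count_traces(loop_with_hoisted_fn, xs)
inner = count_traces(loop_with_inner_lambda, xs)
print(hoisted)  # ([2.0, -1.0, 4.0], 1)
print(inner)    # same outputs, but one trace per call
```

Both loops produce identical numbers. Only the compile count differs, and in a long training loop that recompilation cost dominates.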

FUTURE_COPYBARA_INTEGRATE_REVIEW=https://github.com/jax-ml/coix/pull/15 from jax-ml:prng 222781afd64b8f77482a0f4adee181830b291783