Closed by NeilGirdhar 1 year ago.
See the linked Jax issue for more ideas about what may be going wrong.
I solved my problem by putting an assertion in jax._src.core._initialize_jax_jit_thread_local_state. This allowed me to find the one point at which I was forcing premature initialization of Jax, and to remove it.
Orion freezes when this code is run:
Environment:
Possible solution:
This could be related to Jax's desire to JIT-compile its random functions.