carlini opened 3 years ago
I guess we could simply add `DEFAULT_GENERATOR` to the `Jit` variables (and to other primitives like vectorize and co.) so that it just works. Not passing it is an error in any case, since the key splitting messes with JAX tracing if the generator is not passed.
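A rough pure-JAX sketch of that idea follows. It assumes that "adding the generator to the Jit variables" amounts to threading the generator's key through the compiled function as an explicit input and output. `GlobalGenerator`, `GEN`, `predict`, and `predict_and_update` are hypothetical names standing in for `DEFAULT_GENERATOR` and objax's machinery. This is not how objax implements it.

```python
import jax
import jax.numpy as jnp


class GlobalGenerator:
    """Hypothetical stand-in for a global stateful RNG such as DEFAULT_GENERATOR."""

    def __init__(self, seed=0):
        self.key = jax.random.PRNGKey(seed)

    def __call__(self):
        # Side effect: split the stored key and keep one half as the new state.
        self.key, subkey = jax.random.split(self.key)
        return subkey


GEN = GlobalGenerator()


@jax.jit
def predict(key, x):
    # The generator state arrives as an explicit argument and leaves as an output,
    # so nothing traced is ever written into Python-side state.
    key, subkey = jax.random.split(key)
    return key, x + jax.random.normal(subkey, x.shape)


def predict_and_update(x):
    # Outside the trace: pass the current key in, then store the concrete
    # updated key back into the generator.
    new_key, y = predict(GEN.key, x)
    GEN.key = new_key
    return y


y = predict_and_update(jnp.zeros(3))
print(y.shape)  # (3,)
layer_key = GEN()  # still a concrete key, so later initializers keep working
```

Because the updated key leaves the compiled function as an ordinary output, it is already a concrete array when it is written back. Later initializers that draw from `GEN` therefore keep working.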
Consider the following code that creates a new linear model, creates a predict op that adds some randomness to the logits, jits it, runs it, and then tries to create a new linear layer.
This gives the very helpful error message
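The original snippet and error text are not reproduced here. As a stand-in, the following minimal pure-JAX sketch reproduces the same failure mode. It assumes a hypothetical `GlobalGenerator` that, like `DEFAULT_GENERATOR`, stores its key and splits it on every call. This is not objax's code, and `init_new_layer` is a hypothetical placeholder for a layer initializer.

```python
import jax
import jax.numpy as jnp


class GlobalGenerator:
    """Hypothetical stand-in for a global stateful RNG such as DEFAULT_GENERATOR."""

    def __init__(self, seed=0):
        self.key = jax.random.PRNGKey(seed)

    def __call__(self):
        # Side effect: split the stored key and keep one half as the new state.
        self.key, subkey = jax.random.split(self.key)
        return subkey


GEN = GlobalGenerator()


def predict(x):
    # Reads and mutates the global generator without passing its state to jit.
    return x + jax.random.normal(GEN(), x.shape)


jitted_predict = jax.jit(predict)
jitted_predict(jnp.zeros(3))  # runs fine, but tracing stored a Tracer in GEN.key


def init_new_layer():
    # Stands in for a layer initializer that draws from the global generator.
    try:
        return jax.random.normal(GEN(), (2, 2))
    except jax.errors.UnexpectedTracerError as err:
        return err


result = init_new_layer()
print(type(result).__name__)
```

The jitted call itself succeeds. However, the key split during tracing is written back into `GEN.key` as a Tracer, and JAX reports any later use of that escaped Tracer as `jax.errors.UnexpectedTracerError`.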
Here is what I think is going on: when we call `objax.Jit()` on the function, JAX traces it in the background, including the call to the random number generator. Because the generator's state gets split during tracing (!!), a Tracer is assigned to the `DEFAULT_GENERATOR` state. The next time we try to create a linear layer, its initialization calls the default generator, which now holds that leaked Tracer, and everything dies.

At the very least, we should forbid the user from doing this: the code is wrong, and jitting in this way isn't the proper behavior. A better option might be to warn the user and still let them shoot themselves in the foot if they really want to, without burning the world down.
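One way to implement the "forbid, or at least warn" option is sketched below. It assumes the generator keeps its state in a single `key` attribute. `guarded_jit` and its `on_leak` parameter are hypothetical and not an existing objax or JAX API. The wrapper notices when tracing replaced the generator state, restores the concrete key, and then either raises or warns.

```python
import warnings

import jax
import jax.numpy as jnp


class GlobalGenerator:
    """Hypothetical stand-in for a global stateful RNG such as DEFAULT_GENERATOR."""

    def __init__(self, seed=0):
        self.key = jax.random.PRNGKey(seed)

    def __call__(self):
        # Side effect: split the stored key and keep one half as the new state.
        self.key, subkey = jax.random.split(self.key)
        return subkey


GEN = GlobalGenerator()


def guarded_jit(fn, generator, on_leak="raise"):
    """Hypothetical wrapper: jit `fn` and refuse (or warn about) generator leaks."""
    message = ("jitted function used the generator without passing its state; "
               "the generator state has been restored")

    def traced(*args):
        # Fresh function object per wrapper, so separate wrappers never share
        # jax.jit's trace cache and each one really traces `fn`.
        return fn(*args)

    jitted = jax.jit(traced)
    leaked = False

    def wrapper(*args):
        nonlocal leaked
        if leaked and on_leak == "raise":
            raise RuntimeError(message)
        before = generator.key
        out = jitted(*args)
        if generator.key is not before:
            # Only tracing can change the state here, so the new value is a
            # leaked Tracer: put the concrete key back before reporting.
            generator.key = before
            leaked = True
            if on_leak == "raise":
                raise RuntimeError(message)
            warnings.warn(message)
        return out

    return wrapper


def predict(x):
    # Uses the global generator without handing its state to jit.
    return x + jax.random.normal(GEN(), x.shape)


strict_predict = guarded_jit(predict, GEN)
try:
    strict_predict(jnp.zeros(3))
except RuntimeError as err:
    print("refused:", err)

lenient_predict = guarded_jit(predict, GEN, on_leak="warn")
y = lenient_predict(jnp.zeros(3))  # warns once, at trace time
layer_key = GEN()  # the generator is still usable afterwards
```

The check can only fire when `jax.jit` actually traces the function, because cached calls never run the Python body. For that reason, the strict mode remembers a detected leak and keeps refusing afterwards. In warn mode, the compiled function has baked in the key it saw while tracing, so every call reuses the same noise: the foot-gun remains, but the generator itself is no longer broken.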