google / objax

Calling objax.random.* from within a jitted function breaks the DEFAULT_GENERATOR #147

Open carlini opened 3 years ago

carlini commented 3 years ago

Consider the following code that creates a new linear model, creates a predict op that adds some randomness to the logits, jits it, runs it, and then tries to create a new linear layer.

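import numpy as np
import objax

new_model = objax.nn.Linear(3, 10)

def predict_op(x):
    # objax.random.normal draws its key from the global DEFAULT_GENERATOR
    return new_model(x) + objax.random.normal((10,))

# only the model's variables are passed to Jit; the generator's are not
predict = objax.Jit(predict_op, new_model.vars())
predict(np.zeros((1, 3)))

# everything is fine here.
# the code is wrong, because we're going to re-use the same randomness over and over, but so far the world looks good

objax.nn.Linear(3, 10) # BOOM!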

This gives the very helpful error message:

Traceback (most recent call last):
  File "wat.py", line 22, in <module>
    objax.nn.Linear(3, 10)
  File "/opt/conda/lib/python3.7/site-packages/objax/nn/layers.py", line 282, in __init__
    self.w = TrainVar(w_init((nin, nout)))
  File "/opt/conda/lib/python3.7/site-packages/objax/nn/init.py", line 106, in xavier_normal
    return random.normal(shape, stddev=gain * xavier_normal_gain(shape))
  File "/opt/conda/lib/python3.7/site-packages/objax/random/random.py", line 62, in normal
    return jr.normal(generator(), shape=shape) * stddev + mean
  File "/opt/conda/lib/python3.7/site-packages/objax/random/random.py", line 49, in __call__
    return self.key.split(1)[0]
  File "/opt/conda/lib/python3.7/site-packages/objax/variable.py", line 180, in split
    values = jr.split(self.value, n + 1)
  File "/opt/conda/lib/python3.7/site-packages/jax/random.py", line 281, in split
    return _split(key, int(num))  # type: ignore
jax.traceback_util.FilteredStackTrace: jax.core.UnexpectedTracerError: Encountered an unexpected tracer. Perhaps this tracer escaped through global state from a previously traced function.
The functions being transformed should not save traced values to global state. Detail: tracer created on line /opt/conda/lib/python3.7/site-packages/objax/variable.py:181 (split).

The stack trace above excludes JAX-internal frames.
The following is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "wat.py", line 22, in <module>
    objax.nn.Linear(3, 10)
  File "/opt/conda/lib/python3.7/site-packages/objax/nn/layers.py", line 282, in __init__
    self.w = TrainVar(w_init((nin, nout)))
  File "/opt/conda/lib/python3.7/site-packages/objax/nn/init.py", line 106, in xavier_normal
    return random.normal(shape, stddev=gain * xavier_normal_gain(shape))
  File "/opt/conda/lib/python3.7/site-packages/objax/random/random.py", line 62, in normal
    return jr.normal(generator(), shape=shape) * stddev + mean
  File "/opt/conda/lib/python3.7/site-packages/objax/random/random.py", line 49, in __call__
    return self.key.split(1)[0]
  File "/opt/conda/lib/python3.7/site-packages/objax/variable.py", line 180, in split
    values = jr.split(self.value, n + 1)
  File "/opt/conda/lib/python3.7/site-packages/jax/random.py", line 281, in split
    return _split(key, int(num))  # type: ignore
  File "/opt/conda/lib/python3.7/site-packages/jax/traceback_util.py", line 139, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/opt/conda/lib/python3.7/site-packages/jax/api.py", line 219, in f_jitted
    donated_invars=donated_invars)
  File "/opt/conda/lib/python3.7/site-packages/jax/core.py", line 1174, in bind
    return call_bind(self, fun, *args, **params)
  File "/opt/conda/lib/python3.7/site-packages/jax/core.py", line 1163, in call_bind
    tracers = map(top_trace.full_raise, args)
  File "/opt/conda/lib/python3.7/site-packages/jax/util.py", line 35, in safe_map
    return list(map(f, *args))
  File "/opt/conda/lib/python3.7/site-packages/jax/core.py", line 356, in full_raise
    val._assert_live()
  File "/opt/conda/lib/python3.7/site-packages/jax/interpreters/partial_eval.py", line 918, in _assert_live
    raise core.escaped_tracer_error(msg)
jax.core.UnexpectedTracerError: Encountered an unexpected tracer. Perhaps this tracer escaped through global state from a previously traced function.
The functions being transformed should not save traced values to global state. Detail: tracer created on line /opt/conda/lib/python3.7/site-packages/objax/variable.py:181 (split).

Here is, I think, what is going on: when we jit the function and run it, JAX traces everything inside it, including the call to the random number generator. The generator's key is not one of the variables we passed to Jit, yet drawing a sample splits that key and writes the result back (!!), so the trace leaves a Tracer stored in the DEFAULT_GENERATOR state. The next time we try to create a linear layer, its initializer calls the default generator, which tries to split the stale Tracer, and everything dies.
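To make the suspected mechanism concrete, here is a minimal sketch in plain JAX rather than objax. The `_state` dictionary and `next_key` helper are hypothetical stand-ins for a module-level generator, not the actual objax implementation; the point is only that a traced function which writes to global state leaves a tracer behind.

```python
import jax
import jax.numpy as jnp
import jax.random as jr

# Hypothetical stand-in for a module-level generator; this is not objax code.
_state = {"key": jr.PRNGKey(0)}

def next_key():
    # Split the global key and write the new state back (a side effect).
    new_state, subkey = jr.split(_state["key"])
    _state["key"] = new_state
    return subkey

@jax.jit
def noisy(x):
    # The global key is not an argument of the jitted function.
    return x + jr.normal(next_key(), x.shape)

noisy(jnp.zeros((3,)))  # the first call traces noisy and runs next_key once

# The global now holds a value created during tracing.
leaked = isinstance(_state["key"], jax.core.Tracer)

try:
    jr.split(_state["key"])  # any later use outside the trace
    error_name = None
except jax.errors.UnexpectedTracerError as err:
    error_name = type(err).__name__
```

After the jitted call, `leaked` is true, and reusing the stored key outside the trace is what JAX reports as an unexpected tracer, which matches the shape of the failure in the traceback above.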

At the very least, we should forbid the user from doing this: the code is wrong, and jitting in this way isn't the proper behavior. Better might be to warn the user and let them shoot themselves in the foot if they really want to, without burning the world down.

david-berthelot commented 3 years ago

I guess we could simply add DEFAULT_GENERATOR to the Jit variables (and to those of other primitives like vectorize and co) so it would just work. Not passing it is an error in any case, since the splitting will interfere with JAX tracing whenever the generator is not passed.
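As a rough plain-JAX analogue of this suggestion (a sketch under my own assumptions, not objax's implementation): if the generator state is an explicit input and output of the jitted function, tracing never writes to anything global, and the key held by the caller stays a concrete array. The function name `noisy` is illustrative only.

```python
import jax
import jax.numpy as jnp
import jax.random as jr

@jax.jit
def noisy(key, x):
    # The key is an explicit input and the advanced key an explicit output,
    # so no traced value is written to global state.
    new_key, subkey = jr.split(key)
    return new_key, x + jr.normal(subkey, x.shape)

key = jr.PRNGKey(0)
key, a = noisy(key, jnp.zeros((3,)))
key, b = noisy(key, jnp.zeros((3,)))

# The caller's key is a concrete array, so it remains usable outside jit.
fresh = jr.normal(key, (3,))

# Each call advanced the key, so the two noise draws differ.
different = not bool(jnp.allclose(a, b))
```

Threading the state this way also addresses the reused-randomness problem noted in the original report, because every call receives a freshly advanced key.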