google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Consider tracing twice to detect Python side effects #1141

Open jekbradbury opened 4 years ago

jekbradbury commented 4 years ago

User code that relies on Python side effects or interpreter state will behave differently under jit or pmap (in particular, the interpreter state will be effectively frozen at compile time).

One example is rng = random.PRNGKey(int(time.time())) inside a jitted function, which will freeze the RNG key at compile time and produce the same random numbers at every step. Currently we have no way of detecting that something like this is happening and flagging it as a warning or error.
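A minimal sketch of the frozen-state problem, assuming a recent JAX release. The counter stands in for time.time() so the run is deterministic; noisy_step and _counter are hypothetical names, not anything in JAX itself.

```python
import itertools

import jax
import jax.numpy as jnp

# Hypothetical stand-in for time.time(): yields a new value each time it is read.
_counter = itertools.count()

@jax.jit
def noisy_step(x):
    # This Python code runs only while jit traces the function.
    # The seed it reads is baked into the compiled computation.
    seed = next(_counter)
    key = jax.random.PRNGKey(seed)
    return x + jax.random.normal(key, x.shape)

x = jnp.zeros(3)
first = noisy_step(x)
second = noisy_step(x)  # same shape and dtype, so the cached computation is reused

# The "random" noise is identical across calls because the seed was frozen.
same_noise = bool(jnp.array_equal(first, second))
```

Passing the key in as an argument (and splitting it per step) avoids the problem, since the key is then a traced input rather than interpreter state.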

One way of identifying code that relies on interpreter state would be to trace user code twice and check that the staged-out representation is the same. It would be straightforward to do this for existing initial-style tracers (compare the jaxprs), but I think jit specifically (as a final-style tracer with initial-style output) would require either switching to initial style or building and comparing two XLA HLO computations.
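A rough sketch of the trace-twice check for the initial-style case, using jax.make_jaxpr and comparing the printed jaxprs. traces_differ is a hypothetical helper, not an existing JAX API, and comparing printed text is an assumption made for brevity: it only catches differences that show up in the printed form (for example inline scalar literals), not values hidden behind captured array constants.

```python
import itertools

import jax
import jax.numpy as jnp

def traces_differ(fn, *args):
    """Hypothetical helper: trace fn twice and compare the printed jaxprs."""
    def trace_once():
        # Wrap fn in a fresh lambda so that any tracing cache keyed on
        # function identity cannot skip the second trace.
        return str(jax.make_jaxpr(lambda *a: fn(*a))(*args))
    return trace_once() != trace_once()

_counter = itertools.count()  # hypothetical interpreter state

def impure(x):
    # Reads Python state during tracing; the value becomes a literal in the jaxpr.
    return x * float(next(_counter))

def pure(x):
    return jnp.sin(x) * 2.0

x = jnp.ones(3)
impure_flagged = traces_differ(impure, x)
pure_flagged = traces_differ(pure, x)
```

Here impure_flagged should come out True and pure_flagged False, since only the first function stages out a different computation on the second trace.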

Are there cases this wouldn't detect, or reasons why this would be a bad idea other than performance hits to the cache-miss path? Do we feel confident in the level of determinism in the core infrastructure? Should we throw an error or a warning, and should we enable it by default or put it behind a flag or kwarg? (A similar approach is used in the tracing component of the PyTorch JIT, and a somewhat less similar approach is used to detect variable misuse in TF2.)

skye commented 4 years ago

This is probably an edge case, but tracing twice would be confusing if you use print statements to see how often compilation is happening. I have vague recollections of people using even fancier trace-time side effects, but I can't remember the specifics now, and maybe that's asking for trouble anyway :)

If we just wanna catch the rng = random.PRNGKey(int(time.time())) case, it might be simpler to expand the DynamicAxisEnv logic to create a context for jit as well as pmap, and have PRNGKey blow up under a jit. Unless there are use cases for creating random numbers in a jit'd function? This would also only catch this one case, not behavior-changing side effects in general. (Although tracing twice would only catch side effects that change the behavior of the second call, and in theory you could have side effects that aren't visible for a couple calls...)
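To illustrate the last caveat, here is a small sketch of a side effect that two traces would not reveal. The function delayed and the calls dict are hypothetical, and comparing printed jaxprs from jax.make_jaxpr is just one assumed way of doing the comparison.

```python
import jax
import jax.numpy as jnp

calls = {"n": 0}  # hypothetical interpreter state mutated during tracing

def delayed(x):
    calls["n"] += 1
    # The staged-out computation only changes from the third trace onward.
    scale = 1.0 if calls["n"] < 3 else 2.0
    return x * scale

def trace_text(fn, *args):
    # A fresh lambda per call, so a cache keyed on function identity cannot skip tracing.
    return str(jax.make_jaxpr(lambda *a: fn(*a))(*args))

x = jnp.ones(2)
first, second, third = (trace_text(delayed, x) for _ in range(3))

missed_by_two_traces = first == second   # a trace-twice check sees no difference
visible_on_third = third != first        # the change shows up only later
```

A trace-twice check would pass this function even though a later recompilation (say, for a new input shape) would silently change its behavior.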