google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

🔪 Remaining Sharp Bit TODOs 🔪 #9952

Open levskaya opened 2 years ago

levskaya commented 2 years ago

We could do with sprucing up the Sharp Bits doc with common problems we've encountered in user code since it was first written.

Top of the list is documenting matmul / conv op precision issues:

We should add some other ideas here.

mattjj commented 2 years ago

Just adding context: the context manager for precision is defined here, and there are some words about it in #6143.
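Roughly, the knobs look something like this (the per-op `precision=` argument plus the `jax.default_matmul_precision` context manager; the shapes and dtypes below are just placeholders):

```python
import jax
import jax.numpy as jnp

x = jnp.ones((512, 512), dtype=jnp.bfloat16)  # placeholder inputs
y = jnp.ones((512, 512), dtype=jnp.bfloat16)

# Per-op control: request highest-precision accumulation for this matmul.
z_hi = jnp.dot(x, y, precision=jax.lax.Precision.HIGHEST)

# Scoped control: override the default matmul/conv precision for every
# dot-like op issued inside the block.
with jax.default_matmul_precision("float32"):
    z_scoped = x @ y
```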

levskaya commented 2 years ago

Other sharp bits:

mattjj commented 2 years ago

Fantastic!

On that first bullet, we could also mention checkify. cc @LenaMartens
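Something like this minimal checkify sketch might be worth including (toy function, not meant as the final doc example):

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify

def log_positive(x):
    # A user-level check that survives jit instead of silently producing NaNs.
    checkify.check(jnp.all(x > 0), "input must be positive")
    return jnp.log(x)

# Wrapping with checkify returns a functional error value alongside the output.
checked = checkify.checkify(log_positive, errors=checkify.user_checks)
err, out = jax.jit(checked)(jnp.array([1.0, -2.0]))
err.throw()  # raises here, outside the traced code, if any check failed
```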

LenaMartens commented 2 years ago

Nice, +1 on the danger of weak_type=True triggering recompilation; we've had people ask for more documentation on that. I might try to add it.
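For reference, the recompilation gotcha is roughly the following (toy jitted function; exact reprs may vary by version):

```python
import jax
import jax.numpy as jnp

@jax.jit
def add_one(x):
    print("tracing!")  # Python-level print only runs when the function is (re)traced
    return x + 1

add_one(jnp.ones(()))  # strongly-typed float32 argument -> trace #1
add_one(1.0)           # Python scalar is a weakly-typed float32 -> trace #2 (recompiles!)

print(jnp.asarray(1.0))  # repr of a weak-typed array includes weak_type=True
print(jnp.ones(()))      # strongly-typed float32, no weak_type flag in the repr
```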

levskaya commented 2 years ago

Also, making sure that the async JAX calls in a training loop aren't interleaved with blocking host-side calls that kill dispatch-pipelining efficiency (e.g. a trivial host-side metrics fn or similar) - one of the most common performance mistakes I see (maybe this belongs in a separate performance-gotchas doc... not sure)
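Roughly the pattern to warn about (toy step function and placeholder shapes, just to illustrate the host sync):

```python
import jax
import jax.numpy as jnp

@jax.jit
def train_step(params, batch):
    # Stand-in for a real update: returns new params and a scalar loss.
    loss = jnp.mean((params - batch) ** 2)
    return params - 0.01 * (params - batch), loss

params = jnp.zeros(1024)
for step in range(1000):
    batch = jnp.ones(1024)
    params, loss = train_step(params, batch)
    # BAD: converting the loss to a Python float (or .item(), or logging it)
    # every step forces the host to wait on the device and stalls dispatch of
    # the next step. Only synchronize occasionally:
    if step % 100 == 0:
        print(step, float(loss))
```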

jakevdp commented 2 years ago

I like the idea of having a new dedicated doc for performance tips and pitfalls.

jakevdp commented 2 years ago

Regarding reworking the Sharp Bits doc, I recently added a section on miscellaneous divergences between numpy and JAX. It might be nice to rearrange things so all the differences between numpy and JAX are listed briefly under a single heading, perhaps with links to deeper discussion later in the doc.

nalzok commented 2 years ago

Regarding the "jit caching behavior", is there any chance you could cache the compiled result to the file system so that it persists across runs? In my development cycle, I typically change some hyperparameters and re-run the experiment. It's a little frustrating to wait for JIT compilation each time, even though I have compiled the exact same code multiple times before.

I am under the impression that this won't be too hard to implement, since we already have a hashing/caching mechanism; all it takes is writing the emitted XLA program to disk. Should I open a new issue for this?

jakevdp commented 2 years ago

@nalzok - there is currently an implementation of this, but only for TPU. See https://github.com/google/jax/tree/main/jax/experimental/compilation_cache for details, and https://github.com/google/jax/issues/2490 where this kind of request is tracked.
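For reference, it's enabled roughly like this (experimental API, TPU-only for now; the cache directory is just a placeholder):

```python
# Persistent compilation cache (experimental, TPU-only at the moment).
from jax.experimental.compilation_cache import compilation_cache as cc

cc.initialize_cache("/tmp/jax_compilation_cache")  # placeholder cache dir

# After this, identical jitted programs compiled in a later process can be
# loaded from the cache directory instead of being recompiled.
```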

JeppeKlitgaard commented 1 year ago

I have a fairly RNG-generation-heavy workload that I am running on Cloud TPU, and I was googling around to try to understand the xla_tpu_spmd_rng_bit_generator_unsafe flag but only found this thread and a brief mention in the JAX documentation. The quality of randomness is not critical for me. Am I right in assuming this flag improves performance, but at the cost of using a less well-understood algorithm underneath?

levskaya commented 1 year ago

@JeppeKlitgaard - yeah, it uses an ad hoc method of splitting keys that we don't have theoretical justification for (and in fact we don't really have well-established statistical tests for split-chain decorrelation when it comes to splittable PRNG systems). That said, it compiles and runs fast, and it's almost certainly good enough for e.g. dropout masks in the context of SGD training of NNs (we've used it for that with no observed ill effects for some time). I'd be a bit more careful if I were doing classic MCMC or something.
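If it helps, TPU-side XLA flags like this one are usually passed via the LIBTPU_INIT_ARGS environment variable before the TPU backend is initialized - something like the sketch below (double-check the mechanism against your libtpu version; the workload at the end is just a placeholder):

```python
import os

# Append the flag to any existing libtpu init args; this must happen before
# JAX initializes the TPU backend (i.e. before the first JAX TPU call).
os.environ["LIBTPU_INIT_ARGS"] = (
    os.environ.get("LIBTPU_INIT_ARGS", "")
    + " --xla_tpu_spmd_rng_bit_generator_unsafe=true"
).strip()

import jax

key = jax.random.PRNGKey(0)
samples = jax.random.uniform(key, (4,))  # placeholder for an RNG-heavy workload
```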