georgedahl closed this 4 years ago
This is an excellent suggestion.
There are some meaty docstrings in lax's parallelization primitives, spmd_mnist_classifier_fromscratch.py, etc., but it would be useful to have a cookbook and API docs.
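Here is a minimal sketch of the kind of small example such a cookbook might open with. It uses only `jax.pmap`, `jax.local_device_count`, and the `lax.psum` collective. It assumes a reasonably recent JAX install, and the function name `normalize` and axis name `"i"` are illustrative choices, not taken from the cookbook. It runs on a single CPU device too, where the mapped axis simply has length 1.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Number of local accelerators (1 on a plain CPU install).
n = jax.local_device_count()

def normalize(x):
    # lax.psum sums x over every device participating in the
    # pmap whose axis is named "i", so each device divides its
    # own value by the global total.
    return x / lax.psum(x, axis_name="i")

# One scalar per device along the leading (mapped) axis.
xs = jnp.arange(1.0, n + 1.0)

# pmap splits xs across devices; the outputs are gathered back
# into an array of shape (n,) whose entries sum to 1.
out = jax.pmap(normalize, axis_name="i")(xs)
print(out)
```

The axis name is what ties the collective to a particular pmap, which is the piece the existing docstrings describe but a worked example makes concrete.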
I believe https://github.com/google/jax/blob/master/cloud_tpu_colabs/Pmap_Cookbook.ipynb solves this. What do you think, @mattjj?
It does indeed. I think the remaining action item is to link the pmap cookbook into the regular JAX documentation a bit better.
(For a long time we had the cookbook ready to go, but we didn't have a free Colab environment with multiple accelerators that we could run it on. We do now because cloud TPU colabs are available.)
It's already linked under pmap (as "SPMD Cookbook"), but maybe adding the cloud_tpu_colabs notebooks to docs/notebooks would be better?
I think this is solved!
A notebook showing a rich and informative set of pmap examples would be an extremely valuable addition.