jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Add a pmap cookbook #1487

Closed: georgedahl closed this issue 4 years ago

georgedahl commented 4 years ago

A notebook showing a rich and informative set of pmap examples would be an extremely valuable addition.

dynamicwebpaige commented 4 years ago

This is an excellent suggestion.

There are some meaty docstrings on lax's parallelization primitives, plus examples such as spmd_mnist_classifier_fromscratch.py, but it would be useful to have a dedicated cookbook and API docs.
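As a rough illustration of what such a cookbook might cover, here is a minimal sketch combining jax.pmap with the lax.psum collective. It assumes a reasonably recent JAX release. The function name normalize and the axis name "devices" are illustrative choices, not taken from this thread. On a CPU-only machine jax.local_device_count() is usually 1. You can simulate several devices by setting XLA_FLAGS=--xla_force_host_platform_device_count=8 before importing jax.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Number of local devices pmap will map over (typically 1 on a plain CPU install).
n_devices = jax.local_device_count()

# pmap splits the leading axis across devices, so its size must equal n_devices.
xs = jnp.arange(n_devices * 4, dtype=jnp.float32).reshape(n_devices, 4)

def normalize(x):
    # Illustrative helper: each device receives one row x of shape (4,).
    # lax.psum sums a value across every device in the named mapped axis.
    total = lax.psum(jnp.sum(x), axis_name="devices")
    return x / total

# axis_name binds the name that collectives inside the mapped function refer to.
normalized = jax.pmap(normalize, axis_name="devices")(xs)

print(normalized.shape)     # shape is (n_devices, 4)
print(jnp.sum(normalized))  # approximately 1.0: every shard was divided by the global sum
```

The key design point is that collectives like psum are keyed by axis_name, which lets code running on one device communicate with its peers without explicitly passing data between devices.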

joaogui1 commented 4 years ago

I believe https://github.com/google/jax/blob/master/cloud_tpu_colabs/Pmap_Cookbook.ipynb solves this. What do you think, @mattjj?

hawkinsp commented 4 years ago

It does indeed. I think the remaining action item is to link the pmap cookbook into the regular JAX documentation a bit better.

(For a long time we had the cookbook ready to go, but we didn't have a free Colab environment with multiple accelerators that we could run it on. We do now because cloud TPU colabs are available.)

joaogui1 commented 4 years ago

It's already linked from the pmap docs (under the name "SPMD Cookbook"), but maybe adding the cloud_tpu_colabs notebooks to docs/notebooks would be better?

mattjj commented 4 years ago

I think this is solved!