google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Document `pjit(xmap)` / `xmap(pjit)` #13514

Open 8bitmp3 opened 1 year ago

8bitmp3 commented 1 year ago

Proposal:

  1. Document that composing pjit with pmap isn't supported

    • Supported alternative: pjit(xmap) / xmap(pjit)
  2. Add this to the FAQ or to The Sharp Bits?

WDYT @skye @jakevdp @mattjj

cc @yashk2810 (now that 0.4/jax.Array has been released)

yashk2810 commented 1 year ago

0.4 is not released yet.

Also, Adam and I have a doc in progress about this.

8bitmp3 commented 1 year ago

Thanks @yashk2810. I saw the v0.4 announcement on jax.readthedocs.io and thought it had already been released.


Which doc are you referring to?

skye commented 1 year ago

Ah, in hindsight, we probably shouldn't have published that doc change before 0.4.0 was released. But it's gonna be released soon I think :)

yashk2810 commented 1 year ago

Ehh, it's fine. We can add "JAX 0.4.0 (coming soon)" if more people start asking.